# **3D Brain Image Analysis with Deep Learning**
This example demonstrates how to implement deep learning models for three fundamental tasks in 3D brain image analysis: **classification** (sex classification), **regression** (age prediction), and **segmentation** (lesion segmentation). All models use the same 3D feature-extraction backbone architecture (each model builds its own instance, so no weights are shared between tasks) but attach different prediction heads tailored to each task type.

In [None]:
import torch
import torch.nn as nn

### **COMMON COMPONENTS FOR ALL TASKS**

In [None]:
class Conv3DBlock(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU, applied in that order.

    Args:
        in_channels: number of channels of the incoming volume.
        out_channels: number of channels produced (also the BatchNorm width).
        kernel_size: convolution kernel size (default 3).
        stride: convolution stride (default 1; 2 halves each spatial axis).
        padding: zero padding on every side (default 1, "same" size at stride 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm3d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)

class FeatureExtractor3D(nn.Module):
    """Shared 3D encoder: four stride-2 Conv3DBlocks (64 -> 128 -> 256 -> 512 channels).

    With kernel 3 / padding 1 / stride 2, every stage maps a spatial size n to
    ceil(n / 2), so the deepest map is roughly 1/16 of the input per axis.
    """

    def __init__(self, input_channels=1):
        super().__init__()
        self.conv1 = Conv3DBlock(input_channels, 64, stride=2)  # ~1/2 resolution
        self.conv2 = Conv3DBlock(64, 128, stride=2)             # ~1/4 resolution
        self.conv3 = Conv3DBlock(128, 256, stride=2)            # ~1/8 resolution
        self.conv4 = Conv3DBlock(256, 512, stride=2)            # ~1/16 resolution

    def forward(self, x):
        """Encode a volume batch.

        Args:
            x: tensor of shape (B, input_channels, D, H, W).

        Returns:
            dict with
              - 'final_features': output of the last stage (B, 512, ~D/16, ~H/16, ~W/16),
                used by the pooled classification/regression heads;
              - 'intermediate_features': list of all four stage outputs
                [f1, f2, f3, f4], used as segmentation skip connections.
        """
        stages = (self.conv1, self.conv2, self.conv3, self.conv4)
        stage_outputs = []
        current = x
        for stage in stages:
            current = stage(current)
            stage_outputs.append(current)
        return {
            'final_features': stage_outputs[-1],
            'intermediate_features': stage_outputs,
        }

### **TASK 1: CLASSIFICATION** - Sex Classification (Binary)
**Pattern: Feature Extraction → Global Pooling → FC → Sigmoid**

In [None]:
class BrainClassifier3D(nn.Module):
    """Binary 3D brain-volume classifier (e.g. 0 = Female, 1 = Male).

    Pipeline: encoder -> global average pooling over all spatial axes ->
    small MLP emitting one logit per volume -> sigmoid, giving P(class 1).
    """

    def __init__(self, input_channels=1):
        super().__init__()
        # Encoder backbone (this model's own instance).
        self.feature_extractor = FeatureExtractor3D(input_channels)
        # Collapse every spatial axis: (B, 512, d, h, w) -> (B, 512, 1, 1, 1).
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        # MLP head 512 -> 256 -> 1 logit, with dropout between the layers.
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
        )
        # Logit -> probability of the positive class (Male).
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return P(Male) with shape (B, 1); P(Female) = 1 - P(Male).

        Args:
            x: tensor of shape (B, input_channels, D, H, W).
        """
        deepest = self.feature_extractor(x)['final_features']
        vector = torch.flatten(self.global_pool(deepest), start_dim=1)  # (B, 512)
        logit = self.classifier(vector)  # (B, 1)
        return self.sigmoid(logit)

### **TASK 2: REGRESSION** - Age Prediction
**Pattern: Feature Extraction → Global Pooling → FC → (No activation)**

In [None]:
class BrainRegressor3D(nn.Module):
    """3D brain-volume regressor (e.g. brain-age prediction).

    Same encoder + global-average-pool front end as the classifier, but the
    MLP head's single output is returned untouched (no activation), so the
    prediction is an unbounded real number per volume.
    """

    def __init__(self, input_channels=1):
        super().__init__()
        # Encoder backbone (this model's own instance).
        self.feature_extractor = FeatureExtractor3D(input_channels)
        # Collapse every spatial axis: (B, 512, d, h, w) -> (B, 512, 1, 1, 1).
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        # MLP head 512 -> 256 -> 1 continuous value, with dropout in between.
        self.regressor = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
        )
        # Deliberately no output activation. To force non-negative outputs
        # (age >= 0), a ReLU could be appended after the head.

    def forward(self, x):
        """Return one predicted value (e.g. age) per volume, shape (B, 1).

        Args:
            x: tensor of shape (B, input_channels, D, H, W).
        """
        deepest = self.feature_extractor(x)['final_features']
        vector = torch.flatten(self.global_pool(deepest), start_dim=1)  # (B, 512)
        return self.regressor(vector)

### **TASK 3: SEGMENTATION** - Lesion Segmentation (Binary)
**Pattern: Feature Extraction (Encoder) → Spatial Prediction Head (Decoder)**

In [None]:
class Decoder3DBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, Conv3DBlock.

    Args:
        in_channels: channels of the incoming (coarser) decoder features.
        skip_channels: channels of the encoder features concatenated as a skip.
        out_channels: channels produced by the fusion convolution.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        # kernel 2 / stride 2 exactly doubles every spatial dimension.
        self.upsample = nn.ConvTranspose3d(in_channels, in_channels, 2, 2)
        # Fuse upsampled decoder features with the skip (channel concat).
        self.conv = Conv3DBlock(in_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, align it to ``skip`` spatially, concatenate, and fuse.

        Args:
            x: decoder features, shape (B, in_channels, d, h, w).
            skip: encoder features, shape (B, skip_channels, D', H', W').

        Returns:
            Tensor of shape (B, out_channels, D', H', W').
        """
        upsampled = self._match_spatial(self.upsample(x), skip)
        concatenated = torch.cat([upsampled, skip], dim=1)
        return self.conv(concatenated)

    @staticmethod
    def _match_spatial(tensor, reference):
        """Crop or zero-pad the trailing edge of tensor's D/H/W to reference's.

        A stride-2 encoder stage maps an odd size n to ceil(n / 2), so doubling
        it back overshoots by one voxel. The extra output comes from the
        window that reached into the trailing padding, so the excess is
        trimmed from the end.
        """
        target = reference.shape[2:]
        if tensor.shape[2:] == target:
            return tensor
        depth, height, width = target
        tensor = tensor[:, :, :depth, :height, :width]
        pad_d = depth - tensor.shape[2]
        pad_h = height - tensor.shape[3]
        pad_w = width - tensor.shape[4]
        # F.pad lists pads from the last dim backwards: (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi).
        return nn.functional.pad(tensor, (0, pad_w, 0, pad_h, 0, pad_d))

class BrainSegmenter3D(nn.Module):
    """Binary 3D lesion segmenter (voxel labels: 0 = Normal, 1 = Lesion).

    U-Net-like layout:
      - the encoder's four feature maps feed three skip-connected decoder stages;
      - a final 2x transposed conv restores roughly the input resolution;
      - a 1x1x1 conv produces one logit per voxel, followed by a sigmoid.

    Args:
        input_channels: number of channels of the input volume (e.g. 1 for a
            single-channel scan). This is NOT the number of classes: the output
            always has a single lesion-probability channel.
    """

    def __init__(self, input_channels=1):
        super().__init__()
        # Encoder backbone (this model's own instance).
        self.feature_extractor = FeatureExtractor3D(input_channels)
        # Decoder stages: each upsamples 2x and fuses with the matching encoder map.
        self.decoder4 = Decoder3DBlock(512, 256, 256)  # f4 up, fused with f3
        self.decoder3 = Decoder3DBlock(256, 128, 128)  # fused with f2
        self.decoder2 = Decoder3DBlock(128, 64, 64)    # fused with f1
        self.decoder1 = nn.ConvTranspose3d(64, 32, 2, 2)  # final 2x upsampling, no skip
        # 1x1x1 conv: 32 channels -> 1 logit per voxel.
        self.final_conv = nn.Conv3d(32, 1, kernel_size=1)
        # Logit -> P(Lesion) per voxel.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-voxel lesion probabilities with shape (B, 1, D, H, W).

        Args:
            x: tensor of shape (B, input_channels, D, H, W).
        """
        features = self.feature_extractor(x)
        f1, f2, f3, f4 = features['intermediate_features']  # skip sources

        # No global pooling: spatial layout is preserved through the decoder.
        d4 = self.decoder4(f4, f3)  # ~(B, 256, D/8, H/8, W/8)
        d3 = self.decoder3(d4, f2)  # ~(B, 128, D/4, H/4, W/4)
        d2 = self.decoder2(d3, f1)  # ~(B, 64, D/2, H/2, W/2)
        d1 = self.decoder1(d2)      # ~(B, 32, D, H, W)
        # The stride-2 encoder rounds odd sizes up (ceil(n/2)), so doubling can
        # overshoot the input by one voxel; trim the trailing edge so the map
        # aligns with the input. No-op when D, H, W are even.
        d1 = d1[:, :, : x.shape[2], : x.shape[3], : x.shape[4]]

        segmentation_logits = self.final_conv(d1)  # (B, 1, D, H, W)
        return self.sigmoid(segmentation_logits)   # P(Lesion) per voxel

### **USAGE EXAMPLES AND KEY DIFFERENCES SUMMARY**

In [None]:
def demonstrate_3d_brain_models(batch_size=4, spatial_size=128):
    """Run one random volume batch through all three models and print a summary.

    The demo only inspects output shapes, so models run in eval mode under
    ``torch.no_grad()``. This avoids keeping activations for backprop (the
    full-resolution segmentation path is very memory-hungry at 128^3). It also
    stops BatchNorm from updating its running statistics on random noise.

    Args:
        batch_size: number of random volumes in the batch.
        spatial_size: edge length used for D = H = W of each cubic volume.
            A multiple of 16 is safest for the segmenter's skip connections.
    """
    # Sample 3D brain MRI input: (B, C, D, H, W)
    input_volume = torch.randn(batch_size, 1, spatial_size, spatial_size, spatial_size)

    print("=== 3D Brain Image Analysis Tasks ===\n")

    with torch.no_grad():
        # TASK 1: CLASSIFICATION
        classifier = BrainClassifier3D().eval()  # Female/Male
        class_probs = classifier(input_volume)
        print(f"CLASSIFICATION OUTPUT: {class_probs.shape}")
        print(f"  - Input:  (B, C, D, H, W) = {input_volume.shape}")
        print(f"  - Output: (B, 1) = {class_probs.shape}")
        print("  - Interpretation: P(Male) per brain volume, P(Female) = 1 - P(Male)")
        print("  - Spatial Info: REMOVED via global pooling\n")

        # TASK 2: REGRESSION
        regressor = BrainRegressor3D().eval()  # Age prediction
        predicted_age = regressor(input_volume)
        print(f"REGRESSION OUTPUT: {predicted_age.shape}")
        print(f"  - Input:  (B, C, D, H, W) = {input_volume.shape}")
        print(f"  - Output: (B, 1) = {predicted_age.shape}")
        print("  - Interpretation: Predicted brain age per volume")
        print("  - Spatial Info: REMOVED via global pooling\n")

        # TASK 3: SEGMENTATION
        segmenter = BrainSegmenter3D().eval()  # Normal/Lesion
        segmentation = segmenter(input_volume)
        print(f"SEGMENTATION OUTPUT: {segmentation.shape}")
        print(f"  - Input:  (B, C, D, H, W) = {input_volume.shape}")
        print(f"  - Output: (B, 1, D, H, W) = {segmentation.shape}")
        print("  - Interpretation: Lesion probability per voxel")
        print("  - Spatial Info: PRESERVED throughout decoder\n")

    # KEY ARCHITECTURAL DIFFERENCES
    print("=== KEY ARCHITECTURAL DIFFERENCES ===")
    print("COMMON:")
    print("  ✓ Same feature extractor backbone")
    print("  ✓ Same hierarchical feature extraction")
    print("  ✓ Same 3D convolution operations\n")

    print("CLASSIFICATION & REGRESSION:")
    print("  ✓ Global pooling: (B,C,D,H,W) → (B,C)")
    print("  ✓ Flatten + FC layers")
    print("  ✓ Single prediction per volume")
    print("  ✗ Spatial information discarded")
    print("  ✓ Binary classification uses Sigmoid activation")
    print("  ✓ Regression uses no output activation\n")

    print("SEGMENTATION:")
    print("  ✗ NO global pooling")
    print("  ✓ Decoder with skip connections")
    print("  ✓ Spatial upsampling")
    print("  ✓ Prediction per voxel")
    print("  ✓ Spatial information preserved")
    print("  ✓ Binary segmentation uses Sigmoid activation")

    # The models above already apply a sigmoid, so BCELoss matches their output;
    # BCEWithLogitsLoss is only correct on the pre-sigmoid logits.
    print("\n=== LOSS FUNCTIONS ===")
    print("CLASSIFICATION: BCELoss on the sigmoid output (BCEWithLogitsLoss only on pre-sigmoid logits)")
    print("REGRESSION: MSELoss or L1Loss (continuous)")
    print("SEGMENTATION: BCELoss and/or Dice loss on the sigmoid output (binary voxel-wise)")

# Run the demo when executed as a script or notebook (__name__ == "__main__"
# in both), but not when this code is imported as a module.
if __name__ == "__main__":
    demonstrate_3d_brain_models()