# 3D Segmentation Pipeline with a ResNet-Style 3D CNN
## Training on Preprocessed 3D Medical Image Data

This notebook implements 3D segmentation using:
- **Data:** Preprocessed 3D volumes (512 × 1024 × 32 × 2 channels)
- **Model:** ResNet-style 3D CNN (3-4-6-3 stage layout of two-convolution blocks, as in ResNet-34) with a transposed-convolution decoder
- **Task:** 5-class segmentation (background + 4 structures)
- **Loss:** CrossEntropyLoss with class weighting
- **Optimization:** Adam (with weight decay) and ReduceLROnPlateau learning-rate scheduling driven by the validation loss

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import nibabel as nib
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score, precision_score, recall_score, f1_score
import json
import warnings
warnings.filterwarnings('ignore')

# Report the runtime environment so logs record which hardware was used.
print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"GPU available: {cuda_ok}")
if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

## 1. Configuration

In [None]:
# Paths
BASE_DIR = Path('C:/FeatureEx')
PREPROCESSED_DIR = BASE_DIR / 'preprocessed_3d_data'
IMAGES_DIR = PREPROCESSED_DIR / 'images'
LABELS_DIR = PREPROCESSED_DIR / 'labels'
MODELS_DIR = BASE_DIR / 'models_3d'
# parents=True so the output directory can be created even when BASE_DIR
# does not exist yet (a plain mkdir would raise FileNotFoundError).
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Model configuration
NUM_CLASSES = 5  # 0=background, 1-4=structures
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2  # Adjust based on GPU memory
NUM_EPOCHS = 50
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-4

# Data split. Train/val sizes are rounded down and the test split receives
# the remainder, so the three ratios must sum to 1 for TEST_RATIO to hold.
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15
if abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) > 1e-6:
    raise ValueError(
        f"TRAIN_RATIO + VAL_RATIO + TEST_RATIO must equal 1.0, got "
        f"{TRAIN_RATIO + VAL_RATIO + TEST_RATIO}"
    )

print(f"Base directory: {BASE_DIR}")
print(f"Images directory: {IMAGES_DIR}")
print(f"Labels directory: {LABELS_DIR}")
print(f"Device: {DEVICE}")
print(f"Configuration:")
print(f"  Num classes: {NUM_CLASSES}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")

## 2. 3D Dataset

In [None]:
class Preprocessed3DDataset(Dataset):
    """Dataset of preprocessed 3D NIfTI image/label volume pairs.

    Images and labels are paired by file name: every ``*.nii.gz`` file in
    ``images_dir`` that has a same-named file in ``labels_dir`` becomes one
    sample. ``__getitem__`` returns ``(image, label, name)`` where

    * ``image`` is a float32 tensor of shape ``(C, D, H, W)``, min-max
      normalized to [0, 1] over the whole volume (all channels together);
    * ``label`` is an int64 tensor of shape ``(D, H, W)`` built from the
      first label channel, spatially aligned with ``image`` as
      ``nn.CrossEntropyLoss`` requires for ``(N, classes, D, H, W)`` logits;
    * ``name`` is the image path's ``stem`` (e.g. ``'case01.nii'``).

    Volumes on disk are assumed to be stored as ``(H, W, D[, C])`` -- e.g.
    ``(512, 1024, 32, 2)`` for this project. TODO: confirm for new data.
    """

    def __init__(self, images_dir, labels_dir):
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)

        # Sorted so the index -> file mapping is deterministic across runs.
        self.image_files = sorted(self.images_dir.glob('*.nii.gz'))
        # Path.stem of 'x.nii.gz' is 'x.nii'; the same key is used for
        # images and labels, so pairing is consistent.
        self.label_files = {f.stem: f for f in self.labels_dir.glob('*.nii.gz')}

        # Keep only images that have a matching label file.
        self.valid_pairs = [f for f in self.image_files if f.stem in self.label_files]

        print(f"Dataset initialized:")
        print(f"  Total image files: {len(self.image_files)}")
        print(f"  Valid pairs: {len(self.valid_pairs)}")

    def __len__(self):
        return len(self.valid_pairs)

    @staticmethod
    def _to_tensors(img_data, label_data):
        """Convert raw on-disk arrays into model-ready tensors.

        Args:
            img_data: image array of shape (H, W, D, C) or (H, W, D).
            label_data: label array of shape (H, W, D, C) or (H, W, D); only
                the first channel of a 4D label is used.

        Returns:
            ``(image, label)``: float32 tensor (C, D, H, W) normalized to
            [0, 1] (left unchanged if the volume is constant), and int64
            tensor (D, H, W).

        Raises:
            ValueError: if the arrays do not have the expected rank.
        """
        img = np.asarray(img_data, dtype=np.float32)
        if img.ndim == 3:
            img = img[..., np.newaxis]  # add a singleton channel axis
        if img.ndim != 4:
            raise ValueError(f"Expected a 3D or 4D image array, got shape {img.shape}")

        img_min = img.min()
        img_max = img.max()
        if img_max > img_min:
            img = (img - img_min) / (img_max - img_min)

        # (H, W, D, C) -> (C, D, H, W), the layout nn.Conv3d expects.
        img_tensor = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img, (3, 2, 0, 1)))
        ).float()

        label = np.asarray(label_data)
        if label.ndim == 4:
            label = label[..., 0]  # take first channel
        if label.ndim != 3:
            raise ValueError(f"Expected a 3D or 4D label array, got shape {label.shape}")
        # Labels may arrive as floats; round before casting so e.g. 0.9999
        # maps to class 1 instead of being truncated to 0.
        label = np.rint(label).astype(np.int64)
        # (H, W, D) -> (D, H, W) so the label's spatial axes match the image's.
        label_tensor = torch.from_numpy(np.ascontiguousarray(np.transpose(label, (2, 0, 1))))

        return img_tensor, label_tensor

    def __getitem__(self, idx):
        img_path = self.valid_pairs[idx]
        label_path = self.label_files[img_path.stem]

        # float32 halves memory compared with get_fdata's default float64.
        img_data = nib.load(img_path).get_fdata(dtype=np.float32)
        label_data = nib.load(label_path).get_fdata(dtype=np.float32)

        img_tensor, label_tensor = self._to_tensors(img_data, label_data)
        return img_tensor, label_tensor, img_path.stem

# Build the paired image/label dataset from the preprocessed folders.
dataset = Preprocessed3DDataset(IMAGES_DIR, LABELS_DIR)
print("\nDataset size: " + str(len(dataset)))

## 3. Data Splitting and Loaders

In [None]:
# Split data. Train/val sizes are rounded down; the test split takes the
# remainder so every sample is used exactly once.
total_size = len(dataset)
train_size = int(total_size * TRAIN_RATIO)
val_size = int(total_size * VAL_RATIO)
test_size = total_size - train_size - val_size

# Training, validation (best-model selection and LR scheduling) and the final
# test evaluation each need at least one sample. Without this check an empty
# split only surfaces later as a division by zero inside the loops.
empty_splits = [
    split_name
    for split_name, split_size in (('train', train_size), ('validation', val_size), ('test', test_size))
    if split_size == 0
]
if empty_splits:
    raise ValueError(
        f"Dataset of {total_size} samples is too small for the configured split ratios; "
        f"empty split(s): {', '.join(empty_splits)}"
    )

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# Create loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Data split:")
print(f"  Training: {train_size} samples")
print(f"  Validation: {val_size} samples")
print(f"  Test: {test_size} samples")
print(f"\nDataLoader created:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

## 4. 3D ResNet Model

In [None]:
class _ResidualBlock3D(nn.Module):
    """Basic 3D residual block: two 3x3x3 conv-BN stages plus a shortcut.

    The shortcut is the identity when shape is preserved, otherwise a strided
    1x1x1 conv + BN projection so it can be added to the main path.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # bias=False: the following BatchNorm has its own learnable shift.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)


class ResNet3DSegmentation(nn.Module):
    """3D ResNet-style encoder with a transposed-convolution decoder.

    The encoder is a 7x7x7 stride-2 stem plus max-pool, followed by four
    stages of basic residual blocks (3-4-6-3 blocks, widths base_channels x
    1/2/4/8). It downsamples by about 32x; the decoder upsamples by 8x. The
    resulting logits are then interpolated (trilinear) back to the input's
    spatial size, so the output always aligns voxel-for-voxel with the input.

    NOTE(review): there are no encoder->decoder skip connections (unlike a
    U-Net), so fine structures are predicted from ~1/4-resolution features.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output classes (logit channels).
        base_channels: width of the stem / first stage; later stages use
            2x, 4x and 8x this width. Default 64 (widths 64-128-256-512).

    Input:  float tensor (N, in_channels, D, H, W).
    Output: logits (N, num_classes, D, H, W).
    """

    def __init__(self, in_channels=2, num_classes=5, base_channels=64):
        super().__init__()
        c1, c2, c3, c4 = (base_channels * m for m in (1, 2, 4, 8))

        # Encoder stem: 3D convolution + pooling
        self.conv1 = nn.Conv3d(in_channels, c1, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(c1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(c1, c1, 3, stride=1)
        self.layer2 = self._make_layer(c1, c2, 4, stride=2)
        self.layer3 = self._make_layer(c2, c3, 6, stride=2)
        self.layer4 = self._make_layer(c3, c4, 3, stride=2)

        # Decoder: three 2x transpose-convolution upsampling stages
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(c4, c3, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(c3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(c3, c2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(c2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(c2, c1, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(c1),
            nn.ReLU(inplace=True),
        )

        # Final per-voxel classification layer
        self.final_conv = nn.Conv3d(c1, num_classes, kernel_size=1)

    def _make_layer(self, in_channels, out_channels, blocks, stride):
        """Stack `blocks` residual blocks; only the first one changes stride/width."""
        layers = [self._make_residual_block(in_channels, out_channels, stride)]
        layers.extend(
            self._make_residual_block(out_channels, out_channels, 1) for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def _make_residual_block(self, in_channels, out_channels, stride):
        """Build one basic residual block (with identity or projection shortcut)."""
        return _ResidualBlock3D(in_channels, out_channels, stride)

    def forward(self, x):
        input_size = tuple(x.shape[2:])  # (D, H, W) the logits must match

        # Encoder
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Decoder
        x = self.decoder(x)

        # Per-voxel class logits
        x = self.final_conv(x)

        # The decoder only undoes part of the encoder's downsampling; resize the
        # logits so CrossEntropyLoss sees the same spatial shape as the labels.
        if tuple(x.shape[2:]) != input_size:
            x = F.interpolate(x, size=input_size, mode='trilinear', align_corners=False)

        return x

# Instantiate the network and move it to the configured device.
model = ResNet3DSegmentation(in_channels=2, num_classes=NUM_CLASSES)
model = model.to(DEVICE)

# Tally total and trainable parameter counts in a single pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print(f"Model created:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Device: {DEVICE}")

## 5. Loss Function and Optimizer

In [None]:
# Class weights for the imbalanced segmentation problem: background (class 0)
# usually dominates the voxel count, so it gets a lower weight. The tensor is
# built from NUM_CLASSES so its length always matches the model's output
# channels (CrossEntropyLoss requires exactly one weight per class).
BACKGROUND_WEIGHT = 0.1
class_weights = torch.ones(NUM_CLASSES, dtype=torch.float32)
class_weights[0] = BACKGROUND_WEIGHT
class_weights = class_weights.to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Halve the LR after 5 epochs without validation-loss improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6
)

print(f"Loss function: CrossEntropyLoss with class weights")
print(f"Class weights: {class_weights.cpu().numpy()}")
print(f"Optimizer: Adam")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Learning rate scheduler: ReduceLROnPlateau")

## 6. Training Loop

In [None]:
def train_epoch(model, loader, criterion, optimizer, device, log_every=10, max_grad_norm=1.0):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: module to train; it is switched to train mode.
        loader: iterable of ``(images, labels, names)`` batches; ``len(loader)``
            is used in progress messages.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device that images and labels are moved to.
        log_every: print the batch loss every this many batches (0 disables).
        max_grad_norm: total gradient-norm clipping threshold.

    Returns:
        Epoch loss as the average of per-batch mean losses weighted by batch
        size, so a smaller final batch is not over-weighted. Each batch's loss
        is the one computed before that batch's optimizer step.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0

    for batch_idx, (images, labels, _names) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass and loss
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass; clipping guards against exploding gradients.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        batch_loss = loss.item()
        batch_size = images.size(0)
        loss_sum += batch_loss * batch_size
        num_samples += batch_size

        if log_every and (batch_idx + 1) % log_every == 0:
            print(f"  Batch {batch_idx + 1}/{len(loader)}, Loss: {batch_loss:.4f}")

    if num_samples == 0:
        raise ValueError("train_epoch received an empty loader; there are no batches to train on")
    return loss_sum / num_samples

def validate(model, loader, criterion, device):
    """Compute the loss of ``model`` on ``loader`` without updating it.

    The model is put in eval mode and gradients are disabled. Predictions are
    not retained: keeping every full-volume prediction/label on the host (as
    an earlier version did, without using them) costs hundreds of MB per sample.

    Args:
        model: module to evaluate; left in eval mode afterwards.
        loader: iterable of ``(images, labels, names)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device that images and labels are moved to.

    Returns:
        Average of per-batch mean losses weighted by batch size.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        for images, labels, _ in loader:
            images = images.to(device)
            labels = labels.to(device)

            loss = criterion(model(images), labels)

            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("validate received an empty loader; there are no batches to evaluate")
    return loss_sum / num_samples


print("Training functions defined.")

## 7. Train Model

In [None]:
# Training history (one entry per epoch)
history = {'train_loss': [], 'val_loss': [], 'learning_rates': []}
best_val_loss = float('inf')
best_epoch = 0  # 0-based index of the epoch with the lowest validation loss

print(f"Starting training for {NUM_EPOCHS} epochs...\n")

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")

    # LR actually used for this epoch's updates. It must be read before
    # scheduler.step(), which may lower it for the *next* epoch; recording the
    # post-step value would shift the plotted schedule one epoch early.
    epoch_lr = optimizer.param_groups[0]['lr']

    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    print(f"  Train Loss: {train_loss:.4f}")

    # Validate
    val_loss = validate(model, val_loader, criterion, DEVICE)
    print(f"  Val Loss: {val_loss:.4f}")

    # Learning rate scheduling (ReduceLROnPlateau watches the validation loss)
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']  # LR for the next epoch

    # History
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['learning_rates'].append(epoch_lr)

    # Save best model (checkpoint selected by lowest validation loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), MODELS_DIR / 'best_resnet3d_segmentation.pth')
        print(f"  ** Best model saved (Epoch {epoch + 1}) **")

    print(f"  Learning rate: {epoch_lr:.2e}")
    if current_lr < epoch_lr:
        print(f"  Learning rate reduced to {current_lr:.2e}")
    print()

print(f"Training complete!")
print(f"Best epoch: {best_epoch + 1} with val_loss: {best_val_loss:.4f}")

## 8. Evaluate on Test Set

In [None]:
# Load the best checkpoint. map_location puts the tensors on DEVICE even if
# the file was written from a different device (e.g. saved on GPU, reloaded
# on a CPU-only machine), which would otherwise fail to deserialize.
best_model_path = MODELS_DIR / 'best_resnet3d_segmentation.pth'
model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
model.eval()

test_loss = validate(model, test_loader, criterion, DEVICE)

print(f"Test Loss: {test_loss:.4f}")

## 9. Plot Training History

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# 1-based epoch numbers so the x axis matches the "Epoch N" numbering used in
# the training log and in the best-epoch legend label.
epochs = range(1, len(history['train_loss']) + 1)

# Loss curve
axes[0].plot(epochs, history['train_loss'], label='Training Loss', marker='o', markersize=3)
axes[0].plot(epochs, history['val_loss'], label='Validation Loss', marker='s', markersize=3)
axes[0].axvline(x=best_epoch + 1, color='r', linestyle='--', alpha=0.5, label=f'Best Epoch ({best_epoch + 1})')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Learning rate (log scale, since ReduceLROnPlateau changes it multiplicatively)
axes[1].plot(epochs, history['learning_rates'], label='Learning Rate', color='orange')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Learning Rate')
axes[1].set_title('Learning Rate Schedule')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_yscale('log')

plt.tight_layout()
plt.savefig(MODELS_DIR / 'training_history_3d.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Training history plot saved to {MODELS_DIR / 'training_history_3d.png'}")

## 10. Save Training Metrics

In [None]:
# Save metrics
metrics = {
    'best_epoch': best_epoch + 1,
    'best_val_loss': float(best_val_loss),
    'final_test_loss': float(test_loss),
    'num_epochs_trained': NUM_EPOCHS,
    'num_classes': NUM_CLASSES,
    'input_shape': [2, 32, 512, 1024],
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'total_parameters': int(total_params),
    'training_samples': train_size,
    'validation_samples': val_size,
    'test_samples': test_size
}

with open(MODELS_DIR / 'metrics_3d.json', 'w') as f:
    json.dump(metrics, f, indent=2)

print(f"Metrics saved:")
for key, val in metrics.items():
    print(f"  {key}: {val}")