## 🎉 Complete!

You now have a full brain tumor segmentation pipeline with:

✅ **Data loading & preprocessing** with MONAI transforms  
✅ **3D U-Net model** (the parameter count is printed when the model is created)  
✅ **Training loop** with mixed precision & checkpointing  
✅ **Evaluation metrics** (WT, TC, ET Dice scores)  
✅ **Visualization** of predictions  
✅ **Failure analysis** for debugging  
✅ **Experiment tracking** for comparison  

### Next Steps:
1. **Hyperparameter tuning**: Try different learning rates, batch sizes, loss weights
2. **Advanced augmentation**: Add more data augmentation techniques
3. **Ensemble methods**: Combine multiple models for better performance
4. **Post-processing**: Add CRF or morphological operations
5. **MedSAM fine-tuning**: Experiment with foundation models

### Key Metrics to Report:
- **Whole Tumor (WT) Dice**: all tumor labels (edema, non-enhancing tumor, enhancing tumor)
- **Tumor Core (TC) Dice**: non-enhancing + enhancing tumor, excluding edema
- **Enhancing Tumor (ET) Dice**: enhancing tumor only

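The three regions are nested sets of label values (ET ⊆ TC ⊆ WT), so every enhancing voxel counts toward all three scores. Here is a minimal sketch of that mapping. It assumes the MSD Task01 label convention (1 = edema, 2 = non-enhancing tumor, 3 = enhancing tumor); check this against the `Labels` line printed during dataset preparation. `REGIONS` and `region_masks` are illustrative names, not part of the notebook:

```python
# Assumed MSD Task01 label values per region (verify against dataset.json "labels")
REGIONS = {
    "WT": {1, 2, 3},  # whole tumor: edema + non-enhancing + enhancing
    "TC": {2, 3},     # tumor core: non-enhancing + enhancing (no edema)
    "ET": {3},        # enhancing tumor only
}


def region_masks(labels):
    """Map a flat sequence of voxel labels to one binary mask per region."""
    return {name: [int(v in values) for v in labels] for name, values in REGIONS.items()}


voxels = [0, 1, 2, 3, 3, 0]
masks = region_masks(voxels)
print(masks["WT"])  # [0, 1, 1, 1, 1, 0]
print(masks["TC"])  # [0, 0, 1, 1, 1, 0]
print(masks["ET"])  # [0, 0, 0, 1, 1, 0]
```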
Good luck with your project! 🚀

In [None]:
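# # Uncomment to install and use MedSAM
# !pip install git+https://github.com/bowang-lab/MedSAM.git
# !wget https://github.com/bowang-lab/MedSAM/releases/download/v0.1/medsam_vit_b.pth

# import torch.nn.functional as F
# from segment_anything import sam_model_registry

# class MedSAMFineTune(nn.Module):
#     """MedSAM with a frozen image encoder and a trainable 2D head, applied slice by slice"""
    
#     def __init__(self, checkpoint_path, num_classes=4):
#         super().__init__()
#         self.medsam = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
        
#         # Freeze the ViT image encoder; only the segmentation head is trained
#         for param in self.medsam.image_encoder.parameters():
#             param.requires_grad = False
        
#         # Custom segmentation head on the 256-channel image embedding
#         self.seg_head = nn.Sequential(
#             nn.Conv2d(256, 128, 3, padding=1),
#             nn.BatchNorm2d(128),
#             nn.ReLU(),
#             nn.Conv2d(128, num_classes, 1)
#         )
    
#     def forward(self, x):
#         """Process a 3D volume (B, C, H, W, D) slice by slice; returns (B, num_classes, H, W, D)"""
#         B, C, H, W, D = x.shape
#         outputs = []
#         for d in range(D):
#             # Average the modalities and repeat to the 3 channels SAM expects
#             slice_x = x[:, :, :, :, d].mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)
#             # Min-max scale each slice to [0, 1], the input range MedSAM's inference code uses
#             s_min = slice_x.amin(dim=(1, 2, 3), keepdim=True)
#             s_max = slice_x.amax(dim=(1, 2, 3), keepdim=True)
#             slice_x = (slice_x - s_min) / (s_max - s_min).clamp(min=1e-8)
#             # The ViT-B encoder has a fixed 1024x1024 input size (64x64 positional embedding)
#             slice_x = F.interpolate(slice_x, size=(1024, 1024), mode="bilinear", align_corners=False)
#             features = self.medsam.image_encoder(slice_x)  # (B, 256, 64, 64)
#             logits = self.seg_head(features)
#             # Upsample the logits back to the original slice size so they match the labels
#             logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
#             outputs.append(logits)
#         return torch.stack(outputs, dim=-1)

# # Create MedSAM model
# medsam_model = MedSAMFineTune('medsam_vit_b.pth', num_classes=4).to(device)
# print("✅ MedSAM model created with frozen encoder")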

## 16. (BONUS) MedSAM Fine-tuning Setup

**Note**: This section sets up MedSAM for brain tumor segmentation (frozen image encoder plus a trainable segmentation head); it does not include a training loop. Uncomment and run it if you want to experiment with foundation models.

In [None]:
class ExperimentTracker:
    """Simple experiment tracking class"""
    
    def __init__(self):
        self.experiments = {}
    
    def log(self, name, config, results):
        """Log an experiment"""
        self.experiments[name] = {
            'config': config,
            'results': results
        }
        print(f"✅ Logged experiment: {name}")
    
    def compare(self):
        """Compare all experiments"""
        print("\n" + "="*90)
        print("📊 EXPERIMENT COMPARISON")
        print("="*90)
        
        for name, exp in self.experiments.items():
            print(f"\n{name}:")
            print(f"  Config: {exp['config']}")
            print(f"  Results:")
            for k, v in exp['results'].items():
                if isinstance(v, float):
                    print(f"    {k}: {v:.4f}")
                else:
                    print(f"    {k}: {v}")
        print("="*90)
    
    def save(self, filepath):
        """Save experiments to JSON"""
        with open(filepath, 'w') as f:
            # default=float converts NumPy scalars (e.g. np.float32) that json cannot serialize
            json.dump(self.experiments, f, indent=2, default=float)
        print(f"💾 Saved experiments to {filepath}")

# Create tracker and log current experiment
tracker = ExperimentTracker()

tracker.log(
    name='3D_UNet_DiceCE_v1',
    config={
        'model': '3D U-Net',
        'loss': 'DiceCE (0.5/0.5)',
        'optimizer': 'AdamW',
        'lr': 1e-4,
        'batch_size': BATCH_SIZE,
        'epochs': NUM_EPOCHS,
        'augmentation': 'Yes'
    },
    results=test_results
)

tracker.compare()

# Save experiments
tracker.save(os.path.join(SAVE_DIR, 'experiments.json'))
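
`compare()` prints every experiment in insertion order. Once several runs are logged, a ranking by one result key can be easier to scan. Below is a minimal sketch. It assumes the `{'config': ..., 'results': ...}` layout that `ExperimentTracker.log` stores. `rank_experiments` is a hypothetical helper, not a method of the class above, and the toy scores are made up for illustration:

```python
def rank_experiments(experiments, key="dice_wt_mean"):
    """Return (name, value) pairs sorted best-first by one result key.

    Experiments whose results lack `key` are skipped.
    """
    scored = [
        (name, float(exp["results"][key]))
        for name, exp in experiments.items()
        if key in exp["results"]
    ]
    return sorted(scored, key=lambda item: item[1], reverse=True)


# Toy data with the same layout as tracker.experiments (values are illustrative)
experiments = {
    "run_a": {"config": {"lr": 1e-4}, "results": {"dice_wt_mean": 0.81}},
    "run_b": {"config": {"lr": 3e-4}, "results": {"dice_wt_mean": 0.84}},
    "run_c": {"config": {"lr": 1e-3}, "results": {}},
}
for name, value in rank_experiments(experiments):
    print(f"{name}: {value:.4f}")
```

With the notebook's tracker, this would be called as `rank_experiments(tracker.experiments, key='dice_et_mean')`.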

## 15. Experiment Tracking

In [None]:
def analyze_failures(model, test_loader, device):
    """Analyze failure cases"""
    model.eval()
    cases = []
    
    print("Analyzing all test cases...")
    
    with torch.no_grad():
        for idx, batch_data in enumerate(test_loader):
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            
            outputs = sliding_window_inference(
                inputs, 
                roi_size=(128, 128, 128),
                sw_batch_size=4, 
                predictor=model
            )
            preds = torch.argmax(outputs, dim=1, keepdim=True)
            
            metrics = compute_region_metrics(preds[0, 0], labels[0, 0])
            
            cases.append({
                'id': idx,
                'dice_wt': metrics['dice_wt'],
                'dice_tc': metrics['dice_tc'],
                'dice_et': metrics['dice_et'],
                'mean': np.mean(list(metrics.values()))
            })
    
    # Sort by mean Dice score
    cases_sorted = sorted(cases, key=lambda x: x['mean'])
    
    print("\n" + "="*70)
    print("🔴 WORST 10 CASES (Lowest Dice Scores)")
    print("="*70)
    for i, case in enumerate(cases_sorted[:10]):
        print(f"{i+1:2d}. Case {case['id']:3d}: Mean={case['mean']:.4f} | "
              f"WT={case['dice_wt']:.4f}, TC={case['dice_tc']:.4f}, ET={case['dice_et']:.4f}")
    
    print("\n" + "="*70)
    print("🟢 BEST 10 CASES (Highest Dice Scores)")
    print("="*70)
    for i, case in enumerate(cases_sorted[-10:][::-1]):
        print(f"{i+1:2d}. Case {case['id']:3d}: Mean={case['mean']:.4f} | "
              f"WT={case['dice_wt']:.4f}, TC={case['dice_tc']:.4f}, ET={case['dice_et']:.4f}")
    
    return cases_sorted

# Perform failure analysis
failure_cases = analyze_failures(model, test_loader, device)
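
Sorting every case works fine for a test set of this size. When only the k worst and best cases are needed, `heapq.nsmallest` and `heapq.nlargest` express that selection directly. This minimal sketch assumes the list-of-dicts layout built in `analyze_failures`. `worst_and_best` is a hypothetical helper:

```python
import heapq


def worst_and_best(cases, k=10, key="mean"):
    """Return the k lowest- and k highest-scoring cases by `key`."""
    worst = heapq.nsmallest(k, cases, key=lambda c: c[key])
    best = heapq.nlargest(k, cases, key=lambda c: c[key])
    return worst, best


cases = [{"id": i, "mean": m} for i, m in enumerate([0.62, 0.91, 0.35, 0.88, 0.74])]
worst, best = worst_and_best(cases, k=2)
print([c["id"] for c in worst])  # [2, 0]
print([c["id"] for c in best])   # [1, 3]
```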

## 14. Failure Analysis

In [None]:
def visualize_predictions(model, test_loader, device, n_samples=5):
    """Visualize model predictions"""
    model.eval()
    
    fig, axes = plt.subplots(n_samples, 4, figsize=(16, n_samples*4))
    if n_samples == 1:
        axes = axes.reshape(1, -1)
    
    with torch.no_grad():
        for idx, batch_data in enumerate(test_loader):
            if idx >= n_samples:
                break
            
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            
            # Get predictions
            outputs = sliding_window_inference(
                inputs, 
                roi_size=(128, 128, 128),
                sw_batch_size=4, 
                predictor=model
            )
            preds = torch.argmax(outputs, dim=1)
            
            # Get middle slice
            z = inputs.shape[-1] // 2
            
            # FLAIR modality
            axes[idx, 0].imshow(inputs[0, 0, :, :, z].cpu(), cmap='gray')
            axes[idx, 0].set_title('FLAIR', fontsize=12, fontweight='bold')
            axes[idx, 0].axis('off')
            
            # T1gd (contrast-enhanced T1) is channel 2 in the MSD modality order (FLAIR, T1w, T1gd, T2w)
            axes[idx, 1].imshow(inputs[0, 2, :, :, z].cpu(), cmap='gray')
            axes[idx, 1].set_title('T1gd', fontsize=12, fontweight='bold')
            axes[idx, 1].axis('off')
            
            # Ground truth
            axes[idx, 2].imshow(labels[0, 0, :, :, z].cpu(), cmap='jet', vmin=0, vmax=3)
            axes[idx, 2].set_title('Ground Truth', fontsize=12, fontweight='bold')
            axes[idx, 2].axis('off')
            
            # Prediction
            axes[idx, 3].imshow(preds[0, :, :, z].cpu(), cmap='jet', vmin=0, vmax=3)
            
            # Calculate Dice for this sample
            metrics = compute_region_metrics(preds[0], labels[0, 0])
            mean_dice = np.mean(list(metrics.values()))
            axes[idx, 3].set_title(f'Prediction (Dice: {mean_dice:.3f})', 
                                   fontsize=12, fontweight='bold')
            axes[idx, 3].axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, 'predictions_visualization.png'), 
                dpi=150, bbox_inches='tight')
    plt.show()

# Visualize predictions
visualize_predictions(model, test_loader, device, n_samples=5)
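
The middle slice (`z = inputs.shape[-1] // 2`) does not always pass through the tumor, which can make a good prediction look empty. One alternative is to display the slice with the most ground-truth tumor voxels. Here is a minimal sketch that works on plain per-slice counts. `pick_slice` is a hypothetical helper:

```python
def pick_slice(tumor_voxels_per_slice):
    """Index of the slice with the most tumor voxels, or the middle slice if there is no tumor."""
    if not any(tumor_voxels_per_slice):
        return len(tumor_voxels_per_slice) // 2
    return max(range(len(tumor_voxels_per_slice)), key=tumor_voxels_per_slice.__getitem__)


print(pick_slice([0, 3, 12, 7, 0]))  # 2
print(pick_slice([0, 0, 0, 0]))      # 2 (no tumor, fall back to the middle slice)
```

In the notebook, the counts for one volume could be computed as `(labels[0, 0] > 0).sum(dim=(0, 1)).tolist()`, and the returned index would then replace `z`.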

## 13. Visualize Predictions

In [None]:
def evaluate_test_set(model, test_loader, device, model_path):
    """Evaluate model on test set"""
    
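    # Load best model. weights_only=False is needed because the checkpoint also stores
    # NumPy metric values (recent PyTorch defaults to weights_only=True); only load trusted files.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)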
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    print(f"Loaded model from epoch {checkpoint['epoch']} with Dice: {checkpoint['best_dice']:.4f}")
    
    all_metrics = {'dice_wt': [], 'dice_tc': [], 'dice_et': []}
    
    with torch.no_grad():
        for idx, batch_data in enumerate(test_loader):
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            
            # Sliding window inference
            outputs = sliding_window_inference(
                inputs, 
                roi_size=(128, 128, 128),
                sw_batch_size=4, 
                predictor=model
            )
            
            preds = torch.argmax(outputs, dim=1, keepdim=True)
            metrics = compute_region_metrics(preds[0, 0], labels[0, 0])
            
            for k, v in metrics.items():
                all_metrics[k].append(v)
            
            if (idx + 1) % 10 == 0:
                print(f"  Evaluated {idx + 1}/{len(test_loader)} samples")
    
    # Calculate statistics
    results = {}
    for k, v in all_metrics.items():
        results[f'{k}_mean'] = np.mean(v)
        results[f'{k}_std'] = np.std(v)
        results[f'{k}_median'] = np.median(v)
    
    # Print results
    print("\n" + "="*70)
    print("📊 TEST SET RESULTS")
    print("="*70)
    print(f"WT (Whole Tumor):    {results['dice_wt_mean']:.4f} ± {results['dice_wt_std']:.4f}")
    print(f"TC (Tumor Core):     {results['dice_tc_mean']:.4f} ± {results['dice_tc_std']:.4f}")
    print(f"ET (Enhancing):      {results['dice_et_mean']:.4f} ± {results['dice_et_std']:.4f}")
    print(f"Mean Dice:           {np.mean([results['dice_wt_mean'], results['dice_tc_mean'], results['dice_et_mean']]):.4f}")
    print("="*70)
    
    return results, all_metrics

# Evaluate on test set
test_results, test_metrics_all = evaluate_test_set(model, test_loader, device, best_model_path)
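
By default, `np.std` computes the population standard deviation (`ddof=0`), so the ± values above are population spreads across test cases. The same summary can be reproduced with the standard library, which helps when checking numbers outside the notebook. A minimal sketch follows. `summarize` is a hypothetical helper:

```python
import statistics


def summarize(scores):
    """Mean, population std (like np.std's default ddof=0) and median of per-case scores."""
    return {
        "mean": statistics.fmean(scores),
        "std": statistics.pstdev(scores),
        "median": statistics.median(scores),
    }


stats = summarize([0.8, 0.9, 0.7, 0.6])
print({k: round(v, 4) for k, v in stats.items()})  # {'mean': 0.75, 'std': 0.1118, 'median': 0.75}
```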

## 12. Test Set Evaluation

In [None]:
def plot_training_curves(history):
    """Plot training curves"""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss curves
    axes[0].plot(history['train_loss'], label='Train Loss', linewidth=2)
    axes[0].plot(history['val_loss'], label='Val Loss', linewidth=2)
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    axes[0].legend(fontsize=11)
    axes[0].grid(True, alpha=0.3)
    
    # Dice score curves
    axes[1].plot(history['dice_wt'], label='WT (Whole Tumor)', linewidth=2)
    axes[1].plot(history['dice_tc'], label='TC (Tumor Core)', linewidth=2)
    axes[1].plot(history['dice_et'], label='ET (Enhancing)', linewidth=2)
    axes[1].plot(history['mean_dice'], label='Mean Dice', linewidth=2, linestyle='--', color='black')
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Dice Score', fontsize=12)
    axes[1].set_title('Validation Dice Scores', fontsize=14, fontweight='bold')
    axes[1].legend(fontsize=11)
    axes[1].grid(True, alpha=0.3)
    axes[1].set_ylim([0, 1])
    
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, 'training_curves.png'), dpi=150, bbox_inches='tight')
    plt.show()

plot_training_curves(history)
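
The curves are easier to interpret alongside the epoch where validation mean Dice peaked. The training loop already tracks `best_epoch` while it runs, so this minimal sketch is mainly useful after reloading `history` from `last_3d_unet.pth`. It assumes one entry per epoch, as appended in the training loop. `best_epoch_from_history` is a hypothetical helper:

```python
def best_epoch_from_history(history, key="mean_dice"):
    """Return (1-based epoch, value) of the highest entry in history[key]."""
    values = history[key]
    if not values:
        raise ValueError(f"history[{key!r}] is empty")
    # max() keeps the first maximum on ties, like the loop's strict '>' comparison
    best_idx = max(range(len(values)), key=values.__getitem__)
    return best_idx + 1, values[best_idx]


history_example = {"mean_dice": [0.41, 0.63, 0.72, 0.70, 0.74, 0.73]}
print(best_epoch_from_history(history_example))  # (5, 0.74)
```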

## 11. Visualize Training Curves

In [None]:
# Training configuration
NUM_EPOCHS = 100
best_dice = 0
best_epoch = 0

# Paths for saving models
best_model_path = os.path.join(SAVE_DIR, "best_3d_unet.pth")
last_model_path = os.path.join(SAVE_DIR, "last_3d_unet.pth")

# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'dice_wt': [],
    'dice_tc': [],
    'dice_et': [],
    'mean_dice': []
}

print(f"🚀 Starting training for {NUM_EPOCHS} epochs...")
print("="*70)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 70)
    
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, loss_function, scaler, device)
    
    # Validate
    val_loss, metrics = validate(model, val_loader, loss_function, device)
    
    # Record the learning rate used in this epoch, then advance the scheduler
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()
    
    # Calculate mean Dice score
    mean_dice = np.mean(list(metrics.values()))
    
    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['dice_wt'].append(metrics['dice_wt'])
    history['dice_tc'].append(metrics['dice_tc'])
    history['dice_et'].append(metrics['dice_et'])
    history['mean_dice'].append(mean_dice)
    
    # Print epoch summary
    print(f"\n📊 Epoch {epoch+1} Summary:")
    print(f"   Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")
    print(f"   Dice WT: {metrics['dice_wt']:.4f} | TC: {metrics['dice_tc']:.4f} | ET: {metrics['dice_et']:.4f}")
    print(f"   Mean Dice: {mean_dice:.4f}")
    
    # Save best model
    if mean_dice > best_dice:
        best_dice = mean_dice
        best_epoch = epoch + 1
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_dice': best_dice,
            'metrics': metrics
        }, best_model_path)
        print(f"   ✅ New best model saved! (Dice: {best_dice:.4f})")
    
    # Save last model
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'history': history
    }, last_model_path)

print("\n" + "="*70)
print(f"✅ Training completed!")
print(f"   Best Dice: {best_dice:.4f} at epoch {best_epoch}")
print(f"   Models saved to: {SAVE_DIR}")

## 10. Training Loop

In [None]:
def train_epoch(model, train_loader, optimizer, loss_function, scaler, device):
    """Train for one epoch"""
    model.train()
    epoch_loss = 0
    
    for batch_idx, batch_data in enumerate(train_loader):
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        
        optimizer.zero_grad()
        
        # Mixed precision training
        with autocast():
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        epoch_loss += loss.item()
        
        if (batch_idx + 1) % 10 == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")
    
    return epoch_loss / len(train_loader)

def validate(model, val_loader, loss_function, device):
    """Validate model"""
    model.eval()
    val_loss = 0
    all_metrics = {'dice_wt': [], 'dice_tc': [], 'dice_et': []}
    
    with torch.no_grad():
        for batch_data in val_loader:
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            
            # Sliding-window inference (a single 128³ window here, since inputs are already resized to 128³)
            outputs = sliding_window_inference(
                inputs, 
                roi_size=(128, 128, 128),
                sw_batch_size=4, 
                predictor=model
            )
            
            val_loss += loss_function(outputs, labels).item()
            
            # Get predictions
            preds = torch.argmax(outputs, dim=1, keepdim=True)
            
            # Compute metrics
            metrics = compute_region_metrics(preds[0, 0], labels[0, 0])
            for k, v in metrics.items():
                all_metrics[k].append(v)
    
    # Average metrics
    avg_metrics = {k: np.mean(v) for k, v in all_metrics.items()}
    avg_loss = val_loss / len(val_loader)
    
    return avg_loss, avg_metrics

print("✅ Training and validation functions defined")

## 9. Training & Validation Functions

In [None]:
def compute_region_metrics(pred, label):
    """
    Compute Dice scores for the BraTS tumor regions using the MSD Task01
    label convention (1 = edema, 2 = non-enhancing tumor, 3 = enhancing tumor;
    check dataset_info['labels'] printed during dataset preparation):
    - WT (Whole Tumor): All tumor classes (1, 2, 3)
    - TC (Tumor Core): Labels 2 and 3
    - ET (Enhancing Tumor): Label 3 only
    
    Args:
        pred: Predicted segmentation (H, W, D)
        label: Ground truth segmentation (H, W, D)
    
    Returns:
        Dictionary with Dice scores (Python floats) for each region
    """
    pred_np = pred.cpu().numpy()
    label_np = label.cpu().numpy()
    
    # Whole Tumor (all non-zero labels)
    wt_pred = (pred_np > 0).astype(np.float32)
    wt_label = (label_np > 0).astype(np.float32)
    
    # Tumor Core (labels 2 and 3: non-enhancing + enhancing, excluding edema)
    tc_pred = ((pred_np == 2) | (pred_np == 3)).astype(np.float32)
    tc_label = ((label_np == 2) | (label_np == 3)).astype(np.float32)
    
    # Enhancing Tumor (label 3 only)
    et_pred = (pred_np == 3).astype(np.float32)
    et_label = (label_np == 3).astype(np.float32)
    
    def dice_score(pred, label):
        """Calculate the Dice coefficient (smoothed so two empty masks score 1.0)"""
        smooth = 1e-5
        intersection = np.sum(pred * label)
        total = np.sum(pred) + np.sum(label)
        # Return a plain float so downstream results stay JSON-serializable
        return float((2.0 * intersection + smooth) / (total + smooth))
    
    return {
        'dice_wt': dice_score(wt_pred, wt_label),
        'dice_tc': dice_score(tc_pred, tc_label),
        'dice_et': dice_score(et_pred, et_label)
    }

print("✅ Evaluation metrics defined")
print("   - WT (Whole Tumor): All tumor classes")
print("   - TC (Tumor Core): Non-enhancing + Enhancing")
print("   - ET (Enhancing Tumor): Enhancing only")

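The `smooth` term in `dice_score` matters most when a region is absent from a case, which happens most often for ET. The pure-Python sketch below applies the same smoothed formula to flat 0/1 lists to show the edge cases. It is an illustrative re-implementation, not the notebook's NumPy function:

```python
def dice_score(pred, label, smooth=1e-5):
    """Smoothed Dice on flat binary sequences (same formula as in compute_region_metrics)."""
    intersection = sum(p * g for p, g in zip(pred, label))
    total = sum(pred) + sum(label)
    return (2.0 * intersection + smooth) / (total + smooth)


empty = [0, 0, 0, 0]
print(dice_score(empty, empty))                # 1.0: both masks empty
print(dice_score([1, 0, 0, 0], empty))         # ~1e-05: a single false-positive voxel
print(dice_score([1, 1, 0, 0], [1, 0, 0, 0]))  # ~0.6667
```

A case with no enhancing tumor therefore scores ET = 1.0 when the model predicts none, and nearly 0 when it predicts even one enhancing voxel. Such cases can dominate the worst-case list in the failure analysis.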
## 8. Evaluation Metrics

In [None]:
# Loss function: Dice + Cross Entropy
loss_function = DiceCELoss(
    include_background=False,  # Don't include background class
    to_onehot_y=True,           # Convert labels to one-hot
    softmax=True,               # Apply softmax to predictions
    lambda_dice=0.5,            # Weight for Dice loss
    lambda_ce=0.5               # Weight for Cross Entropy loss
)

# Optimizer: AdamW with weight decay
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=1e-4, 
    weight_decay=1e-5
)

# Learning rate scheduler: Cosine annealing
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=100
)

# Mixed precision training scaler
scaler = GradScaler()

print("✅ Training components initialized")
print(f"   - Loss: Dice + Cross Entropy (50/50)")
print(f"   - Optimizer: AdamW (lr=1e-4, wd=1e-5)")
print(f"   - Scheduler: CosineAnnealingLR")
print(f"   - Mixed precision: Enabled")
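
When stepped once per epoch, `CosineAnnealingLR` (default `eta_min=0`) follows the closed form `lr(t) = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max))`. With `T_max=100` and 100 epochs, the learning rate therefore falls from 1e-4 to 0 by the end of training. This minimal sketch evaluates the formula for inspection only; the notebook keeps using the PyTorch scheduler:

```python
import math


def cosine_lr(epoch, base_lr=1e-4, t_max=100, eta_min=0.0):
    """Closed-form cosine annealing learning rate after `epoch` scheduler steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 25, 50, 75, 100):
    print(epoch, f"{cosine_lr(epoch):.2e}")
```

If `NUM_EPOCHS` is changed, `T_max` should usually be changed with it. Otherwise the schedule either stops partway down the curve or reaches 0 before training ends.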

## 7. Loss Function, Optimizer & Scheduler

In [None]:
def create_3d_unet(in_channels=4, out_channels=4):
    """
    Create 3D U-Net model for brain tumor segmentation
    
    Args:
        in_channels: Number of input modalities (MSD order: FLAIR, T1w, T1gd, T2w)
        out_channels: Number of output classes (background, edema, non-enhancing tumor, enhancing tumor)
    
    Returns:
        MONAI UNet model
    """
    model = UNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=out_channels,
        channels=(32, 64, 128, 256, 512),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        dropout=0.1
    )
    return model

# Create model
model = create_3d_unet(in_channels=4, out_channels=4).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✅ 3D U-Net model created")
print(f"   - Total parameters: {total_params:,}")
print(f"   - Trainable parameters: {trainable_params:,}")
print(f"   - Model size: ~{total_params * 4 / 1024**2:.1f} MB")
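
The printed numbers can be sanity-checked by hand. A 3D convolution with kernel size k has k³·C_in·C_out weights plus C_out biases, and the size estimate assumes 4 bytes (float32) per parameter. This minimal sketch covers only that arithmetic, and the helper names are hypothetical. MONAI's residual units add further convolutions and layers, so summing a few single convolutions will not reproduce the full printed total:

```python
def conv3d_params(c_in, c_out, k=3, bias=True):
    """Learnable parameters of a Conv3d layer with a k x k x k kernel."""
    return k ** 3 * c_in * c_out + (c_out if bias else 0)


def float32_size_mb(n_params):
    """Memory for n_params float32 values in MiB (same formula as the 'Model size' print)."""
    return n_params * 4 / 1024 ** 2


print(conv3d_params(4, 32))        # 3488  (4 input modalities -> 32 feature maps)
print(conv3d_params(32, 64))       # 55360
print(float32_size_mb(1024 ** 2))  # 4.0
```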

## 6. Model Architecture - 3D U-Net

In [None]:
def create_dataloaders(train_files, val_files, test_files, 
                       train_transforms, val_transforms, test_transforms,
                       batch_size=2, cache_rate=1.0):
    """
    Create MONAI DataLoaders with caching for faster training
    
    Args:
        train_files, val_files, test_files: List of data dictionaries
        train_transforms, val_transforms, test_transforms: MONAI transform objects
        batch_size: Batch size for training
        cache_rate: Fraction of data to cache in memory (1.0 = all data)
    
    Returns:
        train_loader, val_loader, test_loader
    """
    
    # Create cached datasets for faster loading
    train_ds = CacheDataset(
        data=train_files, 
        transform=train_transforms, 
        cache_rate=cache_rate, 
        num_workers=4
    )
    
    val_ds = CacheDataset(
        data=val_files, 
        transform=val_transforms, 
        cache_rate=cache_rate, 
        num_workers=4
    )
    
    test_ds = CacheDataset(
        data=test_files, 
        transform=test_transforms, 
        cache_rate=cache_rate, 
        num_workers=4
    )
    
    # Create data loaders
    train_loader = MonaiDataLoader(
        train_ds, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    
    val_loader = MonaiDataLoader(
        val_ds, 
        batch_size=1, 
        shuffle=False, 
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    
    test_loader = MonaiDataLoader(
        test_ds, 
        batch_size=1, 
        shuffle=False, 
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    
    return train_loader, val_loader, test_loader

# Create data loaders
BATCH_SIZE = 2
train_loader, val_loader, test_loader = create_dataloaders(
    train_files, val_files, test_files,
    train_transforms, val_transforms, test_transforms, 
    batch_size=BATCH_SIZE
)

print(f"✅ DataLoaders created successfully")
print(f"   - Train batches: {len(train_loader)} (batch size: {BATCH_SIZE})")
print(f"   - Val batches: {len(val_loader)} (batch size: 1)")
print(f"   - Test batches: {len(test_loader)} (batch size: 1)")
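
With `cache_rate=1.0`, `CacheDataset` keeps the output of the deterministic transforms (everything before the first random transform) in RAM. Every cached case is therefore already cropped or padded to 128³. The rough estimate below assumes 4 image channels and 1 label channel stored as float32, and it ignores metadata overhead. `cache_size_gib` is a hypothetical helper:

```python
def cache_size_gib(n_cases, image_channels=4, label_channels=1,
                   spatial=(128, 128, 128), bytes_per_voxel=4):
    """Approximate RAM needed to cache n_cases preprocessed volumes, in GiB."""
    voxels = spatial[0] * spatial[1] * spatial[2]
    per_case = (image_channels + label_channels) * voxels * bytes_per_voxel
    return n_cases * per_case / 1024 ** 3


print(f"per case: {cache_size_gib(1) * 1024:.0f} MiB")  # per case: 40 MiB
for n in (100, 300):
    print(n, "cases:", f"{cache_size_gib(n):.1f} GiB")
```

If the estimate exceeds the available RAM, lowering `cache_rate` (or switching to MONAI's non-caching `Dataset`) uses less memory at the cost of slower epochs.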

## 5. Create DataLoaders

In [None]:
def get_preprocessing_transforms(mode='train'):
    """
    Get preprocessing transforms for training, validation, or testing
    
    Args:
        mode: 'train', 'val', or 'test'
    
    Returns:
        MONAI Compose object with transforms
    """
    
    # Common transforms for all modes
    common_transforms = [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), 
                mode=("bilinear", "nearest")),
        CropForegroundd(keys=["image", "label"], source_key="image", margin=10),
        ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=(128, 128, 128)),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
    
    # Add augmentation transforms for training only
    if mode == 'train':
        train_transforms = common_transforms + [
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(0, 1)),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ]
        return Compose(train_transforms)
    
    return Compose(common_transforms)

# Create transforms for each dataset split
train_transforms = get_preprocessing_transforms(mode='train')
val_transforms = get_preprocessing_transforms(mode='val')
test_transforms = get_preprocessing_transforms(mode='test')

print("✅ Preprocessing transforms created")
print(f"   - Training: {len(train_transforms.transforms)} transforms (with augmentation)")
print(f"   - Validation: {len(val_transforms.transforms)} transforms")
print(f"   - Test: {len(test_transforms.transforms)} transforms")
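
`NormalizeIntensityd(nonzero=True, channel_wise=True)` z-scores each modality using only its non-zero (brain) voxels and leaves the zero background untouched. Here is a simplified pure-Python sketch of that behavior for a single channel. It approximates MONAI's implementation rather than reproducing it, and `normalize_nonzero` is a hypothetical name:

```python
import statistics


def normalize_nonzero(values):
    """Z-score the non-zero entries of one channel; zeros stay zero."""
    nonzero = [v for v in values if v != 0]
    if not nonzero:
        return [0.0 for _ in values]
    mean = statistics.fmean(nonzero)
    std = statistics.pstdev(nonzero) or 1.0  # guard against a constant channel
    return [(v - mean) / std if v != 0 else 0.0 for v in values]


out = normalize_nonzero([0, 2, 4, 6, 0])
print([round(v, 4) for v in out])  # [0.0, -1.2247, 0.0, 1.2247, 0.0]
```

Because `channel_wise=True`, each of the four modalities gets its own mean and standard deviation.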

## 4. Preprocessing Transforms

In [None]:
class MSDDatasetPreparation:
    """Dataset preparation class for MSD Task01_BrainTumour"""
    
    def __init__(self, data_root, train_ratio=0.7, val_ratio=0.15):
        self.data_root = Path(data_root)
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        
    def load_dataset_json(self):
        """Load dataset.json metadata"""
        json_path = self.data_root / "dataset.json"
        with open(json_path, 'r') as f:
            dataset_info = json.load(f)
        return dataset_info
    
    def prepare_data_dicts(self):
        """Prepare data dictionaries with image and label paths"""
        dataset_info = self.load_dataset_json()
        data_dicts = []
        
        for item in dataset_info['training']:
            # Convert relative paths to absolute paths
            image_path = str(self.data_root / item['image'].lstrip('./'))
            label_path = str(self.data_root / item['label'].lstrip('./'))
            
            data_dicts.append({
                'image': image_path,
                'label': label_path
            })
        
        return data_dicts, dataset_info
    
    def split_dataset(self, data_dicts, seed=42):
        """Split dataset into train, validation, and test sets"""
        np.random.seed(seed)
        n_total = len(data_dicts)
        indices = np.random.permutation(n_total)
        
        n_train = int(n_total * self.train_ratio)
        n_val = int(n_total * self.val_ratio)
        
        train_files = [data_dicts[i] for i in indices[:n_train]]
        val_files = [data_dicts[i] for i in indices[n_train:n_train+n_val]]
        test_files = [data_dicts[i] for i in indices[n_train+n_val:]]
        
        print(f"Dataset split:")
        print(f"  Train: {len(train_files)} samples ({self.train_ratio*100:.0f}%)")
        print(f"  Val: {len(val_files)} samples ({self.val_ratio*100:.0f}%)")
        print(f"  Test: {len(test_files)} samples ({(1-self.train_ratio-self.val_ratio)*100:.0f}%)")
        
        return train_files, val_files, test_files

# Initialize dataset preparation
data_prep = MSDDatasetPreparation(data_root=DATA_ROOT, train_ratio=0.7, val_ratio=0.15)
data_dicts, dataset_info = data_prep.prepare_data_dicts()
train_files, val_files, test_files = data_prep.split_dataset(data_dicts)

# Display dataset information
print(f"\nDataset: {dataset_info['name']}")
print(f"Modalities: {dataset_info['modality']}")
print(f"Labels: {dataset_info['labels']}")
print(f"Total training samples: {dataset_info['numTraining']}")

## 3. Dataset Preparation

In [None]:
# Set dataset paths
DATA_ROOT = "/content/drive/MyDrive/BrainTumor/Task01_BrainTumour"
SAVE_DIR = "/content/drive/MyDrive/BrainTumor/models"

# Create save directory if it doesn't exist
os.makedirs(SAVE_DIR, exist_ok=True)

print(f"📁 Checking dataset paths...")
print(f"Data root: {DATA_ROOT}")
print(f"Save directory: {SAVE_DIR}")
print()

# Verify dataset exists
if os.path.exists(DATA_ROOT):
    print(f"✅ Dataset root exists")
    
    # Check for required folders and files
    imagesTr_path = os.path.join(DATA_ROOT, "imagesTr")
    labelsTr_path = os.path.join(DATA_ROOT, "labelsTr")
    dataset_json_path = os.path.join(DATA_ROOT, "dataset.json")
    
    print(f"   - imagesTr folder: {'✅ Found' if os.path.isdir(imagesTr_path) else '❌ Missing'}")
    print(f"   - labelsTr folder: {'✅ Found' if os.path.isdir(labelsTr_path) else '❌ Missing'}")
    print(f"   - dataset.json: {'✅ Found' if os.path.isfile(dataset_json_path) else '❌ Missing'}")
    
    if os.path.isdir(imagesTr_path):
        num_images = len([f for f in os.listdir(imagesTr_path) if f.endswith('.nii.gz') or f.endswith('.nii')])
        print(f"   - Number of images: {num_images}")
    
    if os.path.isdir(labelsTr_path):
        num_labels = len([f for f in os.listdir(labelsTr_path) if f.endswith('.nii.gz') or f.endswith('.nii')])
        print(f"   - Number of labels: {num_labels}")
else:
    print(f"❌ Dataset root NOT found!")
    print()
    print("🔍 Searching for possible locations...")
    
    # Search for possible dataset locations
    search_paths = [
        "/content/drive/MyDrive/BrainTumor",
        "/content/drive/MyDrive",
        "/content/drive/My Drive/BrainTumor",
        "/content/drive/My Drive"
    ]
    
    for search_path in search_paths:
        if os.path.exists(search_path):
            print(f"\n📂 Found: {search_path}")
            items = os.listdir(search_path)
            print(f"   Contents: {items[:10]}")  # Show first 10 items
            
            # Look for Task01 folder
            for item in items:
                if 'task01' in item.lower() or 'brain' in item.lower():
                    print(f"   ⭐ Possible dataset: {os.path.join(search_path, item)}")
    
    print()
    print("💡 Please update DATA_ROOT to the correct path where your dataset is located.")
    print("   Your dataset should contain: imagesTr/, labelsTr/, and dataset.json")
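
The checks above can be wrapped in a reusable function, which makes it easy to test any candidate path the search turns up. Here is a minimal `pathlib` sketch. `check_msd_root` and `count_nifti` are hypothetical helpers. Unlike the cell above, they skip hidden files (names starting with `.`, such as macOS `._*` metadata files) so those do not inflate the counts:

```python
from pathlib import Path


def count_nifti(folder):
    """Count visible .nii / .nii.gz files in a folder (0 if the folder does not exist)."""
    if not folder.is_dir():
        return 0
    return sum(
        1 for f in folder.iterdir()
        if f.is_file() and not f.name.startswith(".") and f.name.endswith((".nii", ".nii.gz"))
    )


def check_msd_root(root):
    """Report whether `root` looks like an MSD task folder."""
    root = Path(root)
    return {
        "imagesTr": (root / "imagesTr").is_dir(),
        "labelsTr": (root / "labelsTr").is_dir(),
        "dataset.json": (root / "dataset.json").is_file(),
        "num_images": count_nifti(root / "imagesTr"),
        "num_labels": count_nifti(root / "labelsTr"),
    }
```

For example, `check_msd_root(DATA_ROOT)` should report all three flags as `True` and matching image and label counts for a usable dataset root.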

In [None]:
from google.colab import drive
import os

# Check if already mounted
if not os.path.exists('/content/drive/MyDrive'):
    drive.mount('/content/drive')
    print("✅ Google Drive mounted successfully")
else:
    print("✅ Google Drive already mounted")

## 2. Mount Google Drive & Setup Paths

In [None]:
# Import all required libraries
import os
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
import monai
from monai.data import CacheDataset, DataLoader as MonaiDataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd,
    Orientationd, CropForegroundd, ResizeWithPadOrCropd,
    NormalizeIntensityd, RandRotate90d, RandFlipd,
    RandScaleIntensityd, RandShiftIntensityd, RandAffined
)
from monai.losses import DiceCELoss, DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.inferers import sliding_window_inference
from pathlib import Path
import matplotlib.pyplot as plt
import json
import nibabel as nib

# Set random seeds for reproducibility
monai.utils.set_determinism(seed=42)
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"MONAI version: {monai.__version__}")
print(f"PyTorch version: {torch.__version__}")

In [None]:
# Install required packages
!pip install monai[all]
!pip install nibabel
!pip install SimpleITK
!pip install matplotlib seaborn
!pip install tensorboard

## 1. Setup & Installation

# Brain Tumor Segmentation - Complete Pipeline
## MSD Task01_BrainTumour | 3D U-Net + MedSAM | MONAI + PyTorch

This notebook implements a complete brain tumor segmentation pipeline with:
- **Dataset**: MSD Task01_BrainTumour (BraTS)
- **Models**: 3D U-Net (MONAI), plus an optional MedSAM fine-tuning setup (bonus section)
- **Metrics**: Dice Score for Whole Tumor (WT), Tumor Core (TC), Enhancing Tumor (ET)
- **Analysis**: Training curves, failure analysis, experiment tracking

### Workflow:
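1. Setup & Installation
2. Mount Google Drive & Setup Paths
3. Dataset Preparation
4. Preprocessing Transforms
5. Create DataLoaders
6. Model Architecture - 3D U-Net
7. Loss Function, Optimizer & Scheduler
8. Evaluation Metrics
9. Training & Validation Functions
10. Training Loop
11. Visualize Training Curves
12. Test Set Evaluation
13. Visualize Predictions
14. Failure Analysis
15. Experiment Tracking
16. (BONUS) MedSAM Fine-tuning Setup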