## 📦 Setup & Imports

In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import json

# Add dataset code to path
code_dir = '/kaggle/input/nsclc-multiorgan-segmentation/code'
sys.path.append(code_dir)

# Import custom modules
from dataset_multi_organ import MultiOrganDataset
from unet_multi_organ import UNetMultiOrgan

# Report the framework version and, when present, the first CUDA device.
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {_gpu_name}")
    print(f"GPU Memory: {_gpu_mem_gb:.2f} GB")

## ⚙️ Configuration

In [None]:
# --- Paths ----------------------------------------------------------------
# Root of the input dataset, and the directory every artifact (plots,
# checkpoints, JSON files) is written to by the cells below.
DATA_ROOT = '/kaggle/input/nsclc-multiorgan-segmentation'
OUTPUT_DIR = Path('/kaggle/working')
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Training hyperparameters ---------------------------------------------
CONFIG = dict(
    batch_size=8,
    num_epochs=50,
    learning_rate=1e-4,
    num_workers=2,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    patience=10,      # epochs without val-loss improvement before early stopping
    in_channels=1,    # single-channel CT input
    out_channels=8,   # one logit per label in ORGAN_NAMES
    bilinear=False,
)

# --- Label index -> structure name ----------------------------------------
ORGAN_NAMES = dict(enumerate([
    'Background',
    'GTV',
    'PTV',
    'Right_Lung',
    'Left_Lung',
    'Heart',
    'Esophagus',
    'Spinal_Cord',
]))

print("Configuration:")
for setting_name, setting_value in CONFIG.items():
    print(f"  {setting_name}: {setting_value}")

## 📊 Load Datasets

In [None]:
print("Loading datasets...")

# Both splits are loaded slice-wise (one sample per slice).
train_dataset = MultiOrganDataset(
    data_root=DATA_ROOT,
    split='train',
    slice_wise=True
)

val_dataset = MultiOrganDataset(
    data_root=DATA_ROOT,
    split='val',
    slice_wise=True
)

print(f"\n‚úÖ Train dataset: {len(train_dataset)} slices")
print(f"‚úÖ Val dataset: {len(val_dataset)} slices")

# Pinned host memory only speeds up host->GPU copies. On a CPU-only run it has
# no benefit, and recent PyTorch versions warn about it, so enable it for CUDA only.
_pin_memory = CONFIG['device'] == 'cuda'

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=_pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=_pin_memory
)

print(f"\n‚úÖ Train batches: {len(train_loader)}")
print(f"‚úÖ Val batches: {len(val_loader)}")

## 🔍 Visualize Sample

In [None]:
# Show the first slice of one training batch: image, label map, and overlay.
sample_batch = next(iter(train_loader))
sample_image = sample_batch['image'][0, 0].cpu().numpy()
sample_mask = sample_batch['mask'][0].cpu().numpy()

# Fixed colour range so each label index always maps to the same colour.
mask_style = dict(cmap='tab10', vmin=0, vmax=7)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax_image, ax_mask, ax_overlay = axes

ax_image.imshow(sample_image, cmap='gray')
ax_mask.imshow(sample_mask, **mask_style)
ax_overlay.imshow(sample_image, cmap='gray')
ax_overlay.imshow(sample_mask, alpha=0.5, **mask_style)

for panel, panel_title in zip(axes, ('CT Image', 'Ground Truth Mask', 'Overlay')):
    panel.set_title(panel_title)
    panel.axis('off')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'sample_visualization.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nImage shape: {sample_batch['image'].shape}")
print(f"Mask shape: {sample_batch['mask'].shape}")
print(f"Unique labels: {torch.unique(sample_batch['mask']).tolist()}")

## ⚖️ Calculate Class Weights

In [None]:
def calculate_class_weights(dataset, num_classes=8, max_samples=1000, seed=None, class_names=None):
    """Estimate per-class loss weights with sqrt inverse label frequency.

    A random subset of at most ``max_samples`` items is drawn without
    replacement from ``dataset``. The ``'mask'`` of each item is histogrammed
    over labels ``0..num_classes-1``, and each class receives the weight
    ``sqrt(total_pixels / class_pixels)``. The weights are then rescaled so
    they average to 1.

    Args:
        dataset: Indexable dataset whose items are dicts with a ``'mask'`` array
            or tensor of integer labels.
        num_classes: Number of classes to count. Labels outside
            ``[0, num_classes)`` are ignored.
        max_samples: Upper bound on the number of items sampled.
        seed: Optional seed for reproducible sampling. ``None`` uses NumPy's
            global random state.
        class_names: Optional mapping ``class index -> name`` for the printed
            summary. Defaults to the module-level ``ORGAN_NAMES``.

    Returns:
        ``torch.float32`` tensor of shape ``(num_classes,)``.

    Raises:
        ValueError: If the dataset is empty or the sampled masks contain no
            in-range label pixels.
    """
    if class_names is None:
        class_names = ORGAN_NAMES

    n_items = len(dataset)
    if n_items == 0:
        raise ValueError("Cannot compute class weights from an empty dataset")
    n_sample = min(max_samples, n_items)

    print(f"Calculating class weights (sampling {n_sample} slices)...")

    # np.random (global state) and a seeded Generator both expose .choice
    # with the same signature for these arguments.
    rng = np.random if seed is None else np.random.default_rng(seed)
    indices = rng.choice(n_items, size=n_sample, replace=False)

    class_counts = np.zeros(num_classes, dtype=np.int64)
    for idx in tqdm(indices):
        labels = np.asarray(dataset[idx]['mask']).ravel()
        # Keep only in-range labels. A negative label used as an index would
        # silently wrap onto the last classes.
        labels = labels[(labels >= 0) & (labels < num_classes)]
        class_counts += np.bincount(labels.astype(np.int64), minlength=num_classes)

    total_pixels = class_counts.sum()
    if total_pixels == 0:
        raise ValueError("No in-range label pixels found in the sampled masks")

    # Sqrt inverse frequency. A class never seen in the sample would otherwise
    # get a near-infinite weight, which after normalisation drives every other
    # weight towards 0. Unseen classes get the largest observed weight instead.
    present = class_counts > 0
    class_weights = np.empty(num_classes, dtype=np.float64)
    class_weights[present] = np.sqrt(total_pixels / class_counts[present])
    class_weights[~present] = class_weights[present].max()
    class_weights = class_weights / class_weights.sum() * num_classes

    print("\nClass weights:")
    for i in range(num_classes):
        freq = class_counts[i] / total_pixels * 100
        name = class_names.get(i, f"class_{i}")
        print(f"  {name:15s}: weight={class_weights[i]:.4f}, freq={freq:.2f}%")

    return torch.tensor(class_weights, dtype=torch.float32)

# Calculate or load weights.
# Weights are cached as JSON so re-runs skip the sampling pass. A cache that
# cannot be read, or that holds a different number of entries than the model
# has classes, is ignored and recomputed: CrossEntropyLoss needs exactly one
# weight per class.
weights_path = OUTPUT_DIR / 'class_weights.json'
_n_classes = CONFIG['out_channels']

class_weights = None
if weights_path.exists():
    print("Loading cached class weights...")
    try:
        with open(weights_path, encoding='utf-8') as f:
            _cached_weights = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        print(f"Could not read cached class weights ({err}); recomputing...")
    else:
        if isinstance(_cached_weights, list) and len(_cached_weights) == _n_classes:
            class_weights = torch.tensor(_cached_weights, dtype=torch.float32)
        else:
            print(f"Cached class weights do not match {_n_classes} classes; recomputing...")

if class_weights is None:
    class_weights = calculate_class_weights(train_dataset, num_classes=_n_classes)
    with open(weights_path, 'w', encoding='utf-8') as f:
        json.dump(class_weights.tolist(), f)

class_weights = class_weights.to(CONFIG['device'])

## 🏗️ Initialize Model

In [None]:
model = UNetMultiOrgan(
    in_channels=CONFIG['in_channels'],
    out_channels=CONFIG['out_channels'],
    bilinear=CONFIG['bilinear']
).to(CONFIG['device'])

# Count parameters. The size uses each parameter's real element size instead
# of assuming 4-byte float32 values.
num_params = sum(p.numel() for p in model.parameters())
_param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"\n‚úÖ Model created: {num_params:,} parameters ({_param_bytes/1e6:.2f} MB)")

# Smoke-test a forward pass in eval mode. torch.no_grad() does not stop layers
# such as BatchNorm from updating their running statistics in train mode, so a
# random-noise batch would otherwise leak into them. (Whether UNetMultiOrgan
# contains such layers is not visible here; eval mode is harmless either way.)
_was_training = model.training
model.eval()
with torch.no_grad():
    test_input = torch.randn(2, CONFIG['in_channels'], 256, 256, device=CONFIG['device'])
    test_output = model(test_input)
    print(f"Test forward pass: {test_input.shape} ‚Üí {test_output.shape}")
model.train(_was_training)

## 📉 Define Loss Functions

In [None]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed from raw logits.

    For every (sample, class) pair, the soft Dice coefficient
    ``(2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)`` is computed over
    all spatial positions. ``P`` is the softmax probability map and ``T`` is
    the one-hot target. The loss is ``1 - mean`` over samples and classes,
    with background included.

    Works for any number of spatial dimensions, e.g. 2D slices
    ``(N, C, H, W)`` or 3D volumes ``(N, C, D, H, W)``.

    Args:
        num_classes: Number of classes ``C``; must equal ``pred.shape[1]``.
        smooth: Additive smoothing. It keeps the ratio defined for classes
            absent from both prediction and target (they score 1).
    """

    def __init__(self, num_classes=8, smooth=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth

    def forward(self, pred, target):
        """Return the scalar Dice loss.

        Args:
            pred: Logits of shape ``(N, C, *spatial)``.
            target: Integer labels of shape ``(N, *spatial)`` with values in
                ``[0, num_classes)``.
        """
        probs = torch.softmax(pred, dim=1)
        # one_hot requires int64 indices and appends the class axis last;
        # move it to dim 1 so it lines up with the logits.
        target_one_hot = torch.nn.functional.one_hot(
            target.long(), self.num_classes
        ).movedim(-1, 1).float()

        spatial_dims = tuple(range(2, pred.dim()))
        intersection = (probs * target_one_hot).sum(dim=spatial_dims)
        union = probs.sum(dim=spatial_dims) + target_one_hot.sum(dim=spatial_dims)

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()

class CombinedLoss(nn.Module):
    """Equal-weight blend of class-weighted cross-entropy and soft Dice.

    ``loss = 0.5 * CE(pred, target; weight=class_weights) + 0.5 * Dice(pred, target)``

    Args:
        class_weights: Per-class weight tensor passed to ``nn.CrossEntropyLoss``.
        num_classes: Number of classes for the Dice term.
    """

    def __init__(self, class_weights, num_classes=8):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights)
        self.dice_loss = DiceLoss(num_classes=num_classes)

    def forward(self, pred, target):
        """Return the combined scalar loss for logits ``pred`` and labels ``target``."""
        cross_entropy_term = self.ce_loss(pred, target)
        dice_term = self.dice_loss(pred, target)
        return 0.5 * (cross_entropy_term + dice_term)

# Training objective and optimizer shared by the training loop below.
criterion = CombinedLoss(
    class_weights=class_weights,
    num_classes=CONFIG['out_channels'],
)
optimizer = optim.Adam(params=model.parameters(), lr=CONFIG['learning_rate'])

print("‚úÖ Loss function: CombinedLoss (0.5 * CrossEntropy + 0.5 * Dice)")
print(f"‚úÖ Optimizer: Adam (lr={CONFIG['learning_rate']})")

## 🎓 Training Loop

In [None]:
def compute_dice_score(pred, target, num_classes=8):
    """Return the mean hard Dice score over foreground classes for a batch.

    Predictions become labels via ``argmax`` over dim 1. Dice is computed per
    foreground class (``1..num_classes-1``) over every position in the whole
    batch. Classes absent from both prediction and target are skipped: there
    is nothing to overlap, and scoring them 0 would penalise a correct
    "not present" prediction. If no foreground class appears in either, the
    batch scores 1.0.

    Args:
        pred: Logits or probabilities of shape ``(N, C, *spatial)``.
        target: Integer labels of shape ``(N, *spatial)``.
        num_classes: Number of classes including background (label 0).

    Returns:
        Mean Dice as a Python float in ``[0, 1]``.
    """
    pred_labels = torch.argmax(pred, dim=1)
    dice_scores = []

    for c in range(1, num_classes):  # Skip background
        pred_c = pred_labels == c
        target_c = target == c

        union = pred_c.sum() + target_c.sum()
        if union == 0:
            continue  # class absent from both: no information, don't score it

        intersection = (pred_c & target_c).sum()
        dice_scores.append((2.0 * intersection / union).item())

    if not dice_scores:
        return 1.0
    return float(np.mean(dice_scores))

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network producing logits of shape ``(N, C, ...)``.
        loader: Iterable of dict batches with ``'image'`` and ``'mask'`` tensors.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, mean_dice)``: unweighted means over the batches seen.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_dice = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        # non_blocking lets the copy overlap with compute when memory is pinned.
        images = batch['image'].to(device, non_blocking=True)
        masks = batch['mask'].to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        # Score every class the model outputs, not the metric's default of 8.
        dice = compute_dice_score(outputs.detach(), masks, num_classes=outputs.shape[1])

        batch_loss = loss.item()
        total_loss += batch_loss
        total_dice += dice
        num_batches += 1

        pbar.set_postfix({'loss': batch_loss, 'dice': dice})

    if num_batches == 0:
        raise ValueError("Training loader yielded no batches")
    return total_loss / num_batches, total_dice / num_batches

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network producing logits of shape ``(N, C, ...)``.
        loader: Iterable of dict batches with ``'image'`` and ``'mask'`` tensors.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, mean_dice)``: unweighted means over the batches seen.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_dice = 0.0
    num_batches = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc='Validation')
        for batch in pbar:
            images = batch['image'].to(device, non_blocking=True)
            masks = batch['mask'].to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, masks)
            # Score every class the model outputs, not the metric's default of 8.
            dice = compute_dice_score(outputs, masks, num_classes=outputs.shape[1])

            batch_loss = loss.item()
            total_loss += batch_loss
            total_dice += dice
            num_batches += 1

            pbar.set_postfix({'loss': batch_loss, 'dice': dice})

    if num_batches == 0:
        raise ValueError("Validation loader yielded no batches")
    return total_loss / num_batches, total_dice / num_batches

## 🚀 Main Training

In [None]:
print("\n" + "="*80)
print("üöÄ STARTING TRAINING")
print("="*80 + "\n")

# Per-epoch metrics, consumed by the plotting and saving cells below.
history = {
    'train_loss': [],
    'train_dice': [],
    'val_loss': [],
    'val_dice': []
}

best_val_loss = float('inf')
patience_counter = 0  # consecutive epochs without a val-loss improvement

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 80)

    # Train
    train_loss, train_dice = train_epoch(
        model, train_loader, criterion, optimizer, CONFIG['device']
    )

    # Validate
    val_loss, val_dice = validate(
        model, val_loader, criterion, CONFIG['device']
    )

    # Normalise to built-in floats. The dice metric may come back as a NumPy
    # scalar, which torch.load(weights_only=True) (the default since
    # PyTorch 2.6) refuses to unpickle from a checkpoint.
    train_loss, train_dice = float(train_loss), float(train_dice)
    val_loss, val_dice = float(val_loss), float(val_dice)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_dice'].append(train_dice)
    history['val_loss'].append(val_loss)
    history['val_dice'].append(val_dice)

    print(f"\nTrain Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}")

    # Save best model (selected on validation loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_dice': val_dice
        }, OUTPUT_DIR / 'best_model.pth')
        print("‚úÖ Saved best model")
    else:
        patience_counter += 1
        print(f"Patience: {patience_counter}/{CONFIG['patience']}")

    # Save a checkpoint every 10 epochs. This runs before the early-stopping
    # check so the checkpoint for the stopping epoch is not skipped.
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, OUTPUT_DIR / f'checkpoint_epoch_{epoch+1}.pth')

    # Early stopping
    if patience_counter >= CONFIG['patience']:
        print(f"\n‚ö†Ô∏è Early stopping at epoch {epoch+1}")
        break

print("\n" + "="*80)
print("‚úÖ TRAINING COMPLETED")
print("="*80)

## 📊 Plot Training Curves

In [None]:
# Epochs are numbered from 1 to match the "Epoch N/M" lines printed during training.
epoch_axis = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss
axes[0].plot(epoch_axis, history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(epoch_axis, history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Dice
axes[1].plot(epoch_axis, history['train_dice'], label='Train Dice', linewidth=2)
axes[1].plot(epoch_axis, history['val_dice'], label='Val Dice', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Dice Score', fontsize=12)
axes[1].set_title('Training & Validation Dice', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# min()/max() raise on an empty list, e.g. when num_epochs is 0.
if history['val_loss']:
    print(f"\nüìä Best Val Loss: {min(history['val_loss']):.4f}")
    print(f"üìä Best Val Dice: {max(history['val_dice']):.4f}")
else:
    print("\nNo epochs were run; no best validation metrics to report.")

## 🔮 Visualize Predictions

In [None]:
# Load the best checkpoint written by the training loop.
# - map_location keeps this working if the checkpoint was written on another device.
# - weights_only=False: the checkpoint stores metric metadata that may include
#   NumPy scalars, which the weights-only unpickler (the default since
#   PyTorch 2.6) rejects. The file was produced by this notebook, so full
#   unpickling is acceptable here; never do this with untrusted files.
checkpoint = torch.load(
    OUTPUT_DIR / 'best_model.pth',
    map_location=CONFIG['device'],
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Get samples
val_batch = next(iter(val_loader))
images = val_batch['image'].to(CONFIG['device'])
masks = val_batch['mask'].cpu().numpy()

with torch.no_grad():
    outputs = model(images)
    preds = torch.argmax(outputs, dim=1).cpu().numpy()

images_np = images.cpu().numpy()

# Fixed colour range over all label indices so labels map to the same colours
# in ground truth and prediction.
max_label = CONFIG['out_channels'] - 1
n_rows = 4
n_show = min(n_rows, len(images))

fig, axes = plt.subplots(n_rows, 3, figsize=(12, 16))

for i in range(n_show):
    # CT
    axes[i, 0].imshow(images_np[i, 0], cmap='gray')
    axes[i, 0].set_title('CT Image', fontsize=10)
    axes[i, 0].axis('off')

    # Ground truth
    axes[i, 1].imshow(masks[i], cmap='tab10', vmin=0, vmax=max_label)
    axes[i, 1].set_title('Ground Truth', fontsize=10)
    axes[i, 1].axis('off')

    # Prediction
    axes[i, 2].imshow(preds[i], cmap='tab10', vmin=0, vmax=max_label)
    axes[i, 2].set_title('Prediction', fontsize=10)
    axes[i, 2].axis('off')

# Hide rows left empty when the batch has fewer than n_rows samples.
for ax in axes[n_show:].ravel():
    ax.axis('off')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'predictions.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ Predictions visualized")

## 💾 Save Final Model

In [None]:
# Persist the per-epoch metrics next to the other artifacts, then list what was written.
(OUTPUT_DIR / 'training_history.json').write_text(json.dumps(history, indent=2))

separator = "=" * 80
saved_files = [
    ("best_model.pth", "Best model checkpoint"),
    ("training_curves.png", "Loss & Dice plots"),
    ("predictions.png", "Sample predictions"),
    ("training_history.json", "Full training history"),
    ("class_weights.json", "Class weights for reuse"),
]

print("\n" + separator)
print("üìÅ SAVED FILES:")
print(separator)
for file_name, file_description in saved_files:
    print(f"  ‚Ä¢ {file_name} - {file_description}")
print("\n‚úÖ Training pipeline completed successfully!")
print(separator)