## 📂 Cell 1: Setup Paths

In [None]:
from pathlib import Path
import sys

# Root of the attached Kaggle dataset and the sub-folders used below
DATA_ROOT = Path('/kaggle/input/nsclc-multiorgan-segmentation')
CT_DIR, MASKS_DIR, CODE_DIR = (
    DATA_ROOT / folder
    for folder in ('normalized_ct', 'normalized_masks', 'code')
)

# Make the code shipped with the dataset importable
sys.path.append(str(CODE_DIR))

# Report whether the dataset is attached and how many volumes it holds
if not DATA_ROOT.exists():
    print("‚ùå Dataset not found!")
    print("Please add dataset: + Add Data ‚Üí Your Datasets ‚Üí nsclc-multiorgan-segmentation")
else:
    print("‚úÖ Dataset found!")
    ct_files = list(CT_DIR.glob('*.nii'))
    mask_files = list(MASKS_DIR.glob('*.nii'))
    print(f"‚úÖ {len(ct_files)} CT files")
    print(f"‚úÖ {len(mask_files)} Mask files")

## 📦 Cell 2: Install Dependencies & Imports

In [None]:
# Install SimpleITK if needed
!pip install SimpleITK -q

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import SimpleITK as sitk
import os

# Report the PyTorch version and, when present, the first CUDA device
_cuda_ready = torch.cuda.is_available()
print(f"‚úÖ PyTorch: {torch.__version__}")
print(f"‚úÖ CUDA available: {_cuda_ready}")
if _cuda_ready:
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")

## 📊 Cell 3: Dataset Class

In [None]:
class MultiOrganDataset(Dataset):
    """Slice-wise dataset over paired CT / mask ``.nii`` volumes.

    Every ``*.nii`` file in ``ct_dir`` is matched to a mask named
    ``<patient_id>_mask_normalized.nii`` in ``masks_dir``, where
    ``patient_id`` is the CT file stem with ``_ct_normalized`` removed.
    Each slice along the first array axis becomes one sample.

    Args:
        ct_dir: Directory containing the CT volumes.
        masks_dir: Directory containing the mask volumes.
        transform: Optional callable. It receives the sample dict
            ``{'image': tensor, 'mask': tensor}`` and must return the
            (possibly modified) sample. ``None`` disables it.
    """

    def __init__(self, ct_dir, masks_dir, transform=None):
        self.ct_dir = Path(ct_dir)
        self.masks_dir = Path(masks_dir)
        self.transform = transform

        # Get all CT files
        self.ct_files = sorted(self.ct_dir.glob('*.nii'))

        # One entry per (patient, slice) pair
        self.samples = []

        print(f"Loading {len(self.ct_files)} patients...")
        for ct_path in tqdm(self.ct_files):
            patient_id = ct_path.stem.replace('_ct_normalized', '')
            mask_path = self.masks_dir / f"{patient_id}_mask_normalized.nii"

            # Patients without a mask are silently left out
            if not mask_path.exists():
                continue

            # Only the file headers are read here; pixel data is loaded
            # lazily in __getitem__.
            num_slices = self._count_slices(ct_path)
            if self._count_slices(mask_path) != num_slices:
                # A mismatched pair would otherwise fail (or misalign)
                # much later, in the middle of training.
                tqdm.write(
                    f"Skipping {patient_id}: CT and mask slice counts differ"
                )
                continue

            self.samples.extend(
                {
                    'ct_path': ct_path,
                    'mask_path': mask_path,
                    'slice_idx': slice_idx,
                }
                for slice_idx in range(num_slices)
            )

        print(f"‚úÖ Total slices: {len(self.samples)}")

    @staticmethod
    def _count_slices(path):
        """Return the number of slices of the image at ``path`` from its header."""
        reader = sitk.ImageFileReader()
        reader.SetFileName(str(path))
        reader.ReadImageInformation()
        # SimpleITK sizes are the reverse of the NumPy axis order, so the
        # last size entry is the first axis of GetArrayFromImage's result.
        return reader.GetSize()[-1]

    @staticmethod
    def _read_slice(path, slice_idx):
        """Load the volume at ``path`` and return slice ``slice_idx`` as an array."""
        # NOTE: the whole volume is decoded for every requested slice.
        volume = sitk.GetArrayFromImage(sitk.ReadImage(str(path)))
        return volume[slice_idx]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{'image': float32 tensor with a leading channel axis,
        'mask': int64 tensor}`` for sample ``idx``."""
        sample = self.samples[idx]

        ct_slice = self._read_slice(sample['ct_path'], sample['slice_idx'])
        mask_slice = self._read_slice(sample['mask_path'], sample['slice_idx'])

        # Cast in NumPy first: torch.from_numpy does not accept every NumPy
        # dtype. np.array also copies, so the tensor does not keep the
        # full volume alive.
        ct_tensor = torch.from_numpy(
            np.array(ct_slice, dtype=np.float32)
        ).unsqueeze(0)
        mask_tensor = torch.from_numpy(np.array(mask_slice, dtype=np.int64))

        item = {
            'image': ct_tensor,
            'mask': mask_tensor
        }

        # Previously the transform was stored but never applied.
        if self.transform is not None:
            item = self.transform(item)

        return item

print("‚úÖ Dataset class defined")

## 🧠 Cell 4: U-Net Model

In [None]:
class UNetMultiOrgan(nn.Module):
    """2D U-Net with four down/up-sampling stages and skip connections.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output channels (one logit map per class).
        base_channels: Width of the first encoder stage; each deeper
            stage doubles it. The default of 64 gives the original
            64/128/256/512/1024 layout.
    """

    # Four 2x poolings: spatial sizes must be multiples of 2**4 internally.
    _DOWNSAMPLE = 16

    def __init__(self, in_channels=1, out_channels=8, base_channels=64):
        super().__init__()

        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        # Encoder
        self.enc1 = self.conv_block(in_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        self.enc4 = self.conv_block(c3, c4)

        # Bottleneck
        self.bottleneck = self.conv_block(c4, c5)

        # Decoder: each dec block sees upsampled features concatenated
        # with the matching encoder output, hence twice the channels.
        self.up4 = nn.ConvTranspose2d(c5, c4, 2, stride=2)
        self.dec4 = self.conv_block(c5, c4)

        self.up3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = self.conv_block(c4, c3)

        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = self.conv_block(c3, c2)

        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = self.conv_block(c2, c1)

        # Output
        self.out = nn.Conv2d(c1, out_channels, 1)

        self.pool = nn.MaxPool2d(2)

    def conv_block(self, in_ch, out_ch):
        """Return two 3x3 conv + batch-norm + ReLU layers as a Sequential."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return logits with the same height and width as ``x``.

        Inputs whose height or width is not a multiple of 16 are
        zero-padded on the bottom/right before the network and the output
        is cropped back. Without this the skip-connection concatenation
        fails on a size mismatch.
        """
        height, width = x.shape[-2:]
        pad_h = -height % self._DOWNSAMPLE
        pad_w = -width % self._DOWNSAMPLE
        padded = bool(pad_h or pad_w)
        if padded:
            # Pad order for the last two dims: (left, right, top, bottom)
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        logits = self.out(d1)
        if padded:
            logits = logits[..., :height, :width]
        return logits

print("‚úÖ U-Net model defined")

## ⚙️ Cell 5: Configuration

In [None]:
# Hyper-parameters and runtime settings read by the cells below
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
CONFIG = dict(
    batch_size=8,
    num_epochs=50,        # full-length training run
    learning_rate=1e-4,
    num_workers=2,
    device=_device,
    patience=10,          # non-improving epochs tolerated before stopping
)

print("Configuration:")
for name, value in CONFIG.items():
    print(f"  {name}: {value}")

## 📊 Cell 6: Create Datasets & Loaders

In [None]:
print("Creating dataset...")
full_dataset = MultiOrganDataset(CT_DIR, MASKS_DIR)

# Group slice indices by patient (one CT file per patient). Splitting at
# slice level would put slices of the same patient in both train and
# validation, which leaks information and inflates the validation score.
slices_by_patient = {}
for sample_idx, sample in enumerate(full_dataset.samples):
    slices_by_patient.setdefault(str(sample['ct_path']), []).append(sample_idx)

patient_keys = sorted(slices_by_patient)
if len(patient_keys) < 2:
    raise RuntimeError(
        f"Need at least 2 patients for a train/val split, found {len(patient_keys)}"
    )

# Seeded shuffle so the split is reproducible across runs
split_generator = torch.Generator().manual_seed(CONFIG.get('seed', 42))
patient_order = torch.randperm(len(patient_keys), generator=split_generator).tolist()
shuffled_patients = [patient_keys[i] for i in patient_order]

# Split: 80% of patients train, 20% val (both sides kept non-empty)
num_train_patients = min(max(int(0.8 * len(patient_keys)), 1), len(patient_keys) - 1)
train_indices = [
    i for key in shuffled_patients[:num_train_patients] for i in slices_by_patient[key]
]
val_indices = [
    i for key in shuffled_patients[num_train_patients:] for i in slices_by_patient[key]
]

train_size = len(train_indices)
val_size = len(val_indices)

train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
val_dataset = torch.utils.data.Subset(full_dataset, val_indices)

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers'],
    pin_memory=True
)

print(f"‚úÖ Train: {len(train_dataset)} slices")
print(f"‚úÖ Val: {len(val_dataset)} slices")

## 🎯 Cell 7: Initialize Model & Loss

In [None]:
# Build the network, move it to the configured device and report its size
model = UNetMultiOrgan(in_channels=1, out_channels=8)
model = model.to(CONFIG['device'])
num_params = sum(param.numel() for param in model.parameters())
print(f"‚úÖ Model: {num_params:,} parameters")

# Loss functions
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    Args:
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when a class is absent.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - mean Dice`` over all samples and classes.

        ``pred`` holds logits with classes on dim 1; ``target`` holds
        integer class indices without a class dim.
        """
        # Take the class count from the logits instead of hard-coding 8,
        # so the loss works for any number of output channels.
        num_classes = pred.shape[1]

        pred = torch.softmax(pred, dim=1)
        target_one_hot = (
            torch.nn.functional.one_hot(target, num_classes)
            .permute(0, 3, 1, 2)
            .float()
        )

        # Sum over the two spatial dims -> one value per (sample, class)
        intersection = (pred * target_one_hot).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()

class CombinedLoss(nn.Module):
    """Weighted sum of cross-entropy and Dice loss.

    Args:
        ce_weight: Multiplier for the cross-entropy term.
        dice_weight: Multiplier for the Dice term.

    The defaults (0.5 / 0.5) reproduce the previously hard-coded mix.
    """

    def __init__(self, ce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        """Return ``ce_weight * CE + dice_weight * Dice`` for the batch."""
        ce_term = self.ce(pred, target)
        dice_term = self.dice(pred, target)
        return self.ce_weight * ce_term + self.dice_weight * dice_term

# Objective and optimiser handed to the training functions
criterion = CombinedLoss()
optimizer = optim.Adam(params=model.parameters(), lr=CONFIG['learning_rate'])

for _message in ("‚úÖ Loss: CombinedLoss (CE + Dice)", "‚úÖ Optimizer: Adam"):
    print(_message)

## 🔄 Cell 8: Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of batches, each a dict with ``'image'`` and
            ``'mask'`` tensors.
        criterion: Callable ``(outputs, masks) -> scalar loss tensor``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Average of the per-batch loss values.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0  # counted here so loaders without __len__ also work

    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the sum and the progress bar
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_batches == 0:
        raise ValueError("train_epoch: loader yielded no batches")

    return total_loss / num_batches

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean batch loss.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of batches, each a dict with ``'image'`` and
            ``'mask'`` tensors.
        criterion: Callable ``(outputs, masks) -> scalar loss tensor``.
        device: Device the batch tensors are moved to.

    Returns:
        Average of the per-batch loss values.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0  # counted here so loaders without __len__ also work

    with torch.no_grad():
        pbar = tqdm(loader, desc='Validation')
        for batch in pbar:
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)

            outputs = model(images)
            # Read the scalar once and reuse it for the sum and the progress bar
            batch_loss = criterion(outputs, masks).item()
            total_loss += batch_loss
            num_batches += 1
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_batches == 0:
        raise ValueError("validate: loader yielded no batches")

    return total_loss / num_batches

print("‚úÖ Training functions ready")

## 🚀 Cell 9: Main Training Loop

In [None]:
banner = "=" * 80
print(banner)
print("üöÄ STARTING TRAINING")
print(banner)

# Per-epoch losses plus early-stopping bookkeeping
history = {'train_loss': [], 'val_loss': []}
best_val_loss = float('inf')
patience_counter = 0

for epoch in range(CONFIG['num_epochs']):
    print(f"\nüìä Epoch {epoch+1}/{CONFIG['num_epochs']}")

    train_loss = train_epoch(model, train_loader, criterion, optimizer, CONFIG['device'])
    val_loss = validate(model, val_loader, criterion, CONFIG['device'])

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    improved = val_loss < best_val_loss
    if not improved:
        # No new best: count the epoch and stop once patience runs out
        patience_counter += 1
        if patience_counter >= CONFIG['patience']:
            print(f"‚ö†Ô∏è Early stopping at epoch {epoch+1}")
            break
        continue

    # New best validation loss: reset the counter and checkpoint the weights
    best_val_loss = val_loss
    patience_counter = 0
    torch.save(model.state_dict(), '/kaggle/working/best_model.pth')
    print("‚úÖ Saved best model")

print(banner)
print("‚úÖ TRAINING COMPLETED")
print(f"Best Val Loss: {best_val_loss:.4f}")
print(banner)

## 📊 Cell 10: Plot Results

In [None]:
# Draw both loss curves on one figure, save it, then show it
plt.figure(figsize=(10, 5))
for series_key, series_label, series_marker in (
    ('train_loss', 'Train Loss', 'o'),
    ('val_loss', 'Val Loss', 's'),
):
    plt.plot(history[series_key], label=series_label, marker=series_marker)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training & Validation Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('/kaggle/working/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

best_recorded_val = min(history['val_loss'])
last_train_loss = history['train_loss'][-1]
print(f"üìä Best Val Loss: {best_recorded_val:.4f}")
print(f"üìä Final Train Loss: {last_train_loss:.4f}")

## 🎨 Cell 11: Visualize Predictions

In [None]:
# Load best model; map_location keeps this working when the checkpoint was
# written on a different device than the one currently configured.
best_state = torch.load('/kaggle/working/best_model.pth', map_location=CONFIG['device'])
model.load_state_dict(best_state)
model.eval()

# Get up to 4 evenly spaced validation samples
num_val = len(val_dataset)
if num_val == 0:
    raise ValueError("Validation set is empty; nothing to visualize")
# max(1, ...) avoids a zero range step when there are fewer than 4 samples
sample_stride = max(1, num_val // 4)
val_samples = [val_dataset[i] for i in range(0, num_val, sample_stride)][:4]

# One row per sample; squeeze=False keeps `axes` 2-D even for a single row
num_rows = len(val_samples)
fig, axes = plt.subplots(num_rows, 3, figsize=(12, 4 * num_rows), squeeze=False)

with torch.no_grad():
    for idx, sample in enumerate(val_samples):
        # Add a batch dimension for the model
        image = sample['image'].unsqueeze(0).to(CONFIG['device'])
        mask_true = sample['mask'].numpy()

        # Prediction: most likely class per pixel
        output = model(image)
        mask_pred = torch.argmax(output, dim=1).cpu().numpy()[0]

        # Plot
        axes[idx, 0].imshow(image.cpu().squeeze(), cmap='gray')
        axes[idx, 0].set_title('CT Scan')
        axes[idx, 0].axis('off')

        # Same colour scale (labels 0..7) for truth and prediction
        axes[idx, 1].imshow(mask_true, cmap='tab10', vmin=0, vmax=7)
        axes[idx, 1].set_title('Ground Truth')
        axes[idx, 1].axis('off')

        axes[idx, 2].imshow(mask_pred, cmap='tab10', vmin=0, vmax=7)
        axes[idx, 2].set_title('Prediction')
        axes[idx, 2].axis('off')

plt.tight_layout()
plt.savefig('/kaggle/working/predictions.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ Visualizations saved")

## 💾 Cell 12: Summary & Download Files

In [None]:
# Collect the report first, then emit it one line at a time
divider = "=" * 60
summary_lines = [
    divider,
    "üìä TRAINING SUMMARY",
    divider,
    f"Dataset: {len(train_dataset)} train, {len(val_dataset)} val slices",
    f"Model: U-Net ({num_params:,} parameters)",
    f"Best Val Loss: {best_val_loss:.4f}",
    f"Epochs: {len(history['train_loss'])}",
    "\nüì• Output files:",
    "  ‚úÖ /kaggle/working/best_model.pth",
    "  ‚úÖ /kaggle/working/training_curves.png",
    "  ‚úÖ /kaggle/working/predictions.png",
    "\nüí° Download files from Output section (sidebar right)",
    divider,
]
for summary_line in summary_lines:
    print(summary_line)