## 1️⃣ Setup Google Drive

In [None]:
# Mount Google Drive so the project folder is reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Change into the project folder (IPython magic). The project modules imported
# later (dataset_multi_organ, unet_multi_organ, train_multi_organ) are presumably
# found because they live in this directory — confirm.
%cd /content/drive/MyDrive/RADIO_PROJET

## 2️⃣ Vérifier GPU

In [None]:
import torch

# Report the PyTorch build and, when a CUDA device is visible, the name and
# total memory (in GB) of device 0.
has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_index = 0
    device_label = torch.cuda.get_device_name(gpu_index)
    total_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1e9
    print(f"CUDA device: {device_label}")
    print(f"CUDA memory: {total_gb:.2f} GB")

## 3️⃣ Installer dépendances

In [None]:
# Install SimpleITK into the Colab runtime (shell escape). It is not imported
# directly in this notebook — presumably required by the project modules
# (e.g. dataset_multi_organ); confirm.
!pip install SimpleITK

## 4️⃣ Imports

In [None]:
# Core PyTorch / data / plotting imports
import torch
import torch.nn as nn  # NOTE(review): nn is not referenced anywhere in this notebook
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm  # NOTE(review): tqdm is not referenced anywhere in this notebook
import json

# Project modules (presumably resolved from the working directory set with %cd)
from dataset_multi_organ import MultiOrganDataset, compute_class_weights, get_class_distribution
from unet_multi_organ import UNetMultiOrgan, count_parameters
from train_multi_organ import CombinedLoss, train_model, plot_training_curves

print("‚úÖ Imports OK")

## 5️⃣ Configuration

In [None]:
# Hyperparameters
BATCH_SIZE = 16  # increase if the GPU has enough memory
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
NUM_CLASSES = 8  # background + 7 structures (matches label_names in the class-weights cell)
NUM_WORKERS = 2  # DataLoader worker processes
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths (all under the project folder on Google Drive)
BASE_DIR = Path('/content/drive/MyDrive/RADIO_PROJET')
DATA_DIR = BASE_DIR / 'DATA' / 'processed'
CT_DIR = DATA_DIR / 'normalized'
MASK_DIR = DATA_DIR / 'masks_multi_organ'
SPLITS_DIR = DATA_DIR / 'splits_rtstruct'
CHECKPOINT_DIR = BASE_DIR / 'checkpoints_multi_organ'
# parents=True also creates any missing parent folder instead of raising
# FileNotFoundError; exist_ok=True makes re-running this cell harmless.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Device: {DEVICE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Checkpoint dir: {CHECKPOINT_DIR}")

## 6️⃣ Charger données

In [None]:
def _read_patient_ids(split_file):
    """Return the patient IDs listed in *split_file*, one per line.

    Lines are whitespace-stripped, and blank lines (such as a trailing empty
    line at the end of the file) are skipped so they never turn into an empty
    patient ID.
    """
    with open(split_file, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


# Load the train/val patient splits
print("Chargement des splits...")
train_ids = _read_patient_ids(SPLITS_DIR / 'train.txt')
val_ids = _read_patient_ids(SPLITS_DIR / 'val.txt')

print(f"‚úÖ Train: {len(train_ids)} patients")
print(f"‚úÖ Val: {len(val_ids)} patients")

# Build the datasets (their length is reported as a number of slices)
print("\nCr√©ation des datasets...")
train_dataset = MultiOrganDataset(train_ids, CT_DIR, MASK_DIR)
val_dataset = MultiOrganDataset(val_ids, CT_DIR, MASK_DIR)

print(f"‚úÖ Train dataset: {len(train_dataset)} slices")
print(f"‚úÖ Val dataset: {len(val_dataset)} slices")

## 7️⃣ Visualiser échantillons

In [None]:
# Preview a few training samples: CT on the top row, colour-coded mask below.
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# RGB colour per label value; label 0 (background) is drawn black.
# This dict is also used by the prediction-visualisation cell further down.
LABEL_COLORS = {
    0: [0, 0, 0], 1: [255, 0, 0], 2: [255, 165, 0],
    3: [0, 255, 255], 4: [0, 191, 255], 5: [255, 0, 255],
    6: [255, 255, 0], 7: [0, 255, 0]
}

# Spread the preview indices evenly over the whole dataset. Fixed multiples of
# 1000 raised IndexError whenever the dataset had 3000 slices or fewer.
n_samples = axes.shape[1]
sample_indices = np.linspace(0, len(train_dataset) - 1, n_samples).astype(int)

for i, idx in enumerate(sample_indices):
    ct, mask = train_dataset[int(idx)]

    ct_np = ct.squeeze().numpy()
    mask_np = mask.numpy()

    # Build an RGB image of the label mask
    mask_rgb = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
    for label, color in LABEL_COLORS.items():
        mask_rgb[mask_np == label] = color

    # CT
    axes[0, i].imshow(ct_np, cmap='gray')
    axes[0, i].set_title(f'Sample {i+1} - CT', fontsize=12)
    axes[0, i].axis('off')

    # Mask; the title counts the non-background labels present in the slice
    # (the old len(unique) - 1 was off by one when label 0 was absent).
    axes[1, i].imshow(mask_rgb)
    unique_labels = np.unique(mask_np)
    n_structures = int(np.count_nonzero(unique_labels != 0))
    axes[1, i].set_title(f'{n_structures} organes', fontsize=12)
    axes[1, i].axis('off')

plt.tight_layout()
plt.savefig('samples_preview.png', dpi=150, bbox_inches='tight')
plt.show()
print("‚úÖ Visualisation sauvegard√©e: samples_preview.png")

## 8️⃣ Calculer class weights

In [None]:
# NOTE: computing the class weights can take 10-15 minutes, so they are cached
# as JSON in CHECKPOINT_DIR and reloaded on later runs.

import os  # NOTE(review): os is not used anywhere in this notebook

weights_file = CHECKPOINT_DIR / 'class_weights.json'

class_weights = None
if weights_file.exists():
    print("Chargement des class weights depuis fichier...")
    try:
        with open(weights_file, 'r') as f:
            weights_list = json.load(f)
    except json.JSONDecodeError:
        # e.g. a file left half-written by an interrupted run
        weights_list = None
    # Only reuse a cache holding exactly one weight per class; a stale file
    # written for a different label set would silently mis-weight the loss.
    if isinstance(weights_list, list) and len(weights_list) == NUM_CLASSES:
        class_weights = torch.tensor(weights_list, dtype=torch.float32).to(DEVICE)
    else:
        print("Cache de class weights invalide, recalcul...")

if class_weights is None:
    print("Calcul des class weights (peut prendre 10-15 min)...")
    class_counts = get_class_distribution(train_dataset)
    class_weights = compute_class_weights(class_counts, method='sqrt_inverse')

    # Cache the weights for reuse on later runs
    with open(weights_file, 'w') as f:
        json.dump(class_weights.tolist(), f)

    class_weights = class_weights.to(DEVICE)

print("\nClass weights:")
label_names = ['Background', 'GTV', 'PTV', 'Poumon_D', 'Poumon_G', 'Coeur', 'Oesophage', 'Moelle']
# .tolist() copies all weights to Python floats once instead of formatting one
# device tensor element at a time.
for i, (name, weight) in enumerate(zip(label_names, class_weights.tolist())):
    print(f"  {i}: {name:12s} = {weight:.4f}")

## 9️⃣ Créer DataLoaders

In [None]:
# Pinned (page-locked) host memory only helps host-to-GPU copies; on a CPU-only
# runtime PyTorch emits a warning when pin_memory=True, so enable it for CUDA only.
use_pin_memory = DEVICE.type == 'cuda'

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle training samples every epoch
    num_workers=NUM_WORKERS,
    pin_memory=use_pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # deterministic order for validation
    num_workers=NUM_WORKERS,
    pin_memory=use_pin_memory
)

print(f"‚úÖ Train batches: {len(train_loader)}")
print(f"‚úÖ Val batches: {len(val_loader)}")

## 🔟 Créer modèle

In [None]:
# Build the multi-class U-Net (1 input channel, NUM_CLASSES outputs) directly
# on the selected device; nn.Module.to() returns the module itself.
model = UNetMultiOrgan(n_channels=1, n_classes=NUM_CLASSES, bilinear=False).to(DEVICE)

# Approximate parameter footprint, assuming 4 bytes (float32) per parameter.
num_params = count_parameters(model)
model_size_mb = (num_params * 4) / 1024 ** 2

print(
    f"‚úÖ Mod√®le cr√©√©",
    f"  Param√®tres: {num_params:,}",
    f"  Taille: {model_size_mb:.2f} MB",
    sep="\n",
)

## 1️⃣1️⃣ Loss et Optimizer

In [None]:
# Combined loss: cross-entropy + Dice with equal 0.5 / 0.5 weighting.
# class_weights presumably weight the cross-entropy term — confirm in CombinedLoss.
criterion = CombinedLoss(
    num_classes=NUM_CLASSES,
    class_weights=class_weights,
    ce_weight=0.5,
    dice_weight=0.5,
)

# Adam over every model parameter at the configured learning rate
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

print("‚úÖ Loss: CombinedLoss (CE + Dice)", "‚úÖ Optimizer: Adam", sep="\n")

## 1️⃣2️⃣ 🚀 ENTRAÎNEMENT

In [None]:
rule = "=" * 70
print(rule, "D√âBUT DE L'ENTRA√éNEMENT", rule, sep="\n")

# Run the training loop; checkpoints go to CHECKPOINT_DIR, and training
# presumably stops after 10 epochs without validation improvement — confirm
# in train_model. The returned history feeds the curve plot below.
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    save_dir=str(CHECKPOINT_DIR),
    early_stopping_patience=10,
)

print(f"\n{rule}", "‚úÖ ENTRA√éNEMENT TERMIN√â!", rule, sep="\n")

## 1️⃣3️⃣ Visualiser courbes d'entraînement

In [None]:
# Plot the history returned by train_model and save the figure next to the checkpoints.
curves_png = CHECKPOINT_DIR / "training_curves.png"
plot_training_curves(history, save_path=str(curves_png))
plt.show()

## 1️⃣4️⃣ Charger meilleur modèle et visualiser prédictions

In [None]:
# Reload the best checkpoint written by train_model. map_location places the
# tensors on DEVICE, so a checkpoint saved on GPU also loads on a CPU runtime.
checkpoint = torch.load(CHECKPOINT_DIR / 'best_model.pth', map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference behaviour for layers such as dropout / batch norm

print(f"‚úÖ Meilleur mod√®le charg√© (Epoch {checkpoint['epoch']})")
print(f"   Val Dice: {checkpoint['val_dice']:.4f}")

# Visualise predictions: rows are CT / ground truth / prediction, one column per sample
fig, axes = plt.subplots(3, 4, figsize=(20, 15))

# Evenly spaced validation indices. Fixed multiples of 1000 raised IndexError
# whenever the validation set had 3000 slices or fewer.
n_samples = axes.shape[1]
sample_indices = np.linspace(0, len(val_dataset) - 1, n_samples).astype(int)

with torch.no_grad():
    for i, idx in enumerate(sample_indices):
        ct, mask_true = val_dataset[int(idx)]

        # Prediction: add a batch dimension, then take the arg-max over the class axis
        ct_batch = ct.unsqueeze(0).to(DEVICE)
        logits = model(ct_batch)
        mask_pred = torch.argmax(logits, dim=1).squeeze().cpu().numpy()

        ct_np = ct.squeeze().numpy()
        mask_true_np = mask_true.numpy()

        # RGB versions of both masks using the LABEL_COLORS palette
        mask_true_rgb = np.zeros((*mask_true_np.shape, 3), dtype=np.uint8)
        mask_pred_rgb = np.zeros((*mask_pred.shape, 3), dtype=np.uint8)
        for label, color in LABEL_COLORS.items():
            mask_true_rgb[mask_true_np == label] = color
            mask_pred_rgb[mask_pred == label] = color

        # CT
        axes[0, i].imshow(ct_np, cmap='gray')
        axes[0, i].set_title(f'Sample {i+1} - CT', fontsize=12)
        axes[0, i].axis('off')

        # Ground truth
        axes[1, i].imshow(mask_true_rgb)
        axes[1, i].set_title('Ground Truth', fontsize=12)
        axes[1, i].axis('off')

        # Prediction
        axes[2, i].imshow(mask_pred_rgb)
        axes[2, i].set_title('Pr√©diction', fontsize=12)
        axes[2, i].axis('off')

plt.tight_layout()
plt.savefig(CHECKPOINT_DIR / 'predictions_samples.png', dpi=200, bbox_inches='tight')
plt.show()
print("‚úÖ Visualisation sauvegard√©e")

## 1️⃣5️⃣ 📊 Résumé final

In [None]:
# Final summary: dataset sizes, model, best checkpoint metrics and saved files.
print("="*70)
print("R√âSUM√â FINAL")
print("="*70)
print("\nüìÅ Dataset:")
print(f"   Train: {len(train_ids)} patients, {len(train_dataset):,} slices")
print(f"   Val: {len(val_ids)} patients, {len(val_dataset):,} slices")

print("\nüéØ Mod√®le:")
print("   Architecture: U-Net multi-classes")
print(f"   Param√®tres: {num_params:,}")
print(f"   Classes: {NUM_CLASSES}")

print("\nüèÜ Meilleur mod√®le:")
print(f"   Epoch: {checkpoint['epoch']}")
print(f"   Val Dice: {checkpoint['val_dice']:.4f}")

print("\nüìä Dice par organe:")
for c, score in checkpoint['val_dice_per_class'].items():
    # Keys are expected to be class indices; int() also accepts numeric string
    # keys (e.g. after a JSON round-trip), which would break list indexing.
    print(f"   {label_names[int(c)]:12s}: {score:.4f}")

print("\nüíæ Fichiers sauvegard√©s:")
print(f"   {CHECKPOINT_DIR / 'best_model.pth'}")
print(f"   {CHECKPOINT_DIR / 'training_history.json'}")
print(f"   {CHECKPOINT_DIR / 'training_curves.png'}")
print(f"   {CHECKPOINT_DIR / 'predictions_samples.png'}")

print("\n" + "="*70)
print("‚úÖ PROJET TERMIN√â!")
print("="*70)