## 📦 Installation des dépendances

In [None]:
!pip install -q SimpleITK nibabel tqdm scikit-learn seaborn

import torch
# Report the installed PyTorch version and whether a CUDA device is visible.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Name and total memory of device 0 (bytes divided by 1e9 to display GB).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 📁 Upload des données

**Option 1: Upload depuis PC**
- Zippe `data/processed/normalized/` sur ton PC
- Upload le zip ici
- Décompresse

**Option 2: Google Drive**
- Monte Google Drive
- Place les données dans Drive

In [None]:
# ==================================================
# LOAD DATA FROM GOOGLE DRIVE
# ==================================================
# Mounts Google Drive, extracts colab_data.zip into /content/data and, on
# success, defines DATA_ROOT / SPLITS_DIR and prints a sanity check of the
# extracted files.
# NOTE: DATA_ROOT and SPLITS_DIR are only assigned in the `else` branch below,
# so they stay undefined when the ZIP is missing.

from google.colab import drive
import zipfile
import os

# Mount Google Drive
drive.mount('/content/drive')

# Path of the ZIP in Drive (uses colab_data.zip - the complete file)
ZIP_PATH = '/content/drive/MyDrive/colab_data.zip'

# Check whether the ZIP exists; if not, only print upload instructions
if not os.path.exists(ZIP_PATH):
    print(f"‚ùå Fichier non trouv√©: {ZIP_PATH}")
    print("\nüí° Instructions:")
    print("   1. Upload colab_data.zip dans Google Drive")
    print("   2. Place-le dans 'Mon Drive' (MyDrive)")
    print("   3. Relance cette cellule")
else:
    # Extract into the Colab filesystem
    print("üì¶ Extraction du ZIP en cours...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall('/content/data')
    
    print("‚úÖ Extraction termin√©e!")
    
    # Define the paths used by the following cells
    DATA_ROOT = '/content/data/normalized'
    SPLITS_DIR = '/content/data/splits'
    
    # Check the extracted structure
    print("\n" + "="*50)
    print("üìä V√âRIFICATION DES DONN√âES")
    print("="*50 + "\n")
    
    # Count normalized files by filename suffix
    ct_files = [f for f in os.listdir(DATA_ROOT) if f.endswith('_ct_normalized.nii.gz')]
    mask_files = [f for f in os.listdir(DATA_ROOT) if f.endswith('_mask_normalized.nii.gz')]
    
    print(f"üìÅ Donn√©es normalis√©es:")
    print(f"   CT scans: {len(ct_files)}")
    print(f"   Masks: {len(mask_files)}")
    print()
    
    # Check the split files: one patient ID per non-empty line
    for split in ['train', 'val', 'test']:
        split_file = os.path.join(SPLITS_DIR, f'{split}.txt')
        if os.path.exists(split_file):
            with open(split_file, 'r') as f:
                patient_ids = [line.strip() for line in f if line.strip()]
            print(f"   {split.capitalize()}: {len(patient_ids)} patients")
        else:
            print(f"   ‚ùå {split}.txt NOT FOUND!")
    
    # NOTE: this message is printed even if a split file was reported missing above
    print("\n‚úÖ Donn√©es pr√™tes pour l'entra√Ænement!")


## 🏗️ PyTorch Dataset

In [None]:
# ==================================================
# 📊 DATASET CLASS (adapté pour structure plate avec splits)
# ==================================================

from pathlib import Path
from typing import Tuple, List

import numpy as np
import SimpleITK as sitk
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class NSCLCDataset(Dataset):
    """Dataset of CT images and segmentation masks stored as flat NIfTI files.

    For every patient ID listed in ``split_file`` the dataset expects the files
    ``<data_root>/<pid>_ct_normalized.nii.gz`` and
    ``<data_root>/<pid>_mask_normalized.nii.gz``.

    In ``'slice'`` mode every sample is a single slice (index along the first
    axis of the array returned by SimpleITK); with any other ``mode`` value
    every sample is a whole volume. Samples are ``(image, mask)`` float tensors
    with a leading channel dimension of size 1.

    Args:
        data_root: Directory containing the ``*_normalized.nii.gz`` files.
        split_file: Text file with one patient ID per line (blank lines ignored).
        mode: ``'slice'`` for per-slice samples; any other value yields volumes.
        transform: Optional callable ``transform(image, mask) -> (image, mask)``.
            It is applied on every access, including cache hits.
        cache: If true, keep the loaded (untransformed) samples in memory.

    Raises:
        FileNotFoundError: If the CT or mask file of a listed patient is missing.
    """

    def __init__(
        self,
        data_root: str,
        split_file: str,
        mode: str = 'slice',
        transform=None,
        cache: bool = False
    ):
        self.data_root = Path(data_root)
        self.mode = mode
        self.transform = transform
        self.cache = cache
        self.cached_data = {} if cache else None

        # Read the patient IDs from the split file
        with open(split_file, 'r') as f:
            self.patient_ids = [line.strip() for line in f if line.strip()]

        # Build the CT/mask path pairs
        self.ct_files = [self.data_root / f"{pid}_ct_normalized.nii.gz" for pid in self.patient_ids]
        self.mask_files = [self.data_root / f"{pid}_mask_normalized.nii.gz" for pid in self.patient_ids]

        # Fail early if any file is missing
        for ct_file, mask_file in zip(self.ct_files, self.mask_files):
            if not ct_file.exists():
                raise FileNotFoundError(f"CT file not found: {ct_file}")
            if not mask_file.exists():
                raise FileNotFoundError(f"Mask file not found: {mask_file}")

        # Build the flat (volume index, slice index) lookup for slice mode
        if self.mode == 'slice':
            self.slice_index = [
                (volume_idx, slice_idx)
                for volume_idx, ct_file in enumerate(self.ct_files)
                for slice_idx in range(self._read_num_slices(ct_file))
            ]

    @staticmethod
    def _read_num_slices(image_path) -> int:
        """Return the size of the third image axis, reading only the file header."""
        reader = sitk.ImageFileReader()
        reader.SetFileName(str(image_path))
        # Metadata-only read: avoids loading the pixel data of every volume
        # just to learn how many slices it has.
        reader.ReadImageInformation()
        return reader.GetSize()[2]

    def _load_sample(self, volume_idx: int, slice_idx) -> Tuple[torch.Tensor, torch.Tensor]:
        """Read one volume (or one slice of it) from disk as float tensors."""
        image = sitk.GetArrayFromImage(sitk.ReadImage(str(self.ct_files[volume_idx])))
        mask = sitk.GetArrayFromImage(sitk.ReadImage(str(self.mask_files[volume_idx])))

        if slice_idx is not None:
            image = image[slice_idx]
            mask = mask[slice_idx]

        # Convert to float tensors and add the channel dimension
        image = torch.from_numpy(image).float().unsqueeze(0)
        mask = torch.from_numpy(mask).float().unsqueeze(0)
        return image, mask

    def __len__(self) -> int:
        """Number of slices in slice mode, number of volumes otherwise."""
        if self.mode == 'slice':
            return len(self.slice_index)
        return len(self.ct_files)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(image, mask)`` pair for sample ``idx``."""
        if self.mode == 'slice':
            volume_idx, slice_idx = self.slice_index[idx]
        else:
            volume_idx, slice_idx = idx, None

        cache_key = (volume_idx, slice_idx)
        if self.cache and cache_key in self.cached_data:
            image, mask = self.cached_data[cache_key]
        else:
            image, mask = self._load_sample(volume_idx, slice_idx)
            if self.cache:
                # Cache the raw sample, not the transformed one, so random
                # augmentations are re-drawn on every access.
                self.cached_data[cache_key] = (image, mask)

        if self.transform:
            if self.cache:
                # Hand copies to the transform so an in-place transform
                # cannot corrupt the cached tensors.
                image, mask = image.clone(), mask.clone()
            image, mask = self.transform(image, mask)

        return image, mask

def create_dataloaders(
    data_root: str,
    splits_dir: str,
    batch_size: int = 8,
    num_workers: int = 2,
    mode: str = 'slice'
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train/val/test DataLoaders from the split files.

    Args:
        data_root: Directory passed to each ``NSCLCDataset``.
        splits_dir: Directory holding ``train.txt``, ``val.txt`` and ``test.txt``.
        batch_size: Batch size shared by the three loaders.
        num_workers: Worker count shared by the three loaders.
        mode: Sampling mode forwarded to ``NSCLCDataset``.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles its samples.
    """
    split_names = ('train', 'val', 'test')

    # Create all datasets first so a missing file is reported before any
    # loader is built.
    datasets = [
        NSCLCDataset(
            data_root=data_root,
            split_file=f"{splits_dir}/{split_name}.txt",
            mode=mode
        )
        for split_name in split_names
    ]

    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split_name == 'train'),
            num_workers=num_workers,
        )
        for split_name, dataset in zip(split_names, datasets)
    )

    return train_loader, val_loader, test_loader

# Quick check: confirm the cell ran and summarize the expected data layout
print("‚úÖ Dataset d√©fini!")
print(f"   Mode: slice-wise (chaque slice = 1 √©chantillon)")
print(f"   Structure: data/normalized/*.nii.gz + data/splits/*.txt")


## 🏛️ U-Net Model

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. With a 3x3 kernel and padding 1 the spatial size of the
    input is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 consumes in_channels, stage 2 consumes stage 1's output.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order are kept so state_dict keys match.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        features = self.double_conv(x)
        return features

class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a ``DoubleConv`` block."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Attribute name and module order are kept so state_dict keys match.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Downsample ``x`` and apply the convolution block."""
        downsampled = self.maxpool_conv(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: transposed-conv upsampling, skip concatenation, ``DoubleConv``.

    Args:
        in_channels: Channels of the incoming decoder feature map. The
            transposed convolution halves this number, and the concatenation
            with the skip connection is fed to a ``DoubleConv`` that expects
            ``in_channels`` channels again.
        out_channels: Channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with the skip tensor ``x2`` and fuse them.

        Args:
            x1: Feature map from the deeper level.
            x2: Skip-connection feature map from the encoder.

        Returns:
            Output of the convolution block applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)

        # MaxPool2d(2) floors odd sizes, so after upsampling x1 can be smaller
        # than the skip tensor. Pad it to the skip size instead of letting
        # torch.cat fail on mismatched shapes.
        diff_h = x2.size(-2) - x1.size(-2)
        diff_w = x2.size(-1) - x1.size(-1)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class UNet(nn.Module):
    """U-Net with four encoder levels, a bottleneck and four decoder levels.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the predicted map.
        features: Channel widths of the four encoder levels; the bottleneck
            uses ``features[3] * 2``. Only the first four entries are read.

    The forward pass ends with a sigmoid, so the output lies in [0, 1] and
    suits probability-based losses such as ``nn.BCELoss``.
    """

    # The default is a tuple (not a list) so it cannot be mutated across instances.
    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Encoder
        self.inc = DoubleConv(in_channels, features[0])
        self.down1 = Down(features[0], features[1])
        self.down2 = Down(features[1], features[2])
        self.down3 = Down(features[2], features[3])

        # Bottleneck
        self.down4 = Down(features[3], features[3] * 2)

        # Decoder (mirrors the encoder; each level consumes one skip connection)
        self.up1 = Up(features[3] * 2, features[3])
        self.up2 = Up(features[3], features[2])
        self.up3 = Up(features[2], features[1])
        self.up4 = Up(features[1], features[0])

        # Output: 1x1 convolution to the requested number of channels
        self.outc = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder and return sigmoid-activated predictions."""
        # Encoder path: keep every level's output for the skip connections
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder path: deepest features first, paired with the matching skip
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        return torch.sigmoid(self.outc(x))

class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        pred = pred.contiguous().view(-1)
        target = target.contiguous().view(-1)
        intersection = (pred * target).sum()
        dice = (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)
        return 1 - dice

class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy and Dice loss.

    Args:
        alpha: Weight of the BCE term; the Dice term is weighted ``1 - alpha``.

    ``nn.BCELoss`` is used, so ``pred`` must already hold probabilities.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        """Return ``alpha * BCE + (1 - alpha) * Dice``."""
        bce_term = self.alpha * self.bce(pred, target)
        dice_term = (1 - self.alpha) * self.dice(pred, target)
        return bce_term + dice_term

# Smoke test: instantiate the model on the available device and report its parameter count
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=1, out_channels=1).to(device)
print(f"‚úì Mod√®le cr√©√©: {sum(p.numel() for p in model.parameters()):,} param√®tres")
print(f"‚úì Device: {device}")

## 🎯 Training Pipeline

In [None]:
from tqdm.notebook import tqdm
import json
from pathlib import Path

def calculate_metrics(pred, target, threshold=0.5):
    """Calcule Dice, IoU, etc."""
    pred_binary = (pred > threshold).float()
    target_binary = target.float()
    
    intersection = (pred_binary * target_binary).sum()
    union = pred_binary.sum() + target_binary.sum()
    
    dice = (2. * intersection + 1e-6) / (union + 1e-6)
    iou = (intersection + 1e-6) / (union - intersection + 1e-6)
    
    return {
        'dice': dice.item(),
        'iou': iou.item()
    }

class Trainer:
    """Train/validate loop with checkpointing, LR scheduling and early stopping.

    Args:
        model: Module to train; expected to already live on ``device``.
        train_loader: Iterable of ``(images, masks)`` training batches.
        val_loader: Iterable of ``(images, masks)`` validation batches.
        criterion: Callable ``criterion(outputs, masks)`` returning the loss.
        optimizer: Optimizer updating the parameters of ``model``.
        device: Device the batches are moved to.
        checkpoint_dir: Directory for ``latest_checkpoint.pth`` / ``best_model.pth``.
        log_dir: Directory for ``training_history.json``.

    Attributes:
        history: Per-epoch lists under ``train_loss``, ``val_loss``,
            ``val_dice`` and ``val_iou``.
        best_val_dice: Highest validation Dice seen so far (starts at 0.0).
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        device,
        checkpoint_dir='checkpoints',
        log_dir='logs'
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device

        self.checkpoint_dir = Path(checkpoint_dir)
        self.log_dir = Path(log_dir)
        # parents=True so nested output paths (e.g. runs/exp1/checkpoints) work.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.history = {'train_loss': [], 'val_loss': [], 'val_dice': [], 'val_iou': []}
        self.best_val_dice = 0.0

    def train_epoch(self, epoch):
        """Run one optimization pass over ``train_loader``.

        Args:
            epoch: Epoch number, used only for the progress-bar label.

        Returns:
            Mean of the per-batch losses.
        """
        self.model.train()
        total_loss = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch} [TRAIN]")
        for images, masks in pbar:
            images = images.to(self.device)
            masks = masks.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, masks)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})

        return total_loss / len(self.train_loader)

    def validate(self):
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            Dict with ``'loss'``, ``'dice'`` and ``'iou'``, each the mean of the
            per-batch values, as plain Python floats.
        """
        self.model.eval()
        total_loss = 0
        all_dice = []
        all_iou = []

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc="[VAL]")
            for images, masks in pbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                outputs = self.model(images)
                loss = self.criterion(outputs, masks)

                total_loss += loss.item()

                metrics = calculate_metrics(outputs, masks)
                all_dice.append(metrics['dice'])
                all_iou.append(metrics['iou'])

        # float(...) keeps NumPy scalar objects out of the history and the
        # checkpoints, so they can be reloaded with weights-only unpickling.
        return {
            'loss': total_loss / len(self.val_loader),
            'dice': float(np.mean(all_dice)),
            'iou': float(np.mean(all_iou))
        }

    def save_checkpoint(self, epoch, val_metrics, is_best=False):
        """Write ``latest_checkpoint.pth`` and, if ``is_best``, ``best_model.pth``.

        Args:
            epoch: Epoch number stored in the checkpoint.
            val_metrics: Validation metrics dict; ``'dice'`` is read for the
                log message when ``is_best`` is true.
            is_best: Whether to also store the checkpoint as the best model.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_metrics': val_metrics,
            'history': self.history
        }

        torch.save(checkpoint, self.checkpoint_dir / 'latest_checkpoint.pth')

        if is_best:
            torch.save(checkpoint, self.checkpoint_dir / 'best_model.pth')
            print(f"‚úì Best model saved (Dice: {val_metrics['dice']:.4f})")

    def train(self, num_epochs, scheduler=None, early_stopping_patience=5):
        """Train for up to ``num_epochs`` epochs and write the history file.

        Args:
            num_epochs: Maximum number of epochs.
            scheduler: Optional LR scheduler whose ``step`` accepts the
                validation Dice as its argument.
            early_stopping_patience: Stop after this many consecutive epochs
                without a new best validation Dice.
        """
        print(f"\n{'='*50}")
        print(f"Starting training for {num_epochs} epochs")
        print(f"Device: {self.device}")
        print(f"{'='*50}\n")

        patience_counter = 0

        for epoch in range(1, num_epochs + 1):
            # Train
            train_loss = self.train_epoch(epoch)

            # Validate
            val_metrics = self.validate()

            # Update history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['val_dice'].append(val_metrics['dice'])
            self.history['val_iou'].append(val_metrics['iou'])

            # Print summary
            print(f"\nEpoch {epoch}/{num_epochs}:")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss: {val_metrics['loss']:.4f}")
            print(f"  Val Dice: {val_metrics['dice']:.4f}")
            print(f"  Val IoU: {val_metrics['iou']:.4f}")

            # Track the best Dice; any epoch that does not improve on it
            # counts against the early-stopping patience.
            is_best = val_metrics['dice'] > self.best_val_dice
            if is_best:
                self.best_val_dice = val_metrics['dice']
                patience_counter = 0
            else:
                patience_counter += 1

            self.save_checkpoint(epoch, val_metrics, is_best)

            # LR scheduler is driven by the validation Dice
            if scheduler:
                scheduler.step(val_metrics['dice'])

            # Early stopping
            if patience_counter >= early_stopping_patience:
                print(f"\n‚ö†Ô∏è Early stopping triggered (patience={early_stopping_patience})")
                break

        # Save history
        with open(self.log_dir / 'training_history.json', 'w') as f:
            json.dump(self.history, f, indent=2)

        print(f"\n{'='*50}")
        print(f"Training completed!")
        print(f"Best Val Dice: {self.best_val_dice:.4f}")
        print(f"{'='*50}\n")

## 🚀 Lancement de l'entraînement

In [None]:
# Configuration
BATCH_SIZE = 16 if torch.cuda.is_available() else 4  # Larger batch when a GPU is available
LEARNING_RATE = 0.001
NUM_EPOCHS = 50  # More epochs are feasible with a GPU
PATIENCE = 10  # Early-stopping patience passed to trainer.train below

print(f"Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Device: {device}\n")

# Create the DataLoaders (DATA_ROOT / SPLITS_DIR come from the data-loading cell)
train_loader, val_loader, test_loader = create_dataloaders(
    data_root=DATA_ROOT,
    splits_dir=SPLITS_DIR,
    batch_size=BATCH_SIZE,
    num_workers=2,
    mode='slice'
)

print(f"\nüì¶ Datasets charg√©s:")
print(f"   Train: {len(train_loader.dataset)} slices")
print(f"   Val: {len(val_loader.dataset)} slices")
print(f"   Test: {len(test_loader.dataset)} slices\n")

# Initialize the model, loss, optimizer and LR scheduler
model = UNet(in_channels=1, out_channels=1).to(device)
criterion = CombinedLoss(alpha=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# mode='max' because the scheduler is stepped with the validation Dice
# (higher is better); the LR is halved after 5 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=5
)

# Create the trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    checkpoint_dir='checkpoints',
    log_dir='logs'
)

# START TRAINING
trainer.train(
    num_epochs=NUM_EPOCHS,
    scheduler=scheduler,
    early_stopping_patience=PATIENCE
)


## 📊 Visualisation des résultats

In [None]:
import matplotlib.pyplot as plt
import json

# Load the training history written by Trainer.train
with open('logs/training_history.json', 'r') as f:
    history = json.load(f)

# Create the 2x2 figure: loss, Dice, IoU and a text summary
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss
axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Dice Score
axes[0, 1].plot(history['val_dice'], label='Val Dice', color='green', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Dice Score')
axes[0, 1].set_title('Validation Dice Score')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# IoU Score
axes[1, 0].plot(history['val_iou'], label='Val IoU', color='orange', linewidth=2)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('IoU Score')
axes[1, 0].set_title('Validation IoU Score')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Summary panel: text only, so the axes are hidden
axes[1, 1].axis('off')
summary_text = f"""
TRAINING SUMMARY
================

Total Epochs: {len(history['train_loss'])}

Final Train Loss: {history['train_loss'][-1]:.4f}
Final Val Loss: {history['val_loss'][-1]:.4f}

Best Val Dice: {max(history['val_dice']):.4f}
Best Val IoU: {max(history['val_iou']):.4f}

Model saved: checkpoints/best_model.pth
"""
axes[1, 1].text(0.1, 0.5, summary_text, fontsize=12, family='monospace',
                verticalalignment='center')

# Save the figure to disk before showing it
plt.tight_layout()
plt.savefig('training_results.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úì Graphiques sauvegard√©s: training_results.png")

## 🔍 Test sur quelques prédictions

In [None]:
# Load the best model saved during training
checkpoint = torch.load('checkpoints/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Take the first batch of the test set as examples
test_iter = iter(test_loader)
images, masks = next(test_iter)
images = images.to(device)
masks = masks.to(device)

# Predictions (no gradient tracking needed for inference)
with torch.no_grad():
    predictions = model(images)

# Visualize up to 5 samples: one column per sample, rows = CT / ground truth / prediction
fig, axes = plt.subplots(3, 5, figsize=(20, 12))
for i in range(min(5, len(images))):
    # CT image
    axes[0, i].imshow(images[i, 0].cpu(), cmap='gray')
    axes[0, i].set_title('CT Image')
    axes[0, i].axis('off')
    
    # Ground truth
    axes[1, i].imshow(masks[i, 0].cpu(), cmap='gray')
    axes[1, i].set_title('Ground Truth')
    axes[1, i].axis('off')
    
    # Prediction, titled with the Dice of this single sample
    axes[2, i].imshow(predictions[i, 0].cpu(), cmap='gray')
    dice = calculate_metrics(predictions[i:i+1], masks[i:i+1])['dice']
    axes[2, i].set_title(f'Prediction (Dice: {dice:.3f})')
    axes[2, i].axis('off')

plt.tight_layout()
plt.savefig('predictions_sample.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úì Pr√©dictions sauvegard√©es: predictions_sample.png")

## 💾 Télécharger les résultats

In [None]:
# Zip all results (checkpoints, logs and the PNG figures in the working directory)
!zip -r results.zip checkpoints/ logs/ *.png

# Download the archive through the Colab files helper
from google.colab import files
files.download('results.zip')

print("‚úì T√©l√©chargement lanc√©!")
print("\nContenu du zip:")
print("  - checkpoints/best_model.pth (meilleur mod√®le)")
print("  - checkpoints/latest_checkpoint.pth (dernier checkpoint)")
print("  - logs/training_history.json (historique)")
print("  - training_results.png (graphiques)")
print("  - predictions_sample.png (exemples)")

## ✅ Résumé Final

**Ce qui a été fait:**
1. ✓ Installation des dépendances
2. ✓ Upload des données normalisées
3. ✓ Création du Dataset PyTorch
4. ✓ Création du modèle U-Net
5. ✓ Entraînement complet avec GPU
6. ✓ Visualisation des résultats
7. ✓ Téléchargement des checkpoints

**Prochaines étapes sur ton PC:**
1. Dézippe `results.zip`
2. Place `checkpoints/` dans ton projet
3. Lance `python evaluate.py` pour l'évaluation complète
4. Lance `python visualize_predictions.py` pour plus de visualisations

**Temps total:** ~10-15 minutes avec GPU T4 🚀