# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from scipy.ndimage import label
from scipy.spatial.distance import directed_hausdorff
from model import BrainTumorSegmentationModel
from torch.utils.data import random_split
from dataset import BraTSDataset
import csv
from torch.utils.tensorboard import SummaryWriter  # Option TensorBoard

# Create (or truncate) the metrics CSV and write its header row. The training
# loop further down appends one row per epoch in this same column order.
with open('training_metrics.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow((
        'epoch',
        'train_loss',
        'val_dice_et',
        'val_dice_tc',
        'val_dice_wt',
        'val_hd_et',
        'val_hd_tc',
        'val_hd_wt',
    ))


# Segmentation network. num_classes=3 presumably matches the three region
# targets (ET, TC, WT) built in the training loop — confirm against model.py.
model = BrainTumorSegmentationModel(num_classes=3, in_channels=1)

# Adam with L2 regularisation passed through its `weight_decay` argument.
optimizer = Adam(params=model.parameters(), weight_decay=1e-4, lr=1e-5)

class BCEDiceLoss(nn.Module):
    """Binary cross-entropy plus the channel-averaged soft Dice loss.

    ``preds`` are fed straight into ``nn.BCELoss``, so they must already be
    probabilities in [0, 1] (no sigmoid is applied here). The Dice term is
    computed separately for every channel along dim 1, pooling the batch and
    all remaining axes, and then averaged over the channels.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()

    def forward(self, preds, targets, smooth=1e-5):
        """Return ``BCE(preds, targets) + mean_c(1 - Dice_c)`` as a scalar tensor.

        ``smooth`` keeps each Dice ratio defined when a channel is empty in
        both ``preds`` and ``targets``.
        """
        bce_term = self.bce(preds, targets)

        def _soft_dice(p, t):
            overlap = (p * t).sum()
            return (2. * overlap + smooth) / (p.sum() + t.sum() + smooth)

        channels = preds.shape[1]
        dice_term = sum(
            (1 - _soft_dice(preds[:, c], targets[:, c]) for c in range(channels)),
            0.0,
        ) / channels
        return bce_term + dice_term

# Loss function: BCE + soft Dice over the region channels.
criterion = BCEDiceLoss()

# Training and validation datasets, each read from its own directory.
train_set = BraTSDataset("BraTS2021_Augmented")
val_set = BraTSDataset("BraTS2021_Val")

# One volume per batch; only the training data is reshuffled every epoch.
train_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=True)
val_loader = DataLoader(dataset=val_set, shuffle=False, batch_size=1)

# Dice score
def dice_score(preds, targets, smooth=1e-5):
    """Return the Dice coefficient of every channel as a list of Python floats.

    For each channel along dim 1 the batch and all remaining axes are pooled:
    ``(2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)``. ``smooth`` makes
    the score 1.0 when a channel is empty in both ``preds`` and ``targets``.
    """
    def _channel_dice(channel):
        p, t = preds[:, channel], targets[:, channel]
        overlap = (p * t).sum()
        return ((2. * overlap + smooth) / (p.sum() + t.sum() + smooth)).item()

    return [_channel_dice(c) for c in range(preds.shape[1])]

# 95th-percentile Hausdorff distance
def hausdorff95(pred, target, voxelspacing=None, connectivity=1):
    """Return the 95th-percentile symmetric Hausdorff distance (HD95) of two masks.

    Distances are measured between the *surfaces* (border voxels) of the two
    binary masks. For every border voxel of one mask, the Euclidean distance to
    the nearest border voxel of the other mask is taken, in both directions.
    The 95th percentile of the pooled distances is returned. Unlike the plain
    Hausdorff maximum, this ignores a small fraction of outlier voxels.

    Args:
        pred: Predicted mask (anything convertible to a boolean array).
        target: Reference mask with the same shape as ``pred``.
        voxelspacing: Optional per-axis voxel size, forwarded as ``sampling`` to
            ``distance_transform_edt``. ``None`` means unit spacing, so distances
            are in voxels.
        connectivity: Connectivity of the structuring element used to extract
            border voxels (see ``scipy.ndimage.generate_binary_structure``).

    Returns:
        HD95 as a float, or ``np.nan`` if either mask is empty.
    """
    from scipy.ndimage import binary_erosion, distance_transform_edt, generate_binary_structure

    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)

    if not pred.any() or not target.any():
        return np.nan

    # Border voxels = mask minus its one-step erosion.
    footprint = generate_binary_structure(pred.ndim, connectivity)
    pred_border = pred ^ binary_erosion(pred, structure=footprint, iterations=1)
    target_border = target ^ binary_erosion(target, structure=footprint, iterations=1)

    # distance_transform_edt gives the distance to the nearest zero, so inverting a
    # border yields, at every voxel, the distance to that border.
    dist_to_target = distance_transform_edt(~target_border, sampling=voxelspacing)
    dist_to_pred = distance_transform_edt(~pred_border, sampling=voxelspacing)

    surface_distances = np.concatenate([dist_to_target[pred_border], dist_to_pred[target_border]])
    return float(np.percentile(surface_distances, 95))

# Remove small ET components (< 200 voxels)
def remove_small_et(pred, min_size=200):
    """Zero out connected components of channel 0 (ET) smaller than ``min_size``.

    ``pred[0]`` is treated as the ET channel. ``pred`` is modified in place and
    is also returned for convenience. Components are found with
    ``scipy.ndimage.label`` and its default structuring element. A component
    with exactly ``min_size`` voxels is kept.

    Args:
        pred: Per-channel prediction array whose first entry is the ET map.
        min_size: Minimum number of voxels a component needs to survive.

    Returns:
        The same ``pred`` object, with small ET components set to 0.
    """
    et = pred[0]  # ET is the first channel
    labeled, num = label(et)
    if num == 0:
        return pred

    # One pass over the volume for all component sizes, instead of a full-volume
    # scan per component.
    sizes = np.bincount(labeled.ravel())
    too_small = sizes < min_size
    too_small[0] = False  # label 0 is background: never modify it
    et[too_small[labeled]] = 0

    pred[0] = et  # no-op for numpy views; kept for non-view inputs
    return pred

# Training + validation loop
n_epochs = 100

# Polynomial learning-rate decay: the base lr is multiplied by
# (1 - epoch / n_epochs) ** 0.9; the scheduler is stepped once per epoch.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: (1 - epoch / n_epochs) ** 0.9)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def _mask_to_regions(mask):
    """Turn a label mask into stacked binary region maps in the order [ET, TC, WT].

    ET = label 4, TC = labels {1, 4}, WT = labels {1, 2, 4}. The maps are
    concatenated along dim 1. This assumes ``mask`` already has a size-1
    channel axis at dim 1 — TODO confirm against BraTSDataset.
    """
    et = (mask == 4).float()
    tc = ((mask == 1) | (mask == 4)).float()
    wt = ((mask == 1) | (mask == 2) | (mask == 4)).float()
    return torch.cat([et, tc, wt], dim=1)


for epoch in range(n_epochs):
    # ---- Training ----
    model.train()
    epoch_loss = 0.0

    for flair, t1, t1ce, t2, mask in tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}"):
        flair, t1, t1ce, t2, mask = flair.to(device), t1.to(device), t1ce.to(device), t2.to(device), mask.to(device)
        # Comparisons on the label mask are non-differentiable, so no graph is built here.
        target = _mask_to_regions(mask)

        output = model(flair, t1, t1ce, t2)
        loss = criterion(output, target)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss = epoch_loss / len(train_loader)
    print(f" Epoch {epoch+1} terminé - Loss moyenne : {train_loss:.4f}")

    # ---- Validation ----
    model.eval()
    dice_totals = [0.0, 0.0, 0.0]
    hausdorff_totals = [0.0, 0.0, 0.0]
    valid_counts = [0, 0, 0]
    n_cases = 0

    with torch.no_grad():
        for flair, t1, t1ce, t2, mask in val_loader:
            flair, t1, t1ce, t2, mask = flair.to(device), t1.to(device), t1ce.to(device), t2.to(device), mask.to(device)
            target = _mask_to_regions(mask)

            output = model(flair, t1, t1ce, t2)
            preds = (output > 0.5).float().cpu().numpy()
            target_np = target.cpu().numpy()

            # Post-process and score every case in the batch (not just item 0),
            # so metrics stay per-case averages for any batch size.
            for b in range(preds.shape[0]):
                # Remove small ET components
                preds[b] = remove_small_et(preds[b])

                # Dice
                dices = dice_score(torch.from_numpy(preds[b:b + 1]), torch.from_numpy(target_np[b:b + 1]))
                dice_totals = [dice_totals[i] + dices[i] for i in range(3)]
                n_cases += 1

                # Hausdorff95: undefined (NaN) cases are excluded from the average.
                for i in range(3):
                    hd = hausdorff95(preds[b, i], target_np[b, i])
                    if not np.isnan(hd):
                        hausdorff_totals[i] += hd
                        valid_counts[i] += 1

    dice_avgs = [d / n_cases if n_cases > 0 else float('nan') for d in dice_totals]
    hausdorff_avgs = [hausdorff_totals[i] / valid_counts[i] if valid_counts[i] > 0 else float('nan') for i in range(3)]

    print(f" Dice Validation - ET: {dice_avgs[0]:.4f} | TC: {dice_avgs[1]:.4f} | WT: {dice_avgs[2]:.4f}")
    print(f" Hausdorff95 Validation - ET: {hausdorff_avgs[0]:.2f} | TC: {hausdorff_avgs[1]:.2f} | WT: {hausdorff_avgs[2]:.2f}")
    scheduler.step()

    # Append this epoch's metrics to the CSV
    with open('training_metrics.csv', 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch + 1, train_loss, *dice_avgs, *hausdorff_avgs])

# Save the model once, after the last epoch (previously this ran every epoch).
torch.save(model.state_dict(), "brain_tumor_model.pth")
print("Modèle entraîné et sauvegardé avec succès.")