# Implémentation UNet

## Importation des bibliothèques

In [None]:
# === Base Python ===
import os
import glob
import random
from pathlib import Path

# === Typage ===
from typing import Optional, Union, Tuple

# === NumPy / Math / Visualisation ===
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
%matplotlib inline

# === PIL (images) ===
from PIL import Image

# === PyTorch / Torchvision ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import transforms

# === Monai (medical imaging) ===
import monai
from monai.transforms import LoadImage

# === SimpleITK ===
import SimpleITK as sitk

# === Affichage modèle ===
from torchinfo import summary

# === Barre de progression ===
from tqdm.auto import tqdm


In [None]:
import torch

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Environment report: PyTorch version, CUDA build version, CUDA availability.
for info in (torch.__version__, torch.version.cuda, torch.cuda.is_available()):
    print(info)

## Prétraitement des fichiers NII (NIfTI)

In [None]:
import os
import ants
import SimpleITK as sitk

def n4_bias_correct(path_in):
    """Read a volume as float32 and return it N4 bias-field corrected.

    The N4 estimate is restricted to an Otsu mask. ``OtsuThreshold(img, 0, 1)``
    labels voxels at/below the Otsu threshold 0 and the rest 1, so the dark
    background is (presumably) excluded from the bias-field fit.

    Args:
        path_in: path to an image file readable by SimpleITK.

    Returns:
        SimpleITK.Image: the bias-corrected volume.
    """
    volume = sitk.ReadImage(path_in, sitk.sitkFloat32)
    otsu_mask = sitk.OtsuThreshold(volume, 0, 1)
    return sitk.N4BiasFieldCorrectionImageFilter().Execute(volume, otsu_mask)

def skull_strip_from_T1(t1_sitk):
    """Return a rough brain mask: voxels brighter than the volume's mean intensity.

    NOTE: this is a single global threshold, not a true skull-stripping
    algorithm; bright non-brain tissue may survive it.

    Args:
        t1_sitk: SimpleITK image (here, the N4-corrected T1).

    Returns:
        SimpleITK.Image: uint8 mask (1 above the mean, 0 otherwise) that carries
        the input's origin, spacing and direction.
    """
    voxels = sitk.GetArrayFromImage(t1_sitk)
    above_mean = voxels > voxels.mean()
    brain_mask = sitk.GetImageFromArray(above_mean.astype("uint8"))
    brain_mask.CopyInformation(t1_sitk)  # reuse the T1 geometry
    return brain_mask

def _sitk_to_ants(image_sitk):
    """Convert a SimpleITK image to an ANTsImage without losing voxel order or geometry.

    ``sitk.GetArrayFromImage`` returns axes in reverse index order (z, y, x),
    while ANTsPy arrays are indexed (x, y, z). The array is therefore
    transposed. Origin, spacing and direction are copied so that the ANTs image
    lies in the same physical space as the SimpleITK one.
    """
    dim = image_sitk.GetDimension()
    array_xyz = np.ascontiguousarray(sitk.GetArrayFromImage(image_sitk).T)
    return ants.from_numpy(
        array_xyz,
        origin=image_sitk.GetOrigin(),
        spacing=image_sitk.GetSpacing(),
        # SimpleITK returns the direction matrix as a flat row-major tuple.
        direction=np.asarray(image_sitk.GetDirection(), dtype=float).reshape(dim, dim),
    )


def preprocess_patient(patient_id, input_root="./data", output_root="./data_traite"):
    """Brain-mask one patient's FLAIR and copy its consensus segmentation.

    Steps:
      1. N4 bias-field correction of the T1 (``n4_bias_correct``).
      2. Crude brain mask from the corrected T1 (``skull_strip_from_T1``).
      3. Affine ANTs registration T1 -> FLAIR, then nearest-neighbour warp of
         the mask into FLAIR space.
      4. Save FLAIR * mask as ``3DFLAIR_traite.nii``, and copy ``Consensus.nii``
         unchanged as ``Consensus_traite.nii``.

    Args:
        patient_id: name of the patient sub-folder.
        input_root: folder containing ``<patient_id>/{3DFLAIR,3DT1,Consensus}.nii``.
        output_root: destination root; ``<output_root>/<patient_id>`` is created.

    Returns:
        None. Prints a message and returns early if an input file is missing.
    """
    import shutil  # only this step copies files

    print(f"\n🧠 Traitement patient : {patient_id}")
    input_dir = os.path.join(input_root, patient_id)
    output_dir = os.path.join(output_root, patient_id)
    os.makedirs(output_dir, exist_ok=True)

    flair_path = os.path.join(input_dir, "3DFLAIR.nii")
    t1_path = os.path.join(input_dir, "3DT1.nii")
    consensus_path = os.path.join(input_dir, "Consensus.nii")

    if not all(os.path.exists(p) for p in (flair_path, t1_path, consensus_path)):
        print("❌ Fichiers requis manquants")
        return

    # N4 bias correction + rough brain mask, both in T1 space.
    t1_corrected = n4_bias_correct(t1_path)
    brain_mask = skull_strip_from_T1(t1_corrected)

    # ANTsPy: affine registration T1 -> FLAIR. Converting with the full header
    # keeps axes and physical geometry consistent between the two images.
    t1_ants = _sitk_to_ants(t1_corrected)
    flair_ants = ants.image_read(flair_path)

    reg = ants.registration(fixed=flair_ants, moving=t1_ants, type_of_transform="Affine")

    # Warp the brain mask into FLAIR space (nearest neighbour keeps it binary).
    mask_ants = _sitk_to_ants(brain_mask)
    warped_mask = ants.apply_transforms(
        fixed=flair_ants,
        moving=mask_ants,
        transformlist=reg['fwdtransforms'],
        interpolator='nearestNeighbor'
    )

    warped_mask = warped_mask.threshold_image(0.5, 1.1, 1, 0)
    flair_stripped = flair_ants * warped_mask

    # Save outputs. The consensus is copied as-is, which assumes it is already
    # in FLAIR space (TODO confirm). shutil avoids building a shell command.
    ants.image_write(flair_stripped, os.path.join(output_dir, "3DFLAIR_traite.nii"))
    shutil.copy(consensus_path, os.path.join(output_dir, "Consensus_traite.nii"))
    print("✅ Sauvegarde OK")

def preprocess_all_patients(input_root="./data", output_root="./data_traite"):
    """Run ``preprocess_patient`` on every sub-directory of *input_root*.

    Args:
        input_root: folder with one sub-folder per patient; plain files are skipped.
        output_root: destination folder, created if missing.
    """
    os.makedirs(output_root, exist_ok=True)

    def is_patient_dir(name):
        return os.path.isdir(os.path.join(input_root, name))

    patient_ids = list(filter(is_patient_dir, os.listdir(input_root)))
    for pid in patient_ids:
        preprocess_patient(pid, input_root, output_root)

    print("\n🎉 Tous les patients ont été traités.")

# ▶️ Run the preprocessing for every patient folder under ./data
preprocess_all_patients("./data", "./data_traite")


## Slicing data into PNG files

In [None]:
# import os
# import SimpleITK as sitk
# import numpy as np
# from PIL import Image

# def preprocess_and_save_slices(input_root, output_root, resized_size=(256, 256)):
#     """
#     Coupe tous les fichiers .nii en slices .png et sauvegarde sur disque.
    
#     input_root: dossier contenant les sous-dossiers patients avec .nii
#     output_root: dossier où sauver les slices png
#     resized_size: taille des images sauvegardées
#     """
#     os.makedirs(output_root, exist_ok=True)

#     for subfolder in os.listdir(input_root):
#         folder_path = os.path.join(input_root, subfolder)
#         if not os.path.isdir(folder_path):
#             continue

#         flair_path = os.path.join(folder_path, "3DFLAIR.nii")
#         mask_path = os.path.join(folder_path, "Consensus.nii")

#         if not (os.path.exists(flair_path) and os.path.exists(mask_path)):
#             continue

#         flair_img = sitk.ReadImage(flair_path)
#         mask_img = sitk.ReadImage(mask_path)

#         flair_array = sitk.GetArrayFromImage(flair_img)  # [D, H, W]
#         mask_array = sitk.GetArrayFromImage(mask_img)    # [D, H, W]

#         patient_out_folder = os.path.join(output_root, subfolder)
#         images_folder = os.path.join(patient_out_folder, "images")
#         masks_folder = os.path.join(patient_out_folder, "masks")

#         os.makedirs(images_folder, exist_ok=True)
#         os.makedirs(masks_folder, exist_ok=True)

#         for idx in range(flair_array.shape[0]):
#             img_slice = flair_array[idx]
#             mask_slice = mask_array[idx]

#             img_slice = (img_slice - np.min(img_slice)) / (np.max(img_slice) - np.min(img_slice) + 1e-8)

#             img_pil = Image.fromarray((img_slice * 255).astype(np.uint8)).resize(resized_size)
#             mask_pil = Image.fromarray((mask_slice > 0).astype(np.uint8) * 255).resize(resized_size)

#             img_pil.save(os.path.join(images_folder, f"slice_{idx:04d}.png"))
#             mask_pil.save(os.path.join(masks_folder, f"slice_{idx:04d}.png"))

#     print(f"✅ Slicing terminé et sauvegardé sous {output_root}")

# # 🔥 Comment l'utiliser :
# input_root = "./data"  # Ton dossier de base contenant les IRMs
# output_root = "./pngData"  # Dossier où sauver les PNG

# preprocess_and_save_slices(input_root, output_root)


## Load PNG Data

In [None]:
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os
import re

class PngSlicesDataset(Dataset):
    """2D slice dataset read from a per-patient PNG tree.

    Layout read by ``__init__``::

        root_dir/<patient>/images/slice_XXXX.png
        root_dir/<patient>/masks/slice_XXXX.png   (same file name as the image)

    An image slice is kept only when its mask exists and, if *slice_range* is
    given, when its slice number lies in the inclusive range.

    Args:
        root_dir: folder containing one sub-folder per patient.
        resized_width, resized_height: output size of both image and mask.
        slice_range: optional inclusive ``(first, last)`` slice-number filter.
        transform: PIL -> tensor conversion, applied separately to the image
            and to the mask (default ``transforms.ToTensor()``). NOTE: because it
            is called twice, random augmentations would not stay in sync.

    Items are ``(image, mask, file_name)``. The image is mapped with
    ``(x - 0.5) / 0.5`` (i.e. [0, 1] -> [-1, 1] when the transform yields
    [0, 1]) and the mask is binarised to {0., 1.}.
    """

    # Slice files are expected to be named slice_<number>.png.
    _SLICE_PATTERN = re.compile(r"slice_(\d+)\.png")

    def __init__(self, root_dir, resized_width=256, resized_height=256, slice_range=None, transform=None):
        self.image_paths = []
        self.mask_paths = []
        self.resized_width = resized_width
        self.resized_height = resized_height
        self.slice_range = slice_range
        self.transform = transform if transform is not None else transforms.ToTensor()

        # os.listdir() order is filesystem-dependent. Sorting fixes the sample
        # order, so a seeded split of this dataset is reproducible.
        for patient in sorted(os.listdir(root_dir)):
            images_folder = os.path.join(root_dir, patient, "images")
            masks_folder = os.path.join(root_dir, patient, "masks")

            if not (os.path.isdir(images_folder) and os.path.isdir(masks_folder)):
                continue

            for img_file in sorted(os.listdir(images_folder)):
                match = self._SLICE_PATTERN.match(img_file)
                if not match:
                    continue  # not a slice_XXXX.png file

                slice_num = int(match.group(1))
                if self.slice_range and not (self.slice_range[0] <= slice_num <= self.slice_range[1]):
                    continue

                img_path = os.path.join(images_folder, img_file)
                mask_path = os.path.join(masks_folder, img_file)

                if os.path.exists(img_path) and os.path.exists(mask_path):
                    self.image_paths.append(img_path)
                    self.mask_paths.append(mask_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        size = (self.resized_height, self.resized_width)
        image_resize = transforms.Resize(size)
        # Nearest neighbour keeps the label mask binary. Bilinear would blend
        # 0/255 values along lesion borders.
        mask_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST)

        # `with` closes the underlying files once converted.
        with Image.open(self.image_paths[idx]) as raw_img:
            img = raw_img.convert('L')
        with Image.open(self.mask_paths[idx]) as raw_mask:
            mask = raw_mask.convert('L')

        img = self.transform(image_resize(img))
        mask = self.transform(mask_resize(mask))

        img = (img - 0.5) / 0.5  # [0, 1] -> [-1, 1]
        mask = (mask > 0.5).float()  # binarise

        return img, mask, os.path.basename(self.image_paths[idx])


## Define the train, test and validation datasets

### Here we get the initial image shapes

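This value is important, as it will allow us to resize our entire dataset. We must also bear in mind that each side must stay divisible by two at every downsampling step, so that the successive halvings (by a factor of two) give **natural numbers**. The UNet below downsamples four times, so width and height must be divisible by $2^4 = 16$. `getNearestMultipleOfTwo` rounds each side to the nearest **power of two**, which satisfies this as long as the result is at least 16.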

In [None]:
from PIL import Image
import os

# Sample patient folder, used only to measure the native slice size.
patient_folder = "./pngData/01016SACH"
images_folder = os.path.join(patient_folder, "images")
masks_folder = os.path.join(patient_folder, "masks")

# Slice file lists (sorted so that index 0 is the first slice).
image_files = sorted(f for f in os.listdir(images_folder) if f.endswith('.png'))
mask_files = sorted(f for f in os.listdir(masks_folder) if f.endswith('.png'))

# Open one image and one mask to read their sizes; `with` closes the files.
with Image.open(os.path.join(images_folder, image_files[0])) as img:
    w, h = img.size
with Image.open(os.path.join(masks_folder, mask_files[0])) as mask:
    mask_w, mask_h = mask.size

# Report the mask's own size (previously the image size was printed twice).
print((w, h, len(image_files)), "W, H, num_slices (à partir des PNG)")
print((mask_w, mask_h, len(mask_files)), "W, H, num_slices (masks PNG)")




In [None]:
def getNearestMultipleOfTwo(x, max_exponent=9):
    """Return the power of two closest to *x*, among ``2**0 .. 2**max_exponent``.

    Despite its name, the candidates are *powers* of two (1, 2, 4, ..., 512 by
    default), not arbitrary multiples of two. Ties go to the smaller power
    (96 -> 64). Values beyond the largest candidate are clamped to it
    (1000 -> 512 with the default ``max_exponent=9``).

    Args:
        x: target size.
        max_exponent: largest exponent considered. The default 9 reproduces the
            original fixed range [1, 512].

    Returns:
        int: the nearest power of two.

    >>> getNearestMultipleOfTwo(70)
    64
    >>> getNearestMultipleOfTwo(150)
    128
    """
    candidates = [2 ** i for i in range(max_exponent + 1)]
    # min() returns the first minimum, so ties resolve to the smaller power.
    return min(candidates, key=lambda p: abs(p - x))

#print(getNearestMultipleOfTwo(70)) # -> 64
#print(getNearestMultipleOfTwo(150)) # -> 128
            

Here we define the `width` and `height` to which every slice of the __dataset__ is resized.

In [None]:
# Resize target: each side of the sample slice (w, h measured above), rounded
# to the nearest power of two.
width = getNearestMultipleOfTwo(w)
height = getNearestMultipleOfTwo(h)
slice_range = [145, 291]  # inclusive range of slice numbers kept per patient

ROOT_DIR = './pngData'

full_dataset = PngSlicesDataset(
    root_dir=ROOT_DIR,
    resized_width=width,
    resized_height=height,
    slice_range=slice_range,
)


def _split_indices_by_patient(image_paths, fractions, generator):
    """Split sample indices into ``len(fractions)`` groups without splitting a patient.

    Neighbouring slices of one scan are highly correlated. A slice-level random
    split would therefore put the same patient in both train and test and
    inflate test scores. Instead, whole patients are shuffled with *generator*
    and allotted by *fractions* (rounded down). Leftover patients go first to
    empty groups, then round-robin from the first group.

    Args:
        image_paths: per-sample paths shaped ``<root>/<patient>/images/<file>``.
        fractions: group sizes as fractions of the number of patients.
        generator: torch.Generator that makes the shuffle reproducible.

    Returns:
        list of index lists, one per fraction, in the order of *fractions*.
    """
    patient_of = [Path(p).parent.parent.name for p in image_paths]
    patients = sorted(set(patient_of))
    order = torch.randperm(len(patients), generator=generator).tolist()
    shuffled = [patients[i] for i in order]

    counts = [int(len(patients) * f) for f in fractions]
    for k in range(len(patients) - sum(counts)):
        counts[counts.index(0) if 0 in counts else k % len(counts)] += 1

    group_of = {}
    start = 0
    for group, count in enumerate(counts):
        for patient in shuffled[start:start + count]:
            group_of[patient] = group
        start += count

    groups = [[] for _ in fractions]
    for idx, patient in enumerate(patient_of):
        groups[group_of[patient]].append(idx)
    return groups


# 80 % train / 10 % test / 10 % validation, split by patient (same proportions
# as the former two-step random_split, but without slice-level leakage).
generator = torch.Generator().manual_seed(25)
train_idx, test_idx, val_idx = _split_indices_by_patient(
    full_dataset.image_paths, [0.8, 0.1, 0.1], generator
)
train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
test_dataset = torch.utils.data.Subset(full_dataset, test_idx)
val_dataset = torch.utils.data.Subset(full_dataset, val_idx)

print("Size dataset :", len(full_dataset))
print("Size train_dataset :", len(train_dataset))
print("Size test_dataset :", len(test_dataset))
print("Size val_dataset :", len(val_dataset))

# NOTE(review): this redefines `device` as a string, replacing the torch.device
# created earlier; `.to(device)` accepts both.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): num_workers is printed but not passed to the DataLoaders below.
num_workers = torch.cuda.device_count() * 4 - 1 if device == "cuda" else 1

print("device:", device)
print("num_workers:", num_workers)

LEARNING_RATE = 3e-4
BATCH_SIZE = 8


### Data loaders and a quick sanity check


In [None]:
# Only the training stream is shuffled. Validation/test keep a fixed order,
# and the test loader yields one slice per batch.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# DataLoader sanity check: inspect only the first test batch (if any).
_first_batch = next(iter(test_dataloader), None)
if _first_batch is not None:
    img, mask, path = _first_batch
    print(f'Image batch shape: {img.shape}')
    print(f'Mask batch shape: {mask.shape}')
    print(f'Image path: {path[0]}')  # file name of the first sample in the batch

## Visualize data samples

In [None]:
"""import torch
import numpy as np
import matplotlib.pyplot as plt

# Obtenir un index aléatoire pour chaque dataset
train_idx = np.random.randint(len(train_dataset))
val_idx = np.random.randint(len(val_dataset))
test_idx = np.random.randint(len(test_dataset))

print("train_idx:", train_idx)
print("val_idx:", val_idx)
print("test_idx:", test_idx)

def plot_slice(dataset, index, dataset_name):
    batch_data = dataset[index]
    image, label = batch_data[0].to(device), batch_data[1].to(device)

    # Conversion en numpy
    image = image.squeeze().detach().cpu().numpy()
    label = label.squeeze().detach().cpu().numpy()

    image = (image + 1) / 2.0  # Normaliser [0,1]

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))

    axes[0].imshow(image, cmap='gray')
    axes[0].set_title(f"{dataset_name} - Image")
    axes[0].axis('off')

    axes[1].imshow(image, cmap='gray')
    axes[1].imshow(label, cmap='Reds', alpha=0.4)
    axes[1].set_title(f"{dataset_name} - Overlay")
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

# 🔥 Afficher séparément
plot_slice(train_dataset, train_idx, dataset_name="Training")
plot_slice(val_dataset, val_idx, dataset_name="Validation")
plot_slice(test_dataset, test_idx, dataset_name="Test")"""


## UNet Network Architecture

#### UNet params

In [None]:
# Model architecture parameters.
# NOTE(review): input_shape, output_shape and init_channels are not passed to
# the MONAI UNet built in the next cell (it hard-codes its channel sizes); here
# they only document the expected tensor shapes.
input_channels = 1
num_classes = 1  # e.g. 1 for binary segmentation (background vs lesion)
# PyTorch image tensors are (C, H, W): height comes before width.
input_shape = (input_channels, height, width)  # shape of one input image
output_shape = (num_classes, height, width)    # shape of one output mask
init_channels = 32  # NOTE(review): the UNet below starts at 16 channels

#### The UNet model 

In [None]:
from monai.networks.nets import UNet

# 2-D U-Net from MONAI (lighter than a hand-written UNet class).
# The in/out channel counts come from the parameters cell so the two stay in sync.
model = UNet(
    spatial_dims=2,                    # 2D U-Net
    in_channels=input_channels,        # 1 = grayscale slice
    out_channels=num_classes,          # 1 logit channel for binary segmentation
    channels=(16, 32, 64, 128, 256),   # feature maps per encoder level
    strides=(2, 2, 2, 2),              # four stride-2 downsamplings: H and W should be divisible by 2**4 = 16
    num_res_units=1,                   # residual units per block
).to(device)


#### Loss functions

In [None]:
# NOTE(review): `learning_rate` is not used below; the optimizer uses
# LEARNING_RATE (3e-4) from the data-setup cell. Kept for reference.
learning_rate = 0.001
n_epochs = 50  # or 100, since the lesions are small

from monai.losses import DiceLoss


class ComboLoss(nn.Module):
    """Weighted sum of soft Dice loss and positively-weighted BCE, both on logits.

    ``loss = dice_weight * Dice(sigmoid(x), y) + (1 - dice_weight) * BCEWithLogits(x, y)``

    Args:
        dice_weight: weight of the Dice term (the BCE term gets ``1 - dice_weight``).
        pos_weight: BCE weight of positive (lesion) voxels. Values > 1 counter
            the scarcity of foreground. 20.0 is the original, tunable choice.

    The module starts on the CPU. Move it with ``.to(device)``: BCEWithLogitsLoss
    stores ``pos_weight`` as a module buffer, so it follows the move.
    """

    def __init__(self, dice_weight=0.7, pos_weight=20.0):
        super().__init__()
        self.dice = DiceLoss(sigmoid=True)
        self.bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
        self.dice_weight = dice_weight

    def forward(self, inputs, targets):
        """Return the combined loss; *inputs* are raw logits, *targets* binary masks."""
        dice_loss = self.dice(inputs, targets)
        bce_loss = self.bce(inputs, targets)
        return self.dice_weight * dice_loss + (1 - self.dice_weight) * bce_loss


# Option 1: Dice only -> criterion = DiceLoss(sigmoid=True)
# Option 2 (used, recommended): Dice + BCE
criterion = ComboLoss(dice_weight=0.7).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

## Train the model

In [None]:
%%time
# Move the model to the selected device.
model = model.to(device)

# Per-epoch mean losses (weighted by batch size), plotted later.
train_losses = []
valid_losses = []

best_metric = -1  # NOTE(review): never updated in this cell
best_valid_loss = float('inf')
best_model = None  # frozen copy of the state_dict of the best epoch
best_epoch = 0

for epoch in tqdm(range(n_epochs)):

    ###################
    # Training phase
    ###################
    model.train()
    train_loss = 0.0

    for batch_data in tqdm(train_dataloader, position=0, leave=True):
        images, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()

        # Raw logits: the criterion applies the sigmoid itself.
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_dataloader.dataset)
    train_losses.append(train_loss)

    ###################
    # Validation phase
    ###################
    model.eval()
    valid_loss = 0.0

    with torch.no_grad():
        for batch_data in tqdm(val_dataloader, position=0, leave=True):
            val_inputs, val_labels = batch_data[0].to(device), batch_data[1].to(device)
            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs, val_labels)
            valid_loss += val_loss.item() * val_inputs.size(0)

    valid_loss /= len(val_dataloader.dataset)
    valid_losses.append(valid_loss)

    print(f'Epoch: {epoch+1} \tTraining Loss: {train_loss:.4f} \tValidation Loss: {valid_loss:.4f}')

    # Keep the best model. state_dict() holds references to the live parameter
    # tensors, which later epochs keep updating; clone them to freeze this epoch.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
        best_epoch = epoch + 1

print(f"✅ Best model selected at epoch {best_epoch} with validation loss: {best_valid_loss:.4f}")


#### We save the model

In [None]:
# Restore the best-validation weights into the live model, then write that
# snapshot to disk.
model.load_state_dict(best_model)
torch.save(obj=best_model, f='best_model_3.pth')
print(f"Best model selected at epoch {best_epoch} with validation loss: {best_valid_loss:.4f}")

## Display the training curves

In [None]:
# Plot the training / validation loss curves. The x-axis is 1-based to match
# the "Epoch: N" lines printed during training and `best_epoch`.
plt.figure(figsize=(6, 6))
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
plt.plot(range(1, len(valid_losses) + 1), valid_losses, label='Validation Loss')
plt.title('Loss Curve')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

## Load and test the "best" model

In [None]:
# Reload saved weights into the `model` already built above (same MONAI UNet).

# Informational shapes in PyTorch (C, H, W) order; not used by the MONAI UNet.
input_shape = (1, height, width)
num_classes = 1
output_shape = (num_classes, height, width)
init_channels = 32

# NOTE(review): the training cell saves to 'best_model_3.pth' but this loads
# 'best_model_2.pth'. Confirm which run is meant to be evaluated.
model_weights_path = "best_model_2.pth"
# map_location lets GPU-saved weights load on a CPU-only machine.
# weights_only=True refuses arbitrary pickled objects (a state_dict only holds tensors).
model.load_state_dict(torch.load(model_weights_path, map_location=device, weights_only=True))

In [None]:
# os.cpu_count() returns None when the CPU count cannot be determined.
num_workers = max((os.cpu_count() or 1) - 1, 0)
print("num_workers =", num_workers)
# NOTE(review): num_workers is not passed to the DataLoader below.

# batch_size=1 loader over the test split, used to display slices one at a time.
test_dataloader2 = DataLoader(
    test_dataset,
    batch_size=1,
    pin_memory=torch.cuda.is_available(),  # page-locked host memory for faster GPU copies
    shuffle=False,
)

# Iterator over test_dataloader2; call next(test_dataloader_iter) to step through slices.
test_dataloader_iter = iter(test_dataloader2)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from monai.networks.nets import UNet
from sklearn.metrics import jaccard_score
from monai.metrics import DiceMetric, MeanIoU

# --- MONAI METRICS ---
# The model has a single output channel, and that channel is the lesion
# (foreground). include_background=False asks MONAI to skip the first channel,
# i.e. the only one here, so it must be kept.
dice_metric = DiceMetric(include_background=True, reduction="mean")
iou_metric = MeanIoU(include_background=True)

# --- METRICS numpy ---
def dice_score(pred, label):
    """Dice coefficient between two binary masks (non-zero = foreground).

    Args:
        pred: predicted mask (array-like of any shape).
        label: ground-truth mask of the same shape.

    Returns:
        float: ``2 |P ∩ L| / (|P| + |L|)``. Returns 1.0 when both masks are
        empty, since an empty prediction on a lesion-free slice is a perfect match.
    """
    pred = np.asarray(pred).astype(np.bool_)
    label = np.asarray(label).astype(np.bool_)
    total = pred.sum() + label.sum()
    if total == 0:
        return 1.0
    return 2.0 * np.logical_and(pred, label).sum() / total

def iou_score(pred, label):
    """Intersection-over-union (Jaccard) between two binary masks (non-zero = foreground).

    Args:
        pred: predicted mask (array-like of any shape).
        label: ground-truth mask of the same shape.

    Returns:
        float: ``|P ∩ L| / |P ∪ L|``. Returns 1.0 when both masks are empty,
        since an empty prediction on a lesion-free slice is a perfect match.
    """
    pred = np.asarray(pred).astype(np.bool_)
    label = np.asarray(label).astype(np.bool_)
    union = np.logical_or(pred, label).sum()
    if union == 0:
        return 1.0
    return np.logical_and(pred, label).sum() / union

# --- EVALUATION ---
def _show_prediction(image_np, label_np, preds_np, dice, iou):
    """Draw a 2x3 figure: image / ground truth / prediction, then the two overlays."""
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))

    axes[0, 0].imshow(image_np, cmap="gray")
    axes[0, 0].set_title("Image d'origine")

    axes[0, 1].imshow(label_np, cmap="jet")
    axes[0, 1].set_title("Mask réel")

    axes[0, 2].imshow(preds_np, cmap="jet")
    axes[0, 2].set_title(f"Prédiction\nDice: {dice:.3f} | IoU: {iou:.3f}")

    axes[1, 0].imshow(image_np, cmap="gray")
    axes[1, 0].set_title("Image d'origine")

    axes[1, 1].imshow(image_np, cmap="gray")
    axes[1, 1].imshow(label_np, cmap="Reds", alpha=0.4)
    axes[1, 1].set_title("Overlay Ground Truth")

    axes[1, 2].imshow(image_np, cmap="gray")
    axes[1, 2].imshow(preds_np, cmap="Blues", alpha=0.4)
    axes[1, 2].set_title("Overlay Prédiction")

    for ax in axes.flat:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


def evaluate_model_on_batch(model, dataloader, device, threshold=0.5):
    """Predict the first sample of the first batch of *dataloader*, plot it and print metrics.

    Args:
        model: segmentation network returning one logit channel per pixel
            (assumed (B, 1, H, W), as the UNet above does; TODO confirm for others).
        dataloader: yields ``(images, labels, names)`` batches, as PngSlicesDataset does.
        device: device on which to run the model.
        threshold: probability cut-off applied after the sigmoid (0.5 originally).

    NOTE: a fresh iterator is created on each call, so with ``shuffle=False``
    the same first slice is always shown.
    """
    model = model.to(device)
    model.eval()

    dice_metric.reset()
    iou_metric.reset()

    images, labels, _names = next(iter(dataloader))

    with torch.no_grad():
        # Keep only the first sample of the batch: [1, C, H, W].
        image = images[:1].to(device)
        label = labels[:1].to(device)

        print("Image shape:", image.shape)

        preds = (torch.sigmoid(model(image)) > threshold).float()

        dice_metric(preds, label)
        iou_metric(preds, label)
        dice_metric_result = dice_metric.aggregate().item()
        iou_metric_result = iou_metric.aggregate().item()

    # Undo the dataset's [-1, 1] normalisation for display.
    image_np = (image.squeeze().cpu().numpy() + 1) / 2.0
    label_np = label.squeeze().cpu().numpy()
    preds_np = preds.squeeze().cpu().numpy()

    dice = dice_score(preds_np, label_np)
    iou = iou_score(preds_np, label_np)

    _show_prediction(image_np, label_np, preds_np, dice, iou)

    print(f"✅ Dice (MONAI) : {dice_metric_result:.3f}")
    print(f"✅ IoU  (MONAI) : {iou_metric_result:.3f}")
    print(f"✅ Dice (NumPy) : {dice:.3f}")
    print(f"✅ IoU  (NumPy) : {iou:.3f}")

# --- USAGE ---
# `model` is the network whose weights were loaded above, and `test_dataloader`
# is the batch_size=1 test loader. A custom setup could look like:
#   model = UNet(...)
#   test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=1)

evaluate_model_on_batch(model=model, dataloader=test_dataloader, device=device)
