In [1]:
import os
# Root folders for the data and for the configuration files.
# Each one can be overridden through an environment variable of the same name.
DATASET_DIR = os.environ.get("DATASET_DIR", "/workspace/data")
CONFIG_DIR = os.environ.get("CONFIG_DIR", "/workspace/configs")


In [2]:
from p9dg.histo_dataset import HistoDataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

In [3]:
# ------------------------------
# 1️⃣ Config
# ------------------------------
# Compute device: the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fixed seed so torch-side randomness (new classifier head init, DataLoader shuffling)
# and NumPy's global RNG are repeatable across runs. GPU kernels may still be
# non-deterministic, so runs are repeatable but not guaranteed bit-exact.
SEED = 42
torch.manual_seed(SEED)  # seeds the RNG on all devices
np.random.seed(SEED)

# Checkpoints, figures and reports are all written under ./artifacts.
os.makedirs("artifacts", exist_ok=True)

In [4]:
# ------------------------------
# 2️⃣ Dataset & DataLoader
# ------------------------------
# Settings shared by both splits; only the split name and the per-epoch sample budget differ.
# Paths come from DATASET_DIR / CONFIG_DIR (config cell) instead of being hard-coded, so the
# environment-variable overrides actually take effect.
_shared_ds_kwargs = dict(
    root_data=DATASET_DIR,
    output_size=256,  # original note: initial tile size is "244" (NOTE(review): likely 224 px — confirm)
    pixel_range="imagenet",
    balance_per_class=True,  # essential for the per-class sampling
    thresholds_json_path=os.path.join(CONFIG_DIR, "seuils_par_classe.json"),
    vahadane_enable=True,
    vahadane_device=DEVICE,
)

train_ds = HistoDataset(
    split="train",
    samples_per_class_per_epoch=200,  # 200 per class: a reasonable budget
    **_shared_ds_kwargs,
)

val_ds = HistoDataset(
    split="val",
    samples_per_class_per_epoch=50,
    **_shared_ds_kwargs,
)

🎨 Référence Vahadane fixée : TUM-RQEVGAED.tif
🎨 Référence Vahadane auto: TUM-RQEVGAED.tif
✅ Seuils par classe chargés depuis : /workspace/configs/seuils_par_classe.json
🎨 Référence Vahadane fixée : TUM-TCGA-TWCEHKLC.tif
🎨 Référence Vahadane auto: TUM-TCGA-TWCEHKLC.tif
✅ Seuils par classe chargés depuis : /workspace/configs/seuils_par_classe.json


In [5]:
# Number of samples each split yields per epoch. With balance_per_class this is presumably
# samples_per_class_per_epoch × number of classes (consistent with the 1800 / 450 output).
for split_label, split_ds in (("Train", train_ds), ("Val", val_ds)):
    print(f"{split_label} set size: {len(split_ds)} images")

Train set size: 1800 images
Val set size: 450 images


In [6]:
# from collections import Counter
# import matplotlib.pyplot as plt

# # Comptage des classes dans le dataset d'entraînement
# cls_counts = Counter([y for _, y, _ in [train_ds[i] for i in range(len(train_ds))]])
# cls_names = [train_ds.idx_to_class[c] for c in cls_counts.keys()]
# cls_values = list(cls_counts.values())

# print(f"🧩 Nombre total d'images dans train_ds : {len(train_ds)}")
# for name, count in zip(cls_names, cls_values):
#     print(f"  - {name}: {count}")

# # Petit histogramme
# plt.figure(figsize=(6,3))
# plt.bar(cls_names, cls_values, color="cornflowerblue")
# plt.title("Distribution des classes (train)")
# plt.xticks(rotation=45, ha="right")
# plt.tight_layout()
# plt.show()


In [7]:
# Backbone selection: "mobilenetv2" or "resnet18" (ImageNet-pretrained weights).
MODEL_NAME = "mobilenetv2"  # default
# DEVICE comes from the config cell; it is intentionally not redefined here.

num_classes = len(train_ds.class_to_idx)
print(f"🧠 Nombre de classes : {num_classes}")

if MODEL_NAME == "mobilenetv2":
    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
    # Replace the final linear layer with one sized for our classes.
    model.classifier[1] = nn.Linear(model.last_channel, num_classes)
elif MODEL_NAME == "resnet18":
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
else:
    # Name the rejected value and the accepted ones so a typo is obvious.
    raise ValueError(
        f"Modèle non reconnu : {MODEL_NAME!r} (attendu : 'mobilenetv2' ou 'resnet18')"
    )

model = model.to(DEVICE)
print(f"✅ Modèle {MODEL_NAME} initialisé sur {DEVICE}")


🧠 Nombre de classes : 9
✅ Modèle mobilenetv2 initialisé sur cuda


In [8]:
# Training setup
# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 30
LR = 1e-3
PATIENCE = 6  # NOTE(review): not read by the training loop, which uses its own ES_PATIENCE
# The loaders run in the main process (0 workers), as the original call sites hard-coded.
# NOTE(review): presumably because the dataset runs Vahadane normalisation on DEVICE
# (possibly CUDA), which does not work in forked worker processes — confirm before raising.
NUM_WORKERS = 0
# Pinned host memory only speeds up host→GPU copies; on a CPU-only run it just emits a warning.
PIN_MEMORY = DEVICE == "cuda"

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
# Halve the LR once the monitored loss has not improved (relative threshold 1e-3) for
# 2 epochs, then wait 1 epoch before monitoring again. `verbose` is deprecated in recent
# PyTorch releases; the training loop prints the current LR every epoch instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5,
    patience=2, threshold=1e-3, cooldown=1,
)

print("🧩 Dataloaders prêts.")


🧩 Dataloaders prêts.




In [9]:
# Vidage RAM GPU
# Release cached, currently unused GPU memory blocks before training (skipped without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Boucle d'entraînement
# Adaptée pour les petits jeux de données :
# ✅ validation lissée (EMA) pour réduire les sauts aléatoires,
# ✅ early stopping plus stable (warm-up + min_delta + patience étendue),
# ✅ scheduler avec seuil et cooldown pour éviter les réductions de LR prématurées,
# ✅ gradient clipping pour stabiliser les phases de descente.


from collections import deque
import numpy as np

# ---- Early-stopping & stability hyperparameters ----
WARMUP_EPOCHS = 3          # no stop decision (early stop or overfitting guard) before this many epochs
ES_PATIENCE   = 6          # post-warmup epochs without EMA improvement before stopping
ES_MIN_DELTA  = 1e-3       # minimum decrease of the EMA that counts as an improvement
best_val_ema  = float("inf")
patience_counter = 0
val_ema = None             # exponential moving average of val_loss
DELTA_POS_PATIENCE = 2     # consecutive post-warmup epochs with val_ema > train_loss before stopping
delta_pos_counter = 0

# Mixed precision is only used on CUDA. On a CPU-only run, autocast and the grad scaler are
# disabled explicitly, which keeps plain fp32 training without "CUDA is not available" warnings.
AMP_DEVICE_TYPE = torch.device(DEVICE).type
USE_AMP = AMP_DEVICE_TYPE == "cuda"
BEST_CKPT_PATH = f"artifacts/{MODEL_NAME}_best.pt"


# ---- EMA smoothing ----
def ema(prev, new, alpha=0.3):
    """Exponential moving average update.

    Returns `new` unchanged on the first call (`prev is None`), otherwise
    `alpha * new + (1 - alpha) * prev`.
    """
    return new if prev is None else (alpha * new + (1 - alpha) * prev)


def _train_one_epoch():
    """Run one optimisation pass over train_loader; return the sample-weighted mean loss."""
    model.train()
    running_loss = 0.0
    for imgs, labels, _ in tqdm(train_loader, desc="Train", leave=False):
        imgs = imgs.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        optimizer.zero_grad()

        with torch.amp.autocast(device_type=AMP_DEVICE_TYPE, enabled=USE_AMP):
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        # Backprop on the scaled loss, then unscale so clipping sees the true gradient norms.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * imgs.size(0)
    return running_loss / len(train_loader.dataset)


def _evaluate():
    """Evaluate on val_loader; return (mean loss, true labels, predicted labels)."""
    model.eval()
    running_loss = 0.0
    targets, predictions = [], []
    with torch.no_grad():
        for imgs, labels, _ in tqdm(val_loader, desc="Val", leave=False):
            imgs = imgs.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            with torch.amp.autocast(device_type=AMP_DEVICE_TYPE, enabled=USE_AMP):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            running_loss += loss.item() * imgs.size(0)
            targets.extend(labels.cpu().numpy())
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())
    return running_loss / len(val_loader.dataset), targets, predictions


# ---- Training loop ----
train_losses, val_losses = [], []
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)

for epoch in range(EPOCHS):
    if epoch % 10 == 0:
        torch.cuda.empty_cache()
        print("🧹 Cache GPU vidé")
    train_ds.set_epoch(epoch)
    print(f"\n=== Epoch {epoch+1}/{EPOCHS} ===")

    train_loss = _train_one_epoch()
    train_losses.append(train_loss)

    val_loss, y_true, y_pred = _evaluate()
    val_losses.append(val_loss)
    val_acc = sum(int(t == p) for t, p in zip(y_true, y_pred)) / max(len(y_true), 1)

    # ---- Smoothing & scheduler ----
    val_ema = ema(val_ema, val_loss, alpha=0.3)
    scheduler.step(val_ema)  # LR schedule driven by the smoothed loss

    # ---- Summary ----
    delta_raw = val_loss - train_loss
    delta_ema = val_ema - train_loss
    current_lr = optimizer.param_groups[0]['lr']
    print(f"LR: {current_lr:.2e}")
    print(f"📊 Epoch {epoch+1:02d}/{EPOCHS} | "
          f"Train: {train_loss:.4f} | Val: {val_loss:.4f} | "
          f"Val(EMA): {val_ema:.4f} | Δraw={delta_raw:+.4f} | Δema={delta_ema:+.4f} | "
          f"Val acc: {val_acc:.3f}")

    # ---- Checkpoint ----
    # Saved on every EMA improvement, warm-up included, and before any stop decision. This
    # guarantees the evaluation cells find a best checkpoint even when training stops early.
    past_warmup = epoch + 1 > WARMUP_EPOCHS
    if (best_val_ema - val_ema) > ES_MIN_DELTA:
        best_val_ema = val_ema
        patience_counter = 0
        torch.save(model.state_dict(), BEST_CKPT_PATH)
    elif past_warmup:
        patience_counter += 1  # patience only runs once the warm-up is over

    # ---- Overfitting guard: smoothed val loss above train loss on consecutive epochs ----
    delta_pos_counter = delta_pos_counter + 1 if (past_warmup and delta_ema > 0) else 0
    if delta_pos_counter >= DELTA_POS_PATIENCE:
        print(f"⚠️ Δema > 0 sur {DELTA_POS_PATIENCE} époques consécutives → arrêt pour surapprentissage.")
        break

    # ---- Early stopping on the EMA (with warm-up & min_delta) ----
    if patience_counter >= ES_PATIENCE:
        print("⏸️ Early stopping déclenché (EMA).")
        break


🧹 Cache GPU vidé

=== Epoch 1/30 ===


Train:  24%|██████████████████████▉                                                                         | 27/113 [02:34<07:56,  5.54s/it]

In [None]:
import matplotlib.pyplot as plt

# Loss curves per epoch. The x axis is 1-based to match the "Epoch k/N" training logs.
# Each curve gets its own range, in case training was interrupted between the two appends.
from matplotlib.ticker import MaxNLocator

fig_loss, ax_loss = plt.subplots(figsize=(7, 4))
ax_loss.plot(range(1, len(train_losses) + 1), train_losses, label="Train", marker="o")
ax_loss.plot(range(1, len(val_losses) + 1), val_losses, label="Validation", marker="s")
ax_loss.set_title(f"Courbe de perte ({MODEL_NAME})")
ax_loss.set_xlabel("Épochs")
ax_loss.set_ylabel("Loss")
ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))  # whole epochs only
ax_loss.legend()
ax_loss.grid(alpha=0.3)
fig_loss.tight_layout()
plt.show()


In [None]:
import torch
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# 1️⃣ Load the best checkpoint written by the training loop
best_path = f"artifacts/{MODEL_NAME}_best.pt"
if not os.path.isfile(best_path):
    # Fail with an actionable message rather than a bare torch.load traceback.
    raise FileNotFoundError(
        f"Checkpoint introuvable : {best_path} — lancez d'abord la cellule d'entraînement."
    )
# weights_only=True restricts unpickling to tensors and plain containers.
state_dict = torch.load(best_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print(f"✅ Meilleur modèle chargé depuis : {best_path}")


In [None]:
# 2️⃣ Predictions on the validation set
y_true, y_pred = [], []
model.eval()  # re-asserted so this cell stays correct when re-run on its own

# Same mixed-precision policy as training: autocast on CUDA only, plain fp32 otherwise
# (a hard-coded 'cuda' autocast just warns and disables itself on a CPU-only machine).
_eval_amp_device = torch.device(DEVICE).type

with torch.no_grad():
    for imgs, labels, _ in tqdm(val_loader, desc="Évaluation finale"):
        imgs = imgs.to(DEVICE, non_blocking=True)
        with torch.amp.autocast(device_type=_eval_amp_device, enabled=_eval_amp_device == "cuda"):
            outputs = model(imgs)
        preds = outputs.argmax(dim=1)
        y_true.extend(labels.numpy())  # labels never left the CPU
        y_pred.extend(preds.cpu().numpy())


In [None]:
# 3️⃣ Classification report (per-class precision / recall / F1, plus macro and weighted averages)
# Class names are ordered by their label index, so target_names[i] really names label i
# instead of relying on the mapping's insertion order.
classes = sorted(train_ds.class_to_idx, key=train_ds.class_to_idx.get)

print("\n=== Rapport de classification ===")
print(classification_report(
    y_true, y_pred,
    # Explicit labels: every class gets a row (zeros if absent), and sklearn no longer raises
    # a size mismatch with target_names when a class is missing from y_true and y_pred.
    labels=[train_ds.class_to_idx[c] for c in classes],
    target_names=classes,
    digits=3,
    zero_division=0
))


In [None]:
# 4️⃣ Confusion matrix (absolute counts — not normalized)

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Rows and columns are pinned to `classes` via their label indices, so they always line up
# with the tick labels, even when a class never occurs in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred, labels=[train_ds.class_to_idx[c] for c in classes])

fig_cm, ax_cm = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm, annot=True, fmt="d", cmap="Blues",
    xticklabels=classes, yticklabels=classes, ax=ax_cm
)
ax_cm.set_xlabel("Prédit")
ax_cm.set_ylabel("Réel")
ax_cm.set_title(f"Matrice de confusion (effectifs absolus) — {MODEL_NAME}")
fig_cm.tight_layout()
plt.show()



In [None]:
# 5️⃣ Save results to artifacts/
from pathlib import Path
from sklearn.metrics import classification_report

artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)  # create the folder before anything is written into it

# One label order for both the figure and the report: class names paired with their indices.
class_labels = [train_ds.class_to_idx[c] for c in classes]

# The confusion-matrix figure from the previous cell has already gone through plt.show().
# With the inline backend it is closed at that point, so plt.savefig() here would write an
# empty canvas. Redraw it on a dedicated figure, save that, and close it without displaying it.
cm_to_save = confusion_matrix(y_true, y_pred, labels=class_labels)
fig_save, ax_save = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm_to_save, annot=True, fmt="d", cmap="Blues",
    xticklabels=classes, yticklabels=classes, ax=ax_save
)
ax_save.set_xlabel("Prédit")
ax_save.set_ylabel("Réel")
ax_save.set_title(f"Matrice de confusion (effectifs absolus) — {MODEL_NAME}")
fig_save.tight_layout()
fig_save.savefig(artifacts_dir / f"confusion_matrix_{MODEL_NAME}.png")
plt.close(fig_save)

# Plain-text classification report
report = classification_report(
    y_true, y_pred,
    labels=class_labels,
    target_names=classes,
    digits=3,
    zero_division=0
)
(artifacts_dir / f"classification_report_{MODEL_NAME}.txt").write_text(report, encoding="utf-8")
print("📁 Résultats enregistrés dans artifacts/")