# ✅ P9 — Test rapide du DataGenerator (HistoDataset)

Ce notebook vérifie :
1. Chargement et **shape** des batches
2. **Comptes par classe** et politique de split sans fuite
3. **Filtre qualité** (scores + seuils JSON par classe)
4. **Normalisation Vahadane** via `torch_staintools` (comparaison ON/OFF)
5. **Visualisation** (grille d’images) et **débit** du DataLoader
6. **Non-réutilisation** des échantillons en **val/test**

> **Pré-requis** :
> - Le package `p9dg/` est à la racine de votre `WORKDIR` (ex. : `/workspace`).
> - Les datasets sont montés sous `/data` avec les dossiers :
>   - `/data/NCT-CRC-HE-100K/` (train)
>   - `/data/CRC-VAL-HE-7K/` (val/test)
> - Le fichier `configs/seuils_par_classe.json` est présent (facultatif : des seuils par défaut sont utilisés s’il est absent).


In [1]:
import os, time, random, numpy as np
import torch
from torch.utils.data import DataLoader
from PIL import Image
from p9dg.histo_dataset import HistoDataset, BalancedRoundRobinSampler

# Notebook configuration: every value a reader may want to tune lives here.
ROOT_DATA = "/data"  # dataset mount point; adapt if needed
THRESHOLDS_JSON = "configs/seuils_par_classe.json"  # relative to the working directory, or an absolute path

BATCH_SIZE = 64
IMG_SIZE = 256
NUM_WORKERS = 4  # 0 for debugging
SEED = 0  # seeds the Python, NumPy and PyTorch RNGs so Restart & Run All is reproducible

# Seed every RNG used downstream (random sample picks, DataLoader workers, model init).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

SyntaxError: invalid syntax (histo_dataset.py, line 329)

## 1) Instanciation des datasets (train/val/test)

In [None]:
def make_dataset(split):
    """Build a HistoDataset for ``split`` with the notebook-wide settings.

    Uses ROOT_DATA, IMG_SIZE and THRESHOLDS_JSON from the config cell, and enables
    Vahadane stain normalization on DEVICE (CUDA when available, otherwise CPU).
    """
    return HistoDataset(root_data=ROOT_DATA, split=split, output_size=IMG_SIZE,
                        thresholds_json_path=THRESHOLDS_JSON,
                        vahadane_enable=True, vahadane_device=DEVICE)


ds_train, ds_val, ds_test = (make_dataset(s) for s in ("train", "val", "test"))

for split_name, split_ds in (("train", ds_train), ("val", ds_val), ("test", ds_test)):
    print(f"Classes ({split_name}):".ljust(16), split_ds.class_counts())
    # An empty split would make every later check vacuous (or divide by zero).
    assert len(split_ds) > 0, f"split '{split_name}' is empty"

len(ds_train), len(ds_val), len(ds_test)

## 2) Visualisation rapide (grille)

In [None]:
from IPython.display import display

# Quick visual check: a grid of 16 training samples as rendered by HistoDataset.vis.
N_PREVIEW = 16
grid = ds_train.vis(N_PREVIEW)
display(grid)

## 3) Sanity check d’un batch + plages de pixels

In [None]:
# Deterministic (unshuffled) loader over the training split, reused by the cells below.
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=False,
                      num_workers=NUM_WORKERS,
                      # Pinned host memory only helps host->GPU copies; on CPU it just warns.
                      pin_memory=(DEVICE == "cuda"),
                      persistent_workers=(NUM_WORKERS > 0))
xb, yb, pb = next(iter(dl_train))
print("xb:", xb.shape, xb.dtype, float(xb.min()), float(xb.max()))
print("yb:", yb.shape, yb.dtype)
print("ex path:", pb[0])

# Batch invariants: images, labels and paths must line up one-to-one.
assert xb.shape[0] == yb.shape[0] == len(pb), "images / labels / paths count mismatch"
# Assumes NCHW output of side IMG_SIZE, which the MobileNetV2 forward cell also requires.
assert xb.ndim == 4 and tuple(xb.shape[-2:]) == (IMG_SIZE, IMG_SIZE), f"unexpected batch shape {tuple(xb.shape)}"
# Stain normalization must not produce NaN/Inf pixels.
assert bool(torch.isfinite(xb).all()), "non-finite pixel values in batch"

## 4) Débit du DataLoader (it/s)

In [None]:
import itertools

N_BENCH_BATCHES = 10  # number of batches timed; the timing includes the latency of the first fetch

# perf_counter is monotonic and high-resolution, unlike time.time().
t0 = time.perf_counter()
# islice stops early (without StopIteration handling) if the loader has fewer batches.
n_batches = sum(1 for _ in itertools.islice(dl_train, N_BENCH_BATCHES))
dt = time.perf_counter() - t0
rate = n_batches / dt if dt > 0 else float("nan")
print(f"{n_batches} batches en {dt:.2f}s → {rate:.2f} it/s")

## 5) Non-réutilisation en val/test (unicité des paths)

In [None]:
def split_paths(ds):
    """Return the sample paths of ``ds`` in index order.

    Each ``ds[i]`` is unpacked as an ``(x, y, path)`` triple. This goes through
    ``__getitem__`` for every index, so it pays the full sample-loading cost
    (presumably including stain normalization) — slow on large splits.
    """
    return [ds[i][2] for i in range(len(ds))]


def unique_ratio(ds, paths=None):
    """Fraction of distinct sample paths in ``ds`` (1.0 means no sample is reused).

    Args:
        ds: indexable dataset whose items are ``(x, y, path)`` triples.
        paths: optional precomputed ``split_paths(ds)``, to avoid reloading the split.

    Returns:
        ``len(set(paths)) / len(paths)``; 1.0 for an empty dataset (vacuously unique).
    """
    if paths is None:
        paths = split_paths(ds)
    if not paths:
        return 1.0
    return len(set(paths)) / len(paths)


# Collect each split's paths once; both checks below reuse them.
val_paths = split_paths(ds_val)
test_paths = split_paths(ds_test)

val_ratio = unique_ratio(ds_val, val_paths)
test_ratio = unique_ratio(ds_test, test_paths)
print("val unique ratio:", val_ratio)
print("test unique ratio:", test_ratio)

# val and test are both carved out of CRC-VAL-HE-7K, so they must not share any file.
val_test_overlap = set(val_paths) & set(test_paths)
print("val/test overlap:", len(val_test_overlap))

assert val_ratio == 1.0, "duplicate samples in val"
assert test_ratio == 1.0, "duplicate samples in test"
assert not val_test_overlap, f"{len(val_test_overlap)} paths shared by val and test"

## 6) Inspection filtre qualité (scores + seuils JSON)

In [None]:
# Inspect the quality-filter scores of one training sample against its class thresholds.
SAMPLE_SEED = 0  # set to None to inspect a different sample on every run
sample_rng = random.Random(SAMPLE_SEED)
i = sample_rng.randrange(len(ds_train))
_, y, p = ds_train[i]
cls = ds_train.idx_to_class[int(y)]

# Score the image as read from disk; the context manager closes the file handle.
with Image.open(p) as raw_img:
    img = raw_img.convert("RGB")
metrics = ds_train.qf.score(img)
thr = ds_train.class_thresholds.get(cls, {})  # empty when no JSON override for this class

print("class:", cls)
print("path:", p)
print("metrics:")
for k in sorted(metrics.keys()):
    print(f"  {k:16s} = {metrics[k]:.4f}")
print("\nthresholds (JSON override):")
for k, v in thr.items():
    print(f"  {k:16s} = {v}")

## 7) Vahadane ON/OFF — comparaison visuelle (optionnel)

In [None]:
from IPython.display import display

# Same split, image size and thresholds as ds_train, but with Vahadane normalization disabled.
ds_no_stain = HistoDataset(
    root_data=ROOT_DATA,
    split="train",
    output_size=IMG_SIZE,
    thresholds_json_path=THRESHOLDS_JSON,
    vahadane_enable=False,
)

# NOTE(review): if vis() samples tiles at random, the two grids show different images;
# TODO confirm vis() is deterministic before comparing them side by side.
grid_off = ds_no_stain.vis(16)
grid_on = ds_train.vis(16)

for caption, preview in (("Normalisation OFF:", grid_off), ("Normalisation ON:", grid_on)):
    print(caption)
    display(preview)

## 8) Test rapide de compatibilité avec MobileNetV2 (forward sur 1 batch)

In [None]:
import torchvision as tv, torch.nn as nn

# ImageNet-pretrained MobileNetV2 with its last linear layer resized to our classes.
NUM_CLASSES = len(ds_train.class_to_idx)
model = tv.models.mobilenet_v2(weights="IMAGENET1K_V1")
model.classifier[1] = nn.Linear(model.last_channel, NUM_CLASSES)
model.to(DEVICE).eval()

# NOTE(review): pretrained weights usually expect ImageNet-normalized inputs;
# confirm HistoDataset's output range against the min/max printed in section 3.
xb, yb, _ = next(iter(dl_train))
with torch.no_grad():
    out = model(xb.to(DEVICE))

# One logit vector per image, one logit per class.
assert out.shape == (xb.shape[0], NUM_CLASSES), f"unexpected logits shape {tuple(out.shape)}"
print("Forward OK:", out.shape)