In [None]:
import numpy as np, random, os
import torch
SEED = 42
# Seed every RNG source used by the notebook: NumPy, stdlib `random`, then PyTorch.
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

_cuda_ok = torch.cuda.is_available()
if _cuda_ok:
    torch.cuda.manual_seed_all(SEED)
DEVICE = "cuda" if _cuda_ok else "cpu"
print(f"✅ Seed={SEED} | Device={DEVICE}")

# Release cached GPU allocator memory (useful when the notebook is re-run).
if DEVICE == "cuda":
    torch.cuda.empty_cache()

In [None]:
from pathlib import Path

PROJECT_ROOT = Path(os.getenv("PROJECT_ROOT", "/workspace")).resolve()
DATA_ROOT = Path(os.getenv("DATA_ROOT", "/workspace/data")).resolve()

# Project sub-directories, all created up front so later cells can write into them.
OUTPUTS_DIR, ARTIFACTS_DIR, CONFIG_DIR, SAMPLES_DIR, CKPT_DIR, METRICS_DIR = (
    PROJECT_ROOT / rel
    for rel in ("outputs", "artifacts", "configs", "samples", "checkpoints", "metrics/gan_metrics")
)
for d in (OUTPUTS_DIR, ARTIFACTS_DIR, CONFIG_DIR, SAMPLES_DIR, CKPT_DIR, METRICS_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Image / DataLoader settings.
IMAGE_SIZE, BATCH_SIZE, NUM_WORKERS = 256, 4, 0
PIN_MEMORY = DEVICE == "cuda"
DROP_LAST = SHUFFLE = True

# Target classes (the filtering itself is done by HistoDataset); set to None to keep all classes.
SELECTED_CLASSES = ["TUM", "STR", "NORM"]
print("CONFIG_DIR:", CONFIG_DIR)


In [None]:
from p9dg.histo_dataset import HistoDataset
from torch.utils.data import DataLoader, Subset

PIXEL_RANGE = "-1_1"
SAMPLES_PER_CLASS = 300
VAHADANE_ENABLE = True

# Class-conditional training dataset; see HistoDataset for the exact semantics of each option.
ds_gan = HistoDataset(
    root_data=str(DATA_ROOT),
    split="train",
    output_size=IMAGE_SIZE,
    pixel_range=PIXEL_RANGE,
    balance_per_class=True,
    samples_per_class_per_epoch=SAMPLES_PER_CLASS,
    vahadane_enable=VAHADANE_ENABLE,
    vahadane_device=DEVICE,
    thresholds_json_path=str(CONFIG_DIR / "seuils_par_classe.json"),
    return_labels=True,
    classes=SELECTED_CLASSES,
)
print("✅ Dataset cGAN initialisé")
print("Classes retenues:", ds_gan.class_counts())
num_classes = len(ds_gan.class_counts())
print("num_classes =", num_classes)

loader_gan = DataLoader(
    dataset=ds_gan,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=DROP_LAST,
)
# Sanity check on one batch: shapes and pixel range.
xb, yb, _ = next(iter(loader_gan))
print(f"Batch: x={tuple(xb.shape)}, y={tuple(yb.shape)} → min={xb.min():.3f}, max={xb.max():.3f}")

# Evaluation pool for the PathoDuet Fréchet metric.
# NOTE: the indices are sampled from ds_gan (the *train* split), so despite the file name this
# pool is not held out from training; the metric measures fit to the training distribution.
import numpy as np
HELDOUT_IDX_PATH = METRICS_DIR / "duet_eval_idx.npy"
N_EVAL_POOL = 2000
n_items = len(ds_gan)

eval_idx = np.load(HELDOUT_IDX_PATH) if HELDOUT_IDX_PATH.exists() else None
# A cache written for another dataset configuration (e.g. different SELECTED_CLASSES) may hold
# out-of-range indices. Those would only fail later inside the FID try/except (printed as a
# "skip"), so validate the cache and regenerate it instead of reusing it blindly.
cache_ok = (
    eval_idx is not None
    and eval_idx.size > 0
    and eval_idx.min() >= 0
    and eval_idx.max() < n_items
)
if not cache_ok:
    if eval_idx is not None:
        print(f"⚠️ cached eval indices invalid for len(ds_gan)={n_items}; regenerating {HELDOUT_IDX_PATH}")
    rng = np.random.default_rng(SEED)
    pick = min(N_EVAL_POOL, n_items)
    eval_idx = rng.choice(np.arange(n_items), size=pick, replace=False)
    np.save(HELDOUT_IDX_PATH, eval_idx)

class StripXY(torch.utils.data.Dataset):
    """Dataset view that keeps only the image of each ``(image, label, extra)`` triple."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        # Strict 3-way unpack: a base item of any other arity raises ValueError.
        image, _label, _extra = self.base[i]
        return image

ds_eval = Subset(ds_gan, eval_idx.tolist())
# Image-only loader over the fixed evaluation pool (no shuffling, so batches are reproducible).
loader_real_eval = DataLoader(
    StripXY(ds_eval),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
print("✅ loader_real_eval prêt:", len(ds_eval))


In [None]:
# PathoDuet extractor (gelé) basé sur timm (DeiT-B/16 224), chargement ckpt HE
import re, torch
import torch.nn as nn
import torch.nn.functional as F
try:
    import timm
except ImportError as e:
    # Chain the original error so version/ABI problems are visible, not masked as "install timm".
    raise RuntimeError("Installe 'timm' (pip install timm)") from e

CKPT_PATH = PROJECT_ROOT / "models/checkpoint_HE.pth"
# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if not CKPT_PATH.exists():
    raise FileNotFoundError(f"Checkpoint introuvable: {CKPT_PATH}")

def _is_tensor_dict(obj):
    """True if ``obj`` is a non-empty dict whose values all expose ``.shape`` (tensor-like)."""
    return isinstance(obj, dict) and len(obj) > 0 and all(hasattr(v, "shape") for v in obj.values())


def _find_tensor_dict(obj):
    """Return the tensor dict stored directly in ``obj`` or under "state_dict"/"model"/"module", else None."""
    if _is_tensor_dict(obj):
        return obj
    if isinstance(obj, dict):
        for key in ("state_dict", "model", "module"):
            if _is_tensor_dict(obj.get(key)):
                return obj[key]
    return None


def load_state_dict_safely(path):
    """Load a DeiT/ViT checkpoint and return a state_dict for a timm DeiT-B built with ``num_classes=0``.

    Steps:
      1. ``torch.load(..., weights_only=True)``. On torch versions that reject the keyword
         (TypeError), it falls back to a plain ``torch.load``.
      2. Accept a flat tensor dict, or one nested under "state_dict", "model" or "module".
      3. Strip leading "module." / "model." prefixes, including stacked ones ("module.model.").
      4. Remove distillation parts: ``dist_token``, ``head_dist.*``, and token index 1 of a
         198-token ``pos_embed`` (assumed to be the distillation token, giving 197 = cls + 196
         patches). Then remove the classification ``head.*``.

    Args:
        path: checkpoint file path (str or PathLike).

    Returns:
        dict[str, Tensor]: the cleaned state_dict.

    Raises:
        RuntimeError: if no non-empty tensor dict is found in the checkpoint.
    """
    try:
        obj = torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Older torch without `weights_only`. NOTE: full unpickling can execute arbitrary code,
        # so only use this with trusted checkpoints.
        obj = torch.load(path, map_location="cpu")

    sd_raw = _find_tensor_dict(obj)
    if sd_raw is None:
        raise RuntimeError("state_dict non trouvé dans le ckpt")

    clean_sd = {re.sub(r"^(?:module\.|model\.)+", "", k): v for k, v in sd_raw.items()}

    # Drop the distillation token/head.
    clean_sd.pop("dist_token", None)
    if "pos_embed" in clean_sd and clean_sd["pos_embed"].shape[1] == 198:
        pos = clean_sd["pos_embed"]
        clean_sd["pos_embed"] = torch.cat([pos[:, :1, :], pos[:, 2:, :]], dim=1)
    # Drop both classification heads (num_classes=0 backbone has none).
    return {k: v for k, v in clean_sd.items() if not k.startswith(("head_dist.", "head."))}

state_dict = load_state_dict_safely(CKPT_PATH)
backbone = timm.create_model("deit_base_patch16_224", pretrained=False, num_classes=0, img_size=224, global_pool="avg")
missing, unexpected = backbone.load_state_dict(state_dict, strict=False)

# strict=False tolerates key mismatches silently. Report them so that a prefix or naming
# mismatch, which would leave the backbone at its random init, is visible.
print(f"PathoDuet ckpt: {len(state_dict)} tensors | missing={len(missing)} | unexpected={len(unexpected)}")
if missing:
    print("  missing   :", list(missing)[:10])
if unexpected:
    print("  unexpected:", list(unexpected)[:10])
# NOTE(review): with global_pool="avg", timm's VisionTransformer may use `fc_norm` instead of
# `norm`. In that case `fc_norm.*` is reported missing (kept at default init) and `norm.*`
# unexpected; confirm for the installed timm version. Any other missing key is treated as fatal.
critical_missing = [k for k in missing if not k.startswith("fc_norm.")]
if critical_missing:
    raise RuntimeError(
        f"Checkpoint does not cover {len(critical_missing)} backbone weights, e.g. {critical_missing[:10]}"
    )

# Freeze the backbone: it is only used as a fixed feature extractor.
for p in backbone.parameters(): p.requires_grad_(False)
backbone.eval()

class PathoDuetExtractor(nn.Module):
    """Gradient-free wrapper feeding [-1, 1] images to a frozen backbone.

    Inputs are clamped to [-1, 1] and mapped to [0, 1]. They are bilinearly resized to
    ``in_size`` x ``in_size`` when needed, then passed to ``bb``.
    NOTE(review): no mean/std normalisation is applied; confirm this matches the
    preprocessing used to pretrain the checkpoint.
    """

    def __init__(self, bb, in_size=224):
        super().__init__()
        self.bb = bb
        self.in_size = in_size

    @torch.no_grad()
    def forward(self, x_m11):
        target_hw = (self.in_size, self.in_size)
        x01 = x_m11.clamp(-1, 1).add(1).mul(0.5)
        if tuple(x01.shape[-2:]) != target_hw:
            x01 = F.interpolate(x01, size=target_hw, mode="bilinear", align_corners=False)
        return self.bb(x01)

pathoduet = PathoDuetExtractor(backbone).to(DEVICE)

# Smoke test: extract features for two images of the sample batch.
with torch.no_grad():
    f = pathoduet(xb[:2].to(DEVICE))
print("PathoDuet features:", tuple(f.shape))


In [None]:
# cGAN: Générateur (StyleGAN-lite) conditionnel (embedding classe → mapping)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as nn_utils

class MappingNetwork(nn.Module):
    """MLP mapping a latent ``z`` to a style vector ``w``.

    ``z`` is first rescaled to unit L2 norm along dim 1, then passed through
    ``n_layers`` blocks of (Linear -> LeakyReLU(0.2)). Submodules live in
    ``self.mapping`` at indices 0, 2, 4, … (Linear) and 1, 3, 5, … (activation).
    """

    def __init__(self, z_dim=512, w_dim=512, n_layers=8):
        super().__init__()
        stack = []
        for idx in range(n_layers):
            fan_in = z_dim if idx == 0 else w_dim
            stack.extend((nn.Linear(fan_in, w_dim), nn.LeakyReLU(0.2, inplace=True)))
        self.mapping = nn.Sequential(*stack)

    def forward(self, z):
        # eps keeps an all-zero latent from producing NaNs.
        z_unit = z / (z.norm(dim=1, keepdim=True) + 1e-8)
        return self.mapping(z_unit)

class ModulatedConv2d(nn.Module):
    """Convolution whose kernel is modulated per sample by a style vector (StyleGAN2-style).

    ``self.style`` maps the style ``s`` to one scale per input channel. The shared kernel is
    multiplied by ``(scale + 1)`` for each sample and, if ``demod`` is set, each resulting
    output filter is rescaled to unit L2 norm. The batch is then folded into the channel axis,
    so one grouped ``F.conv2d`` applies a different kernel to every sample.

    Args:
        in_ch / out_ch: input / output channel counts.
        kernel: square kernel size; padding is ``kernel // 2``.
        style_dim: size of the style vector fed to ``self.style``.
        demod: apply weight demodulation.
        up: nearest-neighbour 2x upsampling of the input before the convolution.
    """
    def __init__(self, in_ch, out_ch, kernel, style_dim, demod=True, up=False):
        super().__init__(); self.up=up; self.demod=demod; self.eps=1e-8; self.pad=kernel//2
        # Shared base kernel with a leading broadcast axis for the batch: (1, out, in, k, k).
        # NOTE(review): drawn from N(0, 1) without 1/sqrt(fan_in) scaling. With demod=True the
        # per-filter normalisation below largely cancels that scale; with demod=False it does not.
        self.weight = nn.Parameter(torch.randn(1, out_ch, in_ch, kernel, kernel))
        self.style = nn.Linear(style_dim, in_ch)
    def forward(self, x, s):
        """Apply the style-modulated convolution.

        Args:
            x: feature map (b, in_ch, h, w).
            s: style batch, presumably (b, style_dim) — TODO confirm with callers.

        Returns:
            Tensor (b, out_ch, h_out, w_out). h_out/w_out equal h/w (or 2x when ``up``) for odd kernels.
        """
        if self.up:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        b, c, h, w = x.shape
        # Per-sample, per-input-channel scale, broadcast against the (1, out, in, k, k) kernel.
        w1 = self.style(s).view(b, 1, c, 1, 1)
        # "+1" means a zero style output leaves the kernel unchanged.
        w2 = self.weight * (w1 + 1)
        if self.demod:
            # Demodulation: rescale each (sample, out-channel) filter to unit L2 norm over (in, kh, kw).
            d = torch.rsqrt((w2 ** 2).sum([2,3,4]) + self.eps)
            w2 = w2 * d.view(b, -1, 1, 1, 1)
        # Grouped-conv trick: fold the batch into channels -> (1, b*in, h, w) and stack the
        # per-sample kernels -> (b*out, in, k, k); groups=b keeps samples independent.
        # NOTE(review): .view requires a contiguous x (e.g. a channels_last input would raise).
        x = x.view(1, -1, h, w)
        w2 = w2.view(b * w2.size(1), w2.size(2), w2.size(3), w2.size(4))
        out = F.conv2d(x, w2, padding=self.pad, groups=b)
        # Unfold back to (b, out_ch, h_out, w_out).
        return out.view(b, -1, out.shape[-2], out.shape[-1])

class Generator(nn.Module):
    """Class-conditional StyleGAN-lite generator.

    A learned class embedding is added to ``z`` before the mapping network. The resulting
    style ``w`` modulates every convolution. Synthesis starts from a learned 4x4 constant and
    doubles the resolution per block while it is below ``img_res``. Output is ``tanh``-squashed
    to [-1, 1].
    """

    def __init__(self, z_dim=512, w_dim=512, img_res=256, fmap_base=256, num_classes=9):
        super().__init__()
        self.z_dim = z_dim
        self.embed = nn.Embedding(num_classes, z_dim)
        nn.init.normal_(self.embed.weight, std=0.02)
        self.mapping = MappingNetwork(z_dim, w_dim)
        self.const = nn.Parameter(torch.randn(1, fmap_base, 4, 4))
        self.blocks = nn.ModuleList()

        channels, resolution = fmap_base, 4
        while resolution < img_res:
            # Width shrinks as resolution grows, with a floor of 64 channels.
            width = max(fmap_base // max(resolution // 8, 1), 64)
            up_conv = ModulatedConv2d(channels, width, 3, w_dim, up=True)
            act_up = nn.LeakyReLU(0.2, inplace=True)
            same_conv = ModulatedConv2d(width, width, 3, w_dim, up=False)
            act_same = nn.LeakyReLU(0.2, inplace=True)
            self.blocks.append(nn.ModuleList([up_conv, act_up, same_conv, act_same]))
            channels = width
            resolution *= 2
        self.to_rgb = nn.Conv2d(channels, 3, 1)

    def forward(self, z, y):
        w = self.mapping(z + self.embed(y))
        h = self.const.repeat(z.size(0), 1, 1, 1)
        for up_conv, act_up, same_conv, act_same in self.blocks:
            h = act_same(same_conv(act_up(up_conv(h, w)), w))
        return torch.tanh(self.to_rgb(h))

class Discriminator(nn.Module):
    """Spectral-normalised convolutional discriminator backbone.

    Six stride-2 3x3 convolutions (each followed by LeakyReLU(0.2)), a 4x4 valid
    convolution and a 1x1 logit head. For a 256x256 input the body yields 4x4 maps,
    so ``conv_last`` outputs 1x1. ``forward`` returns ``(logit (b, 1), feat (b, out_ch))``.
    """

    def __init__(self, ch=64):
        super().__init__()

        def sn_conv(cin, cout, k=3, s=2, p=1):
            return nn_utils.spectral_norm(nn.Conv2d(cin, cout, k, s, p))

        widths = [3, ch, ch * 2, ch * 4, ch * 8, ch * 8, ch * 8]
        layers = []
        for cin, cout in zip(widths[:-1], widths[1:]):
            layers.append(sn_conv(cin, cout, 3, 2, 1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*layers)
        self.conv_last = sn_conv(ch * 8, ch * 8, 4, 1, 0)
        self.to_logit = sn_conv(ch * 8, 1, 1, 1, 0)
        self.out_ch = ch * 8

    def forward(self, x):
        n = x.size(0)
        h = self.conv_last(self.body(x))
        return self.to_logit(h).view(n, 1), h.view(n, self.out_ch)

class CombinedCondDiscriminator(nn.Module):
    """Projection-conditional discriminator with an optional frozen PathoDuet branch.

    logit = base_logit + <feat, embed(y)> + alpha * duet_head(duet(x)) / t_duet (when t_duet > 1)

    Args:
        base: module returning ``(logit (b, 1), feat (b, base.out_ch))``.
        num_classes: number of class labels for the projection embedding.
        duet_extractor: callable mapping [-1, 1] images to ``(b, duet_dim)`` features.
        alpha: weight of the PathoDuet logit. With ``alpha == 0`` the extractor is not run at
            all, because its contribution would be multiplied by zero anyway.
        t_duet: temperature dividing the PathoDuet logit (only applied when > 1).
        duet_dim: feature width of ``duet_extractor`` (768 for DeiT-B).

    NOTE: the extractor runs under ``torch.no_grad()``. Its logit therefore sends no gradient
    back to the input image (hence none to G); only ``duet_head`` is trained by that branch.
    """
    def __init__(self, base, num_classes, duet_extractor, alpha=0.0, t_duet=2.0, duet_dim=768):
        super().__init__()
        self.base = base
        self.embed = nn.Embedding(num_classes, base.out_ch)
        nn.init.normal_(self.embed.weight, std=0.02)
        self.duet = duet_extractor
        self.duet_head = nn.Linear(duet_dim, 1)
        self.alpha = float(alpha); self.t_duet = float(t_duet)

    def forward(self, x_m11, y):
        logit_base, feat = self.base(x_m11)
        # Projection discriminator term: inner product between features and class embedding.
        logit_proj = (feat * self.embed(y)).sum(dim=1, keepdim=True)
        logits = logit_base + logit_proj
        if self.alpha == 0.0:
            # Skip the costly extractor forward: alpha * logit_duet would be exactly zero.
            return logits
        with torch.no_grad():
            f_duet = self.duet(x_m11)
        logit_duet = self.duet_head(f_duet)
        if self.t_duet > 1.0:
            logit_duet = logit_duet / self.t_duet
        return logits + self.alpha * logit_duet

Z_DIM = 512
G = Generator(
    z_dim=Z_DIM, w_dim=Z_DIM, img_res=IMAGE_SIZE, fmap_base=256, num_classes=num_classes,
).to(DEVICE)
D_base = Discriminator(ch=64).to(DEVICE)
D = CombinedCondDiscriminator(
    D_base, num_classes=num_classes, duet_extractor=pathoduet, alpha=0.0, t_duet=2.0,
).to(DEVICE)

# Smoke test: one conditional forward pass through G and D without gradients.
with torch.no_grad():
    z = torch.randn(4, Z_DIM, device=DEVICE)
    y = torch.randint(0, num_classes, (4,), device=DEVICE)
    fake = G(z, y)
    out = D(fake, y)
print("Smoke: fake=", tuple(fake.shape), "D(out)=", tuple(out.shape))


In [None]:
# Entraînement NS-GAN + R1 + EMA + ADA minimal, FID PathoDuet conditionnel
import time
import torchvision.utils as vutils
from csv import DictWriter
from contextlib import nullcontext

def requires_grad(m, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of module ``m`` (in place; returns None)."""
    params = m.parameters()
    for tensor in params:
        tensor.requires_grad_(flag)

@torch.no_grad()
def ema_update(ema_m, m, decay=0.9995):
    """In-place EMA of parameters: ema = decay * ema + (1 - decay) * live.

    Parameters are paired positionally (``zip`` stops at the shorter list); buffers are not touched.
    """
    keep, blend = decay, 1.0 - decay
    for shadow, live in zip(ema_m.parameters(), m.parameters()):
        shadow.data.mul_(keep).add_(live.data, alpha=blend)

def d_logistic_loss(r, f):
    """Non-saturating logistic D loss: softplus(-real_logits) + softplus(fake_logits), each batch-averaged."""
    softplus = torch.nn.functional.softplus
    real_term = softplus(-r).mean()
    fake_term = softplus(f).mean()
    return real_term + fake_term
def g_nonsat_loss(f):
    """Non-saturating G loss: mean softplus(-fake_logits)."""
    softplus = torch.nn.functional.softplus
    return softplus(-f).mean()
def r1_penalty(x, pred):
    """R1 penalty: batch mean of the squared L2 norm of d(sum pred)/dx (graph kept for backward)."""
    (grad_x,) = torch.autograd.grad(outputs=pred.sum(), inputs=x, create_graph=True, retain_graph=True, only_inputs=True)
    per_sample = grad_x.pow(2).flatten(1).sum(1)
    return per_sample.mean()

def ada_augment(x, p, translate=0.04):
    """Minimal batch-wide augmentation, applied with probability ``p`` per op.

    1. Horizontal flip (last dim) of the whole batch.
    2. Circular shift (``torch.roll``) of up to ``max(1, int(H * translate))`` pixels on both
       spatial axes, with one shift shared by the whole batch.
    Returns ``x`` itself when ``p <= 0``.
    """
    if not p > 0:
        return x
    dev = x.device
    do_flip = torch.rand(1, device=dev).item() < p
    if do_flip:
        x = torch.flip(x, dims=[3])
    do_shift = torch.rand(1, device=dev).item() < p
    if do_shift:
        height = x.shape[-2]
        limit = max(1, int(height * translate))
        shift_a = torch.randint(-limit, limit + 1, (1,), device=dev).item()
        shift_b = torch.randint(-limit, limit + 1, (1,), device=dev).item()
        x = torch.roll(x, shifts=(shift_a, shift_b), dims=(2, 3))
    return x

# ADA controller state and hyper-parameters.
ADA_STATE = {"p": 0.0, "acc_ema": None}
ADA_TARGET=0.6; ADA_DECAY=0.99; ADA_SPEED=0.25; ADA_MAX_P=0.08
# R1 regularisation: gamma and lazy interval (in steps).
R1_GAMMA=10.0; R1_EVERY=8
EPOCHS=2; LR_G=4e-4; LR_D=1e-4; BETAS=(0.0,0.99)
SAMPLE_EVERY=400; SAVE_EVERY=800

opt_G = torch.optim.Adam(G.parameters(), lr=LR_G, betas=BETAS)
# D.parameters() also yields the frozen PathoDuet weights (D.duet); torch.optim.Adam skips
# parameters whose .grad is None, so they are never updated.
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D, betas=BETAS)

try:
    from torch.amp import GradScaler as _GradScaler  # device-generic scaler (recent torch)
except ImportError:
    from torch.cuda.amp import GradScaler as _GradScaler  # older torch


def autocast_cm(enabled=True):
    """Return an autocast context for the training device.

    Uses ``torch.amp.autocast`` (the non-deprecated API) on CUDA. On CPU it returns a no-op
    context, matching the previous CUDA-only behaviour.
    """
    if not str(DEVICE).startswith("cuda"):
        return nullcontext()
    return torch.amp.autocast(device_type="cuda", enabled=enabled)


USE_AMP = False
scaler_G=_GradScaler(enabled=USE_AMP); scaler_D=_GradScaler(enabled=USE_AMP)

class RealStatsEMA:
    """EMA of per-channel mean/std of real images (mapped from [-1, 1] to [0, 1]).

    ``penalty`` is a soft moment-matching loss pulling generated images' per-channel
    statistics toward the tracked real statistics (zero until the first ``update``).
    """

    def __init__(self, m=0.995):
        self.m = m
        self.mean = None
        self.std = None

    @staticmethod
    def _channel_moments(t):
        t01 = (t.clamp(-1, 1) + 1) * 0.5
        reduce_dims = (0, 2, 3)
        return t01.mean(dim=reduce_dims), t01.std(dim=reduce_dims).clamp_min(1e-6)

    @torch.no_grad()
    def update(self, x):
        batch_mean, batch_std = self._channel_moments(x)
        if self.mean is None:
            self.mean, self.std = batch_mean, batch_std
            return
        keep = self.m
        self.mean = keep * self.mean + (1 - keep) * batch_mean
        self.std = keep * self.std + (1 - keep) * batch_std

    def penalty(self, f):
        if self.mean is None:
            return f.new_zeros(())
        fake_mean, fake_std = self._channel_moments(f)
        mean_gap = (fake_mean - self.mean.detach()).pow(2).mean()
        std_gap = (fake_std - self.std.detach()).pow(2).mean()
        return mean_gap + std_gap

real_stats = RealStatsEMA(0.995)
LAMBDA_STATS = 3e-3

@torch.no_grad()
def duet_frechet_cond(real_batch_iter, G_ema, z_dim, num_classes, n_real=128, n_fake=128, device=None,
                      extractor=None, fake_batch=32):
    """Fréchet distance between extractor features of real and class-conditional fake images.

    FD = ||mu_r - mu_f||^2 + tr(C_r) + tr(C_f) - 2 tr sqrt(C_r^{1/2} C_f C_r^{1/2})

    Args:
        real_batch_iter: iterable of real image batches in [-1, 1] (e.g. ``loader_real_eval``).
        G_ema: generator called as ``G_ema(z, y)``; its first parameter gives the default device.
        z_dim: latent size.
        num_classes: fake labels are drawn uniformly from ``[0, num_classes)``.
        n_real / n_fake: number of real / fake samples used.
        device: compute device (defaults to G_ema's device).
        extractor: feature extractor taking raw [-1, 1] images. Defaults to the global
            ``pathoduet``, which clamps, maps to [0, 1] and resizes to 224 itself. Images are
            therefore NOT pre-converted here, since doing so would squash them into [0.5, 1].
        fake_batch: fakes are generated in chunks of this size to bound memory.

    Returns:
        float: the Fréchet distance.

    Raises:
        ValueError: if fewer than 2 real or fake samples are available.
    """
    dev = device or next(G_ema.parameters()).device
    feat_fn = pathoduet if extractor is None else extractor
    if n_fake < 2:
        raise ValueError("n_fake must be >= 2 to estimate a covariance")

    # --- real features ---
    real_chunks, n_acc = [], 0
    for x in real_batch_iter:
        real_chunks.append(feat_fn(x.to(dev)))
        n_acc += x.size(0)
        if n_acc >= n_real:
            break
    if not real_chunks:
        raise ValueError("real_batch_iter yielded no batches")
    feats_r = torch.cat(real_chunks, dim=0)[:n_real]
    if feats_r.size(0) < 2:
        raise ValueError("need at least 2 real samples to estimate a covariance")

    # --- fake features (z/y drawn up front so the random stream does not depend on fake_batch) ---
    z = torch.randn(n_fake, z_dim, device=dev)
    y = torch.randint(0, num_classes, (n_fake,), device=dev)
    step = max(1, int(fake_batch))
    feats_f = torch.cat(
        [feat_fn(G_ema(z[i:i + step], y[i:i + step])) for i in range(0, n_fake, step)], dim=0
    )

    def _moments(fe):
        fe = fe.double()
        mu = fe.mean(dim=0)
        xc = fe - mu
        return mu, (xc.t() @ xc) / (fe.size(0) - 1)

    mu_r, cov_r = _moments(feats_r)
    mu_f, cov_f = _moments(feats_f)

    # tr sqrt(C_r C_f) == tr sqrt(C_r^{1/2} C_f C_r^{1/2}). The latter matrix is symmetric PSD, so
    # a symmetric eigensolver is valid (unlike on the non-symmetric product (C_r^{1/2} C_f^{1/2})^2).
    eye = torch.eye(cov_r.shape[0], dtype=cov_r.dtype, device=cov_r.device)
    evals, evecs = torch.linalg.eigh(cov_r + 1e-6 * eye)
    cov_r_half = evecs @ torch.diag(evals.clamp_min(0).sqrt()) @ evecs.t()
    inner = cov_r_half @ cov_f @ cov_r_half
    inner = 0.5 * (inner + inner.t())  # remove round-off asymmetry
    tr_covmean = torch.linalg.eigvalsh(inner).clamp_min(0).sqrt().sum()

    diff = mu_r - mu_f
    fd = diff @ diff + torch.trace(cov_r) + torch.trace(cov_f) - 2.0 * tr_covmean
    return float(fd.item())

def log_metrics_csv(path, row):
    """Append ``row`` (a dict with exactly the columns below) to the CSV at ``path``.

    Parent directories are created as needed; the header is written only when the
    file is missing or empty.
    """
    columns = [
        "epoch", "step", "d_loss", "g_loss",
        "real_mu", "real_min", "real_max",
        "fake_mu", "fake_min", "fake_max",
        "ada_p", "acc", "acc_ema", "alpha",
    ]
    path.parent.mkdir(parents=True, exist_ok=True)
    needs_header = not path.exists() or path.stat().st_size == 0
    with path.open("a", newline="") as fh:
        writer = DictWriter(fh, fieldnames=columns)
        if needs_header:
            writer.writeheader()
        writer.writerow(row)

# EMA copy of the generator, used for samples, the Fréchet metric and checkpoints.
G_ema = Generator(z_dim=Z_DIM, w_dim=Z_DIM, img_res=IMAGE_SIZE, fmap_base=G.const.shape[1], num_classes=num_classes).to(DEVICE)
G_ema.load_state_dict(G.state_dict()); requires_grad(G_ema, False)

# D.duet wraps the PathoDuet backbone frozen in the extractor cell. Toggling D.parameters()
# wholesale would flip those weights back to requires_grad=True, so only D's own trainable
# parameters are toggled between the D and G steps.
_duet_param_ids = {id(p) for p in D.duet.parameters()}
D_TRAINABLE = [p for p in D.parameters() if id(p) not in _duet_param_ids]


def set_d_trainable(flag):
    """Set requires_grad on D's trainable parameters only (the frozen extractor is untouched)."""
    for p in D_TRAINABLE:
        p.requires_grad_(flag)


def save_checkpoint(step, epoch_idx):
    """Save G / D / G_ema / optimizer states to ``CKPT_DIR / cgan_step{step:06d}.pt``."""
    torch.save({
        "G": G.state_dict(), "D": D.state_dict(), "G_ema": G_ema.state_dict(),
        "opt_G": opt_G.state_dict(), "opt_D": opt_D.state_dict(),
        "step": step, "epoch": epoch_idx, "num_classes": num_classes,
    }, CKPT_DIR / f"cgan_step{step:06d}.pt")


global_step = 0
last_saved_step = 0
print("[init] D.alpha=0.0 (branche PathoDuet)")
D.alpha = 0.0
METRICS_CSV = OUTPUTS_DIR / "gan/metrics_gan.csv"
if METRICS_CSV.exists(): METRICS_CSV.unlink()  # fresh metrics log for each run

for epoch in range(EPOCHS):
    t0 = time.time()
    for real, y_real, _ in loader_gan:
        global_step += 1
        real = real.to(DEVICE, non_blocking=True)
        y_real = y_real.to(DEVICE, non_blocking=True)

        # ---- Discriminator step ----
        set_d_trainable(True)
        requires_grad(G, False)
        y_fake = torch.randint(0, num_classes, (real.size(0),), device=DEVICE)
        z = torch.randn(real.size(0), Z_DIM, device=DEVICE)
        opt_D.zero_grad(set_to_none=True)
        with autocast_cm(enabled=USE_AMP):
            fake = G(z, y_fake).detach()
            real_aug = ada_augment(real, ADA_STATE["p"], translate=0.02)
            fake_aug = ada_augment(fake, ADA_STATE["p"], translate=0.02)
            real_pred = D(real_aug, y_real)
            fake_pred = D(fake_aug, y_fake)
            d_loss = d_logistic_loss(real_pred, fake_pred)
        scaler_D.scale(d_loss).backward()
        # Lazy R1 on (augmented) real images every R1_EVERY steps.
        # NOTE(review): the penalty is not multiplied by R1_EVERY, so its average per-step
        # strength is R1_GAMMA / R1_EVERY.
        do_r1 = (global_step % R1_EVERY) == 0
        if do_r1:
            real_r1 = real.detach().requires_grad_(True)
            with autocast_cm(enabled=USE_AMP):
                real_aug_r1 = ada_augment(real_r1, ADA_STATE["p"], translate=0.04)
                real_pred_r1 = D(real_aug_r1, y_real)
            r1 = 0.5*R1_GAMMA*r1_penalty(real_aug_r1, real_pred_r1)
            scaler_D.scale(r1).backward()
        torch.nn.utils.clip_grad_norm_(D_TRAINABLE, 1.0)
        scaler_D.step(opt_D); scaler_D.update()
        with torch.no_grad(): real_stats.update(real)
        # ADA: adapt the augmentation probability from an EMA of D's accuracy on (augmented) reals.
        with torch.no_grad():
            acc = torch.sigmoid(real_pred).gt(0.5).float().mean().item()
            ADA_STATE["acc_ema"] = acc if ADA_STATE["acc_ema"] is None else ADA_DECAY*ADA_STATE["acc_ema"] + (1-ADA_DECAY)*acc
            err = ADA_STATE["acc_ema"] - ADA_TARGET
            ADA_STATE["p"] = float(min(ADA_MAX_P, max(0.0, ADA_STATE["p"] + 0.001*err*ADA_SPEED)))

        # ---- Generator step ----
        set_d_trainable(False)
        requires_grad(G, True)
        y_fake = torch.randint(0, num_classes, (real.size(0),), device=DEVICE)
        z = torch.randn(real.size(0), Z_DIM, device=DEVICE)
        opt_G.zero_grad(set_to_none=True)
        with autocast_cm(enabled=USE_AMP):
            fake = G(z, y_fake)
            fake_aug = ada_augment(fake, ADA_STATE["p"], translate=0.03)
            fake_pred = D(fake_aug, y_fake)
            g_loss = g_nonsat_loss(fake_pred) + LAMBDA_STATS*real_stats.penalty(fake)
        scaler_G.scale(g_loss).backward()
        torch.nn.utils.clip_grad_norm_(G.parameters(), 1.0)
        scaler_G.step(opt_G); scaler_G.update()
        ema_update(G_ema, G, decay=0.9995)

        # ---- Logging ----
        if global_step % 50 == 0:
            with torch.no_grad():
                r_mean, r_min, r_max = real.mean().item(), real.min().item(), real.max().item()
                f_clamp = fake.clamp(-1,1)
                f_mean, f_min, f_max = f_clamp.mean().item(), f_clamp.min().item(), f_clamp.max().item()
                alpha = getattr(D, 'alpha', 0.0)
            log_metrics_csv(METRICS_CSV, {
                "epoch": epoch+1, "step": global_step,
                "d_loss": float(d_loss.item()), "g_loss": float(g_loss.item()),
                "real_mu": r_mean, "real_min": r_min, "real_max": r_max,
                "fake_mu": f_mean, "fake_min": f_min, "fake_max": f_max,
                "ada_p": ADA_STATE["p"], "acc": acc, "acc_ema": ADA_STATE["acc_ema"], "alpha": alpha
            })

        # ---- Samples + Fréchet metric (best effort: failures are reported, not raised) ----
        if global_step % SAMPLE_EVERY == 0:
            with torch.no_grad():
                z_vis = torch.randn(16, Z_DIM, device=DEVICE)
                y_vis = torch.randint(0, num_classes, (16,), device=DEVICE)
                imgs = G_ema(z_vis, y_vis).clamp(-1,1)
                vutils.save_image((imgs+1)*0.5, str(SAMPLES_DIR / f"sample_step{global_step:06d}.png"), nrow=4)
                try:
                    fd = duet_frechet_cond(loader_real_eval, G_ema, Z_DIM, num_classes, n_real=96, n_fake=96, device=DEVICE)
                    print(f"[Duet-FID] step {global_step:06d} -> {fd:.3f}")
                except Exception as e:
                    print("[Duet-FID] skip:", e)

        if global_step % SAVE_EVERY == 0:
            save_checkpoint(global_step, epoch)
            last_saved_step = global_step

    print(f"✅ Epoch {epoch+1}/{EPOCHS} OK ({(time.time()-t0)/60:.1f} min)")

# Always persist the final weights. If the total number of steps is not a multiple of
# SAVE_EVERY (e.g. fewer than SAVE_EVERY steps overall), the in-loop save never captures the
# end state. `epoch` is 0-based here, matching the in-loop checkpoints.
if global_step > last_saved_step:
    save_checkpoint(global_step, epoch)
