# üß© Cellule 1 ‚Äì Environnement et imports

In [1]:
# ===============================================================
# üß© 08_PixCell_LoRA_Conditionnel.ipynb
# Cellule 1 : Imports & configuration de l'environnement
# ===============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from diffusers import DiffusionPipeline

# ---------------------------------------------------------------
# üîß Configuration g√©n√©rale
# ---------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"‚úÖ Device actif : {device}")

# R√©pertoires du projet
PROJECT_ROOT = Path("/workspace")  # adapte selon ton environnement
DATA_DIR     = PROJECT_ROOT / "data" / "NCT-CRC-HE-100K"
CONFIG_DIR   = PROJECT_ROOT / "configs"
OUTPUT_DIR   = PROJECT_ROOT / "outputs" / "pixcell_lora"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"üìÇ Dossier de sortie : {OUTPUT_DIR}")

# ---------------------------------------------------------------
# V√©rification CUDA
# ---------------------------------------------------------------
if device.type == "cuda":
    print(f"GPU d√©tect√© : {torch.cuda.get_device_name(0)}")
    print(f"CUDA version : {torch.version.cuda}")
    print(f"PyTorch version : {torch.__version__}")
else:
    print("‚ö†Ô∏è Pas de GPU d√©tect√© ‚Äî attention aux performances.")


‚úÖ Device actif : cuda
üìÇ Dossier de sortie : /workspace/outputs/pixcell_lora
GPU d√©tect√© : NVIDIA GeForce RTX 4060 Laptop GPU
CUDA version : 12.4
PyTorch version : 2.4.0


# üß© Cellule 2 ‚Äì Chargement de PixCell + UNI-2h

In [2]:
# ===============================================================
# üß© Cellule 2 : Chargement PixCell (avec VAE explicite) + UNI-2h (timm)
# ===============================================================
import os, shutil, torch, timm
import torch.nn.functional as F
from pathlib import Path
from huggingface_hub import login
from diffusers import DiffusionPipeline, AutoencoderKL
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# ---------------------------------------------------------------
# üîê Authentification Hugging Face
# ---------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
assert HF_TOKEN is not None, "‚ö†Ô∏è Variable d‚Äôenvironnement HF_TOKEN absente."
login(token=HF_TOKEN, add_to_git_credential=True)
print("üîë Authentification Hugging Face r√©ussie.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CACHE_DIR = Path(os.getenv("HF_HOME", "/workspace/.cache/huggingface"))

# ---------------------------------------------------------------
# 1Ô∏è‚É£ Chargement du pipeline PixCell-256 (avec VAE explicite)
# ---------------------------------------------------------------
def load_pixcell_with_explicit_vae(device):
    print("‚è≥ Chargement du VAE (SD 3.5 Large)‚Ä¶")
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",
        subfolder="vae",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    ).to(device)
    print("‚úÖ VAE charg√©.")

    print("‚è≥ Chargement de PixCell-256‚Ä¶")
    pipe = DiffusionPipeline.from_pretrained(
        "StonyBrook-CVLab/PixCell-256",
        custom_pipeline="StonyBrook-CVLab/PixCell-pipeline",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        vae=vae,
        token=HF_TOKEN,
    ).to(device)
    print("‚úÖ PixCell-256 charg√©.")
    return pipe

try:
    pipe256 = load_pixcell_with_explicit_vae(device)
except OSError as e:
    print("‚ö†Ô∏è Cache local corrompu :", e)
    model_cache_root = CACHE_DIR / "hub" / "models--StonyBrook-CVLab--PixCell-256" / "snapshots"
    if model_cache_root.exists():
        for snap in model_cache_root.iterdir():
            vae_dir = snap / "vae"
            if vae_dir.exists():
                print(f"üßπ Suppression du sous-dossier VAE corrompu : {vae_dir}")
                shutil.rmtree(vae_dir, ignore_errors=True)
    pipe256 = load_pixcell_with_explicit_vae(device)

# ---------------------------------------------------------------
# 2Ô∏è‚É£ Chargement du mod√®le UNI-2h (timm)
# ---------------------------------------------------------------
print("\n‚è≥ Chargement du mod√®le UNI-2h (via timm)‚Ä¶")

timm_kwargs = {
    "img_size": 224,
    "patch_size": 14,
    "depth": 24,
    "num_heads": 24,
    "embed_dim": 1536,
    "mlp_ratio": 2.66667 * 2,   # sp√©cifique √† UNI-2h
    "init_values": 1e-5,
    "num_classes": 0,
    "no_embed_class": True,
    "reg_tokens": 8,
    "dynamic_img_size": True,
    "mlp_layer": timm.layers.SwiGLUPacked,
    "act_layer": torch.nn.SiLU,
}

uni = timm.create_model(
    "hf-hub:MahmoodLab/UNI2-h",
    pretrained=True,
    **timm_kwargs
).eval().to(device)

cfg = resolve_data_config(uni.pretrained_cfg, model=uni)
tfm = create_transform(**cfg)

print("‚úÖ UNI-2h pr√™t ‚Üí", next(uni.parameters()).dtype, "| device:", device)

# ---------------------------------------------------------------
# 3Ô∏è‚É£ Fonction utilitaire : embeddings UNI-2h
# ---------------------------------------------------------------
@torch.inference_mode()
def uni_embed_from_tensor(x_3chw: torch.Tensor) -> torch.Tensor:
    """
    Extrait les embeddings UNI-2h d'une image (B,3,H,W)
    ‚Üí Retourne (B,1536) en float32
    """
    if x_3chw.shape[-1] != 224:
        x_3chw = F.interpolate(x_3chw, size=(224, 224),
                               mode="bilinear", align_corners=False)
    e = uni(x_3chw).float()
    if hasattr(e, "shape") and e.ndim == 3:
        e = e[:, 0]  # token [CLS]
    return e

# ---------------------------------------------------------------
# 4Ô∏è‚É£ V√©rification dtype/device PixCell
# ---------------------------------------------------------------
ref_weight = pipe256.transformer.pos_embed.proj.weight
PIPE_DTYPE = ref_weight.dtype
PIPE_DEV   = ref_weight.device
print(f"\nPixCell dtype : {PIPE_DTYPE}, device : {PIPE_DEV}")


Token has not been saved to git credential helper.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.

git config --global credential.helper store

Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.[0m
üîë Authentification Hugging Face r√©ussie.
‚è≥ Chargement du VAE (SD 3.5 Large)‚Ä¶
‚úÖ VAE charg√©.
‚è≥ Chargement de PixCell-256‚Ä¶


Keyword arguments {'trust_remote_code': True} are not expected by PixCellPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]

The config attributes {'flow_shift': 1.0, 'use_flow_sigmas': False} were passed to DPMSolverMultistepScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
The config attributes {'double_self_attention': False, 'num_vector_embeds': None, 'only_cross_attention': False, 'use_linear_projection': False} were passed to PixCellTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.


‚úÖ PixCell-256 charg√©.

‚è≥ Chargement du mod√®le UNI-2h (via timm)‚Ä¶
‚úÖ UNI-2h pr√™t ‚Üí torch.float32 | device: cuda

PixCell dtype : torch.float16, device : cuda:0


# üß© Cellule 3 ‚Äì Dataset, mappings et prototypes UNI-2h

In [5]:
# ===============================================================
# ‚öôÔ∏è Configuration du chemin projet pour importer p9dg
# ===============================================================
import sys
from pathlib import Path

PROJECT_ROOT = Path("/workspace")  # üîß adapte si besoin (par ex. ".." en local)
PACKAGE_DIR = PROJECT_ROOT / "p9dg"

if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

assert PACKAGE_DIR.exists(), f"‚ùå Dossier {PACKAGE_DIR} introuvable."
print(f"‚úÖ Chemin ajout√© au PYTHONPATH : {PROJECT_ROOT}")



‚úÖ Chemin ajout√© au PYTHONPATH : /workspace


In [9]:
# ‚úÖ D√©tection robuste du split + init dataset sans doublon de chemin
from pathlib import Path
from p9dg.histo_dataset import HistoDataset

DATA_ROOT = Path("/workspace/data")            # <-- parent, pas le sous-dossier
assert DATA_ROOT.exists(), f"{DATA_ROOT} introuvable"

if (DATA_ROOT / "NCT-CRC-HE-100K").is_dir():
    SPLIT = "NCT-CRC-HE-100K"
elif (DATA_ROOT / "CRC-VAL-HE-7K").is_dir():
    SPLIT = "CRC-VAL-HE-7K"
else:
    raise FileNotFoundError("Ni 'NCT-CRC-HE-100K' ni 'CRC-VAL-HE-7K' trouv√©s sous /workspace/data")

print(f"üîé Split d√©tect√© : {SPLIT}")
print(f"üìÅ Base effective : {DATA_ROOT / SPLIT}")

train_ds = HistoDataset(
    root_data=DATA_ROOT,
    split=SPLIT,
    apply_quality_filter=True,
    thresholds_json_path=SEUILS_PATH,   # chemin JSON (pas un dict)
    balance_per_class=True,
    samples_per_class_per_epoch=30,
)

print(f"‚úÖ Dataset charg√© ({len(train_ds)} images) | {len(train_ds.classes)} classes")
print("Classes :", train_ds.classes)


üîé Split d√©tect√© : NCT-CRC-HE-100K
üìÅ Base effective : /workspace/data/NCT-CRC-HE-100K
üé® R√©f√©rence Vahadane fix√©e : TUM-TCGA-WAEEFPKC.tif
üé® R√©f√©rence Vahadane auto: TUM-TCGA-WAEEFPKC.tif
‚úÖ Seuils par classe charg√©s depuis : /workspace/configs/seuils_par_classe.json
‚öñÔ∏è √âchantillonnage √©quilibr√© activ√© (30 images / classe).


AttributeError: 'HistoDataset' object has no attribute 'classes'

In [None]:
# ===============================================================
# üß© Cellule 3 : Chargement du dataset + mappings + prototypes UNI-2h
# ===============================================================
from pathlib import Path
import json, torch
from p9dg.histo_dataset import HistoDataset
from utils.class_mappings import class_labels, make_idx_mappings

# ---------------------------------------------------------------
# 1Ô∏è‚É£ Chargement des seuils par classe
# ---------------------------------------------------------------
SEUILS_PATH = CONFIG_DIR / "seuils_par_classe.json"
with open(SEUILS_PATH, "r") as f:
    seuils_par_classe = json.load(f)
print(f"‚úÖ Seuils par classe charg√©s depuis : {SEUILS_PATH}")

# ---------------------------------------------------------------
# 2Ô∏è‚É£ Initialisation du dataset √©quilibr√©
# ---------------------------------------------------------------
DATA_PATH = Path("/workspace/data/NCT-CRC-HE-100K")  # üîß chemin correct

train_ds = HistoDataset(
    root_data=DATA_PATH,
    apply_quality_filter=True,
    thresholds_json_path=SEUILS_PATH,        # chemin JSON (pas dict)
    balance_per_class=True,
    samples_per_class_per_epoch=30,          # pour un test rapide
)

print(f"‚úÖ Dataset charg√© ({len(train_ds)} images) | {len(train_ds.classes)} classes")
print("Classes :", train_ds.classes)

# ---------------------------------------------------------------
# 3Ô∏è‚É£ Mappings lisibles
# ---------------------------------------------------------------
idx_to_name, idx_to_color, class_to_label = make_idx_mappings(train_ds.class_to_idx)

# ---------------------------------------------------------------
# 4Ô∏è‚É£ Calcul ou rechargement des prototypes UNI-2h
# ---------------------------------------------------------------
PROTOS_PATH = OUTPUT_DIR / "prototypes_uni2h.pt"
if PROTOS_PATH.exists():
    PROTOS = torch.load(PROTOS_PATH)
    print(f"‚úÖ Prototypes UNI-2h recharg√©s depuis {PROTOS_PATH}")
else:
    print("‚è≥ Calcul des prototypes UNI-2h par classe...")
    PROTOS = {}
    for cname in train_ds.class_to_idx.keys():
        imgs = [train_ds._resize(train_ds._load_image(p))
                for p in train_ds.paths_by_class[train_ds.class_to_idx[cname]][:8]]
        batch = torch.stack([train_ds._to_tensor(img).to(device) for img in imgs])
        e = uni_embed_from_tensor(batch)  # [B,1536]
        PROTOS[cname] = e.mean(0).detach().cpu()
    torch.save(PROTOS, PROTOS_PATH)
    print(f"‚úÖ Prototypes UNI-2h sauvegard√©s ‚Üí {PROTOS_PATH}")

# ---------------------------------------------------------------
# 5Ô∏è‚É£ V√©rification rapide
# ---------------------------------------------------------------
for cname, v in PROTOS.items():
    print(f"{cname:<8} ‚Üí {tuple(v.shape)}")
