
# 08_Adapter_LoRA_diffusion ‚Äî UNI2‚Äëh (gel√©) ‚Üî PixCell (r√©el) *gated*

**But** : r√©duire le *domain shift* observ√© sur les embeddings (UNI2‚Äëh, PathoDuet) en adaptant l√©g√®rement la pipeline PixCell via :
- un **Adapter** MLP (projection/pack de l‚Äôembedding UNI2‚Äëh en quelques tokens de contexte) ;
- des **LoRA** l√©g√®res inject√©es uniquement dans les **cross‚Äëattentions** du **UNet** (PixCell).

**Principe** :  
UNI2‚Äëh est gel√© et fournit un vecteur de conditionnement `z_uni`. On monkey‚Äëpatch la pipeline sans toucher aux poids de base : on concat√®ne les tokens projet√©s par l'adapter aux `encoder_hidden_states` (texte) consomm√©s par les cross‚Äëattn, et on n‚Äôentra√Æne que l‚Äôadapter + les matrices basses‚Äërangs (LoRA).

> ‚ö†Ô∏è Le *monkey patch* est central : on n‚Äôalt√®re pas la signature publique de la pipeline, on accroche un hook propre √† `_encode_prompt` (SD‚Äëlike) et on append nos tokens.  
> ‚ö†Ô∏è Param√©trage d√©licat ‚Üí toutes les dimensions sensibles sont factoris√©es dans une seule section de config.


## Cellule 1 ‚Äî Environnement & login Hugging Face

In [1]:
# ========= Cellule 1 ‚Äî Environnement & login Hugging Face =========
# - Installe/MAJ les libs n√©cessaires (sans xformers pour √©viter les conflits CUDA)
# - Se connecte √† Hugging Face avec un token personnel (non affich√©)
# - Affiche un r√©cap GPU / versions pour sanity-check

# %pip -qv install --upgrade diffusers transformers accelerate safetensors huggingface_hub

import os, torch
from huggingface_hub import login
from getpass import getpass

# --- Connexion Hugging Face ---
# Option 1 : mettre le token dans la variable d'environnement HF_TOKEN avant le lancement du notebook
# Option 2 : saisie s√©curis√©e au clavier (recommand√© si tu n'as rien export√©)
token = os.environ.get("HF_TOKEN") or getpass("üëâ Entrez votre Hugging Face token (il ne sera pas affich√©) : ")
login(token=token, add_to_git_credential=False)

# --- (Optionnel) acc√©l√©ration transferts ---
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# --- R√©cap environnement GPU ---
device = "cuda" if torch.cuda.is_available() else "cpu"
cuda_ver = getattr(torch.version, "cuda", None)
gpu_name = torch.cuda.get_device_name(0) if device == "cuda" else "CPU"
print(f"[OK] Torch={torch.__version__} | CUDA={cuda_ver} | Device={gpu_name}")


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


[OK] Torch=2.4.0 | CUDA=12.4 | Device=NVIDIA GeForce RTX 4060 Laptop GPU


In [2]:
# ========= Cellule 2 ‚Äî Config minimal, chemins & seed =========
# - Fixe device, dtype, seed
# - Centralise les chemins/projets/datasets
# - Pr√©pare les IDs Hugging Face (√† compl√©ter selon tes repos)

import os, json, random, math
from pathlib import Path

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# ---- Device & dtype conseill√©s
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BF16_OK = (DEVICE == "cuda") and torch.cuda.is_bf16_supported()
AMP_DTYPE = torch.bfloat16 if BF16_OK else torch.float16

# ---- Seed
SEED = 42
def seed_everything(seed: int = 42):
    import numpy as np, torch, random, os
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # on reste en mode perfs (pas de determinism strict)
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

# ---- Arborescence projet / outputs
PROJECT_ROOT = Path.cwd()
OUTPUTS_DIR  = PROJECT_ROOT / "outputs" / "08_adapter_lora"
CHECKPOINTS_DIR = OUTPUTS_DIR / "checkpoints"
SAMPLES_DIR  = OUTPUTS_DIR / "samples"
for d in (OUTPUTS_DIR, CHECKPOINTS_DIR, SAMPLES_DIR):
    d.mkdir(parents=True, exist_ok=True)

# ---- Donn√©es locales (ajuste si besoin)
DATA_ROOT   = Path("/workspace/data")
TRAIN_DIR   = DATA_ROOT / "NCT-CRC-HE-100K"
VAL_DIR     = DATA_ROOT / "CRC-VAL-HE-7K"

# ---- Hyperparam√®tres de base
IMG_SIZE       = 256
BATCH_SIZE     = 4          # commence bas (VRAM friendly)
NUM_WORKERS    = 0          # √©viter le multiprocessing sous Jupyter+CUDA
LEARNING_RATE  = 1e-4
TRAIN_STEPS    = 1000
GRAD_ACCUM     = 1

# ---- Hugging Face (√† compl√©ter) : on tirera UNI2-h & PixCell directement depuis HF
UNI_REPO_ID       = "hf-hub:MahmoodLab/UNI2-h"        # ex: "acme-ai/uni2h-base"
UNI_REVISION      = "main"
PIXCELL_REPO_ID = "StonyBrook-CVLab/PixCell-256"      # ex: "acme-ai/pixcell-sd15"
PIXCELL_REVISION  = "main"

# ---- Dimensions de conditionnement (adapter & cross-attn)
UNI_OUT_DIM     = 512       # dim de l'embedding UNI2-h (√† confirmer selon le repo)
TEXT_CTX_DIM    = 768       # dim du contexte texte (SD 1.5-like)
TOKENS_FROM_UNI = 4         # nb de tokens inject√©s par l'adapter

# ---- LoRA (cross-attn uniquement)
LORA_RANK   = 8
LORA_ALPHA  = 8
LORA_SCALE  = 1.0
LORA_TARGET = ["to_q", "to_k", "to_v", "to_out.0"]  # motifs √† patcher dans le UNet

# ---- Qualit√© de vie Hugging Face cache
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")      # transferts plus rapides
os.environ.setdefault("HF_HOME", str(PROJECT_ROOT / ".hf"))  # cache local dans le projet

# ---- R√©cap rapide
summary = {
    "device": DEVICE,
    "amp_dtype": "bf16" if AMP_DTYPE is torch.bfloat16 else "fp16",
    "seed": SEED,
    "img_size": IMG_SIZE,
    "batch_size": BATCH_SIZE,
    "data_root": str(DATA_ROOT),
    "train_dir_exists": TRAIN_DIR.exists(),
    "val_dir_exists": VAL_DIR.exists(),
    "outputs_dir": str(OUTPUTS_DIR),
    "uni_repo_id": UNI_REPO_ID,
    "pixcell_repo_id": PIXCELL_REPO_ID,
    "uni_out_dim": UNI_OUT_DIM,
    "text_ctx_dim": TEXT_CTX_DIM,
    "tokens_from_uni": TOKENS_FROM_UNI,
    "lora_rank": LORA_RANK,
}
print(json.dumps(summary, indent=2))


{
  "device": "cuda",
  "amp_dtype": "bf16",
  "seed": 42,
  "img_size": 256,
  "batch_size": 4,
  "data_root": "/workspace/data",
  "train_dir_exists": true,
  "val_dir_exists": true,
  "outputs_dir": "/workspace/notebooks/outputs/08_adapter_lora",
  "uni_repo_id": "hf-hub:MahmoodLab/UNI2-h",
  "pixcell_repo_id": "StonyBrook-CVLab/PixCell-256",
  "uni_out_dim": 512,
  "text_ctx_dim": 768,
  "tokens_from_uni": 4,
  "lora_rank": 8
}


In [3]:
# ========= Cellule 3 ‚Äî PixCell (custom) + VAE SD3.5 LARGE ‚Äî version 07 =========
import torch
from diffusers import DiffusionPipeline, AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) VAE SD3.5 Large (subfolder="vae")
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    subfolder="vae",
    torch_dtype=torch.float16,     # comme dans 07
)

# 2) Pipeline PixCell-256 (custom pipeline + code distant)
pipe256 = DiffusionPipeline.from_pretrained(
    "StonyBrook-CVLab/PixCell-256",    # tu peux garder la variable PIXCELL_REPO_ID si tu pr√©f√®res
    vae=vae,
    custom_pipeline="StonyBrook-CVLab/PixCell-pipeline",
    trust_remote_code=True,
    torch_dtype=torch.float16,         # comme dans 07
).to(device)

# 4) Sanity check minimal (pas d'acc√®s .unet ni ._modules)
print("Pipeline charg√©e:", type(pipe256).__name__)
print("Has vae:", hasattr(pipe256, "vae"))
print("Has transformer:", hasattr(pipe256, "transformer"))
print("Has scheduler:", hasattr(pipe256, "scheduler"))
print("Device:", device, "| dtype:", torch.float16)


Keyword arguments {'trust_remote_code': True} are not expected by PixCellPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]

The config attributes {'double_self_attention': False, 'num_vector_embeds': None, 'only_cross_attention': False, 'use_linear_projection': False} were passed to PixCellTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
The config attributes {'flow_shift': 1.0, 'use_flow_sigmas': False} were passed to DPMSolverMultistepScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


Pipeline charg√©e: PixCellPipeline
Has vae: True
Has transformer: True
Has scheduler: True
Device: cuda | dtype: torch.float16


In [4]:
# ========= Cellule 4 ‚Äî UNI-2h (timm) + transform =========
# Charge UNI-2h depuis HF via timm, avec les kwargs fournis par la model card PixCell.
# Fournit aussi une fonction utilitaire pour extraire un embedding (B, 1, D).

import torch
try:
    import timm
    from timm.data import resolve_data_config
    from timm.data.transforms_factory import create_transform
except Exception as e:
    raise RuntimeError(
        "Le package 'timm' est requis. Installe-le puis relance cette cellule :\n"
        "%pip install timm"
    ) from e

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Kwargs conformes √† la doc PixCell (ViT-H style)
timm_kwargs = {
    'img_size': 224,
    'patch_size': 14,
    'depth': 24,
    'num_heads': 24,
    'init_values': 1e-5,
    'embed_dim': 1536,
    'mlp_ratio': 2.66667*2,
    'num_classes': 0,
    'no_embed_class': True,
    'mlp_layer': timm.layers.SwiGLUPacked,
    'act_layer': torch.nn.SiLU,
    'reg_tokens': 8,
    'dynamic_img_size': True
}

# 1) Mod√®le UNI-2h pr√©-entra√Æn√© depuis HF (MahmoodLab/UNI2-h)
uni_model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
uni_model.eval().to(device)

# 2) Transform officiel d√©riv√© de la config timm
transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))

# 3) Helper: encode une PIL.Image ou un batch tensor (B,3,H,W) -> (B,1,D)
@torch.inference_mode()
def encode_uni(input_):
    """
    input_: PIL.Image ou torch.Tensor (B,3,H,W) dans [0,1]
    retourne: torch.Tensor (B, 1, D)
    """
    if isinstance(input_, torch.Tensor):
        x = input_.to(device)
        if x.dim() == 3:
            x = x.unsqueeze(0)
        # assume d√©j√† normalis√© ? sinon on passe par la transform
        # Ici on repasse par transform par s√©curit√© : convertissons en PIL pour rester fid√®les
        from torchvision.transforms.functional import to_pil_image
        imgs = [transform(to_pil_image(t.cpu().clamp(0,1))) for t in x]
        x = torch.stack(imgs, dim=0).to(device)
    else:
        # PIL.Image -> tensor via transform
        x = transform(input_).unsqueeze(0).to(device)

    z = uni_model(x)          # (B, D)
    z = z.unsqueeze(1)        # (B, 1, D)
    return z

# 4) Mini smoke test avec l'image d'exemple du repo PixCell
from huggingface_hub import hf_hub_download
from PIL import Image

test_path = hf_hub_download(repo_id="StonyBrook-CVLab/PixCell-256", filename="test_image.png")
img = Image.open(test_path).convert("RGB")
z_uni = encode_uni(img)
print("UNI emb shape:", tuple(z_uni.shape))  # attendu: (1, 1, 1536)


test_image.png:   0%|          | 0.00/131k [00:00<?, ?B/s]

UNI emb shape: (1, 1, 1536)
