# Z-Image Turbo + FluxSR FTD — Fixed Colab Training Notebook

**Key fixes over previous version:**
- `call_transformer` converts `(B,C,H,W)` → `List[(C,1,H,W)]` per sample (Z-Image's expected format)
- `all_cap_feats` passed as `List[Tensor(seq_len, cap_dim)]` per sample
- Debug cell dumps model shapes/signatures to files for inspection
- No duplicate definitions; every cell is safe to re-run
- Consistent `call_transformer` signature across FTD and recon loss


In [None]:
# Install/upgrade the third-party packages used by the cells below
# (runs quietly; `!` is a notebook shell escape, not Python).
!pip -q install -U diffusers accelerate peft transformers safetensors torchvision lpips

## 1) Mount Drive + unzip dataset

In [None]:
# Mount Google Drive, copy the dataset zip to local disk, unzip it, and
# report how many sample directories were found.
from google.colab import drive
drive.mount('/content/drive')

import shutil
from pathlib import Path

# -------- EDIT THESE --------
ZIP_NAME = "zimage_offline_pairs_with_lr.zip"
DRIVE_ZIP_PATH = f"/content/drive/MyDrive/{ZIP_NAME}"
WORKDIR = Path("/content")
# ---------------------------

# Copy first so the unzip reads from the local disk instead of the Drive mount.
local_zip = WORKDIR / ZIP_NAME
print("Copying zip from Drive:", DRIVE_ZIP_PATH)
shutil.copy(DRIVE_ZIP_PATH, local_zip)

print("Unzipping...")
# -o overwrites existing files without prompting, so re-running the cell is safe.
!unzip -q -o "{local_zip}" -d "{WORKDIR}"


# NOTE(review): this path is hard-coded and not derived from ZIP_NAME; it
# must match the folder layout inside the archive.
PAIRS_DIR = Path("/content/zimage_offline_pairs/pairs")
print("PAIRS_DIR:", PAIRS_DIR)
# Every sub-directory counts as one sample here; no file checks are done yet.
print("Num samples:", len([d for d in PAIRS_DIR.iterdir() if d.is_dir()]))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
PAIRS_DIR: /content/zimage_offline_pairs/pairs
Num samples: 1200


## 2) Load Z-Image Turbo pipeline once

In [None]:
import torch, numpy as np, inspect, json
from diffusers import ZImageImg2ImgPipeline

# Load the Z-Image Turbo img2img pipeline once, freeze all of its parts,
# read the scaling constants from the model configs, and encode the empty
# prompt that is used as conditioning everywhere else in the notebook.
DTYPE = "bf16"  # "bf16" or "fp16"
torch_dtype = torch.bfloat16 if DTYPE == "bf16" else torch.float16
device = "cuda"
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"

pipe = ZImageImg2ImgPipeline.from_pretrained(MODEL_ID, torch_dtype=torch_dtype).to(device)

# Freeze everything; trainable LoRA parameters are added later.
pipe.vae.requires_grad_(False)
if getattr(pipe, "text_encoder", None) is not None:
    pipe.text_encoder.requires_grad_(False)
pipe.transformer.requires_grad_(False)

torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# Constants read from the configs, with neutral defaults if a key is absent.
# vae_sh is only printed in this notebook; encode/decode use vae_sf alone.
t_scale = float(getattr(pipe.transformer.config, "t_scale", 1.0))
vae_sf = float(getattr(pipe.vae.config, "scaling_factor", 1.0))
vae_sh = float(getattr(pipe.vae.config, "shift_factor", 0.0))
print(f"t_scale={t_scale}, vae_sf={vae_sf}, vae_sh={vae_sh}")

# --- Null conditioning ---
# Only pass the keyword arguments that this diffusers version's
# encode_prompt actually declares, so the call works across versions.
with torch.no_grad():
    sig = inspect.signature(pipe.encode_prompt)
    kwargs = {}
    if "device" in sig.parameters: kwargs["device"] = device
    if "do_classifier_free_guidance" in sig.parameters: kwargs["do_classifier_free_guidance"] = False
    if "num_images_per_prompt" in sig.parameters: kwargs["num_images_per_prompt"] = 1
    pe = pipe.encode_prompt("", **kwargs)

# Extract first tensor from encode_prompt result
null_cap_feats = None


def _first_tensor(x):
    """Return the first tensor found in ``x``, or ``None`` if there is none.

    ``x`` may be a tensor, or an arbitrarily nested tuple/list of tensors.
    Tensors that sit directly in ``x`` are preferred; only when there are
    none are nested tuples/lists searched (depth-first, in order). The
    nested search matters when a prompt encoder returns something like
    ``(list_of_tensors, list_of_tensors)``: a single-level scan would find
    nothing and the caller would silently fall back to zero conditioning.

    Args:
        x: A tensor, a (possibly nested) tuple/list, or any other object.

    Returns:
        The first tensor found, or ``None``.
    """
    if torch.is_tensor(x):
        return x
    if isinstance(x, (tuple, list)):
        # Pass 1: direct children, so a top-level tensor always wins.
        for v in x:
            if torch.is_tensor(v):
                return v
        # Pass 2: descend into nested containers.
        for v in x:
            found = _first_tensor(v)
            if found is not None:
                return found
    return None

# Keep the empty-prompt embedding (if one was found) as a detached tensor.
null_cap_feats = _first_tensor(pe)
if null_cap_feats is not None:
    null_cap_feats = null_cap_feats.detach()
    print("null_cap_feats:", tuple(null_cap_feats.shape), null_cap_feats.dtype)
else:
    # Later cells substitute a zero tensor when this is None.
    print("WARNING: encode_prompt returned no tensor")

# Offload text encoder
# The prompt has been encoded, so the encoder is moved off the GPU.
if getattr(pipe, "text_encoder", None) is not None:
    pipe.text_encoder.to("cpu")
    torch.cuda.empty_cache()

# Gradient checkpointing
if hasattr(pipe.transformer, "enable_gradient_checkpointing"):
    pipe.transformer.enable_gradient_checkpointing()

print("transformer.forward signature:", inspect.signature(pipe.transformer.forward))
print("Pipeline ready.")


Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

t_scale=1000.0, vae_sf=0.3611, vae_sh=0.1159
transformer.forward signature: (x: List[torch.Tensor], t, cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, return_dict: bool = True)
Pipeline ready.


## 2b) DEBUG: Introspect transformer & dump shapes to files

In [None]:
import inspect, json, traceback

# Write the transformer's forward signature, selected source code, config
# and module names to files under DEBUG_DIR for offline inspection.
DEBUG_DIR = Path("/content/debug_shapes")
DEBUG_DIR.mkdir(exist_ok=True)

# --- 1. Forward signature ---
sig = inspect.signature(pipe.transformer.forward)
sig_str = str(sig)
print("Forward signature:", sig_str)
with open(DEBUG_DIR / "forward_signature.txt", "w") as f:
    f.write(sig_str + "\n")
    for name, p in sig.parameters.items():
        f.write(f"  {name}: default={p.default}, kind={p.kind}\n")

# --- 2. Source of patchify_and_embed (if accessible) ---
# Best effort: any failure is reported and the cell continues.
try:
    src = inspect.getsource(pipe.transformer.patchify_and_embed)
    with open(DEBUG_DIR / "patchify_and_embed_source.py", "w") as f:
        f.write(src)
    print("Saved patchify_and_embed source")
except Exception as e:
    print(f"Could not get patchify_and_embed source: {e}")

# --- 3. Source of forward ---
try:
    # unwrap PEFT if present
    base = pipe.transformer
    if hasattr(base, "base_model"):
        base = base.base_model
    if hasattr(base, "model"):
        base = base.model
    src = inspect.getsource(type(base).forward)
    with open(DEBUG_DIR / "forward_source.py", "w") as f:
        f.write(src)
    print("Saved forward source (", len(src), "chars)")
except Exception as e:
    print(f"Could not get forward source: {e}")

# --- 4. Config ---
# default=str makes json stringify any value it cannot serialise natively.
config_dict = dict(pipe.transformer.config) if hasattr(pipe.transformer.config, '__iter__') else {}
print("Transformer config:", json.dumps(config_dict, indent=2, default=str))
with open(DEBUG_DIR / "transformer_config.json", "w") as f:
    json.dump(config_dict, f, indent=2, default=str)

# --- 5. Named modules (all names written to file, first 30 printed) ---
module_names = [name for name, _ in pipe.transformer.named_modules()]
with open(DEBUG_DIR / "named_modules.txt", "w") as f:
    for n in module_names:
        f.write(n + "\n")
print(f"Total named modules: {len(module_names)}")
print("First 30:", module_names[:30])

# --- 6. Test dummy forward to discover shapes ---
# Runs one no-grad forward pass on random latents to confirm the input
# format the transformer accepts and to record the output shape.
print("\n--- Testing dummy forward ---")
B, C, H, W = 1, 16, 64, 64  # 512x512 image at 8x compression
dummy_latent = torch.randn(C, 1, H, W, device=device, dtype=torch_dtype)
cap_dim = int(getattr(pipe.transformer.config, "cap_feat_dim", 2560))

# Prepare cap_feats: extract 2D tensor (seq_len, cap_dim)
# Falls back to a single all-zero token when no usable embedding exists.
if null_cap_feats is not None:
    cf = null_cap_feats
    while cf.ndim > 2: cf = cf[0]
    if cf.shape[-1] != cap_dim:
        cf = torch.zeros(1, cap_dim, device=device, dtype=torch_dtype)
else:
    cf = torch.zeros(1, cap_dim, device=device, dtype=torch_dtype)

print(f"cap_feats_2d shape: {tuple(cf.shape)}")
t_dummy = torch.tensor([0.5 * t_scale], device=device, dtype=torch_dtype)

# Try: all_image = [tensor(C,1,H,W)], all_cap_feats = [tensor(seq,dim)]
try:
    with torch.no_grad():
        out = pipe.transformer(
            [dummy_latent],        # List of one (C,1,H,W) tensor
            t_dummy,               # (1,) timestep
            [cf],                  # List of one (seq_len, cap_dim) tensor
            return_dict=True,
        )
    # Accept an output object with .sample, a tuple/list, or a bare value.
    result = out.sample if hasattr(out, "sample") else out[0] if isinstance(out, (tuple,list)) else out
    if isinstance(result, list):
        print(f"SUCCESS: output is list of {len(result)} tensors, first shape: {tuple(result[0].shape)}")
        info = f"list of {len(result)}, shapes: {[tuple(r.shape) for r in result]}"
    else:
        print(f"SUCCESS: output tensor shape: {tuple(result.shape)}")
        info = f"tensor shape: {tuple(result.shape)}"
    with open(DEBUG_DIR / "dummy_forward_result.txt", "w") as f:
        f.write(f"Input: all_image=[({C},1,{H},{W})], t=({t_dummy.shape}), all_cap_feats=[{tuple(cf.shape)}]\n")
        f.write(f"Output: {info}\n")
except Exception as e:
    # Record the full traceback so the failure can be inspected later.
    print(f"FAILED with List[(C,1,H,W)]: {e}")
    traceback.print_exc()
    with open(DEBUG_DIR / "dummy_forward_error.txt", "w") as f:
        f.write(traceback.format_exc())

    # Fallback: try 5D tensor (B,C,F,H,W)
    try:
        dummy_5d = torch.randn(1, C, 1, H, W, device=device, dtype=torch_dtype)
        with torch.no_grad():
            out = pipe.transformer(
                [dummy_5d],
                t_dummy,
                [cf],
                return_dict=True,
            )
        result = out.sample if hasattr(out, "sample") else out
        print(f"FALLBACK 5D SUCCESS: {type(result)}, shape: {tuple(result.shape) if torch.is_tensor(result) else 'list'}")
    except Exception as e2:
        print(f"FALLBACK 5D ALSO FAILED: {e2}")

torch.cuda.empty_cache()
print("\nDebug files saved to:", DEBUG_DIR)
print("Files:", sorted([p.name for p in DEBUG_DIR.iterdir()]))


## 3) (Optional) Ensure zL.pt exists

In [None]:
from PIL import Image
from tqdm.auto import tqdm

@torch.no_grad()
def vae_encode_latents(pil_img: Image.Image):
    """Encode a PIL image into scaled VAE latents.

    The image is converted to a float array in [0, 1], rearranged from HWC
    to a 1xCxHxW batch, mapped to [-1, 1], moved to the module-level
    ``device``/``torch_dtype`` and encoded with ``pipe.vae``. A sample drawn
    from the latent distribution is multiplied by ``vae_sf``.

    Args:
        pil_img: Image to encode; indexed as (H, W, C) after conversion.

    Returns:
        The sampled latents times ``vae_sf`` (scale only, no shift).
    """
    pixels = np.array(pil_img).astype(np.float32) / 255.0
    batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    batch = (batch * 2.0 - 1.0).to(device=device, dtype=torch_dtype)
    posterior = pipe.vae.encode(batch).latent_dist
    # match Phase 1 convention (scale only, no shift)
    return posterior.sample() * vae_sf

# Create zL.pt for every sample directory that has lr_up.png but no zL.pt
# yet. Existing files are never overwritten, so the cell is safe to re-run.
created = 0
pair_dirs = sorted(p for p in PAIRS_DIR.iterdir() if p.is_dir())
for d in tqdm(pair_dirs, desc="zL.pt"):
    zl = d / "zL.pt"
    lr_up = d / "lr_up.png"
    # Skip when the latent already exists or there is nothing to encode.
    if zl.exists() or not lr_up.exists():
        continue
    zL = vae_encode_latents(Image.open(lr_up).convert("RGB")).cpu()
    torch.save(zL, zl)
    created += 1

print("Created zL.pt:", created)
torch.cuda.empty_cache()


## 4) Dataset + call_transformer + vae_decode

In [None]:
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# ──────────────────────────────────────────────────────────
# Dataset
# ──────────────────────────────────────────────────────────
class FTDPairDataset(Dataset):
    """Dataset over per-sample directories of precomputed tensors.

    A sub-directory of ``pairs_dir`` is a valid sample when it contains
    ``eps.pt``, ``z0.pt`` and ``zL.pt`` and, if ``load_pixels`` is true,
    also ``x0.png``. Samples are kept in sorted directory order.

    Args:
        pairs_dir: Directory whose sub-directories are the samples.
        load_pixels: Whether items also carry the ``x0.png`` pixels.

    Raises:
        RuntimeError: If no valid sample directory is found.
    """

    def __init__(self, pairs_dir: Path, load_pixels: bool):
        self.load_pixels = load_pixels
        required = ("eps.pt", "z0.pt", "zL.pt")
        self.items = [
            sample_dir
            for sample_dir in sorted(pairs_dir.iterdir())
            if sample_dir.is_dir()
            and all((sample_dir / fname).exists() for fname in required)
            and (not load_pixels or (sample_dir / "x0.png").exists())
        ]
        if len(self.items) == 0:
            raise RuntimeError("No valid samples found.")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return a dict with ``eps``, ``z0``, ``zL`` (and ``x0_pixels``).

        Each tensor is loaded on CPU and has a leading size-1 dimension
        squeezed away. ``x0_pixels`` is a float32 CHW tensor in [0, 1].
        """
        sample_dir = self.items[idx]
        tensor_files = {"eps": "eps.pt", "z0": "z0.pt", "zL": "zL.pt"}
        out = {}
        for key, fname in tensor_files.items():
            out[key] = torch.load(sample_dir / fname, map_location="cpu").squeeze(0)
        if not self.load_pixels:
            return out

        from PIL import Image as _Img
        rgb = _Img.open(sample_dir / "x0.png").convert("RGB")
        pixels = np.array(rgb).astype(np.float32) / 255.0
        out["x0_pixels"] = torch.from_numpy(pixels).permute(2, 0, 1)
        return out

def ftd_collate(batch):
    """Stack a list of per-sample dicts into one dict of batched tensors.

    The keys are taken from the first sample; every sample must provide
    them with tensors of matching shape. Each value gains a new leading
    batch dimension.
    """
    stacked = {}
    for key in batch[0]:
        stacked[key] = torch.stack([sample[key] for sample in batch])
    return stacked

# ──────────────────────────────────────────────────────────
# call_transformer — Z-Image format conversion
# ──────────────────────────────────────────────────────────
# Z-Image transformer.forward expects:
#   all_image:     List[Tensor(C, F, H, W)] — one 4D tensor per batch sample, F=1
#   t:             Tensor — timestep(s)
#   all_cap_feats: List[Tensor(seq_len, cap_dim)] — one 2D tensor per batch sample
#
# We convert from our (B, C, H, W) training tensors to this format.
# ──────────────────────────────────────────────────────────

def call_transformer(transformer, *, latents, timestep, cap_feats_2d):
    """Call ``transformer`` with list-style inputs and return a batched tensor.

    Args:
        latents:      (B, C, H, W) float tensor
        timestep:     (B,) float tensor (already multiplied by t_scale)
        cap_feats_2d: (seq_len, cap_dim) tensor; the same object is passed
                      once per sample
    Returns:
        The prediction. A list of per-sample outputs is stacked along a new
        batch dimension (after dropping a size-1 second axis from 4-D
        entries); a 5-D tensor with a size-1 third axis has that axis
        removed; anything else is returned unchanged.
    """
    # One (C, 1, H, W) tensor per sample.
    per_sample = [sample.unsqueeze(1) for sample in latents]
    cap_list = [cap_feats_2d] * len(per_sample)

    raw = transformer(per_sample, timestep, cap_list, return_dict=False)
    result = raw[0] if isinstance(raw, (tuple, list)) else raw

    if isinstance(result, list):
        # Each entry is (C, 1, H, W) or (C, H, W); drop the frame axis if present.
        squeezed = [
            r.squeeze(1) if (r.ndim == 4 and r.shape[1] == 1) else r
            for r in result
        ]
        return torch.stack(squeezed, dim=0)
    if result.ndim == 5 and result.shape[2] == 1:
        # (B, C, 1, H, W) -> (B, C, H, W)
        return result.squeeze(2)
    return result


def vae_decode(pipe, z):
    """Decode scaled latents to pixels in [0, 1].

    The latents are divided by the module-level ``vae_sf`` (no shift is
    applied) and decoded by ``pipe.vae`` inside a CUDA bfloat16 autocast
    region. The decoder output is mapped with ``x / 2 + 0.5`` and clamped.
    """
    unscaled = z / vae_sf
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        decoded = pipe.vae.decode(unscaled).sample
    pixels = decoded / 2 + 0.5
    return pixels.clamp(0, 1)

print("vae_decode patched with autocast")

print("Helpers defined.")


vae_decode patched with autocast
Helpers defined.


## 5) TV-LPIPS + ADL losses

In [None]:
import lpips

def total_variation_filter(x):
    """Return the per-pixel absolute forward-difference map of ``x``.

    ``x`` is indexed as a 4-D tensor; the last two axes are differenced.
    The result is the sum of the absolute difference to the next row and
    to the next column, so both trailing axes are one element shorter
    than in the input.
    """
    anchor = x[:, :, :-1, :-1]
    row_step = (x[:, :, 1:, :-1] - anchor).abs()
    col_step = (x[:, :, :-1, 1:] - anchor).abs()
    return row_step + col_step

class TVLPIPSLoss(torch.nn.Module):
    """LPIPS on the images plus ``gamma`` times LPIPS on their TV maps.

    Both inputs are treated as images in [0, 1] and are rescaled to [-1, 1]
    before being passed to the frozen VGG-based LPIPS network. The
    total-variation maps of both images are divided by their shared maximum
    (at least 1e-6), clamped to [0, 1] and compared the same way.

    Args:
        gamma: Weight of the TV-map term.
    """

    def __init__(self, gamma=0.5):
        super().__init__()
        self.gamma = gamma
        self.lpips_fn = lpips.LPIPS(net="vgg")
        # The perceptual network is a fixed metric, never trained.
        self.lpips_fn.requires_grad_(False)

    def forward(self, x, y):
        perceptual = self.lpips_fn(x * 2 - 1, y * 2 - 1).mean()

        tv_x = total_variation_filter(x)
        tv_y = total_variation_filter(y)
        # One shared, detached normaliser keeps both maps on the same scale.
        peak = max(tv_x.detach().max().item(), tv_y.detach().max().item(), 1e-6)
        tv_x = (tv_x / peak).clamp(0, 1)
        tv_y = (tv_y / peak).clamp(0, 1)
        edge = self.lpips_fn(tv_x * 2 - 1, tv_y * 2 - 1).mean()

        return perceptual + self.gamma * edge

print("Losses defined.")


Losses defined.


## 6) FTD Training Loop

FluxSR FTD (Eq.16/17/18/21). Key points:
- LoRA on transformer attention layers only
- `call_transformer` properly converts to Z-Image's List format
- FTD loss on randomly sampled t ∈ [TL, 1]
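- Pixel recon loss every `REC_LOSS_EVERY` steps (with the default `DETACH_RECON = True` it is computed under `no_grad`, so it is logged but contributes no gradient)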


In [None]:
# Turn off the VAE config's force_upcast flag; vae_decode in this notebook
# runs the decoder under bf16 autocast instead.
pipe.vae.config.force_upcast = False
print("Disabled VAE force_upcast")

Disabled VAE force_upcast


In [None]:
import os, inspect, math
from pathlib import Path
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from tqdm.auto import tqdm

# ── Hyperparams ── for 40GB
# Trailing comments on the values record an alternative setting.
TL          = 0.25 # 0.15
BATCH_SIZE  = 4 # 1
GRAD_ACCUM  = 2 # 8
LR          = 5e-5 # 1e-4
MAX_STEPS   = 750
LOG_EVERY   = 20
SAVE_EVERY  = 150 #500
REC_LOSS_EVERY = 8      # set to 0 to disable, increase if OOM
LAMBDA_TVLPIPS = 1.0
GAMMA_TV       = 0.5

# Checkpoints and preview images are written to this Drive folder.
SAVE_DIR = Path("/content/drive/MyDrive/zimage_sr_lora_runs/ftd_run_v2")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# mixed_precision="no": tensors are cast to torch_dtype explicitly below.
accelerator = Accelerator(mixed_precision="no", gradient_accumulation_steps=GRAD_ACCUM)
device = accelerator.device

# ── Safety: if LoRA is already attached (cell re-run), reload a clean
# transformer so adapters are not stacked on top of each other ──
if hasattr(pipe.transformer, "peft_config"):
    print("Reloading clean transformer...")
    from diffusers import ZImageTransformer2DModel
    pipe.transformer = ZImageTransformer2DModel.from_pretrained(
        MODEL_ID, subfolder="transformer", torch_dtype=torch_dtype
    ).to(device)
    pipe.transformer.requires_grad_(False)
    if hasattr(pipe.transformer, "enable_gradient_checkpointing"):
        pipe.transformer.enable_gradient_checkpointing()

# ── LoRA setup ──
# Freeze the base weights; only the LoRA parameters added below train.
pipe.transformer.requires_grad_(False)

# Find attention Linear layers (avoid norms, embeds, time projections, PEFT internals)
def find_lora_targets(model):
    """Return the sorted names of ``torch.nn.Linear`` modules to adapt.

    A Linear module is selected when its qualified name contains none of
    the skip tokens (PEFT internals, norms, embeddings, time/position
    projections) and at least one of the keep tokens.
    """
    skip_tokens = ("lora_", "base_layer", "peft_", "norm", "embed", "time", "t_embed", "pos")
    keep_tokens = ("attention", "attn", "to_q", "to_k", "to_v", "to_out", "proj", "mlp", "ff")

    selected = set()
    for qualified_name, module in model.named_modules():
        is_linear = isinstance(module, torch.nn.Linear)
        skipped = any(tok in qualified_name for tok in skip_tokens)
        kept = any(tok in qualified_name for tok in keep_tokens)
        if is_linear and kept and not skipped:
            selected.add(qualified_name)
    return sorted(selected)

# Attach LoRA, build the conditioning tensor, loss, data loader and
# optimizer, and initialise the training state used by the loop below.
targets = find_lora_targets(pipe.transformer)
if not targets:
    # Fallback: all Linear layers except norms/embeds
    targets = sorted(set(
        name for name, mod in pipe.transformer.named_modules()
        if isinstance(mod, torch.nn.Linear) and not any(x in name for x in ["lora_", "base_layer", "norm", "embed"])
    ))

print(f"[LoRA] Targeting {len(targets)} layers. First 15:")
for t in targets[:15]: print(f"  {t}")

lora_cfg = LoraConfig(r=16, lora_alpha=16, lora_dropout=0.0, bias="none", target_modules=targets)
pipe.transformer = get_peft_model(pipe.transformer, lora_cfg)
pipe.transformer.print_trainable_parameters()

# ── Conditioning: prepare cap_feats_2d ──
# Reduce the cached empty-prompt embedding to 2-D; use a single all-zero
# token when it is missing or its last dimension does not match cap_dim.
cap_dim = int(getattr(pipe.transformer.config, "cap_feat_dim", 2560))

if null_cap_feats is not None:
    cf = null_cap_feats.to(device=device, dtype=torch_dtype)
    while cf.ndim > 2: cf = cf[0]         # (B, S, D) -> (S, D)
    if cf.ndim == 1: cf = cf.unsqueeze(0)  # (D,) -> (1, D)
    if cf.shape[-1] != cap_dim:
        print(f"[WARN] cap_feats dim {cf.shape[-1]} != {cap_dim}, using zeros")
        cf = torch.zeros(1, cap_dim, device=device, dtype=torch_dtype)
else:
    cf = torch.zeros(1, cap_dim, device=device, dtype=torch_dtype)

cap_feats_2d = cf.detach()
print(f"cap_feats_2d: {tuple(cap_feats_2d.shape)}, cap_dim={cap_dim}")

# ── TV-LPIPS ──
tv_lpips = TVLPIPSLoss(gamma=GAMMA_TV).eval()
for p in tv_lpips.parameters(): p.requires_grad_(False)

# ── Dataset ──
# Pixels are only loaded when the reconstruction loss is enabled.
ds = FTDPairDataset(PAIRS_DIR, load_pixels=(REC_LOSS_EVERY > 0))
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
                collate_fn=ftd_collate, drop_last=True)
print(f"FTDPairDataset: {len(ds)} samples")

# ── Optimizer ──
# Only parameters with requires_grad=True (the LoRA weights) are optimised.
optimizer = torch.optim.AdamW(
    [p for p in pipe.transformer.parameters() if p.requires_grad], lr=LR
)

# ── Accelerator prepare ──
pipe.transformer, optimizer, dl = accelerator.prepare(pipe.transformer, optimizer, dl)

# ── TL constants ──
# TL shaped (1, 1, 1, 1) so it broadcasts against (B, C, H, W) tensors.
TL_bc = torch.tensor([TL], device=device, dtype=torch_dtype).view(1, 1, 1, 1)

# ── Training ──
global_step = 0
pipe.transformer.train()
pipe.vae.eval()

# Running sums of the logged losses; reset after every LOG_EVERY steps.
loss_log = {"ftd": 0.0, "rec": 0.0, "total": 0.0}
pbar = tqdm(total=MAX_STEPS, desc="FluxSR-FTD")

# Main FTD training loop.
#
# For every dataloader batch: compute the FTD loss at a random t in [TL, 1],
# optionally add the pixel reconstruction loss, and step the optimizer under
# accelerate's gradient accumulation. `global_step` is incremented once per
# dataloader batch (micro-batch), so MAX_STEPS, LOG_EVERY and SAVE_EVERY are
# all counted in micro-batches.
#
# NOTE(review): timesteps are passed as `t * t_scale`, as call_transformer's
# docstring requires. Confirm that transformer.forward does not multiply by
# t_scale again internally.

# Recon-branch switches. They never change during the run, so they are set
# once here instead of being re-assigned on every micro-batch.
DETACH_LPIPS = True  # Set False to backprop through LPIPS (uses ~40% more VRAM); not read anywhere in this cell
DETACH_RECON = True  # True = gradient-free recon (stable), False = full gradients (may OOM/dtype issues)

while global_step < MAX_STEPS:
    for batch in dl:
        if global_step >= MAX_STEPS:
            break

        eps = batch["eps"].to(device=device, dtype=torch_dtype)
        z0  = batch["z0"].to(device=device, dtype=torch_dtype)
        zL  = batch["zL"].to(device=device, dtype=torch_dtype)
        B = eps.shape[0]

        u_t = eps - z0  # Algorithm 1 line 6

        with accelerator.accumulate(pipe.transformer):

            # ── FTD Loss (Eq. 16/17) ──
            # t is uniform in [TL, 1).
            t = torch.rand(B, device=device, dtype=torch_dtype) * (1.0 - TL) + TL
            t_bc = t.view(B, 1, 1, 1)

            # Eq. 16: interpolation (x_t == zL at t = TL, x_t == eps at t = 1)
            x_t = ((1.0 - t_bc) / (1.0 - TL)) * zL + ((t_bc - TL) / (1.0 - TL)) * eps

            v_theta = call_transformer(
                pipe.transformer,
                latents=x_t,
                timestep=t * t_scale,
                cap_feats_2d=cap_feats_2d,
            )

            # Eq. 17
            ftd_pred   = u_t - v_theta * TL_bc
            ftd_target = eps - zL
            L_FTD = F.mse_loss(ftd_pred, ftd_target)

            # ── Recon Loss (Eq. 18/21) ──
            L_Rec = torch.tensor(0.0, device=device)
            do_rec = (REC_LOSS_EVERY > 0) and (global_step % REC_LOSS_EVERY == 0) and ("x0_pixels" in batch)

            if do_rec:
                x_HR = batch["x0_pixels"].to(device=device, dtype=torch_dtype)
                TL_t = torch.full((B,), TL * t_scale, device=device, dtype=torch_dtype)

                if DETACH_RECON:
                    # Everything here runs under no_grad, so L_Rec carries no
                    # gradient: it only affects the logged values.
                    with torch.no_grad():
                        v_TL = call_transformer(
                            pipe.transformer, latents=zL, timestep=TL_t, cap_feats_2d=cap_feats_2d,
                        )
                        z0_hat = zL - v_TL * TL_bc
                        x0_hat = vae_decode(pipe, z0_hat)

                        L_MSE = F.mse_loss(x0_hat, x_HR)

                        # LPIPS is moved to the GPU only while it is needed.
                        tv_lpips.to(device)
                        L_TVLP = tv_lpips(x0_hat.float(), x_HR.float())
                        tv_lpips.to("cpu")

                    L_Rec = (L_MSE + LAMBDA_TVLPIPS * L_TVLP).to(device=device, dtype=torch_dtype)
                else:
                    v_TL = call_transformer(
                        pipe.transformer, latents=zL, timestep=TL_t, cap_feats_2d=cap_feats_2d,
                    )
                    z0_hat = zL - v_TL * TL_bc
                    x0_hat = vae_decode(pipe, z0_hat)

                    L_MSE = F.mse_loss(x0_hat, x_HR)

                    tv_lpips.to(device)
                    L_TVLP = tv_lpips(x0_hat, x_HR)
                    tv_lpips.to("cpu")
                    torch.cuda.empty_cache()

                    L_Rec = L_MSE + LAMBDA_TVLPIPS * L_TVLP

            loss = L_FTD + L_Rec

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        # ── Logging ──
        global_step += 1
        loss_log["ftd"]   += L_FTD.item()
        loss_log["rec"]   += L_Rec.item()
        loss_log["total"] += loss.item()

        pbar.update(1)
        if global_step % LOG_EVERY == 0:
            # "rec" is averaged over all LOG_EVERY steps although it is only
            # non-zero on the steps where do_rec was true.
            n = LOG_EVERY
            pbar.set_postfix({
                "ftd": f"{loss_log['ftd']/n:.4f}",
                "rec": f"{loss_log['rec']/n:.4f}",
                "tot": f"{loss_log['total']/n:.4f}",
            })
            loss_log = {k: 0.0 for k in loss_log}

        if global_step % SAVE_EVERY == 0:
            sp = SAVE_DIR / f"lora_step_{global_step}"
            sp.mkdir(parents=True, exist_ok=True)
            accelerator.unwrap_model(pipe.transformer).save_pretrained(sp)
            ok = (sp / "adapter_config.json").exists()
            print(f"\n[Step {global_step}] Saved: {sp} ({'OK' if ok else 'MISSING adapter_config!'})")

            # --- quick inference ---
            # Best-effort one-step preview on the first sample; a failure is
            # reported but never interrupts training.
            try:
                # These names are otherwise only imported in other cells;
                # import locally so the preview works on a fresh run.
                import torchvision.transforms.functional as TF
                from PIL import Image as _PreviewImage

                pipe.transformer.eval()
                with torch.no_grad():
                    sample = ds[0]  # first pair
                    zL_s = sample["zL"].unsqueeze(0).to(device=device, dtype=torch_dtype)
                    # Same timestep convention as the training calls above.
                    # Local name so the loop's TL_t / TL_bc are not overwritten.
                    preview_t = torch.full((1,), TL * t_scale, device=device, dtype=torch_dtype)

                    # Use the same 2-D conditioning tensor the model is trained with.
                    v = call_transformer(pipe.transformer, latents=zL_s, timestep=preview_t, cap_feats_2d=cap_feats_2d)
                    z0_hat = zL_s - v * TL_bc
                    img = vae_decode(pipe, z0_hat)[0]  # (C,H,W)
                    img_pil = TF.to_pil_image(img.clamp(0,1).float().cpu())

                    out_path = SAVE_DIR / f"sr_step_{global_step}.png"
                    img_pil.save(out_path)
                    print(f"[Step {global_step}] Inference saved: {out_path}")

                    # display inline if notebook
                    try:
                        from IPython.display import display
                        display(img_pil.resize((512,512), _PreviewImage.LANCZOS))
                    except Exception as disp_err:
                        print(f"[Step {global_step}] Inline preview skipped: {disp_err}")
            except Exception as e:
                print(f"[Step {global_step}] Inference failed: {e}")
            finally:
                # Always return to training mode, whatever happened above.
                pipe.transformer.train()

pbar.close()

# Final save
# Write the adapter from the unwrapped model into SAVE_DIR/lora_final and
# list the resulting files as a quick check.
final = SAVE_DIR / "lora_final"
final.mkdir(parents=True, exist_ok=True)
accelerator.unwrap_model(pipe.transformer).save_pretrained(final)
print("Final LoRA saved:", final)
print("Files:", sorted([p.name for p in final.iterdir()]))


[LoRA] Targeting 136 layers. First 15:
  context_refiner.0.attention.to_k
  context_refiner.0.attention.to_out.0
  context_refiner.0.attention.to_q
  context_refiner.0.attention.to_v
  context_refiner.1.attention.to_k
  context_refiner.1.attention.to_out.0
  context_refiner.1.attention.to_q
  context_refiner.1.attention.to_v
  layers.0.attention.to_k
  layers.0.attention.to_out.0
  layers.0.attention.to_q
  layers.0.attention.to_v
  layers.1.attention.to_k
  layers.1.attention.to_out.0
  layers.1.attention.to_q
trainable params: 16,711,680 || all params: 6,171,620,416 || trainable%: 0.2708
cap_feats_2d: (1, 2560), cap_dim=2560
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/vgg.pth
FTDPairDataset: 1199 samples


FluxSR-FTD:   0%|          | 0/3000 [00:00<?, ?it/s]


[Step 500] Saved: /content/drive/MyDrive/zimage_sr_lora_runs/ftd_run_v2/lora_step_500 (OK)


KeyboardInterrupt: 

## 7) One-step inference sanity check

In [None]:
import random
from PIL import Image, ImageDraw
from peft import PeftModel
import torchvision.transforms.functional as TF

# One-step SR sanity check: LR latent -> base prediction vs. LoRA prediction,
# saved side by side with the HR ground truth.
#
# Differences from a naive "unwrap, then wrap again" approach:
# - The base pass runs with the adapters switched off. Unwrapping a PeftModel
#   to `.base_model.model` still leaves the LoRA layers injected, so that
#   model is not a clean base.
# - The checkpoint is loaded as an extra adapter into the existing PeftModel
#   rather than wrapping the already-adapted model a second time.
# - Timestep and conditioning follow the training convention
#   (t * t_scale, 2-D cap feats).
from contextlib import nullcontext

# Re-encode null conditioning if needed
if null_cap_feats is None or not isinstance(null_cap_feats, torch.Tensor):
    cap_dim = getattr(pipe.transformer, 'config', None)
    if cap_dim and hasattr(cap_dim, 'cap_feat_dim'):
        cap_dim = cap_dim.cap_feat_dim
    else:
        cap_dim = 2560
    null_cap_feats = torch.zeros(1, cap_dim, device=device, dtype=torch_dtype)
    print(f"null_cap_feats: {null_cap_feats.shape}")

# Bring the conditioning to the (seq_len, cap_dim) layout that
# call_transformer documents, with the same zero fallback as training.
eval_cap_feats = null_cap_feats.to(device=device, dtype=torch_dtype)
while eval_cap_feats.ndim > 2:
    eval_cap_feats = eval_cap_feats[0]
if eval_cap_feats.ndim == 1:
    eval_cap_feats = eval_cap_feats.unsqueeze(0)
expected_cap_dim = int(getattr(pipe.transformer.config, "cap_feat_dim", 2560))
if eval_cap_feats.shape[-1] != expected_cap_dim:
    eval_cap_feats = torch.zeros(1, expected_cap_dim, device=device, dtype=torch_dtype)

# Pick random sample (only directories that have both files used below)
candidates = [
    p for p in PAIRS_DIR.iterdir()
    if p.is_dir() and (p / "zL.pt").exists() and (p / "x0.png").exists()
]
d = random.choice(candidates)
print("Sample:", d.name)

zL = torch.load(d/"zL.pt", map_location="cpu").to(device=device, dtype=torch_dtype)
if zL.ndim == 3: zL = zL.unsqueeze(0)

# The timestep is scaled exactly as in training; TL_bc stays unscaled because
# it multiplies the predicted velocity.
TL_t  = torch.full((1,), TL * t_scale, device=device, dtype=torch_dtype)
TL_bc = torch.tensor([TL], device=device, dtype=torch_dtype).view(1,1,1,1)

# --- Locate the trained model ---
trained_tr = accelerator.unwrap_model(pipe.transformer)
is_peft = isinstance(trained_tr, PeftModel)
# `base_tr` is kept for the later debugging cells. When the model is
# PEFT-wrapped it still contains the injected LoRA layers.
base_tr = trained_tr.base_model.model if is_peft else trained_tr
if is_peft:
    print("Unwrapped PeftModel to base transformer")

# --- Base model SR (no LoRA) ---
trained_tr.eval()
adapters_off = trained_tr.disable_adapter() if is_peft else nullcontext()
with torch.no_grad(), adapters_off:
    v_base = call_transformer(trained_tr, latents=zL, timestep=TL_t, cap_feats_2d=eval_cap_feats)
    z0_base = zL - v_base * TL_bc
    base_img = vae_decode(pipe, z0_base)[0]
    lr_dec   = vae_decode(pipe, zL)[0]

# --- Load the saved LoRA and run SR ---
lora_path = str(SAVE_DIR / "lora_step_500")  # or lora_final
if is_peft:
    # Load the checkpoint as a separate adapter and make it the active one.
    trained_tr.load_adapter(lora_path, adapter_name="eval_ckpt")
    trained_tr.set_adapter("eval_ckpt")
    lora_tr = trained_tr
else:
    lora_tr = PeftModel.from_pretrained(trained_tr, lora_path).to(device)
lora_tr.eval()
print("LoRA loaded cleanly")

with torch.no_grad():
    v_lora = call_transformer(lora_tr, latents=zL, timestep=TL_t, cap_feats_2d=eval_cap_feats)
    z0_lora = zL - v_lora * TL_bc
    lora_img = vae_decode(pipe, z0_lora)[0]

# --- Restore pipe.transformer to LoRA version for further use ---
pipe.transformer = lora_tr

def to_pil(t): return TF.to_pil_image(t.clamp(0,1).float().cpu())

hr_pil = Image.open(d/"x0.png").convert("RGB")

# Side-by-side: LR | Base SR | LoRA SR | HR
imgs   = [to_pil(lr_dec), to_pil(base_img), to_pil(lora_img), hr_pil]
labels = ["LR (decoded)", "Base model SR", "LoRA SR", "HR ground truth"]

# Every panel is resized to the first image's size; 24 px are reserved
# below the images for the labels.
W, H = imgs[0].size
canvas = Image.new("RGB", (W*4, H+24), "white")
draw = ImageDraw.Draw(canvas)
for i, (img, lbl) in enumerate(zip(imgs, labels)):
    canvas.paste(img.resize((W, H)), (W*i, 0))
    draw.text((W*i + 4, H+4), lbl, fill="black")

canvas.save(SAVE_DIR / "eval_base_vs_lora.png")
display(canvas)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Add this right after both inferences:
# Mean absolute difference between the two velocity predictions. A value of
# exactly 0.0 means both passes produced identical outputs.
diff = (v_lora - v_base).abs().mean().item()
print(f"v_base vs v_lora mean abs diff: {diff}")

v_base vs v_lora mean abs diff: 0.0


In [None]:
import safetensors
from pathlib import Path

# List the tensor names stored in the adapter checkpoint to check their
# key prefix.
lora_path = str(SAVE_DIR / "lora_step_500_fixed")
# Check saved key names
st_files = sorted(Path(lora_path).glob("*.safetensors"))
if not st_files:
    # Fail with a clear message instead of an IndexError on an empty list.
    raise FileNotFoundError(f"No .safetensors file found in {lora_path}")
st_file = st_files[0]
# Use the context manager so the file handle is closed after reading.
with safetensors.safe_open(str(st_file), framework="pt") as st_handle:
    keys = list(st_handle.keys())
print("First 5 keys:", keys[:5])
print("Total keys:", len(keys))

First 5 keys: ['base_model.model.context_refiner.0.attention.to_k.lora_A.weight', 'base_model.model.context_refiner.0.attention.to_k.lora_B.weight', 'base_model.model.context_refiner.0.attention.to_out.0.lora_A.weight', 'base_model.model.context_refiner.0.attention.to_out.0.lora_B.weight', 'base_model.model.context_refiner.0.attention.to_q.lora_A.weight']
Total keys: 272


In [None]:
from safetensors.torch import load_file

# Rewrite doubly nested key prefixes in the adapter checkpoint, save the
# result under lora_step_500_fixed, and load it with PEFT.
# NOTE(review): `st_file` and `lora_path` come from a previously run cell.
# If `lora_path` already points at lora_step_500_fixed, the adapter_config
# copy below has identical source and destination — confirm the intended
# cell execution order.
state = load_file(str(st_file))
# Strip double nesting
# Keys without the doubled prefix are copied through unchanged.
fixed = {}
for k, v in state.items():
    k_fixed = k.replace("base_model.model.base_model.model.", "base_model.model.")
    fixed[k_fixed] = v

# Save fixed weights
from safetensors.torch import save_file
fixed_path = SAVE_DIR / "lora_step_500_fixed"
fixed_path.mkdir(exist_ok=True)
save_file(fixed, str(fixed_path / st_file.name))
# Copy adapter config
import shutil
shutil.copy(Path(lora_path)/"adapter_config.json", fixed_path/"adapter_config.json")

# Now load fixed LoRA
lora_tr = PeftModel.from_pretrained(base_tr, str(fixed_path)).to(device)

In [None]:
# Check if LoRA weights are actually non-zero after loading
lora_tr = PeftModel.from_pretrained(base_tr, str(SAVE_DIR / "lora_step_500")).to(device)

# Print the mean magnitude of the first lora_A and the first lora_B
# parameter. A lora_B mean of 0 means that matrix is all zeros.
for tag in ("lora_A", "lora_B"):
    hit = next(((n, p) for n, p in lora_tr.named_parameters() if tag in n), None)
    if hit is None:
        print(f"No {tag} parameters found")
        continue
    name, param = hit
    print(f"{name}: mean={param.abs().mean().item():.6f}")

# Check active adapter
print("Active adapter:", lora_tr.active_adapter)
# The wrapper has no `.disabled` attribute. Read the per-layer
# `disable_adapters` flag instead; `is True` ignores modules where that
# name is a method rather than a boolean.
adapters_disabled = any(
    getattr(m, "disable_adapters", False) is True for m in lora_tr.modules()
)
print("Adapter enabled:", not adapters_disabled)

base_model.model.base_model.model.noise_refiner.0.attention.to_q.lora_A.default.weight: mean=0.008084
base_model.model.base_model.model.noise_refiner.0.attention.to_q.lora_B.default.weight: mean=0.000000
Active adapter: default


AttributeError: 'ZImageTransformer2DModel' object has no attribute 'disabled'

In [None]:
# Side-by-side evaluation on one random sample:
# decoded LR latent | base-model prediction | LoRA prediction | HR ground truth.
# The result is saved to SAVE_DIR / "eval_base_vs_lora.png" and displayed inline.
import random
from PIL import Image, ImageDraw
from peft import PeftModel
import torchvision.transforms.functional as TF
from diffusers import ZImageTransformer2DModel

# Fall back to an all-zero (1, 2560) conditioning tensor when null_cap_feats is
# None or not a tensor (nothing is re-encoded here).
# NOTE(review): 2560 is presumably the caption-feature width — confirm.
if null_cap_feats is None or not isinstance(null_cap_feats, torch.Tensor):
    null_cap_feats = torch.zeros(1, 2560, device=device, dtype=torch_dtype)

# Pick random sample
d = random.choice([p for p in PAIRS_DIR.iterdir() if p.is_dir()])
print("Sample:", d.name)

# Load the stored LR latent and add a leading batch dimension if it is 3-D.
zL = torch.load(d/"zL.pt", map_location="cpu").to(device=device, dtype=torch_dtype)
if zL.ndim == 3: zL = zL.unsqueeze(0)

# Timestep as a 1-element tensor, plus a (1,1,1,1) view of it so it broadcasts
# against the 4-D latent below.
TL_t  = torch.tensor([TL], device=device, dtype=torch_dtype)
TL_bc = TL_t.view(1,1,1,1)

# --- Load FRESH base transformer (no PEFT wrapping) ---
base_tr = ZImageTransformer2DModel.from_pretrained(
    MODEL_ID, subfolder="transformer", torch_dtype=torch_dtype
).to(device)
base_tr.eval()

# --- Base model SR ---
# This must run BEFORE the LoRA wrap below, which is built on the same base_tr
# object. NOTE(review): PEFT presumably injects adapter layers into base_tr in
# place, so base_tr would no longer be a clean baseline afterwards — confirm.
with torch.no_grad():
    v_base = call_transformer(base_tr, latents=zL, timestep=TL_t, cap_feats_2d=null_cap_feats)
    # Single-step estimate of the clean latent from the model output.
    z0_base = zL - v_base * TL_bc
    base_img = vae_decode(pipe, z0_base)[0]
    lr_dec   = vae_decode(pipe, zL)[0]

# --- Load LoRA on clean base ---
lora_path = str(SAVE_DIR / "lora_step_500")
lora_tr = PeftModel.from_pretrained(base_tr, lora_path).to(device)
lora_tr.eval()

# Verify weights actually loaded: print the mean |w| of the first lora_B
# parameter found. A value of 0.000000 here deserves a closer look.
for name, param in lora_tr.named_parameters():
    if 'lora_B' in name:
        print(f"{name}: mean={param.abs().mean().item():.6f}")
        break

# Same single-step estimate as above, now with the LoRA-wrapped transformer.
with torch.no_grad():
    v_lora = call_transformer(lora_tr, latents=zL, timestep=TL_t, cap_feats_2d=null_cap_feats)
    z0_lora = zL - v_lora * TL_bc
    lora_img = vae_decode(pipe, z0_lora)[0]

# Mean absolute difference between the two model outputs.
diff = (v_lora - v_base).abs().mean().item()
print(f"v_base vs v_lora mean abs diff: {diff}")

# --- Cleanup: drop the base_tr name and hand lora_tr to the pipeline ---
# NOTE(review): lora_tr was built from base_tr, so this presumably releases no
# memory; it only removes the extra global name — confirm.
del base_tr
pipe.transformer = lora_tr

# Clamp to [0, 1], cast to float32 on CPU, and convert to a PIL image.
def to_pil(t): return TF.to_pil_image(t.clamp(0,1).float().cpu())

hr_pil = Image.open(d/"x0.png").convert("RGB")

imgs   = [to_pil(lr_dec), to_pil(base_img), to_pil(lora_img), hr_pil]
labels = ["LR (decoded)", "Base model SR", "LoRA SR", "HR ground truth"]

# Four tiles side by side, each resized to the first image's size, with a
# 24-pixel white strip underneath for the text labels.
W, H = imgs[0].size
canvas = Image.new("RGB", (W*4, H+24), "white")
draw = ImageDraw.Draw(canvas)
for i, (img, lbl) in enumerate(zip(imgs, labels)):
    canvas.paste(img.resize((W, H)), (W*i, 0))
    draw.text((W*i + 4, H+4), lbl, fill="black")

canvas.save(SAVE_DIR / "eval_base_vs_lora.png")
display(canvas)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Free old transformer
# Best-effort: move the pipeline's current transformer to the CPU, drop the
# pipeline's reference to it, then always collect garbage and release cached
# CUDA memory, even if the move itself failed.
import gc  # imported here so this cell also works on a fresh runtime

old_tr = None
try:
    old_tr = getattr(pipe, "transformer", None)
    if old_tr is not None:
        # Only real modules can be moved; anything else is just dropped.
        if hasattr(old_tr, "to"):
            old_tr.to("cpu")
        # Rebind to None instead of `del`: the attribute stays defined, so
        # later `pipe.transformer is not None` checks keep working and this
        # cell can be re-run safely.
        pipe.transformer = None
except Exception as exc:
    # Cleanup stays best-effort, but report the failure instead of hiding it.
    print(f"Transformer cleanup skipped: {exc!r}")
finally:
    old_tr = None  # release this cell's own reference before collecting
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
import gc
# Drop the pipeline's reference to its transformer, then run a garbage
# collection pass and release PyTorch's cached CUDA memory.
pipe.transformer = None
gc.collect()
torch.cuda.empty_cache()

In [None]:
# TL sweep on one random sample: run the LoRA transformer at several TL values
# and tile the decoded results between the decoded LR latent and the HR ground
# truth. The result is saved to SAVE_DIR / "eval_TL_sweep.png" and displayed.
import gc
import random

from PIL import Image, ImageDraw
from peft import PeftModel
import torchvision.transforms.functional as TF
from diffusers import ZImageTransformer2DModel

# Fall back to an all-zero (1, 2560) conditioning tensor when null_cap_feats is
# not a tensor (isinstance() already rejects None). Nothing is re-encoded here.
if not isinstance(null_cap_feats, torch.Tensor):
    null_cap_feats = torch.zeros(1, 2560, device=device, dtype=torch_dtype)

# Pick random sample
d = random.choice([p for p in PAIRS_DIR.iterdir() if p.is_dir()])
print("Sample:", d.name)

# Load the stored LR latent and add a leading batch dimension if it is 3-D.
zL = torch.load(d / "zL.pt", map_location="cpu").to(device=device, dtype=torch_dtype)
if zL.ndim == 3:
    zL = zL.unsqueeze(0)

# Free old transformer
# Guarded like the standalone cleanup cell: only call .to() on objects that
# have it, and rebind the attribute to None rather than deleting it so the
# cell can be re-run after a partial failure.
old_tr = getattr(pipe, "transformer", None)
if old_tr is not None:
    if hasattr(old_tr, "to"):
        old_tr.to("cpu")
    pipe.transformer = None
    del old_tr  # drop this cell's own reference before collecting
    gc.collect()
    torch.cuda.empty_cache()

# --- Load FRESH base transformer ---
base_tr = ZImageTransformer2DModel.from_pretrained(
    MODEL_ID, subfolder="transformer", torch_dtype=torch_dtype
).to(device)
base_tr.eval()

# --- Load LoRA on clean base ---
lora_path = str(SAVE_DIR / "lora_step_500")
lora_tr = PeftModel.from_pretrained(base_tr, lora_path).to(device)
lora_tr.eval()

# Sanity check: print the mean |w| of the first lora_B parameter found.
for name, param in lora_tr.named_parameters():
    if 'lora_B' in name:
        print(f"{name}: mean={param.abs().mean().item():.6f}")
        break


def to_pil(t):
    """Clamp *t* to [0, 1], cast to float32 on CPU, and return a PIL image."""
    return TF.to_pil_image(t.clamp(0, 1).float().cpu())


# --- Decode LR and HR once ---
with torch.no_grad():
    lr_dec = vae_decode(pipe, zL)[0]
hr_pil = Image.open(d / "x0.png").convert("RGB")

# --- Test multiple TL values ---
test_TLs = [0.15, 0.25, 0.35, 0.50]
lora_imgs = []

for tl in test_TLs:
    # Timestep as a 1-element tensor, plus a (1,1,1,1) view that broadcasts
    # against the 4-D latent.
    TL_t = torch.tensor([tl], device=device, dtype=torch_dtype)
    TL_bc = TL_t.view(1, 1, 1, 1)
    with torch.no_grad():
        v = call_transformer(lora_tr, latents=zL, timestep=TL_t, cap_feats_2d=null_cap_feats)
        # Single-step estimate of the clean latent from the model output.
        z0_hat = zL - v * TL_bc
        img = vae_decode(pipe, z0_hat)[0]
        lora_imgs.append(to_pil(img))
    print(f"TL={tl} done")

# --- Build canvas: LR | TL=0.15 | TL=0.25 | TL=0.35 | TL=0.50 | HR ---
all_imgs = [to_pil(lr_dec)] + lora_imgs + [hr_pil]
all_labels = ["LR"] + [f"TL={tl}" for tl in test_TLs] + ["HR GT"]

# Tiles are resized to the first image's size; the 24-pixel white strip under
# them holds the text labels.
W, H = all_imgs[0].size
n = len(all_imgs)
canvas = Image.new("RGB", (W * n, H + 24), "white")
draw = ImageDraw.Draw(canvas)
for i, (img, lbl) in enumerate(zip(all_imgs, all_labels)):
    canvas.paste(img.resize((W, H)), (W * i, 0))
    draw.text((W * i + 4, H + 4), lbl, fill="black")

canvas.save(SAVE_DIR / "eval_TL_sweep.png")
display(canvas)

# Keep lora_tr: hand it to the pipeline and drop the extra base_tr name.
# NOTE(review): lora_tr was built from base_tr, so the `del` presumably frees
# no memory by itself — confirm.
del base_tr
pipe.transformer = lora_tr

Output hidden; open in https://colab.research.google.com to view.