
# JP‑GAN (E1) — FFHQ‑128 — **I‑JEPA enabled (HF)** + metrics + 2×GPU + baseline vs JP‑GAN

**Now with native I‑JEPA support via 🤗Transformers (>= 4.49).**  
This notebook runs two configurations — **baseline** (λ=0) and **JP‑GAN** (λ>0) — and reports **FID/KID/IS**. It will:
- try to load **`facebook/ijepa_vith14_1k`** first (feature extractor),
- if that fails, automatically **fall back to DINOv2** (then DINO ViT‑S/16 or ViT‑B/16), trying models in the order given by `CFG["embedder_preference"]`.

> If more than one GPU is available (e.g. Kaggle's 2× T4), **G** and **D** are wrapped in `nn.DataParallel` so each batch is split across the GPUs.


In [None]:

# %%capture
# Notebook-only dependency installs (IPython `!` shell syntax, not plain Python).
# NOTE(review): reinstalling torch/torchvision from the cu118 index may replace the
# runtime's preinstalled build — confirm it matches the available CUDA driver.
!pip -q install --upgrade pip
!pip -q install "transformers>=4.53.3" timm==0.9.12 accelerate==0.33.0
!pip -q install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118
!pip -q install torchmetrics==1.4.0


In [None]:

import os, sys, math, random, glob, time, json
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils as vutils
from PIL import Image

from transformers import AutoProcessor, AutoModel

from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.inception import InceptionScore

# Use the GPU when present. N_GPUS is only reported here; the DataParallel
# decisions further down call torch.cuda.device_count() directly.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N_GPUS = torch.cuda.device_count()
print("Device:", DEVICE, "| #GPUs:", N_GPUS)

# Experiment configuration read by all the cells below.
CFG = {
    "image_size": 128,  # side of the square images (dataset resize, G/D size)
    "latent_dim": 256,  # dimension of the generator's noise input z
    "g_channels": 64,  # generator base width
    "d_channels": 64,  # discriminator base width
    "batch_size": 64,
    "num_workers": 4,  # DataLoader worker processes
    "epochs": 2,
    "lr_g": 2e-4,
    "lr_d": 2e-4,
    "betas": (0.0, 0.9),  # Adam (beta1, beta2) for both optimizers
    "n_crit": 1,  # discriminator steps per generator step
    "lambda_jepa": 1.0,  # NOTE: run_training reads lambda_jepa from each "runs" entry, not this key
    "mixed_precision": True,  # CUDA autocast + GradScaler
    "max_steps_per_epoch": None,  # optional cap on batches per epoch (None = full pass)
    "seed": 42,  # passed to set_seed
    "runs": [  # each entry builds fresh G/D, is trained, then evaluated
        {"name": "baseline", "lambda_jepa": 0.0},
        {"name": "jpgan",    "lambda_jepa": 1.0},
    ],
    # Feature extractors tried in order by load_embedder; the first that loads is used.
    "embedder_preference": ["facebook/ijepa_vith14_1k", "facebook/dinov2-base", "facebook/dino-vits16", "google/vit-base-patch16-224"],
    "save_dir": "/kaggle/working/jp_gan_e1_ijepa",  # checkpoints, sample grids, metrics JSON
    "eval_num_real": 2000,  # real images used for FID/KID
    "eval_num_fake": 2000,  # generated images used for FID/KID/IS
}

os.makedirs(CFG["save_dir"], exist_ok=True)
# Let cuDNN autotune conv algorithms for the fixed input sizes; this may make
# runs non-bit-exact even with fixed seeds.
torch.backends.cudnn.benchmark = True

def set_seed(s):
    """Seed Python's ``random`` module and torch (CPU plus every CUDA device) with ``s``."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)
# Seed the global RNGs before the models below are built.
set_seed(CFG["seed"])

class Timer:
    """Context manager that measures the wall-clock duration of its body.

    ``t0`` is the ``time.time()`` stamp taken on entry; after exit, ``dt``
    holds the elapsed seconds. Exceptions raised in the body propagate.
    """

    def __enter__(self):
        self.t0 = time.time()
        return self

    def __exit__(self, *exc_info):
        finished = time.time()
        self.dt = finished - self.t0


In [None]:

def find_images(root="/kaggle/input"):
    """Return the sorted, de-duplicated image paths found under ``root``.

    Every immediate subdirectory of ``root`` is searched recursively, as are
    its well-known image folders (``images``, ``thumbnails128x128``,
    ``ffhq-128``, ``ffhq``). Those folders are usually already inside the
    recursive search of their parent, so results are collected in a set to
    avoid listing the same file twice. Only regular files whose suffix is a
    known image extension (case-insensitive) are returned. Loose files
    directly in ``root`` are ignored.
    """
    exts = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    found = set()
    for dataset_dir in Path(root).glob("*"):
        if not dataset_dir.is_dir():
            continue
        candidates = [
            dataset_dir,
            dataset_dir / "images",
            dataset_dir / "thumbnails128x128",
            dataset_dir / "ffhq-128",
            dataset_dir / "ffhq",
        ]
        for candidate in candidates:
            if not candidate.exists():
                continue
            for f in candidate.rglob("*"):
                # is_file(): skip directories that merely end in ".png" etc.
                if f.suffix.lower() in exts and f.is_file():
                    found.add(str(f))
    return sorted(found)

class ImageFolderFlat(Dataset):
    """Dataset over an explicit list of image file paths.

    Each item is loaded as RGB, resized to ``size x size`` with bicubic
    interpolation, randomly flipped horizontally, converted to a tensor and
    normalized from [0, 1] to [-1, 1].
    """

    def __init__(self, files: List[str], size: int = 128):
        self.files = files
        pipeline = [
            transforms.Resize((size, size), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with Image.open(self.files[idx]) as handle:
            rgb = handle.convert("RGB")
        return self.transform(rgb)

# Discover the training images and build the shuffled training DataLoader.
files = find_images()
print(f"Found {len(files)} image files under /kaggle/input/*")
if not files:
    # Explicit raise: an `assert` would be stripped when Python runs with -O.
    raise FileNotFoundError("No images found. Add an FFHQ-128 dataset via 'Add data'.")
if len(files) < CFG["batch_size"]:
    # drop_last=True below would otherwise give a loader with zero batches.
    raise ValueError(f"Need at least batch_size={CFG['batch_size']} images, found {len(files)}.")

ds = ImageFolderFlat(files, size=CFG["image_size"])
dl = DataLoader(ds, batch_size=CFG["batch_size"], shuffle=True, num_workers=CFG["num_workers"], pin_memory=True, drop_last=True)


In [None]:

class ResBlockUp(nn.Module):
    """Residual upsampling block.

    The input is nearest-neighbour upsampled 2x. The main path applies two
    conv3x3 -> BatchNorm -> ReLU stages; a 1x1 conv shortcut on the same
    upsampled input is added to the result.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Creation order fixes both RNG consumption at init and state_dict layout.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.bn1, self.bn2 = nn.BatchNorm2d(out_ch), nn.BatchNorm2d(out_ch)
        self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        up = F.interpolate(x, scale_factor=2, mode="nearest")
        main = up
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            main = F.relu(norm(conv(main)), inplace=True)
        return main + self.skip(up)

class Generator(nn.Module):
    """ResNet-style generator mapping a latent ``z`` to an RGB image in [-1, 1].

    A linear layer produces a ``(base*8, img_size//16, img_size//16)`` feature
    map. Four ``ResBlockUp`` stages each double the resolution and halve the
    channel count. A final 3x3 conv followed by ``tanh`` produces the image.
    """

    def __init__(self, z_dim=256, base=64, img_size=128):
        super().__init__()
        self.init_sz = img_size // 16
        top = base * 8
        self.fc = nn.Linear(z_dim, top * self.init_sz * self.init_sz)
        # Channel widths top, top/2, ..., top/16 (repeated floor-halving).
        widths = [top // (2 ** k) for k in range(5)]
        self.blocks = nn.Sequential(
            *[ResBlockUp(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )
        self.to_rgb = nn.Conv2d(widths[-1], 3, 3, 1, 1)

    def forward(self, z):
        batch = z.size(0)
        side = self.init_sz
        feats = self.blocks(self.fc(z).view(batch, -1, side, side))
        return torch.tanh(self.to_rgb(feats))

def SN(m):
    """Wrap module ``m`` with spectral normalization of its weight."""
    return nn.utils.spectral_norm(m)

class DiscBlock(nn.Module):
    """Spectrally-normalized residual block for the discriminator.

    Main path: two conv3x3 + LeakyReLU(0.2). Shortcut: 1x1 conv. With
    ``down=True`` both paths are 2x average-pooled before being summed.
    """
    def __init__(self, in_ch, out_ch, down=True):
        super().__init__()
        self.conv1 = SN(nn.Conv2d(in_ch, out_ch, 3, 1, 1))
        self.conv2 = SN(nn.Conv2d(out_ch, out_ch, 3, 1, 1))
        self.down = down
        self.skip = SN(nn.Conv2d(in_ch, out_ch, 1, 1, 0))
    def forward(self, x):
        h = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
        h = F.leaky_relu(self.conv2(h), 0.2, inplace=True)
        if self.down:
            h = F.avg_pool2d(h, 2)
            s = F.avg_pool2d(self.skip(x), 2)
        else:
            s = self.skip(x)
        return h + s

class Discriminator(nn.Module):
    """Residual discriminator returning one unbounded logit per image.

    Four downsampling ``DiscBlock``s (channels base, 2*base, 4*base, 8*base)
    are followed by one non-downsampling block, a flatten and a
    spectrally-normalized linear layer. The linear input size is derived from
    ``img_size`` instead of assuming 128x128 inputs.

    Raises:
        ValueError: if ``img_size`` < 16 (four 2x poolings would reach 0).
    """
    def __init__(self, base=64, img_size=128):
        super().__init__()
        if img_size < 16:
            raise ValueError(f"img_size must be >= 16 (four 2x downsamplings), got {img_size}")
        ch = base
        blocks = [DiscBlock(3, ch, down=True)]
        for _ in range(3):
            blocks.append(DiscBlock(ch, ch*2, down=True))
            ch *= 2
        blocks.append(DiscBlock(ch, ch, down=False))
        self.blocks = nn.Sequential(*blocks)
        # Four stride-2 average pools, each flooring odd sizes -> img_size // 16.
        final_sz = img_size // 16
        self.linear = SN(nn.Linear(ch * final_sz * final_sz, 1))
    def forward(self, x):
        h = self.blocks(x)
        h = h.view(h.size(0), -1)
        out = self.linear(h)
        return out.squeeze(1)

# Module-level G/D (used for the parameter-count printout below). run_training()
# rebinds the globals G and D to freshly built models for every run.
G = Generator(z_dim=CFG["latent_dim"], base=CFG["g_channels"], img_size=CFG["image_size"])
D = Discriminator(base=CFG["d_channels"], img_size=CFG["image_size"])

# With more than one CUDA device, wrap both nets in nn.DataParallel so each
# batch is split across the GPUs.
if DEVICE == "cuda" and torch.cuda.device_count() > 1:
    print("Using DataParallel on", torch.cuda.device_count(), "GPUs for G and D")
    G = nn.DataParallel(G).to(DEVICE)
    D = nn.DataParallel(D).to(DEVICE)
else:
    G = G.to(DEVICE); D = D.to(DEVICE)

def unwrap(m):
    """Return the inner module of an ``nn.DataParallel`` wrapper, else ``m`` unchanged."""
    if isinstance(m, nn.DataParallel):
        return m.module
    return m

# Report parameter counts (in millions) of the underlying, unwrapped G and D.
print(f"G params: {sum(p.numel() for p in unwrap(G).parameters())/1e6:.2f}M | D params: {sum(p.numel() for p in unwrap(D).parameters())/1e6:.2f}M")


In [None]:

def _as_hw(spec):
    """Convert a 🤗 image-processor size spec (int, dict or SizeDict) to ``(h, w)``, or None."""
    if spec is None:
        return None
    if isinstance(spec, int):
        return (spec, spec)
    lookup = spec.get if isinstance(spec, dict) else (lambda key: getattr(spec, key, None))
    height, width = lookup("height"), lookup("width")
    if height and width:
        return (int(height), int(width))
    edge = lookup("shortest_edge")
    if edge:
        # Treated as a square resize: correct here because the dataset and the
        # generator both produce square images.
        return (int(edge), int(edge))
    return None


def _pool_features(out):
    """Reduce a 🤗 model output to one feature vector per image.

    Uses the token-mean of ``last_hidden_state`` when present, else
    ``pooler_output``, else ``logits``.
    """
    hidden = getattr(out, "last_hidden_state", None)
    if hidden is not None:
        return hidden.mean(dim=1)
    pooled = getattr(out, "pooler_output", None)
    if pooled is not None:
        return pooled
    logits = getattr(out, "logits", None)
    if logits is None:
        raise RuntimeError("Unsupported model outputs for feature extraction.")
    return logits


def _make_featurizers(model, image_proc):
    """Build ``(featurize_nograd, featurize_with_grad)`` closures around ``model``.

    Preprocessing uses differentiable torch ops: map [-1, 1] -> [0, 1], apply a
    bicubic resize, an optional center crop, and mean/std normalization, all read
    from ``image_proc``. This replaces the previous PIL round-trip, which
    detached the input and so blocked gradients to the generator. Both closures
    share this preprocessing, so real and fake features are computed the same way.
    """
    resize_hw = _as_hw(getattr(image_proc, "size", None)) or (224, 224)  # assumed default if unspecified
    crop_hw = None
    if getattr(image_proc, "do_center_crop", False):
        crop_hw = _as_hw(getattr(image_proc, "crop_size", None))
    mean = std = None
    if getattr(image_proc, "do_normalize", True):
        img_mean = getattr(image_proc, "image_mean", None)
        img_std = getattr(image_proc, "image_std", None)
        if img_mean is None:
            img_mean = [0.5, 0.5, 0.5]
        if img_std is None:
            img_std = [0.5, 0.5, 0.5]
        mean = torch.tensor(img_mean, dtype=torch.float32, device=DEVICE).view(1, -1, 1, 1)
        std = torch.tensor(img_std, dtype=torch.float32, device=DEVICE).view(1, -1, 1, 1)

    def preprocess(x):
        # x is in [-1, 1] (tanh generator output / Normalize(0.5, 0.5) dataset).
        imgs = (x.float() * 0.5 + 0.5).clamp(0, 1)
        imgs = F.interpolate(imgs, size=resize_hw, mode="bicubic", align_corners=False).clamp(0, 1)
        if crop_hw is not None:
            ch, cw = crop_hw
            top = max((imgs.shape[-2] - ch) // 2, 0)
            left = max((imgs.shape[-1] - cw) // 2, 0)
            imgs = imgs[..., top:top + ch, left:left + cw]
        if mean is not None:
            imgs = (imgs - mean) / std
        return imgs

    def featurize_with_grad(x):
        # Differentiable w.r.t. x; the embedder's own weights are frozen.
        return _pool_features(model(pixel_values=preprocess(x)))

    def featurize_nograd(x):
        with torch.no_grad():
            return featurize_with_grad(x)

    return featurize_nograd, featurize_with_grad


def load_embedder(preference_list, trust_remote_code=True):
    """Load the first usable vision embedder from ``preference_list``.

    Each name is tried in order with ``AutoProcessor``/``AutoModel``. The model
    is moved to DEVICE, set to eval mode and frozen. A two-image probe must
    succeed before the model is accepted.

    Note: backpropagating through a large backbone (e.g. ViT-H/14) may need a
    lot of GPU memory; reduce the batch size or pick a smaller embedder if it
    runs out.

    Returns:
        (processor, model, featurize_nograd, featurize_with_grad, name)
    Raises:
        The last loading exception if every candidate fails, or RuntimeError
        if ``preference_list`` is empty.
    """
    last_err = None
    for name in preference_list:
        try:
            print(f"Trying embedder: {name}")
            # SECURITY: trust_remote_code=True runs Python code shipped with the
            # checkpoint repository. Pass trust_remote_code=False when every
            # checkpoint you use is built into transformers.
            processor = AutoProcessor.from_pretrained(name, trust_remote_code=trust_remote_code)
            model = AutoModel.from_pretrained(name, trust_remote_code=trust_remote_code)
            model.eval().to(DEVICE)
            # Frozen feature extractor: gradients should reach the generator's
            # images, not accumulate in (never-optimized) embedder weights.
            model.requires_grad_(False)
            image_proc = getattr(processor, "image_processor", processor)
            featurize_nograd, featurize_with_grad = _make_featurizers(model, image_proc)
            _ = featurize_nograd(torch.randn(2, 3, CFG["image_size"], CFG["image_size"], device=DEVICE))
            print(f"Loaded embedder: {name}")
            return processor, model, featurize_nograd, featurize_with_grad, name
        except Exception as e:
            print("Failed to load", name, "because:", repr(e))
            last_err = e
    if last_err is None:
        raise RuntimeError("load_embedder: preference_list is empty")
    raise last_err

# Load the embedder once; all runs share it through the module-level
# featurize_nograd / featurize_with_grad used by jepa_feature_moment_loss.
processor, embedder_model, featurize_nograd, featurize_with_grad, embedder_name = load_embedder(CFG["embedder_preference"])
print("Embedder selected:", embedder_name)


In [None]:

# Optimizers for the module-level G/D; run_training() builds fresh ones for each run.
opt_g = torch.optim.Adam(unwrap(G).parameters(), lr=CFG["lr_g"], betas=CFG["betas"])
opt_d = torch.optim.Adam(unwrap(D).parameters(), lr=CFG["lr_d"], betas=CFG["betas"])
# Mixed-precision loss scaler; acts as a pass-through when CFG["mixed_precision"] is False.
scaler = torch.cuda.amp.GradScaler(enabled=CFG["mixed_precision"])

def d_hinge_loss(real_logits, fake_logits):
    """Discriminator hinge loss: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    real_term = F.relu(1.0 - real_logits).mean()
    fake_term = F.relu(1.0 + fake_logits).mean()
    return real_term + fake_term

def g_hinge_loss(fake_logits):
    """Generator hinge loss: the negated mean discriminator logit on generated samples."""
    mean_logit = fake_logits.mean()
    return -mean_logit

@torch.no_grad()
def sample_grid(G_model, n=64, fname_prefix="samples"):
    """Sample ``n`` images from ``G_model`` and save them as one PNG grid.

    The generator runs in eval mode. Its previous train/eval mode is restored
    even if sampling raises. Images are clamped to [-1, 1], mapped to [0, 1],
    and laid out with ``int(sqrt(n))`` images per row.

    Returns:
        Path of the written ``{fname_prefix}_{unix_time}.png`` in ``CFG["save_dir"]``.
    """
    was_train = G_model.training
    G_model.eval()
    try:
        z = torch.randn(n, CFG["latent_dim"], device=DEVICE)
        x = G_model(z).clamp(-1, 1)
    finally:
        # Restore the caller's mode even if the forward pass failed.
        if was_train:
            G_model.train()
    grid = vutils.make_grid((x * 0.5 + 0.5), nrow=int(math.sqrt(n)))
    out_path = os.path.join(CFG["save_dir"], f"{fname_prefix}_{int(time.time())}.png")
    vutils.save_image(grid, out_path)
    return out_path


In [None]:

def jepa_feature_moment_loss(x_real, x_fake):
    """Match the per-dimension mean and std of embedder features of real vs fake images.

    Real features come from ``featurize_nograd``. Fake features come from
    ``featurize_with_grad``, so the loss can backpropagate into ``x_fake``
    only if that featurizer is differentiable. Returns
    ``mean((mu_r - mu_f)^2) + mean((sd_r - sd_f)^2)`` over feature dimensions.
    """
    with torch.no_grad():
        real_feats = featurize_nograd(x_real)
    fake_feats = featurize_with_grad(x_fake)
    mean_gap = (real_feats.mean(0) - fake_feats.mean(0)).pow(2).mean()
    std_gap = (real_feats.std(0) - fake_feats.std(0)).pow(2).mean()
    return mean_gap + std_gap

def _build_gan():
    """Build a fresh (G, D) pair on DEVICE, wrapped in nn.DataParallel when >1 CUDA device."""
    g = Generator(z_dim=CFG["latent_dim"], base=CFG["g_channels"], img_size=CFG["image_size"])
    d = Discriminator(base=CFG["d_channels"], img_size=CFG["image_size"])
    if DEVICE == "cuda" and torch.cuda.device_count() > 1:
        return nn.DataParallel(g).to(DEVICE), nn.DataParallel(d).to(DEVICE)
    return g.to(DEVICE), d.to(DEVICE)


def run_training(run_cfg):
    """Train a fresh hinge-loss GAN for one configuration and return its generator.

    ``run_cfg`` needs ``"name"`` (used in logs and file names) and
    ``"lambda_jepa"``, the weight of the embedder feature-moment loss. A value
    of 0 skips the embedder forward passes entirely. Each run is re-seeded with
    ``CFG["seed"]`` so the baseline and JP-GAN start from identical weights.
    Rebinds the module globals ``G``, ``D``, ``opt_g`` and ``opt_d``. Writes
    sample grids every 200 steps, a checkpoint per epoch, and a final grid to
    ``CFG["save_dir"]``.
    """
    global G, D, opt_g, opt_d
    lambda_jepa = run_cfg["lambda_jepa"]
    run_name = run_cfg["name"]
    print(f"\n=== Run: {run_name} | lambda_jepa={lambda_jepa} | embedder={embedder_name} ===")

    set_seed(CFG["seed"])
    G, D = _build_gan()
    opt_g = torch.optim.Adam(unwrap(G).parameters(), lr=CFG["lr_g"], betas=CFG["betas"])
    opt_d = torch.optim.Adam(unwrap(D).parameters(), lr=CFG["lr_d"], betas=CFG["betas"])

    use_amp = CFG["mixed_precision"] and DEVICE == "cuda"
    # One scaler per optimizer, created per run. GradScaler.step() may be called
    # only once per optimizer between update() calls, so a shared scaler would
    # raise for n_crit > 1. Fresh scalers also stop scale state leaking across runs.
    scaler_d = torch.cuda.amp.GradScaler(enabled=use_amp)
    scaler_g = torch.cuda.amp.GradScaler(enabled=use_amp)
    use_jepa = lambda_jepa != 0
    max_steps = CFG["max_steps_per_epoch"]

    step = 0
    for epoch in range(1, CFG["epochs"] + 1):
        running = {"d": 0.0, "g": 0.0, "jepa": 0.0}
        with Timer() as t_epoch:
            for i, real in enumerate(dl):
                if max_steps is not None and i >= max_steps:
                    break
                real = real.to(DEVICE, non_blocking=True)
                bs = real.size(0)

                # Discriminator update(s); fakes need no graph back into G.
                z = torch.randn(bs, CFG["latent_dim"], device=DEVICE)
                for _ in range(CFG["n_crit"]):
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        with torch.no_grad():
                            fake = G(z)
                        d_real = D(real)
                        d_fake = D(fake)
                        loss_d = d_hinge_loss(d_real, d_fake)
                    opt_d.zero_grad(set_to_none=True)
                    scaler_d.scale(loss_d).backward()
                    scaler_d.step(opt_d)
                    scaler_d.update()

                # Generator update: adversarial term plus the optional feature-moment term.
                z = torch.randn(bs, CFG["latent_dim"], device=DEVICE)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    fake = G(z)
                    g_adv = g_hinge_loss(D(fake))
                    if use_jepa:
                        jepa = jepa_feature_moment_loss(real, fake) * lambda_jepa
                    else:
                        # Baseline: skip the embedder (saves compute; avoids 0 * NaN).
                        jepa = torch.zeros((), device=DEVICE)
                    loss_g = g_adv + jepa
                opt_g.zero_grad(set_to_none=True)
                scaler_g.scale(loss_g).backward()
                scaler_g.step(opt_g)
                scaler_g.update()

                running["d"] += float(loss_d.detach().cpu())
                running["g"] += float(g_adv.detach().cpu())
                running["jepa"] += float(jepa.detach().cpu())

                if (step % 200) == 0:
                    path = sample_grid(G, 36, fname_prefix=f"{run_name}_samples")
                    print(f"[{run_name}] ep {epoch:02d} it {i:05d} | D={running['d']/(i+1):.3f}  G={running['g']/(i+1):.3f}  JEPA={running['jepa']/(i+1):.3f}  → {path}")
                step += 1

        ckpt = {
            "G": unwrap(G).state_dict(),
            "D": unwrap(D).state_dict(),
            "opt_g": opt_g.state_dict(),
            "opt_d": opt_d.state_dict(),
            "cfg": CFG,
            "run_cfg": run_cfg,
            "embedder": embedder_name,
        }
        out_ckpt = os.path.join(CFG["save_dir"], f"{run_name}_ckpt_e{epoch:02d}.pt")
        torch.save(ckpt, out_ckpt)
        print(f"[{run_name}] Epoch {epoch} done in {t_epoch.dt:.1f}s. Saved: {out_ckpt}")

    final_img = sample_grid(G, 64, fname_prefix=f"{run_name}_final")
    print(f"[{run_name}] Final sample grid: {final_img}")
    return G


In [None]:

@torch.no_grad()
def evaluate_metrics(G_model, num_real=1000, num_fake=1000, batch_size=64):
    """Compute FID, KID and Inception Score for ``G_model``.

    Exactly ``num_real`` real images are drawn from the global training loader
    ``dl`` (including its random flip), unless it yields fewer. They are mapped
    from [-1, 1] to [0, 1]. ``num_fake`` images are sampled from ``G_model`` in
    eval mode, so BatchNorm running statistics are not modified. The model's
    previous mode is restored afterwards. IS is computed on fakes only.

    Returns:
        dict with FID, KID_mean/std, IS_mean/std, num_real, num_fake, embedder.
    """
    print(f"Evaluating metrics on {num_real} real / {num_fake} fake images...")
    # All three metrics receive float images in [0, 1], hence normalize=True.
    fid = FrechetInceptionDistance(feature=2048, normalize=True).to(DEVICE)
    # torchmetrics' KID errors if subset_size exceeds the sample count, so cap it.
    kid = KernelInceptionDistance(subset_size=min(1000, num_real, num_fake), normalize=True).to(DEVICE)
    isc = InceptionScore(splits=10, normalize=True).to(DEVICE)

    n_real = 0
    for real in dl:
        if n_real >= num_real:
            break
        # Trim the last batch so exactly num_real images are used.
        real = real[: num_real - n_real].to(DEVICE, non_blocking=True)
        imgs = (real * 0.5 + 0.5).clamp(0, 1)
        fid.update(imgs, real=True)
        kid.update(imgs, real=True)
        n_real += real.size(0)

    was_training = G_model.training
    G_model.eval()
    try:
        n_fake = 0
        while n_fake < num_fake:
            bs = min(batch_size, num_fake - n_fake)
            z = torch.randn(bs, CFG["latent_dim"], device=DEVICE)
            imgs = (G_model(z) * 0.5 + 0.5).clamp(0, 1)
            fid.update(imgs, real=False)
            kid.update(imgs, real=False)
            isc.update(imgs)
            n_fake += bs
    finally:
        if was_training:
            G_model.train()

    fid_val = float(fid.compute().cpu())
    kid_mean, kid_std = kid.compute()
    is_mean, is_std = isc.compute()
    res = {
        "FID": fid_val,
        "KID_mean": float(kid_mean.cpu()),
        "KID_std": float(kid_std.cpu()),
        "IS_mean": float(is_mean.cpu()),
        "IS_std": float(is_std.cpu()),
        "num_real": n_real,
        "num_fake": n_fake,
        "embedder": embedder_name,
    }
    print("Metrics:", res)
    return res


In [None]:

# Train and evaluate every configuration in CFG["runs"] (baseline λ=0, JP-GAN λ>0).
all_results = {}
for run in CFG["runs"]:
    G_trained = run_training(run)
    res = evaluate_metrics(G_trained, num_real=CFG["eval_num_real"], num_fake=CFG["eval_num_fake"], batch_size=CFG["batch_size"])
    all_results[run["name"]] = res

# Persist every run's metrics to one JSON file in the save directory and echo it.
summary_path = os.path.join(CFG["save_dir"], "metrics_summary.json")
with open(summary_path, "w") as f:
    json.dump(all_results, f, indent=2)
print("Saved metrics summary to:", summary_path)
print(json.dumps(all_results, indent=2))



### References
- **I‑JEPA in 🤗Transformers** (v4.49+ docs, example with `facebook/ijepa_vith14_1k`).  
- **I‑JEPA model card**: `facebook/ijepa_vith14_1k` (feature extractor).  
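- **Original paper**: Assran et al., *Self‑Supervised Learning from Images with a Joint‑Embedding Predictive Architecture*, CVPR 2023.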
