
# JP‑GAN (E1): I‑JEPA‑guided GAN on FFHQ‑128 (Kaggle Notebook)

This notebook implements **Experiment E1 (JP‑GAN)** — a lightweight GAN trained on **FFHQ 128×128** with an **auxiliary I‑JEPA feature‑matching prior** added to the generator loss.

**Idea (JP‑GAN):** alongside the standard adversarial objective, the generator minimizes a **feature moment matching loss** in a *semantic* embedding space extracted by a frozen **I‑JEPA** encoder. Concretely, in each batch we compute embeddings for real images `f(x)` and generated images `f(G(z))` and minimize the difference between **batch means and standard deviations** of these embeddings:
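$$
\mathcal{L}_{\text{JEPA}} = \frac{1}{d}\Big(\lVert \mu_\text{real}-\mu_\text{fake}\rVert_2^2 + \lVert \sigma_\text{real}-\sigma_\text{fake}\rVert_2^2\Big)
$$
where $d$ is the embedding dimension — the code averages (rather than sums) the squared differences over feature dimensions.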
The final generator loss is:
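$$
\mathcal{L}_G = \mathcal{L}_{\text{adv}} + \lambda \cdot \mathcal{L}_{\text{JEPA}},
$$
where $\mathcal{L}_{\text{adv}} = -\mathbb{E}_z[D(G(z))]$ is the hinge‑GAN generator loss and $\lambda$ is `CFG['lambda_jepa']`.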

> **Why this variant?** It’s (1) cheap and stable, (2) keeps I‑JEPA **frozen** (no extra training), and (3) avoids computing full FID-like covariances while still nudging the generator toward semantically plausible modes.

## What you need to run this on Kaggle

1. **Add an FFHQ‑128 dataset to your notebook** (Kaggle → *Add Data*) by searching for one of:  
   - `dullaz/flickrfaces-dataset-nvidia-128x128`  
   - `potatohd404/ffhq-128-70k`  
   The loader will auto-detect images under `/kaggle/input/*/`.

2. **Turn on Internet** (Notebook *Settings → Internet → On*) so Hugging Face can download the **I‑JEPA** (or fallback) weights the first time.

3. Optional: for a quick sanity run, set `CFG['max_steps_per_epoch']` (default `None`, i.e. full epochs) to a small value such as 1000.

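> If loading an I‑JEPA checkpoint fails (e.g. a download error, a model type the installed `transformers` does not support, or out‑of‑memory during the small 2‑image smoke test), the code **automatically falls back to the next entry in `CFG['embedder_preference']`**, ending with the smaller DINO ViT‑S/16 encoder. This check cannot catch out‑of‑memory errors that only appear at the full training batch size. You can force a specific embedder by editing `CFG['embedder_preference']`.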


In [None]:

# %%capture
!pip -q install --upgrade pip
!pip -q install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118
!pip -q install transformers==4.43.3 timm==0.9.12 accelerate==0.33.0


In [None]:

import os, sys, math, random, glob, time, json
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils as vutils
from PIL import Image

from transformers import AutoImageProcessor, AutoModel, AutoModelForImageClassification

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# ---- Config ----
# All hyper-parameters live in one plain dict; train() stores it in every checkpoint.
CFG = dict(
    image_size=128,
    latent_dim=256,
    g_channels=64,             # base channels for G
    d_channels=64,             # base channels for D
    batch_size=64,
    num_workers=4,
    epochs=2,                  # increase as you wish
    lr_g=2e-4,
    lr_d=2e-4,
    betas=(0.0, 0.9),          # Adam betas commonly used with hinge-loss GANs
    n_crit=1,                  # D steps per G step
    lambda_jepa=1.0,           # weight of the JEPA moment-matching term in G's loss
    mixed_precision=True,
    max_steps_per_epoch=None,  # e.g., 1000 for quick runs; None = full epochs
    seed=42,
    # Tried in order: I-JEPA checkpoints first, DINO ViT-S/16 as the last fallback.
    embedder_preference=["facebook/ijepa_vith14_1k", "facebook/ijepa_vith16_1k", "facebook/dino-vits16"],
    save_dir="/kaggle/working/jp_gan_e1",
)

# Create the output directory (and any missing parents) up front.
Path(CFG["save_dir"]).mkdir(parents=True, exist_ok=True)

def set_seed(s):
    """Seed Python's and PyTorch's RNGs (CPU and every CUDA device) with *s*.

    Also enables cuDNN autotuning, which favours speed over bit-exact
    reproducibility of GPU kernels.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)
    torch.backends.cudnn.benchmark = True
# Seed all RNGs before any data-loader or model-initialisation randomness is drawn.
set_seed(CFG["seed"])

# Utility: simple timer
class Timer:
    """Context manager that measures the wall-clock duration of its ``with`` body.

    After the block exits, the elapsed time in seconds is available as ``dt``.
    ``time.perf_counter`` is used because it is monotonic and high-resolution,
    unlike ``time.time``, which can jump when the system clock is adjusted.
    """

    def __enter__(self):
        self.t0 = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.dt = time.perf_counter() - self.t0
        # Returning None lets any exception raised inside the body propagate.


In [None]:

def find_images(root="/kaggle/input"):
    """Return sorted, de-duplicated paths of image files found under *root*.

    Every immediate subdirectory of *root* (one per attached Kaggle dataset) is
    searched recursively. Files sitting directly in *root* are ignored. A path is
    kept when it is a regular file whose lower-cased suffix is a known image
    extension.

    Args:
        root: directory whose subdirectories hold the datasets.
    Returns:
        Sorted list of unique file paths as strings (empty if nothing matches
        or *root* does not exist).
    """
    exts = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    found = set()
    for p in Path(root).glob("*"):
        if not p.is_dir():
            continue
        # `p.rglob` already walks these subdirectories. They are listed explicitly
        # in case one is a symlink that rglob does not descend into. Collecting
        # into a set removes the duplicates this overlap would otherwise create
        # (previously every image under e.g. `p/images` was listed twice).
        candidates = [p, p / "images", p / "thumbnails128x128", p / "ffhq-128", p / "ffhq"]
        for c in candidates:
            if not c.exists():
                continue
            for f in c.rglob("*"):
                # is_file() skips directories that merely have an image-like suffix.
                if f.suffix.lower() in exts and f.is_file():
                    found.add(str(f))
    return sorted(found)

class ImageFolderFlat(Dataset):
    """Map-style dataset over an explicit list of image file paths.

    Each item is loaded as RGB, resized to ``size``×``size`` (bicubic), randomly
    flipped horizontally, converted to a tensor and normalised with mean 0.5 /
    std 0.5 per channel, mapping [0, 1] to [-1, 1] to match the generator's
    tanh output range.
    """

    def __init__(self, files: List[str], size: int = 128):
        self.files = files
        steps = [
            transforms.Resize((size, size), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        rgb = Image.open(self.files[idx]).convert("RGB")
        return self.transform(rgb)

files = find_images()
print(f"Found {len(files)} image files under /kaggle/input/*")
if not files:
    # Explicit raise rather than `assert`, which is skipped under `python -O`.
    raise FileNotFoundError(
        "No images found. Add an FFHQ-128 Kaggle dataset to the notebook (see intro cell)."
    )

ds = ImageFolderFlat(files, size=CFG["image_size"])
# drop_last keeps every batch at exactly batch_size; the JEPA loss uses batch std.
# Pinned host memory only speeds up host->GPU copies, so enable it only on CUDA.
dl = DataLoader(
    ds,
    batch_size=CFG["batch_size"],
    shuffle=True,
    num_workers=CFG["num_workers"],
    pin_memory=(DEVICE == "cuda"),
    drop_last=True,
)


In [None]:

class ResBlockUp(nn.Module):
    """Residual generator block that doubles the spatial resolution.

    Both branches start from a nearest-neighbour 2x upsample of the input.
    The main branch applies two (3x3 conv -> BatchNorm -> ReLU) stages, and the
    shortcut is a 1x1 conv that matches the channel count. Returns their sum.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Submodules are created in a fixed order, so parameter initialisation
        # consumes the RNG identically from run to run.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        up = F.interpolate(x, scale_factor=2, mode="nearest")
        main = up
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            main = F.relu(norm(conv(main)), inplace=True)
        return main + self.skip(up)

class Generator(nn.Module):
    """ResNet-style generator mapping latent vectors to images in [-1, 1].

    A linear layer projects z to a (base*8)-channel map of side img_size // 16.
    Four ResBlockUp stages then each double the side and halve the channels,
    and a 3x3 conv + tanh produces 3 output channels. The output side is
    16 * (img_size // 16), which equals img_size when it is a multiple of 16.
    """

    def __init__(self, z_dim=256, base=64, img_size=128):
        super().__init__()
        self.init_sz = img_size // 16  # 8 for 128x128
        self.fc = nn.Linear(z_dim, base * 8 * self.init_sz * self.init_sz)
        # Channel widths per stage, e.g. 512, 256, 128, 64, 32 for base=64.
        widths = [base * 8 // (2 ** k) for k in range(5)]
        self.blocks = nn.Sequential(*[ResBlockUp(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])])
        self.to_rgb = nn.Conv2d(widths[-1], 3, 3, 1, 1)

    def forward(self, z):
        seed = self.fc(z).view(z.size(0), -1, self.init_sz, self.init_sz)
        return torch.tanh(self.to_rgb(self.blocks(seed)))

def SN(module):
    """Return *module* wrapped with ``torch.nn.utils.spectral_norm``.

    The wrapper re-parameterises the module's ``weight`` (storing the raw tensor
    as ``weight_orig``) so that its spectral norm is approximately 1.
    """
    wrapped = nn.utils.spectral_norm(module)
    return wrapped

class DiscBlock(nn.Module):
    """Spectral-normalised residual discriminator block.

    Main branch: two 3x3 convs, each followed by LeakyReLU(0.2). Shortcut: a
    1x1 conv. When ``down`` is True, both branches are average-pooled by 2
    before being summed.
    """

    def __init__(self, in_ch, out_ch, down=True):
        super().__init__()
        self.conv1 = SN(nn.Conv2d(in_ch, out_ch, 3, 1, 1))
        self.conv2 = SN(nn.Conv2d(out_ch, out_ch, 3, 1, 1))
        self.down = down
        self.skip = SN(nn.Conv2d(in_ch, out_ch, 1, 1, 0))

    def forward(self, x):
        shortcut = self.skip(x)
        main = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
        main = F.leaky_relu(self.conv2(main), 0.2, inplace=True)
        if self.down:
            main, shortcut = F.avg_pool2d(main, 2), F.avg_pool2d(shortcut, 2)
        return main + shortcut

class Discriminator(nn.Module):
    """Spectral-normalised residual discriminator returning one logit per image.

    Four downsampling DiscBlocks take the input from img_size to img_size / 16
    (e.g. 128 -> 64 -> 32 -> 16 -> 8), doubling the channels after the first
    block. A non-downsampling block follows, then a spectral-normalised linear
    head over the flattened feature map.

    Args:
        base: channel width of the first block (the final width is base * 8).
        img_size: side length of the square inputs; it sets the size of the
            linear head and must be a positive multiple of 16.
    Raises:
        ValueError: if img_size is not a positive multiple of 16.
    """

    def __init__(self, base=64, img_size=128):
        super().__init__()
        if img_size <= 0 or img_size % 16 != 0:
            raise ValueError(f"img_size must be a positive multiple of 16, got {img_size}")
        ch = base
        blocks = [DiscBlock(3, ch, down=True)]  # img_size -> img_size/2
        for _ in range(3):  # three more halvings -> img_size/16
            blocks.append(DiscBlock(ch, ch * 2, down=True))
            ch *= 2
        blocks.append(DiscBlock(ch, ch, down=False))
        self.blocks = nn.Sequential(*blocks)
        # Previously hard-coded to 8x8, which only matched 128x128 inputs.
        feat_sz = img_size // 16
        self.linear = SN(nn.Linear(ch * feat_sz * feat_sz, 1))

    def forward(self, x):
        h = self.blocks(x).flatten(1)
        return self.linear(h).squeeze(1)

# Build both networks on DEVICE (G first, so weight init order is unchanged).
G = Generator(
    z_dim=CFG["latent_dim"],
    base=CFG["g_channels"],
    img_size=CFG["image_size"],
).to(DEVICE)
D = Discriminator(base=CFG["d_channels"], img_size=CFG["image_size"]).to(DEVICE)
print(
    f"G params: {sum(p.numel() for p in G.parameters())/1e6:.2f}M | "
    f"D params: {sum(p.numel() for p in D.parameters())/1e6:.2f}M"
)


In [None]:

def load_embedder(preference_list):
    """Load the first usable frozen image encoder from *preference_list*.

    Each Hugging Face model id is tried in order. Its image processor and model
    are loaded, and the model is put in eval mode on DEVICE with every parameter
    frozen. A 2-image smoke test is then run through the featurizer. The first
    id that passes is returned; failures are printed and the next id is tried.

    Args:
        preference_list: Hugging Face model ids, most preferred first.
    Returns:
        (processor, model, featurize, name). featurize(x) maps images x in
        [-1, 1] (B x 3 x H x W) to pooled features without tracking gradients.
    Raises:
        ValueError: if preference_list is empty.
        Exception: the last loading/smoke-test error if every candidate fails.
    """
    if not preference_list:
        # Otherwise `raise last_err` below would turn into the confusing `raise None`.
        raise ValueError("preference_list must contain at least one model id.")
    last_err = None
    for name in preference_list:
        try:
            print(f"Trying embedder: {name}")
            # NOTE(review): trust_remote_code=True executes code shipped with the
            # checkpoint repo; keep preference_list restricted to trusted ids.
            processor = AutoImageProcessor.from_pretrained(name, trust_remote_code=True)
            model = AutoModel.from_pretrained(name, trust_remote_code=True)
            # Frozen prior. featurize_with_grad runs this model with autograd
            # enabled, so without requires_grad_(False) every generator backward
            # would also compute and accumulate (never-cleared) .grad tensors on
            # the encoder's weights.
            model.eval().requires_grad_(False).to(DEVICE)

            def featurize(x):  # x in [-1,1], Bx3xHxW
                with torch.no_grad():
                    # Map to [0,1]; the processor handles resizing/normalisation.
                    imgs = (x * 0.5 + 0.5).clamp(0, 1)
                    inputs = processor(images=[transforms.ToPILImage()(img) for img in imgs.cpu()], return_tensors="pt").to(DEVICE)
                    out = model(**inputs)
                if hasattr(out, "last_hidden_state"):
                    feats = out.last_hidden_state.mean(dim=1)  # mean pool tokens
                elif hasattr(out, "pooler_output") and out.pooler_output is not None:
                    feats = out.pooler_output
                else:
                    # Fallback: use logits if this is a classifier model.
                    logits = getattr(out, "logits", None)
                    if logits is None:
                        raise RuntimeError("Unsupported model outputs for feature extraction.")
                    feats = logits
                return feats

            # Smoke test on a tiny random batch.
            _ = featurize(torch.randn(2, 3, CFG["image_size"], CFG["image_size"], device=DEVICE))
            print(f"Loaded embedder: {name}")
            return processor, model, featurize, name
        except Exception as e:  # deliberate best-effort fallback to the next id
            print("Failed to load", name, "because:", repr(e))
            last_err = e
            # Drop the partially loaded candidate so its (possibly GPU) memory can
            # be reclaimed before the next candidate is tried.
            processor = model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    raise last_err

# Load the frozen encoder once. `featurize_nograd` is the no-grad featurizer built by
# load_embedder (PIL-based preprocessing through `processor`).
processor, embedder, featurize_nograd, embedder_name = load_embedder(CFG["embedder_preference"])

def _processor_hw(proc, default=224):
    """Return the (height, width) that a HF image processor *proc* resizes/crops to.

    Handles ``crop_size`` (used when ``do_center_crop`` is set), ``{"height",
    "width"}`` dicts, ``{"shortest_edge"}`` dicts (square inputs stay square) and
    bare ints. Falls back to *default* when none of these is present.
    """
    size = getattr(proc, "size", None)
    if getattr(proc, "do_center_crop", False) and getattr(proc, "crop_size", None):
        size = proc.crop_size
    if isinstance(size, int):
        return size, size
    if isinstance(size, dict):
        if "height" in size and "width" in size:
            return int(size["height"]), int(size["width"])
        if "shortest_edge" in size:
            side = int(size["shortest_edge"])
            return side, side
    return default, default


# Version with gradient for fake images (so G receives signal).
def featurize_with_grad(x, model=None, proc=None):
    """Embed images *x* with the frozen encoder while keeping the autograd graph to *x*.

    Preprocessing uses differentiable tensor ops: map [-1, 1] to [0, 1],
    bilinearly resize to the processor's target size, then normalise with the
    processor's image_mean/image_std. A PIL round-trip would require detaching
    *x* and would give G zero gradient from the JEPA loss. Backpropagating
    through a large encoder (e.g. ViT-H) is memory-heavy; lower the batch size
    or prefer a smaller embedder if this runs out of memory.

    Args:
        x: image batch in [-1, 1], shape (B, 3, H, W).
        model: encoder to call with ``pixel_values=``; defaults to the global ``embedder``.
        proc: HF image processor providing size/mean/std; defaults to the global ``processor``.
    Returns:
        (B, D) features: mean-pooled ``last_hidden_state`` if present, else
        ``pooler_output``, else ``logits``.
    Raises:
        RuntimeError: if the model output exposes none of those fields.
    """
    model = embedder if model is None else model
    proc = processor if proc is None else proc
    imgs = (x.float() * 0.5 + 0.5).clamp(0, 1)
    height, width = _processor_hw(proc)
    if imgs.shape[-2:] != (height, width):
        imgs = F.interpolate(imgs, size=(height, width), mode="bilinear", align_corners=False)
    if getattr(proc, "do_normalize", True):
        mean = torch.as_tensor(proc.image_mean, dtype=imgs.dtype, device=imgs.device).view(1, -1, 1, 1)
        std = torch.as_tensor(proc.image_std, dtype=imgs.dtype, device=imgs.device).view(1, -1, 1, 1)
        imgs = (imgs - mean) / std
    out = model(pixel_values=imgs)
    if hasattr(out, "last_hidden_state"):
        feats = out.last_hidden_state.mean(dim=1)  # mean pool tokens
    elif hasattr(out, "pooler_output") and out.pooler_output is not None:
        feats = out.pooler_output
    else:
        logits = getattr(out, "logits", None)
        if logits is None:
            raise RuntimeError("Unsupported model outputs for feature extraction.")
        feats = logits
    return feats

# Report which encoder the JEPA prior actually uses (it may be a fallback from the preference list).
print("Embedder selected:", embedder_name)


In [None]:

# Adam for both players with shared betas; G and D each get their own learning rate.
opt_g, opt_d = (
    torch.optim.Adam(net.parameters(), lr=CFG[lr_key], betas=CFG["betas"])
    for net, lr_key in ((G, "lr_g"), (D, "lr_d"))
)

# Loss scaler for mixed-precision training (a no-op when disabled).
scaler = torch.cuda.amp.GradScaler(enabled=CFG["mixed_precision"])

def d_hinge_loss(real_logits, fake_logits):
    """Discriminator hinge loss: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    real_term = F.relu(1.0 - real_logits).mean()  # penalise real logits below +1
    fake_term = F.relu(fake_logits + 1.0).mean()  # penalise fake logits above -1
    return real_term + fake_term

def g_hinge_loss(fake_logits):
    """Generator hinge loss: the negated mean discriminator logit on fakes."""
    return fake_logits.mean().neg()

@torch.no_grad()
def sample_grid(n=64):
    """Sample *n* images from G, save them as one PNG grid in CFG["save_dir"], return its path.

    G is switched to eval mode for sampling (BatchNorm then uses its running
    statistics) and is afterwards restored to the mode it was in before, even if
    sampling raises. Images are mapped from [-1, 1] to [0, 1], and the grid has
    int(sqrt(n)) images per row.
    """
    was_training = G.training
    G.eval()
    try:
        z = torch.randn(n, CFG["latent_dim"], device=DEVICE)
        x = G(z).clamp(-1, 1)
    finally:
        G.train(was_training)
    grid = vutils.make_grid(x * 0.5 + 0.5, nrow=max(1, int(math.sqrt(n))))
    # Nanosecond timestamps keep names unique even for grids saved within one second.
    out_path = os.path.join(CFG["save_dir"], f"samples_{time.time_ns()}.png")
    vutils.save_image(grid, out_path)
    return out_path


In [None]:

def jepa_feature_moment_loss(x_real, x_fake, featurize=None):
    """JEPA moment-matching loss between a real and a generated image batch.

    Both batches go through the same featurizer, so real and fake images get
    identical preprocessing. Only the fake branch keeps gradients, so the real
    statistics act as fixed targets. The loss is

        mean_d (mu_real - mu_fake)^2 + mean_d (sd_real - sd_fake)^2

    where mu/sd are per-dimension batch mean and (unbiased) standard deviation,
    computed in float32 so half-precision features under autocast cannot
    overflow or lose precision in the variance.

    Args:
        x_real: real images in [-1, 1], shape (B, 3, H, W); needs B >= 2
            (the unbiased std of a single sample is NaN).
        x_fake: generated images in [-1, 1], same layout.
        featurize: callable mapping an image batch to (B, D) features while
            preserving autograd; defaults to ``featurize_with_grad``.
    Returns:
        Scalar tensor (differentiable w.r.t. x_fake).
    """
    if featurize is None:
        featurize = featurize_with_grad
    with torch.no_grad():
        f_real = featurize(x_real).float()  # fixed target: no grad needed
    f_fake = featurize(x_fake).float()      # gradient flows back to x_fake / G
    mu_r, mu_f = f_real.mean(0), f_fake.mean(0)
    sd_r, sd_f = f_real.std(0), f_fake.std(0)
    return (mu_r - mu_f).pow(2).mean() + (sd_r - sd_f).pow(2).mean()

def _d_step(real):
    """Run one discriminator update on *real* against freshly sampled fakes; return the detached loss."""
    z = torch.randn(real.size(0), CFG["latent_dim"], device=DEVICE)
    with torch.cuda.amp.autocast(enabled=CFG["mixed_precision"]):
        # Fakes are constants for the D update, so skip building G's autograd graph.
        with torch.no_grad():
            fake = G(z)
        loss_d = d_hinge_loss(D(real), D(fake))
    opt_d.zero_grad(set_to_none=True)
    scaler.scale(loss_d).backward()
    scaler.step(opt_d)
    # GradScaler allows only one step() per optimizer between update() calls;
    # updating after every D step keeps n_crit > 1 from raising a RuntimeError.
    scaler.update()
    return loss_d.detach()


def _g_step(real):
    """Run one generator update (hinge adversarial + weighted JEPA prior).

    Returns the detached ``(g_adv, jepa)`` terms; ``jepa`` already includes the
    ``CFG["lambda_jepa"]`` weight.
    """
    z = torch.randn(real.size(0), CFG["latent_dim"], device=DEVICE)
    with torch.cuda.amp.autocast(enabled=CFG["mixed_precision"]):
        fake = G(z)
        g_adv = g_hinge_loss(D(fake))
        jepa = jepa_feature_moment_loss(real, fake) * CFG["lambda_jepa"]
        loss_g = g_adv + jepa
    opt_g.zero_grad(set_to_none=True)
    # D's parameters also receive gradients here; the zero_grad(set_to_none=True)
    # at the start of the next D step discards them.
    scaler.scale(loss_g).backward()
    scaler.step(opt_g)
    scaler.update()
    return g_adv.detach(), jepa.detach()


def train():
    """Train G and D for ``CFG["epochs"]`` epochs over the DataLoader ``dl``.

    Each iteration runs ``CFG["n_crit"]`` discriminator updates, then one
    generator update. Every 200 global steps a 36-image sample grid is saved and
    running-average losses are printed. At the end of each epoch a checkpoint is
    written with the models, optimizers, AMP scaler state, config and embedder
    name.
    """
    step = 0
    for epoch in range(1, CFG["epochs"] + 1):
        running = {"d": 0.0, "g": 0.0, "jepa": 0.0}
        with Timer() as t_epoch:
            for i, real in enumerate(dl):
                if CFG["max_steps_per_epoch"] is not None and i >= CFG["max_steps_per_epoch"]:
                    break
                real = real.to(DEVICE, non_blocking=True)

                # ------------------ Train D ------------------
                d_losses = [_d_step(real) for _ in range(CFG["n_crit"])]
                # ------------------ Train G ------------------
                g_adv, jepa = _g_step(real)

                # Logs: the D loss of the last D step in this iteration.
                running["d"] += d_losses[-1].item() if d_losses else 0.0
                running["g"] += g_adv.item()
                running["jepa"] += jepa.item()

                if step % 200 == 0:
                    path = sample_grid(36)
                    print(f"[ep {epoch:02d} | it {i:05d}] D={running['d']/(i+1):.3f}  G={running['g']/(i+1):.3f}  JEPA={running['jepa']/(i+1):.3f}  → saved {path}")
                step += 1

        # Save checkpoint per epoch
        ckpt = {
            "G": G.state_dict(),
            "D": D.state_dict(),
            "opt_g": opt_g.state_dict(),
            "opt_d": opt_d.state_dict(),
            "scaler": scaler.state_dict(),  # needed to resume AMP training exactly
            "cfg": CFG,
            "embedder": embedder_name,
        }
        torch.save(ckpt, os.path.join(CFG["save_dir"], f"ckpt_e{epoch:02d}.pt"))
        print(f"Epoch {epoch} done in {t_epoch.dt:.1f}s. Checkpoint saved.")

# Run training, then save one final grid of 64 samples for a quick visual check.
train()
print("Training finished.")
print("Sample preview:", sample_grid(64))
