In [48]:
#!pip -q uninstall -y diffusers transformers huggingface_hub
#!pip -q install -U "huggingface_hub" "transformers" "diffusers"


In [49]:
#!pip -q install torchmetrics
#!pip -q install torch-fidelity



In [50]:
from diffusers import DDPMPipeline
from torchvision import datasets, transforms
import random, numpy as np, torch
import torch.nn as nn
import matplotlib.pyplot as plt
import copy

seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

In [51]:
# ----------------------------
# Load pretrained DDPM (CIFAR10)
# ----------------------------
pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to(device)
pipe.set_progress_bar_config(disable=True)
unet = pipe.unet.eval()

for p in unet.parameters():
    p.requires_grad_(False)

alphas_cumprod = pipe.scheduler.alphas_cumprod.to(device)  # shape [T]
T = alphas_cumprod.shape[0]

# DDPM marginal: x_t = sqrt(a_bar_t) x0 + sqrt(1-a_bar_t) eps
def add_ddpm_noise(x0, t, eps=None):
    """
    x0: (B,3,32,32) in [-1,1]
    t:  int timestep in [0, T-1]
    """
    if eps is None:
        eps = torch.randn_like(x0)
    a_bar = alphas_cumprod[t]
    mean_coeff = torch.sqrt(a_bar)
    sigma_t = torch.sqrt(1.0 - a_bar)   # <-- DDPM noise std
    x_t = mean_coeff * x0 + sigma_t * eps
    return x_t, sigma_t

def score_fn_xt(x_t, t, sigma_t=None, with_grad=False):
    """
    Returns score wrt x_t: ∇_{x_t} log p_t(x_t) ≈ - eps_pred / sigma_t
    If with_grad=True, allows gradients wrt x_t (unet weights stay frozen).
    """
    if sigma_t is None:
        sigma_t = torch.sqrt(1.0 - alphas_cumprod[t])

    t_tensor = torch.full((x_t.shape[0],), int(t), device=x_t.device, dtype=torch.long)

    if with_grad:
        eps_pred = unet(x_t, t_tensor).sample   # grad flows to x_t
    else:
        with torch.no_grad():
            eps_pred = unet(x_t, t_tensor).sample

    return -(eps_pred / sigma_t)


Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

An error occurred while trying to fetch /root/.cache/huggingface/hub/models--google--ddpm-cifar10-32/snapshots/267b167dc01f0e4e61923ea244e8b988f84deb80: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--google--ddpm-cifar10-32/snapshots/267b167dc01f0e4e61923ea244e8b988f84deb80.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


In [52]:
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda z: z*2-1),
])
cifar = datasets.CIFAR10(root="./data", train=True, download=True, transform=tfm)
loader = torch.utils.data.DataLoader(cifar, batch_size=64, shuffle=True, num_workers=2)


In [53]:
@torch.no_grad()
def select_renyi_landmarks_torch(X, m, sigma2=0.5):
    # X: (n,3,32,32)
    n = X.shape[0]
    m = min(m, n)
    Xf = X.view(n, -1)

    # squared distances (n,n)
    dist2 = torch.cdist(Xf, Xf, p=2.0) ** 2
    K = torch.exp(-dist2 / (2.0 * sigma2))
    diag = torch.diagonal(K)

    row_sums = K.sum(dim=1)
    first = torch.argmin(row_sums).item()

    selected = [first]
    cross_sums = K[:, first].clone()

    while len(selected) < m:
        scores = 2 * cross_sums + diag
        scores[selected] = float("inf")
        nxt = torch.argmin(scores).item()
        selected.append(nxt)
        cross_sums += K[:, nxt]

    return torch.tensor(selected, device=X.device, dtype=torch.long)


In [54]:
import torch

@torch.no_grad()
def estimate_sigma2_median_heuristic(X_t, num_pairs=512, eps=1e-6):
    """
    Fast median heuristic for RBF bandwidth on flattened X_t.
    Returns sigma2 (variance) used in exp(-||x-y||^2 / (2*sigma2)).
    """
    B = X_t.shape[0]
    Xf = X_t.view(B, -1)

    # sample random pairs (i,j)
    i = torch.randint(0, B, (num_pairs,), device=X_t.device)
    j = torch.randint(0, B, (num_pairs,), device=X_t.device)

    diff = Xf[i] - Xf[j]
    dist2 = (diff * diff).sum(dim=1)

    med = torch.median(dist2)
    sigma2 = 0.5 * med + eps  # common choice: sigma^2 = median/2
    return float(sigma2.item())


In [55]:
class RenyiNystroemKSDTorch:
    def __init__(self, sigma2=1.0, m_fn=lambda n: int(4*np.sqrt(n)), ridge=1e-3, mc_eps=4):
        self.sigma2 = float(sigma2)
        self.m_fn = m_fn
        self.ridge = float(ridge)
        self.mc_eps = int(mc_eps)

    def h_p(self, X, Y, t, sigma_t, sigma2, with_grad_score=False):

        X_flat = X.view(X.shape[0], -1)
        Y_flat = Y.view(Y.shape[0], -1)

        grad_logpX = score_fn_xt(X, t, sigma_t, with_grad=with_grad_score).view(X.shape[0], -1)
        grad_logpY = score_fn_xt(Y, t, sigma_t, with_grad=with_grad_score).view(Y.shape[0], -1)


        diff = X_flat[:, None, :] - Y_flat[None, :, :]
        dist2 = (diff**2).sum(dim=2)

        K = torch.exp(-dist2 / (2 * sigma2))

        gram_glogp = grad_logpX @ grad_logpY.T
        gradX = -(diff / sigma2) * K[:, :, None]
        gradY = -gradX

        B = (gradX * grad_logpY[None, :, :]).sum(dim=2)
        C = (gradY * grad_logpX[:, None, :]).sum(dim=2)

        d = X_flat.shape[1]
        gradXY_sum = (dist2 / (sigma2**2) - d / sigma2) * K

        return K * gram_glogp + B + C + gradXY_sum

    def _single_stat_from_xt(self, X_t, t, sigma_t, ridge, sigma2, with_grad_score=False):

        n = X_t.shape[0]
        m = min(self.m_fn(n), n)

        idx = select_renyi_landmarks_torch(X_t, m, sigma2=sigma2)

        H_mn = self.h_p(X_t[idx], X_t, t=t, sigma_t=sigma_t, sigma2=sigma2, with_grad_score=with_grad_score)
        H_mm = H_mn[:, idx]

        H_mm = 0.5 * (H_mm + H_mm.T)
        I = torch.eye(m, device=X_t.device, dtype=H_mm.dtype)
        H_mm_reg = H_mm + ridge * I

        beta = H_mn @ (torch.ones(n, 1, device=X_t.device, dtype=H_mn.dtype) / n)

        x = torch.linalg.solve(H_mm_reg, beta)
        stat = (beta.T @ x).squeeze()
        return stat

    def compute_stat_from_xt(self, X_t, t, sigma_t, ridge=None, sigma2=None, with_grad_score=False):
        if ridge is None:
            ridge = self.ridge
        if sigma2 is None:
            sigma2 = self.sigma2
        return self._single_stat_from_xt(
            X_t, t=int(t), sigma_t=sigma_t,
            ridge=float(ridge), sigma2=float(sigma2),
            with_grad_score=with_grad_score
        )


    def compute_stat(self, X0, t, mc_samples=None, ridge=None, sigma2=None, bw_mode="fixed"):
        if mc_samples is None:
            mc_samples = self.mc_eps
        if ridge is None:
            ridge = self.ridge

        stats = []
        for _ in range(int(mc_samples)):
            X_t, sigma_t = add_ddpm_noise(X0, t)

            # bandwidth selection
            if bw_mode == "fixed":
                sigma2_eff = self.sigma2 if sigma2 is None else float(sigma2)
            elif bw_mode == "median":
                sigma2_eff = estimate_sigma2_median_heuristic(X_t)
            else:
                raise ValueError("bw_mode must be 'fixed' or 'median'")

            stats.append(self._single_stat_from_xt(X_t, t=t, sigma_t=sigma_t, ridge=float(ridge), sigma2=sigma2_eff))

        return torch.stack(stats).mean()

ksd = RenyiNystroemKSDTorch(sigma2=1.0, ridge=1e-3, mc_eps=4)


In [56]:
class DCGANGen(nn.Module):
    def __init__(self, z_dim=128, ngf=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, ngf*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*4), nn.ReLU(True),

            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2), nn.ReLU(True),

            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf), nn.ReLU(True),

            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

G = DCGANGen(z_dim=128).to(device)
opt = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))


In [57]:
def estimate_real_baseline(t, num_batches=5, mc_samples=4, ridge=1e-3):
    vals = []
    it = iter(loader)
    for _ in range(num_batches):
        xr, _ = next(it)
        xr = xr.to(device)
        vals.append(ksd.compute_stat(xr, t=t, mc_samples=mc_samples, ridge=ridge).item())
    return float(np.mean(vals)), float(np.std(vals))

# Choose an initial diffusion level (not too small at first)
# (You can later anneal it downward.)
t_baseline = int(0.6 * (T - 1))
real_mean, real_std = estimate_real_baseline(t_baseline, num_batches=5)
print(f"KSD(real @ t={t_baseline}) baseline: mean={real_mean:.2f}, std={real_std:.2f}")

KSD(real @ t=599) baseline: mean=0.56, std=0.02


In [None]:
def rbf_kernel_flat(X, Y, sigma2):
    Xf = X.view(X.size(0), -1)
    Yf = Y.view(Y.size(0), -1)
    XX = (Xf**2).sum(dim=1, keepdim=True)
    YY = (Yf**2).sum(dim=1, keepdim=True)
    dist2 = XX - 2*Xf @ Yf.T + YY.T
    return torch.exp(-dist2 / (2.0 * sigma2))

def mmd2_rbf(X, Y, sigma2):
    Kxx = rbf_kernel_flat(X, X, sigma2)
    Kyy = rbf_kernel_flat(Y, Y, sigma2)
    Kxy = rbf_kernel_flat(X, Y, sigma2)
    return Kxx.mean() + Kyy.mean() - 2.0 * Kxy.mean()

# ----------------------------
# Training loop: t-mixture + bandwidth tuning + REAL ANCHOR
# ----------------------------
steps = 2000
batch_size = 64
losses = []
log_every = 50

t_min = int(0.6 * (T - 1))   # start safer
t_max = int(0.9 * (T - 1))

K_T = 4
MC_SAMPLES = 8
RIDGE_LAM  = 1e-3
LAMBDA_MMD = 1.0   # try 0.1, 1.0, 10.0

it = iter(loader)

for i in range(steps):
    # ---- real batch ----
    try:
        real0, _ = next(it)
    except StopIteration:
        it = iter(loader)
        real0, _ = next(it)
    real0 = real0.to(device)

    # ---- fake batch ----
    z = torch.randn(batch_size, 128, 1, 1, device=device)
    fake0 = G(z)

    # ---- sample t's ----
    t_list = sample_t_batch(K_T, t_min, t_max, bias_to_small=False)

    loss_terms = []

    for t in t_list:
        fake_t, sigma_t = add_ddpm_noise(fake0, t)
        real_t, _       = add_ddpm_noise(real0, t)

        Xt = torch.cat([real_t, fake_t], dim=0)
        sigma2_eff = estimate_sigma2_median_heuristic(Xt)

        # KSD on x_t (NOT on x0), and allow grad wrt x_t
        ksd_t = ksd.compute_stat_from_xt(
            X_t=fake_t,
            t=t,
            sigma_t=sigma_t,
            ridge=RIDGE_LAM,
            sigma2=sigma2_eff,
            with_grad_score=True,   # IMPORTANT
        )

        mmd_t = mmd2_rbf(real_t, fake_t, sigma2=sigma2_eff)

        loss_terms.append(ksd_t + LAMBDA_MMD * mmd_t)


    loss = torch.stack(loss_terms).mean()

    if not torch.isfinite(loss):
        print(f"Skipping step {i} due to non-finite loss")
        continue

    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=5.0)
    opt.step()

    losses.append(loss.item())

    if (i + 1) % log_every == 0:
        print(f"Step {i+1}: t_list={t_list} | loss={loss.item():.3f}")

Step 50: t_list=[731, 686, 872, 810] | loss=24.146
Step 100: t_list=[869, 754, 630, 617] | loss=24.433
Step 150: t_list=[661, 842, 648, 861] | loss=24.173
Step 200: t_list=[892, 601, 801, 696] | loss=24.267
Step 250: t_list=[688, 886, 764, 634] | loss=24.170
Step 300: t_list=[688, 874, 887, 658] | loss=24.046
Step 350: t_list=[788, 650, 799, 805] | loss=24.021
Step 400: t_list=[812, 616, 732, 692] | loss=24.311
Step 450: t_list=[764, 807, 811, 670] | loss=23.953
Step 500: t_list=[898, 773, 636, 732] | loss=24.005
Step 550: t_list=[657, 702, 772, 599] | loss=24.139
Step 600: t_list=[681, 861, 821, 820] | loss=23.969
Step 650: t_list=[869, 837, 638, 735] | loss=23.981
Step 700: t_list=[732, 628, 676, 720] | loss=24.188
Step 750: t_list=[836, 747, 866, 825] | loss=23.811
Step 800: t_list=[746, 746, 779, 755] | loss=23.904
Step 850: t_list=[789, 829, 604, 895] | loss=24.028
Step 900: t_list=[728, 702, 651, 653] | loss=24.194
Step 950: t_list=[829, 664, 746, 777] | loss=24.034
Step 1000: t_

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import matplotlib.pyplot as plt

# ---------------------------------------------------------
# Discriminator (DCGAN-style) for CIFAR10 32x32
# Output: logits (no sigmoid inside; we use BCEWithLogitsLoss)
# ---------------------------------------------------------
class DCGANDis(nn.Module):
    def __init__(self, ndf=64):
        super().__init__()
        self.net = nn.Sequential(
            # (3,32,32) -> (ndf,16,16)
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf,16,16) -> (ndf*2,8,8)
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf*2,8,8) -> (ndf*4,4,4)
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf*4,4,4) -> (1,1,1)
            nn.Conv2d(ndf*4, 1, 4, 1, 0, bias=False)
        )

    def forward(self, x):
        # returns logits shape (B, 1)
        return self.net(x).view(-1, 1)

# ---------------------------------------------------------
# Helper: show a small grid of samples
# ---------------------------------------------------------
@torch.no_grad()
def show_samples(G, title="samples", n=64, z_dim=128):
    G.eval()
    z = torch.randn(n, z_dim, 1, 1, device=device)
    x = G(z).detach().cpu()  # in [-1,1]
    x = (x + 1) / 2.0        # to [0,1] for plotting

    # Make a square grid
    s = int(np.sqrt(n))
    x = x[:s*s]
    grid = x.view(s, s, 3, 32, 32).permute(0, 3, 1, 4, 2).reshape(s*32, s*32, 3).numpy()

    plt.figure(figsize=(5,5))
    plt.imshow(grid)
    plt.axis("off")
    plt.title(title)
    plt.show()

# ---------------------------------------------------------
# Standard GAN training loop (DCGAN-ish)
# - Same G architecture as yours
# - Adds D and trains with BCEWithLogits
# ---------------------------------------------------------
def train_standard_gan(
    G_init,
    loader,
    steps=2000,
    batch_size=64,
    z_dim=128,
    lr_g=2e-4,
    lr_d=2e-4,
    betas=(0.5, 0.999),
    d_steps=1,
    label_smooth=0.9,
    log_every=200
):
    G = copy.deepcopy(G_init).to(device)
    D = DCGANDis(ndf=64).to(device)

    optG = torch.optim.Adam(G.parameters(), lr=lr_g, betas=betas)
    optD = torch.optim.Adam(D.parameters(), lr=lr_d, betas=betas)
    bce = nn.BCEWithLogitsLoss()

    lossesG, lossesD = [], []

    it = iter(loader)
    for step in range(1, steps + 1):
        try:
            real, _ = next(it)
        except StopIteration:
            it = iter(loader)
            real, _ = next(it)

        real = real.to(device)

        # -------------------------
        # (A) Train D
        # -------------------------
        for _ in range(d_steps):
            z = torch.randn(real.size(0), z_dim, 1, 1, device=device)
            fake = G(z).detach()

            logits_real = D(real)
            logits_fake = D(fake)

            # real labels ~ 0.9, fake labels ~ 0.0
            y_real = torch.full_like(logits_real, label_smooth, device=device)
            y_fake = torch.zeros_like(logits_fake, device=device)

            lossD = bce(logits_real, y_real) + bce(logits_fake, y_fake)

            optD.zero_grad(set_to_none=True)
            lossD.backward()
            optD.step()

        # -------------------------
        # (B) Train G (tries to fool D)
        # -------------------------
        z = torch.randn(real.size(0), z_dim, 1, 1, device=device)
        fake = G(z)
        logits_fake = D(fake)

        # want D(fake)=1
        y_gen = torch.ones_like(logits_fake, device=device)
        lossG = bce(logits_fake, y_gen)

        optG.zero_grad(set_to_none=True)
        lossG.backward()
        optG.step()

        lossesD.append(lossD.item())
        lossesG.append(lossG.item())

        if step % log_every == 0:
            print(f"[GAN] step {step}/{steps} | lossD={lossD.item():.3f} | lossG={lossG.item():.3f}")

    return G.eval(), D.eval(), lossesG, lossesD


In [None]:
@torch.no_grad()
def eval_ksd_over_t(
    G,
    ksd,
    t_list,
    batch_size=64,
    z_dim=128,
    mc_samples=4,
    ridge=1e-3
):
    G.eval()
    vals = []
    for t in t_list:
        z = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake0 = G(z)
        v = ksd.compute_stat(fake0, t=int(t), mc_samples=mc_samples, ridge=ridge).item()
        vals.append(v)
    return np.array(vals)

@torch.no_grad()
def real_baseline_over_t(
    loader,
    ksd,
    t_list,
    num_batches=3,
    mc_samples=4,
    ridge=1e-3
):
    # average KSD(real@t) across a few batches per t
    out_mean, out_std = [], []
    for t in t_list:
        vals = []
        it = iter(loader)
        for _ in range(num_batches):
            xr, _ = next(it)
            xr = xr.to(device)
            vals.append(ksd.compute_stat(xr, t=int(t), mc_samples=mc_samples, ridge=ridge).item())
        out_mean.append(np.mean(vals))
        out_std.append(np.std(vals))
    return np.array(out_mean), np.array(out_std)

def plot_ksd_comparison(t_list, real_mean, real_std, ksd_vals_A, label_A, ksd_vals_B, label_B):
    plt.figure(figsize=(7,4))
    plt.plot(t_list, ksd_vals_A, marker="o", label=label_A)
    plt.plot(t_list, ksd_vals_B, marker="o", label=label_B)
    plt.plot(t_list, real_mean, marker="o", label="real baseline")
    plt.fill_between(t_list, real_mean-real_std, real_mean+real_std, alpha=0.2)
    plt.gca().invert_xaxis()  # optional: show high t on left, low t on right
    plt.xlabel("t (diffusion step)")
    plt.ylabel("KSD@t (diffusion-score)")
    plt.title("KSD diagnostic across diffusion levels")
    plt.legend()
    plt.show()


In [None]:
# 0) Keep a copy of your KSD-trained generator
G_ksd = copy.deepcopy(G).eval()

# 1) Train a standard GAN generator from scratch (same architecture)
G0 = DCGANGen(z_dim=128).to(device)  # fresh init for fair baseline
G_gan, D_gan, lossesG, lossesD = train_standard_gan(
    G_init=G0,
    loader=loader,
    steps=2000,            # start with 2k–5k; increase if stable
    batch_size=64,
    z_dim=128,
    log_every=200
)

show_samples(G_ksd, title="KSD-trained G samples")
show_samples(G_gan, title="Standard GAN-trained G samples")

# 2) Evaluate both with KSD@t
MC_SAMPLES = 8
RIDGE_LAM  = 1e-2

# IMPORTANT: don't go too low t initially (that’s where you saw explosions)
t_list = [int(x*(T-1)) for x in [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]]

real_mean, real_std = real_baseline_over_t(
    loader, ksd, t_list,
    num_batches=3, mc_samples=MC_SAMPLES, ridge=RIDGE_LAM
)

ksd_ksdG = eval_ksd_over_t(
    G_ksd, ksd, t_list,
    batch_size=64, z_dim=128, mc_samples=MC_SAMPLES, ridge=RIDGE_LAM
)

ksd_ganG = eval_ksd_over_t(
    G_gan, ksd, t_list,
    batch_size=64, z_dim=128, mc_samples=MC_SAMPLES, ridge=RIDGE_LAM
)

plot_ksd_comparison(t_list, real_mean, real_std, ksd_ksdG, "G trained by KSD", ksd_ganG, "G trained by GAN")
