In [None]:
# Use the %pip magic rather than !pip so the packages are installed into the
# environment of the running kernel, not whichever `pip` is first on PATH.
%pip install -q torch torchvision diffusers==0.24.0 transformers==4.35.2


In [None]:
import torch
from diffusers import DDPMPipeline
from torchvision import datasets, transforms


In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the pretrained "google/ddpm-cifar10-32" pipeline and move it to `device`.
pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=True)

# Only the noise-prediction network is used below; switch it to eval mode.
unet = pipe.unet
unet.eval()

def score_from_noise_pred(eps_pred, sigma):
    """Convert a noise prediction into a score estimate.

    Args:
        eps_pred: Predicted noise (tensor or scalar).
        sigma: Noise level the prediction is divided by; must be non-zero.

    Returns:
        ``eps_pred`` multiplied by ``-(1 / sigma)``.
    """
    neg_inv_sigma = -(1.0 / sigma)
    return neg_inv_sigma * eps_pred

@torch.no_grad()
def map_sigma_to_t(sigma):
    """Return the scheduler timestep whose noise level is closest to ``sigma``.

    The per-timestep noise level is ``sqrt((1 - abar_t) / abar_t)``, computed
    from the global pipeline's ``scheduler.alphas_cumprod``.

    Args:
        sigma: Target noise level.

    Returns:
        The index (Python ``int``) of the closest entry in the noise-level grid.
    """
    abar = pipe.scheduler.alphas_cumprod.to(device)
    sigma_grid = ((1 - abar) / abar).sqrt()
    gap = torch.abs(sigma_grid - sigma)
    return int(torch.argmin(gap).item())

@torch.no_grad()
def score_fn(x, sigma=0.1):
    """Estimate the score of the data smoothed with noise level ``sigma``.

    ``x`` is interpreted as ``x0 + sigma_t * eps``, the parameterisation implied
    by ``sigma_t = sqrt((1 - abar_t) / abar_t)`` in ``map_sigma_to_t``. DDPM
    noises data as ``sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps`` (see
    ``DDPMScheduler.add_noise``), so the network input is ``sqrt(abar_t) * x``.
    The score with respect to ``x`` is then ``-eps_pred / sigma_t``.

    Args:
        x: Batch of images, shape ``(batch, ...)``, on the same device as the UNet.
        sigma: Requested noise level. It is snapped to the closest level on the
            scheduler grid, and that grid value is used for the scaling.

    Returns:
        Tensor shaped like the UNet output. The function runs under
        ``torch.no_grad()``, so no gradient flows back to ``x`` through the score.
    """
    t = map_sigma_to_t(sigma)

    # Use the noise level of the timestep actually selected, not the requested
    # one, so the 1/sigma factor matches the level the network is evaluated at.
    alpha_bar_t = pipe.scheduler.alphas_cumprod[t].to(x.device)
    sigma_t = torch.sqrt((1 - alpha_bar_t) / alpha_bar_t)

    t_tensor = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
    # Rescale into the variance-preserving input the UNet is evaluated on.
    eps_pred = unet(x * torch.sqrt(alpha_bar_t), t_tensor).sample
    return score_from_noise_pred(eps_pred, sigma_t)


In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 then applies
# (x - 0.5) / 0.5 = 2x - 1, i.e. the same [-1, 1] mapping as a `z*2-1` lambda.
# Unlike a lambda, Normalize can be pickled, so the DataLoader worker processes
# also start on platforms that spawn workers (Windows, macOS).
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
cifar = datasets.CIFAR10(root="./data", train=True, download=True, transform=tfm)
loader = torch.utils.data.DataLoader(cifar, batch_size=64, shuffle=True, num_workers=2)


In [None]:
def k_gauss(X, Y, sigma2=0.5):
    """Gaussian (RBF) kernel matrix between the rows of ``X`` and ``Y``.

    Args:
        X: Tensor of shape ``(n, d)``.
        Y: Tensor of shape ``(m, d)``.
        sigma2: Kernel bandwidth (variance), ``k = exp(-||x - y||^2 / (2 * sigma2))``.

    Returns:
        Tensor of shape ``(n, m)`` with entries in ``[0, 1]``.
    """
    XX = (X**2).sum(dim=1, keepdim=True)
    YY = (Y**2).sum(dim=1, keepdim=True)
    # ||x||^2 - 2<x, y> + ||y||^2 suffers from cancellation and can come out
    # slightly negative for (nearly) identical rows; clamp so that k never
    # exceeds 1 and cannot overflow for small bandwidths.
    dist2 = (XX - 2 * X @ Y.T + YY.T).clamp_min(0)
    return torch.exp(-dist2 / (2 * sigma2))


In [None]:
import numpy as np

def select_renyi_landmarks_np(X, m, sigma2=0.5):
    """Greedily pick ``m`` landmark rows that keep the selected Gram sum small.

    The first landmark is the row with the smallest Gaussian-kernel row sum.
    Each later step adds the unselected row ``j`` minimising
    ``2 * sum_{s in selected} K[j, s] + K[j, j]``, which is the increase of the
    sum of all kernel entries over the selected set caused by adding ``j``.

    Args:
        X: Torch tensor of shape ``(n, d)``; extra trailing dimensions are
            flattened into ``d``.
        m: Number of landmarks requested; capped at ``n``.
        sigma2: Bandwidth of the Gaussian kernel.

    Returns:
        NumPy integer array with ``min(m, n)`` distinct row indices in selection
        order (empty when ``m <= 0`` or ``X`` has no rows).
    """
    Xn = X.detach().cpu().numpy()
    n = len(Xn)
    m = min(m, n)
    if m <= 0:
        # Nothing to select; also avoids argmin on an empty array.
        return np.empty(0, dtype=int)

    # Treat every sample as one flat feature vector.
    Xn = Xn.reshape(n, -1)

    # Gaussian kernel Gram matrix; clamp round-off negatives in the squared distance.
    XX = (Xn**2).sum(axis=1, keepdims=True)
    dist2 = np.maximum(XX - 2 * Xn @ Xn.T + XX.T, 0.0)
    K = np.exp(-dist2 / (2 * sigma2))
    diag = np.diag(K)

    first = int(np.argmin(K.sum(axis=1)))
    selected = [first]
    # cross_sums[j] = sum of K[j, s] over the landmarks chosen so far.
    cross_sums = K[:, first].copy()
    while len(selected) < m:
        scores = 2 * cross_sums + diag
        scores[selected] = np.inf  # never pick the same row twice
        nxt = int(np.argmin(scores))
        selected.append(nxt)
        cross_sums += K[:, nxt]
    return np.array(selected, dtype=int)


In [None]:
class RenyiNystroemKSDTorch:
    """Nyström-approximated kernel Stein discrepancy with a Gaussian kernel.

    ``compute_stat`` selects ``m`` landmark samples, evaluates the Stein kernel
    ``h_p`` between the landmarks and all samples, and returns
    ``beta^T pinv(H_mm) beta`` where ``beta`` is the row mean of ``H_mn``.

    Samples may be passed as ``(n, d)`` matrices or with extra trailing
    dimensions (e.g. ``(n, C, H, W)`` images). ``score_fn`` always receives the
    samples in their original layout; kernel computations use the flattened
    ``(n, d)`` view.

    Args:
        score_fn: Callable mapping a batch of samples to the score
            ``grad_x log p(x)`` with the same number of elements per sample.
        sigma2: Bandwidth of the Gaussian kernel ``exp(-||x - y||^2 / (2 * sigma2))``.
        m_fn: Callable mapping the sample count ``n`` to the number of
            landmarks (default ``int(4 * sqrt(n))``); the result is capped at ``n``.
    """

    def __init__(self, score_fn, sigma2=0.5, m_fn=lambda n: int(4*torch.sqrt(torch.tensor(n)).item())):
        self.score_fn = score_fn
        self.sigma2 = sigma2
        self.m_fn = m_fn

    def h_p(self, X, Y):
        """Stein kernel matrix ``h_p(x_i, y_j)`` for all pairs of rows.

        ``h_p(x, y) = s(x)^T s(y) k + s(y)^T grad_x k + s(x)^T grad_y k
        + trace(grad_x grad_y k)`` with ``s = score_fn``.

        Args:
            X: Samples of shape ``(n, ...)``.
            Y: Samples of shape ``(m, ...)`` with the same per-sample shape.

        Returns:
            Tensor of shape ``(n, m)``.
        """
        # Evaluate the score on the original layout (the model may need image
        # shaped input), then flatten everything to (batch, d).
        grad_logpX = self.score_fn(X).reshape(X.shape[0], -1)
        grad_logpY = self.score_fn(Y).reshape(Y.shape[0], -1)
        X = X.reshape(X.shape[0], -1)
        Y = Y.reshape(Y.shape[0], -1)
        d = X.shape[1]

        gram_glogp = grad_logpX @ grad_logpY.T
        K = k_gauss(X, Y, sigma2=self.sigma2)

        # Gaussian kernel gradients: grad_x k = -(x - y) / sigma2 * k = -grad_y k.
        diff = X[:, None, :] - Y[None, :, :]  # (n, m, d)
        gradX = -(diff / self.sigma2) * K[:, :, None]
        gradY = -gradX

        B = (gradX * grad_logpY[None, :, :]).sum(dim=2)  # s(y)^T grad_x k
        C = (gradY * grad_logpX[:, None, :]).sum(dim=2)  # s(x)^T grad_y k

        # trace(grad_x grad_y k) = k * (d / sigma2 - ||x - y||^2 / sigma2^2).
        # The d / sigma2 part comes from differentiating (x_i - y_i) itself and
        # is what keeps h_p(x, x) = ||s(x)||^2 + d / sigma2 on the diagonal.
        sq_dist = (diff ** 2).sum(dim=2)
        trace_term = K * (d / self.sigma2 - sq_dist / self.sigma2 ** 2)

        return K * gram_glogp + B + C + trace_term

    def compute_stat(self, X):
        """Nyström estimate of the KSD V-statistic for the samples ``X``.

        Args:
            X: Samples of shape ``(n, ...)``.

        Returns:
            Zero-dimensional tensor ``beta^T pinv(H_mm) beta``; differentiable
            with respect to ``X`` wherever ``score_fn`` and the kernel are.
        """
        n = X.shape[0]
        m = min(self.m_fn(n), n)

        # Landmark selection works on flat feature vectors.
        X_flat = X.reshape(n, -1)
        idx = select_renyi_landmarks_np(X_flat, m, sigma2=self.sigma2)
        idx = torch.as_tensor(idx, dtype=torch.long, device=X.device)

        H_mn = self.h_p(X[idx], X)          # landmarks vs. all samples, (m, n)
        H_mm = H_mn[:, idx]                 # landmarks vs. landmarks, (m, m)
        H_mm_inv = torch.linalg.pinv(H_mm)
        # Row mean == H_mn @ (1/n); mean() keeps the dtype of H_mn.
        beta = H_mn.mean(dim=1, keepdim=True)
        stat = (beta.T @ H_mm_inv @ beta).squeeze()
        return stat


In [None]:
# KSD estimator driven by the diffusion-model score.
ksd = RenyiNystroemKSDTorch(sigma2=0.5, score_fn=score_fn)

# Pull one mini-batch of real images; the class labels are discarded.
x, _ = next(iter(loader))
x = x.to(device)

# Evaluate the statistic on the real batch and report it as a Python float.
loss = ksd.compute_stat(x)
print("KSD stat:", float(loss))


In [None]:
# placeholder generator for smoke test
import torch.nn as nn

# Small MLP: 128-d latent -> 512 hidden units -> flat 3*32*32 output squashed by tanh.
G = nn.Sequential(
    nn.Linear(128, 512),
    nn.ReLU(),
    nn.Linear(512, 3*32*32),
    nn.Tanh(),
)
G = G.to(device)

opt = torch.optim.Adam(G.parameters(), lr=2e-4)

# One optimisation step: sample 64 latents, reshape the generator output into
# 3x32x32 images, use the KSD statistic as the loss and back-propagate it.
z = torch.randn(64, 128, device=device)
fake = G(z).view(-1, 3, 32, 32)
loss = ksd.compute_stat(fake)
loss.backward()
opt.step()

print("Generator step done, loss:", loss.item())
