In [None]:
!pip install torch torchvision matplotlib tqdm

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


In [None]:
# --- Data: MNIST in [-1,1] ---
# Map MNIST pixels from [0, 1] to [-1, 1] so they match the Tanh output range of the
# VAE decoder. Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 == 2x - 1.
# Unlike transforms.Lambda wrapping a lambda, Normalize can be pickled, so the dataset
# keeps working if num_workers is raised on platforms that spawn worker processes.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # [0,1] -> [-1,1]
])

# Downloads MNIST into ./data on first use.
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)

# Sanity check: dataset size and the shape of one batch.
print("Train size:", len(train_dataset))
imgs, labels = next(iter(train_loader))
print("Batch shape:", imgs.shape, "Labels shape:", labels.shape)


In [None]:
# --- Convolutional VAE ---
class ConvVAE(nn.Module):
    """Convolutional VAE for 1x28x28 images.

    The latent is a flat vector of ``latent_channels * 7 * 7`` values that can also be
    viewed as a ``(latent_channels, 7, 7)`` feature map for the latent diffusion model.
    """

    def __init__(self, latent_channels=4):
        super().__init__()
        self.latent_channels = latent_channels
        self.H_lat = 7
        self.W_lat = 7
        self.latent_dim = latent_channels * self.H_lat * self.W_lat

        feat_dim = 64 * 7 * 7  # flattened size of the 64x7x7 encoder feature map

        # Encoder: 1x28x28 -> 32x14x14 -> 64x7x7
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),   # 14x14
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 7x7
            nn.ReLU(inplace=True),
        )
        self.enc_fc_mu = nn.Linear(feat_dim, self.latent_dim)
        self.enc_fc_logvar = nn.Linear(feat_dim, self.latent_dim)

        # Decoder: latent_dim -> 64x7x7 -> 32x14x14 -> 1x28x28
        self.dec_fc = nn.Linear(self.latent_dim, feat_dim)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 14x14
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),   # 28x28
            nn.Tanh(),  # output in [-1,1]
        )

    def encode(self, x):
        """Return the posterior mean and log-variance, each of shape (batch, latent_dim)."""
        features = self.enc(x).view(x.shape[0], -1)
        return self.enc_fc_mu(features), self.enc_fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(logvar / 2)."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z_vec):
        """Decode a flat latent (batch, latent_dim) into an image batch (batch, 1, 28, 28)."""
        feature_map = self.dec_fc(z_vec).view(z_vec.shape[0], 64, 7, 7)
        return self.dec(feature_map)

    def forward(self, x):
        """Encode, sample a latent, decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sampled = self.reparameterize(mu, logvar)
        return self.decode(sampled), mu, logvar

    def encode_to_2d_latent(self, x):
        """Sample a latent for x and view it as (batch, latent_channels, H_lat, W_lat)."""
        mu, logvar = self.encode(x)
        latent_shape = (x.shape[0], self.latent_channels, self.H_lat, self.W_lat)
        return self.reparameterize(mu, logvar).view(*latent_shape)

    def decode_from_2d_latent(self, z_2d):
        """Flatten a 2-D latent feature map per sample and decode it to images."""
        return self.decode(z_2d.view(z_2d.shape[0], -1))

def vae_loss(x_recon, x, mu, logvar):
    """Per-sample-averaged VAE objective.

    Returns a tuple ``(total, recon_loss, kl)`` where ``recon_loss`` is the summed
    squared error divided by the batch size, ``kl`` is the KL divergence between the
    diagonal Gaussian posterior N(mu, exp(logvar)) and N(0, I) divided by the batch
    size, and ``total`` is their sum.
    """
    batch_size = x.shape[0]
    recon_loss = F.mse_loss(x_recon, x, reduction="sum") / batch_size
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * kl_terms.sum() / batch_size
    total = recon_loss + kl
    return total, recon_loss, kl

# Build the VAE on the selected device together with its Adam optimizer.
vae = ConvVAE(latent_channels=4)
vae = vae.to(device)
vae_opt = torch.optim.Adam(vae.parameters(), lr=1e-3)

vae_param_count = sum(p.numel() for p in vae.parameters())
print("VAE parameters (M):", vae_param_count/1e6)


In [None]:
# --- Train VAE ---
num_vae_epochs = 50  # increase for better reconstructions

vae.train()
for epoch in range(num_vae_epochs):
    epoch_no = epoch + 1
    progress = tqdm(train_loader, desc=f"VAE Epoch {epoch_no}/{num_vae_epochs}")
    loss_sum = 0.0  # running sum of (per-sample loss * batch size) for the epoch average
    for batch, _ in progress:
        batch = batch.to(device)
        recon_batch, mu, logvar = vae(batch)
        loss, recon_term, kl_term = vae_loss(recon_batch, batch, mu, logvar)

        # Standard optimisation step.
        vae_opt.zero_grad()
        loss.backward()
        vae_opt.step()

        loss_sum += loss.item() * batch.size(0)
        progress.set_postfix({"loss": loss.item(), "recon": recon_term.item(), "kl": kl_term.item()})

    print(f"Epoch {epoch_no}: avg loss = {loss_sum/len(train_dataset):.4f}")


In [None]:
# --- Quick VAE reconstruction check ---
vae.eval()

# Take the first eight images of one training batch and reconstruct them without
# tracking gradients.
first_batch = next(iter(train_loader))
imgs = first_batch[0].to(device)[:8]
with torch.no_grad():
    recon = vae(imgs)[0]

def show_batch(x, title):
    """Plot a batch of [-1,1] images side by side in one grayscale strip.

    Only channel 0 of each image is displayed.
    """
    unit_range = (x.cpu() + 1) / 2  # [-1,1] -> [0,1]
    # Iterating a tensor yields its per-sample slices; join them along the width axis.
    strip = torch.cat(list(unit_range), dim=2)[0]
    plt.figure(figsize=(8, 2))
    plt.imshow(strip, cmap="gray")
    plt.axis("off")
    plt.title(title)
    plt.show()

# Originals first, then their VAE reconstructions.
for image_batch, caption in ((imgs, "Original"), (recon, "VAE Reconstruction")):
    show_batch(image_batch, caption)


In [None]:
# --- Diffusion setup in latent space ---
# Linear beta schedule with T diffusion steps.
T = 200
beta_start, beta_end = 1e-4, 0.02

betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1.0 - betas
# Cumulative product of alphas, and the same sequence shifted right by one step
# with a leading 1.0.
alphas_cumprod = alphas.cumprod(dim=0)
alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]], dim=0)

# Precomputed terms used by q_sample and the reverse step.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

def q_sample(z0, t, noise=None):
    """Noise ``z0`` to timestep ``t``: sqrt(a_bar_t) * z0 + sqrt(1 - a_bar_t) * noise.

    ``t`` indexes the precomputed schedule tensors; the looked-up coefficients are
    reshaped to (-1, 1, 1, 1) so they broadcast over a 4-D ``z0``. Fresh Gaussian
    noise is drawn when ``noise`` is not supplied.
    """
    eps = torch.randn_like(z0) if noise is None else noise
    signal_scale = sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
    noise_scale = sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
    return signal_scale * z0 + noise_scale * eps

def sample_timesteps(batch_size):
    """Draw ``batch_size`` integer timesteps uniformly from [0, T) on ``device``."""
    shape = (batch_size,)
    return torch.randint(0, T, shape, device=device)

def sinusoidal_time_embedding(timesteps, dim):
    """Sinusoidal embedding of integer timesteps.

    Args:
        timesteps: 1-D tensor of timesteps, shape (batch,).
        dim: embedding width (>= 1).

    Returns:
        Float tensor of shape (batch, dim): ``dim // 2`` sine features followed by
        ``dim // 2`` cosine features, plus one trailing zero column when ``dim`` is odd.
        Frequencies are spaced geometrically from 1 down to 1/10000.
    """
    device_ = timesteps.device
    half_dim = dim // 2
    # Guard the denominator: `half_dim - 1` is zero for dim in {2, 3}, which used to
    # raise ZeroDivisionError. With a single frequency the exponent is 0 anyway, so the
    # value of the scale is irrelevant in that case.
    scale = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=device_) * -scale)
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if dim % 2 == 1:
        # Pad with an explicit (batch, 1) zero column; slicing `emb[:, :1]` would be
        # empty when dim == 1 and silently produce a (batch, 0) result.
        emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
    return emb


In [None]:
# --- Conditional U-Net in latent space ---
# Width of the sinusoidal timestep embedding; passed to LatentUNet as time_dim and
# read by its forward pass.
time_embed_dim = 64
# Width of the label embedding; shared by LabelTextEncoder and LatentUNet (text_dim).
text_embed_dim = 64
# Channel count after the U-Net's input convolution (passed as base_ch).
base_channels  = 64
# Number of label classes, i.e. rows in the label embedding table.
num_classes    = 10
# Classifier-free guidance weight.
# NOTE(review): the sampling functions visible in this notebook take their own
# guidance_scale argument (default 3.0) and do not read this global.
guidance_scale = 3.0
# Probability of zeroing a sample's label embedding during diffusion training.
drop_cond_prob = 0.1

class ResidualBlock(nn.Module):
    """Two-conv residual block whose second normalisation is modulated by a conditioning vector.

    The conditioning vector is projected to a per-channel scale (gamma) and shift (beta)
    applied after the second GroupNorm. The skip path is a 1x1 convolution when the
    channel count changes and the identity otherwise.
    """

    def __init__(self, in_ch, out_ch, cond_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        # One linear layer yields both gamma and beta (out_ch values each).
        self.cond_proj = nn.Linear(cond_dim, 2 * out_ch)

        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, cond):
        """Apply the block to feature map ``x`` using conditioning ``cond`` of shape (batch, cond_dim)."""
        gamma, beta = self.cond_proj(cond).chunk(2, dim=1)
        # Add trailing spatial axes so the modulation broadcasts over H and W.
        gamma = gamma[:, :, None, None]
        beta = beta[:, :, None, None]

        h = self.conv1(F.silu(self.norm1(x)))
        h = self.norm2(h) * (1 + gamma) + beta
        h = self.conv2(F.silu(h))

        return h + self.skip(x)

class LatentUNet(nn.Module):
    """Small conditional noise-prediction network operating on latent feature maps.

    Conditioning is the concatenation of an MLP-processed sinusoidal timestep embedding
    and an MLP-processed label embedding; it is passed to every ResidualBlock.

    Args:
        in_ch: channels of the latent input (and of the predicted noise).
        base_ch: channels after the input convolution; deeper blocks use 2 * base_ch.
        time_dim: width of the sinusoidal timestep embedding.
        text_dim: width of the label embedding passed to ``forward``.
    """

    def __init__(self, in_ch=4, base_ch=64, time_dim=64, text_dim=64):
        super().__init__()
        # Keep the embedding width on the instance so forward() always builds an
        # embedding that matches time_mlp, instead of depending on a module-level global.
        self.time_dim = time_dim

        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )
        self.text_mlp = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )
        cond_dim = time_dim + text_dim

        self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        self.down1 = ResidualBlock(base_ch, base_ch, cond_dim)
        self.down2 = ResidualBlock(base_ch, base_ch * 2, cond_dim)
        self.pool1 = nn.AvgPool2d(2)
        self.down3 = ResidualBlock(base_ch * 2, base_ch * 2, cond_dim)

        self.mid = ResidualBlock(base_ch * 2, base_ch * 2, cond_dim)

        self.up1 = ResidualBlock(base_ch * 2, base_ch * 2, cond_dim)
        self.up2 = ResidualBlock(base_ch * 2, base_ch, cond_dim)
        self.conv_out = nn.Conv2d(base_ch, in_ch, 3, padding=1)

    def forward(self, x, t, text_emb):
        """Predict the noise in ``x`` at timesteps ``t`` given label embeddings ``text_emb``.

        Returns a tensor with the same channel count and spatial size as ``x``.
        """
        t_emb = sinusoidal_time_embedding(t, self.time_dim)
        t_emb = self.time_mlp(t_emb)
        text_emb = self.text_mlp(text_emb)
        cond = torch.cat([t_emb, text_emb], dim=1)

        x = self.conv_in(x)
        x1 = self.down1(x, cond)
        x2 = self.down2(x1, cond)
        x2p = self.pool1(x2)  # halves the spatial resolution (floor division)
        x3 = self.down3(x2p, cond)

        m = self.mid(x3, cond)

        u1 = self.up1(m, cond)
        # Upsample back to the pre-pooling size; passing an explicit size handles odd
        # resolutions that the pooling rounded down.
        u1_up = F.interpolate(u1, size=x2.shape[-2:], mode="nearest")
        u2 = self.up2(u1_up, cond)
        out = self.conv_out(u2)
        return out

class LabelTextEncoder(nn.Module):
    """Learned lookup table mapping integer class labels to embedding vectors."""

    def __init__(self, num_classes=10, embed_dim=64):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=embed_dim)

    def forward(self, labels):
        """Return the embedding row for each label; output shape is labels.shape + (embed_dim,)."""
        embedded = self.emb(labels)
        return embedded

latent_channels = 4

# Freeze the VAE: evaluation mode and no gradients for any of its parameters.
vae.eval()
vae.requires_grad_(False)

unet = LatentUNet(
    in_ch=latent_channels,
    base_ch=base_channels,
    time_dim=time_embed_dim,
    text_dim=text_embed_dim,
)
unet = unet.to(device)
text_encoder = LabelTextEncoder(num_classes, text_embed_dim).to(device)

# A single AdamW optimizer updates the U-Net and the label encoder together.
diffusion_params = [*unet.parameters(), *text_encoder.parameters()]
diff_opt = torch.optim.AdamW(diffusion_params, lr=2e-4)

diffusion_param_count = sum(p.numel() for p in diffusion_params)
print("Diffusion params (M):", diffusion_param_count/1e6)


In [None]:
# --- Train latent diffusion (conditional on labels) ---
num_diff_epochs = 20  # increase for better sample quality

unet.train()
text_encoder.train()

for epoch in range(num_diff_epochs):
    progress = tqdm(train_loader, desc=f"Diff Epoch {epoch+1}/{num_diff_epochs}")
    for images, digits in progress:
        images, digits = images.to(device), digits.to(device)

        # Encode with the frozen VAE; no gradients are needed for this step.
        with torch.no_grad():
            z0 = vae.encode_to_2d_latent(images)
        batch_size = z0.size(0)

        # Forward diffusion: noise each latent to its own random timestep.
        t = sample_timesteps(batch_size)
        noise = torch.randn_like(z0)
        z_t = q_sample(z0, t, noise)

        # With probability drop_cond_prob a sample's label embedding is zeroed.
        # Sampling later uses an all-zero embedding as the unconditional input.
        keep_cond = (torch.rand(batch_size, device=device) >= drop_cond_prob).float().view(-1, 1)
        text_emb = text_encoder(digits) * keep_cond

        loss = F.mse_loss(unet(z_t, t, text_emb), noise)

        diff_opt.zero_grad()
        loss.backward()
        diff_opt.step()

        progress.set_postfix({"loss": loss.item()})
    print(f"Epoch {epoch+1}: last batch loss = {loss.item():.4f}")


In [None]:
# --- Sampling: label -> latent diffusion -> VAE decoder ---
@torch.no_grad()
def p_sample(x_t, t, text_emb, text_emb_uncond, guidance_scale=3.0):
    """One reverse-diffusion step x_t -> x_{t-1} with classifier-free guidance.

    Args:
        x_t: current noisy latents; the schedule terms are reshaped to (-1, 1, 1, 1),
            so a 4-D batch is expected.
        t: integer timesteps, one per sample, used to index the schedule tensors.
        text_emb: conditional label embeddings.
        text_emb_uncond: embeddings used for the unconditional prediction.
        guidance_scale: weight applied to (eps_cond - eps_uncond).

    Returns:
        The latents for the previous timestep. Samples with ``t == 0`` receive the
        posterior mean without added noise.
    """
    eps_cond = unet(x_t, t, text_emb)
    eps_uncond = unet(x_t, t, text_emb_uncond)
    eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

    beta_t = betas[t].view(-1, 1, 1, 1)
    sqrt_om = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_recip_a = sqrt_recip_alphas[t].view(-1, 1, 1, 1)

    mean = sqrt_recip_a * (x_t - beta_t / sqrt_om * eps)

    # Decide per sample whether to add noise. Checking only t[0] assumed that every
    # sample shares one timestep, failed on an empty batch, and forced a host sync.
    add_noise = (t > 0).to(x_t.dtype).view(-1, 1, 1, 1)
    var = posterior_variance[t].view(-1, 1, 1, 1)
    return mean + add_noise * torch.sqrt(var) * torch.randn_like(x_t)

@torch.no_grad()
def sample_from_labels(labels, num_steps=100, guidance_scale=3.0):
    """Generate images for the given class labels via latent diffusion + VAE decoding.

    Args:
        labels: sequence or tensor of integer class labels.
        num_steps: number of reverse steps; clamped to [1, T]. The steps are spread
            evenly over the whole trained schedule and always start at timestep T-1.
        guidance_scale: classifier-free guidance weight.

    Returns:
        Decoded images scaled to [0, 1], one per label.

    The previous loop ran timesteps ``num_steps-1 .. 0``. For ``num_steps < T`` that
    started from pure Gaussian noise at a timestep where the forward process is not
    yet pure noise, and for ``num_steps > T`` it indexed past the schedule. Here each
    step instead goes from a timestep to the previous *selected* timestep using the
    effective beta ``1 - a_bar_t / a_bar_prev``; with ``num_steps == T`` this is the
    same update as the single-step DDPM rule.
    """
    vae.eval()
    unet.eval()
    text_encoder.eval()

    if not torch.is_tensor(labels):
        labels = torch.tensor(labels, device=device, dtype=torch.long)
    else:
        labels = labels.to(device)
    B = labels.size(0)

    # Evenly spaced, strictly increasing timesteps that end at T-1.
    num_steps = max(1, min(int(num_steps), T))
    if num_steps == 1:
        schedule = [T - 1]
    else:
        schedule = sorted({round(i * (T - 1) / (num_steps - 1)) for i in range(num_steps)})

    z_t = torch.randn(B, latent_channels, vae.H_lat, vae.W_lat, device=device)

    text_emb = text_encoder(labels)
    text_emb_uncond = torch.zeros_like(text_emb)  # matches the zeroed embedding used in training

    for idx in tqdm(reversed(range(len(schedule))), total=len(schedule), desc="Sampling"):
        step = schedule[idx]
        t = torch.full((B,), step, device=device, dtype=torch.long)

        eps_cond = unet(z_t, t, text_emb)
        eps_uncond = unet(z_t, t, text_emb_uncond)
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        ac_t = alphas_cumprod[step]
        # a_bar of the timestep we are stepping to; 1 when stepping to the clean latent.
        ac_prev = alphas_cumprod[schedule[idx - 1]] if idx > 0 else torch.ones_like(ac_t)
        alpha_eff = ac_t / ac_prev
        beta_eff = 1.0 - alpha_eff

        mean = (z_t - beta_eff / torch.sqrt(1.0 - ac_t) * eps) / torch.sqrt(alpha_eff)
        if idx > 0:
            var = beta_eff * (1.0 - ac_prev) / (1.0 - ac_t)
            z_t = mean + torch.sqrt(var) * torch.randn_like(z_t)
        else:
            # Final step: return the mean without adding noise.
            z_t = mean

    imgs = vae.decode_from_2d_latent(z_t)
    imgs = (imgs.clamp(-1, 1) + 1) / 2.0
    return imgs


In [None]:
# --- Visualize samples for digits 0..9 ---
labels_to_generate = [0,1,2,3,4,5,6,7,8,9]
samples = sample_from_labels(labels_to_generate, num_steps=100, guidance_scale=3.0).cpu()

# Two rows of subplots, one generated digit per panel.
n_cols = (len(labels_to_generate) + 1) // 2
plt.figure(figsize=(12, 3))
for position, (lbl, sample) in enumerate(zip(labels_to_generate, samples), start=1):
    plt.subplot(2, n_cols, position)
    plt.imshow(sample[0], cmap="gray")
    plt.axis("off")
    plt.title(f"digit {lbl}")
plt.tight_layout()
plt.show()


In [None]:
# --- FID Evaluation Utilities ---
import numpy as np

@torch.no_grad()
def get_inception_model(device):
    """Load torchvision's ImageNet-pretrained Inception v3 as a feature extractor.

    The final fully connected layer is replaced by an identity, so the model outputs
    the features that would otherwise feed the classifier. The model is returned in
    evaluation mode on ``device``.
    """
    net = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
    net.fc = nn.Identity()
    net.eval()
    net.to(device)
    return net

@torch.no_grad()
def get_inception_features(x, model, device, normalize=True):
    """Run a batch of [0, 1] images through ``model`` and return features as a NumPy array.

    Args:
        x: image batch (N, C, H, W) with values in [0, 1]; single-channel input is
            replicated to three channels.
        model: callable feature extractor; if it returns a tuple, the first element
            is used.
        device: device on which to run the model.
        normalize: when True (default) and the batch has three channels, apply the
            ImageNet mean/std normalisation. torchvision's pretrained Inception v3
            enables ``transform_input`` by default, which presumes ImageNet-normalised
            input; feeding raw [0, 1] pixels shifts the input distribution. Pass False
            to keep the previous un-normalised behaviour.

    Returns:
        ``numpy.ndarray`` of features, one row per image.
    """
    # Move first so the (much larger) upsampled batch is created on the target device.
    x = x.to(device)
    if x.size(1) == 1:
        x = x.repeat(1, 3, 1, 1)
    x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
    if normalize and x.size(1) == 3:
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
        x = (x - mean) / std
    feat = model(x)
    if isinstance(feat, tuple):
        feat = feat[0]
    return feat.cpu().numpy()

def compute_activation_statistics_from_loader(loader, model, num_samples=10000, device="cuda"):
    """Mean and covariance of model features over up to ``num_samples`` images from ``loader``.

    Each batch is rescaled with (x + 1) / 2 before feature extraction, i.e. the loader
    is expected to yield images in [-1, 1]. Returns ``(mu, sigma)`` where ``mu`` is the
    per-feature mean and ``sigma`` the feature covariance matrix.
    """
    model.eval()
    collected = []
    n_seen = 0
    for batch, _ in tqdm(loader, desc="Real features for FID"):
        batch = (batch + 1) / 2.0
        remaining = num_samples - n_seen
        if batch.size(0) > remaining:
            # Trim the final batch so exactly num_samples images are used.
            batch = batch[:remaining]
        collected.append(get_inception_features(batch, model, device))
        n_seen += batch.size(0)
        if n_seen >= num_samples:
            break
    stacked = np.concatenate(collected, axis=0)
    return np.mean(stacked, axis=0), np.cov(stacked, rowvar=False)

def compute_activation_statistics_from_generator(gen_fn, model, num_samples=10000, batch_size=64, device="cuda"):
    """Mean and covariance of model features over images produced by ``gen_fn``.

    ``gen_fn(n)`` is called with the number of images still needed (at most
    ``batch_size``) until at least ``num_samples`` images have been processed. Its
    output is passed to the feature extractor unchanged. Returns ``(mu, sigma)``.
    """
    model.eval()
    collected = []
    n_seen = 0
    while n_seen < num_samples:
        generated = gen_fn(min(batch_size, num_samples - n_seen))
        collected.append(get_inception_features(generated, model, device))
        n_seen += generated.size(0)
    stacked = np.concatenate(collected, axis=0)
    return np.mean(stacked, axis=0), np.cov(stacked, rowvar=False)

def sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition.

    Negative eigenvalues are clipped to zero before taking the root.
    """
    eigvals, eigvecs = np.linalg.eigh(mat)
    root_vals = np.sqrt(np.clip(eigvals, 0, None))
    # Scale each eigenvector column by its root eigenvalue, then rotate back.
    return (eigvecs * root_vals) @ eigvecs.T

def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    FID = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr((sigma1 sigma2)^{1/2})

    The product ``sigma1 @ sigma2`` is generally not symmetric, so it must not be fed
    to a symmetric eigensolver. The trace term is instead computed from the symmetric
    PSD matrix ``sigma1^{1/2} sigma2 sigma1^{1/2}``, which has the same eigenvalues as
    ``sigma1 sigma2``.

    Args:
        mu1, mu2: feature means, shape (d,).
        sigma1, sigma2: symmetric PSD covariance matrices, shape (d, d).
        eps: kept for backward compatibility; unused, because the symmetric
            formulation needs no diagonal regularisation (adding it would bias the
            result for rank-deficient covariances).

    Returns:
        The distance as a Python float.
    """
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)

    diff = mu1 - mu2

    # Symmetric square root of sigma1 (negative round-off eigenvalues clipped to zero).
    vals1, vecs1 = np.linalg.eigh(sigma1)
    sqrt_sigma1 = (vecs1 * np.sqrt(np.maximum(vals1, 0))) @ vecs1.T

    inner = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
    inner = (inner + inner.T) / 2  # remove asymmetry introduced by round-off
    inner_vals = np.linalg.eigh(inner)[0]
    tr_covmean = np.sqrt(np.maximum(inner_vals, 0)).sum()

    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
    return float(fid)


In [None]:
# --- Compute FID between real MNIST and generated samples ---
inception = get_inception_model(device)

# 1) Real stats
real_stats = compute_activation_statistics_from_loader(
    loader=train_loader,
    model=inception,
    num_samples=500,  # can increase for better estimate
    device=device,
)
mu_real, cov_real = real_stats
print("Real stats done.")

# 2) Fake stats from diffusion model
@torch.no_grad()
def diffusion_generate_fn(batch_size):
    """Generate ``batch_size`` images for uniformly random labels drawn from 0..9."""
    random_digits = torch.randint(0, 10, (batch_size,), device=device)
    return sample_from_labels(random_digits, num_steps=100, guidance_scale=3.0)

# Feature statistics of generated samples, using the same sample count as the real set.
fake_stats = compute_activation_statistics_from_generator(
    gen_fn=diffusion_generate_fn,
    model=inception,
    num_samples=500,
    batch_size=64,
    device=device,
)
mu_fake, cov_fake = fake_stats
print("Fake stats done.")

fid_value = calculate_fid(mu_real, cov_real, mu_fake, cov_fake)
print(f"FID (latent diffusion vs real MNIST): {fid_value:.4f}")
