## 0. Environment check & imports

In [None]:
# Show the attached GPU, if any; the shell fallback prints a message when
# the nvidia-smi command fails.
!nvidia-smi || echo "No GPU found"

import torch
# Device string used throughout the notebook: "cuda" when PyTorch sees a
# CUDA device, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
# Quietly install the notebook's dependencies; CLIP is installed directly
# from its GitHub repository.
!pip install -q torch torchvision tqdm pycocotools git+https://github.com/openai/CLIP.git

In [None]:
import os, math, random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import transforms
from torchvision.datasets import CocoCaptions
from torchvision.utils import make_grid, save_image

from tqdm.auto import tqdm
import clip

## 1. Download COCO val2017 images and captions

In [None]:
# Folder (relative to the current working directory) that holds the COCO files.
data_root = Path("./coco")
data_root.mkdir(parents=True, exist_ok=True)

# NOTE(review): the mkdir above runs before this cd, so it creates ./coco in
# whatever the working directory was at that point; presumably that is already
# /content on Colab — confirm.
%cd /content

# Download images (val2017)
# Skipped when coco/val2017 already exists.
if not (data_root / "val2017").exists():
    !wget -q http://images.cocodataset.org/zips/val2017.zip -O val2017.zip
    !unzip -q val2017.zip -d coco

# Download captions annotations
# Skipped when coco/annotations already exists.
if not (data_root / "annotations").exists():
    !wget -q http://images.cocodataset.org/annotations/annotations_trainval2017.zip -O annotations_trainval2017.zip
    !unzip -q annotations_trainval2017.zip -d coco

# Make sure later cells resolve "./coco" relative to /content.
%cd /content

## 2. Dataset: images + raw captions (CLIP will tokenize)

In [None]:
# Paths to the val2017 images and the matching captions annotation file.
coco_root = "./coco"
img_root = os.path.join(coco_root, "val2017")
ann_file = os.path.join(coco_root, "annotations", "captions_val2017.json")

# Untransformed dataset, built here only to report how many entries it has.
raw_coco = CocoCaptions(root=img_root, annFile=ann_file)
len(raw_coco)

In [None]:
# Side length (pixels) every image is resized to; the sampler later derives
# the latent resolution from it as image_size // 4.
image_size = 64

# Resize to a square, convert to a tensor, then shift/scale each of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),  # to [-1,1]
])

class CocoTextImageDataset(Dataset):
    """Pairs each COCO image with one of its captions.

    Wraps ``CocoCaptions`` and exposes only the first ``subset_size`` entries
    (or all of them when ``subset_size`` is ``None``). Indexing yields
    ``(image, caption)`` where the caption is drawn at random from the
    captions stored for that image, so repeated reads of the same index may
    return different captions.
    """

    def __init__(self, root, ann_file, transform=None, subset_size=None):
        self.coco = CocoCaptions(root=root, annFile=ann_file, transform=transform)
        available = len(self.coco)
        # Cap the requested subset at the number of entries actually present.
        keep = available if subset_size is None else min(subset_size, available)
        self.indices = list(range(keep))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, captions = self.coco[self.indices[idx]]
        return image, random.choice(captions)

subset_size = 4000  # adjust for Colab speed/quality trade-off
dataset_full = CocoTextImageDataset(img_root, ann_file,
                                    transform=transform,
                                    subset_size=subset_size)

# train/val split
# val_size is truncated to an int; the train split takes the remainder.
# random_split is called without a generator, so the partition differs
# between runs unless the global RNG is seeded beforehand.
val_ratio = 0.1
val_size = int(val_ratio * len(dataset_full))
train_size = len(dataset_full) - val_size
train_dataset, val_dataset = random_split(dataset_full, [train_size, val_size])

batch_size = 32

# Only the training loader shuffles; both load in the main process
# (num_workers=0) and request pinned host memory.
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=0, pin_memory=True)

len(train_dataset), len(val_dataset)

## 3. CLIP text encoder (conditioning)

In [None]:
# Load the ViT-B/32 CLIP checkpoint (non-JIT) onto the selected device.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)

# CLIP is used as a frozen text encoder: eval mode, no gradients.
clip_model.eval()
for p in clip_model.parameters():
    p.requires_grad = False

# Encode one dummy caption to read the text-embedding width off the output.
with torch.no_grad():
    dummy = clip.tokenize(["hello world"]).to(device)
    dummy_emb = clip_model.encode_text(dummy)
text_dim = dummy_emb.shape[-1]
text_dim

In [None]:
@torch.no_grad()
def get_text_emb(captions, device=device):
    """Encode captions with the frozen CLIP text encoder.

    Captions are tokenized with truncation enabled, encoded by
    ``clip_model``, scaled to unit L2 norm along the last dimension and
    returned as float32.
    """
    token_ids = clip.tokenize(captions, truncate=True).to(device)
    features = clip_model.encode_text(token_ids)
    unit_features = features / features.norm(dim=-1, keepdim=True)
    return unit_features.float()

## 4. Convolutional VAE (image ↔ latent)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that outputs the mean and log-variance maps.

    Two stride-2 convolutions each halve the spatial size, followed by one
    stride-1 convolution; two separate 1x1 convolutions then map the
    128-channel features to ``latent_ch`` channels each.
    """

    def __init__(self, in_ch=3, latent_ch=4):
        super().__init__()
        widths = (32, 64, 128)
        self.net = nn.Sequential(
            # stride 2: halves H and W
            nn.Conv2d(in_ch, widths[0], kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # stride 2: halves H and W again
            nn.Conv2d(widths[0], widths[1], kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # stride 1: keeps the spatial size
            nn.Conv2d(widths[1], widths[2], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.mu_conv = nn.Conv2d(widths[2], latent_ch, kernel_size=1)
        self.logvar_conv = nn.Conv2d(widths[2], latent_ch, kernel_size=1)

    def forward(self, x):
        """Return ``(mu, logvar)`` computed from the shared feature map."""
        features = self.net(x)
        return self.mu_conv(features), self.logvar_conv(features)

class Decoder(nn.Module):
    """Convolutional decoder mapping a latent map back to an image.

    A stride-1 convolution lifts the latent to 128 channels, two stride-2
    transposed convolutions each double the spatial size, and a final
    convolution followed by ``Tanh`` produces ``out_ch`` channels in [-1, 1].
    """

    def __init__(self, out_ch=3, latent_ch=4):
        super().__init__()
        widths = (128, 64, 32)
        self.net = nn.Sequential(
            nn.Conv2d(latent_ch, widths[0], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # stride 2: doubles H and W
            nn.ConvTranspose2d(widths[0], widths[1], kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # stride 2: doubles H and W again
            nn.ConvTranspose2d(widths[1], widths[2], kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(widths[2], out_ch, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Decode latent ``z`` into an image tensor."""
        return self.net(z)

class VAE(nn.Module):
    """Variational autoencoder built from ``Encoder`` and ``Decoder``.

    The decoder is constructed with ``in_ch`` output channels, so
    reconstructions have the same channel count as the input.
    """

    def __init__(self, in_ch=3, latent_ch=4):
        super().__init__()
        self.encoder = Encoder(in_ch, latent_ch)
        self.decoder = Decoder(in_ch, latent_ch)

    def encode(self, x):
        """Return ``(z, mu, logvar)`` with ``z`` drawn via reparameterization.

        A fresh noise sample is drawn on every call, regardless of
        train/eval mode.
        """
        mu, logvar = self.encoder(x)
        sigma = (0.5 * logvar).exp()
        z = mu + torch.randn_like(sigma) * sigma
        return z, mu, logvar

    def decode(self, z):
        """Map a latent back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample, decode; return ``(x_rec, mu, logvar)``."""
        z, mu, logvar = self.encode(x)
        return self.decode(z), mu, logvar

In [None]:
# Number of latent channels; reused later for the diffusion UNet and sampler.
vae_latent_ch = 4
vae = VAE(in_ch=3, latent_ch=vae_latent_ch).to(device)
# Parameter count in millions (displayed as the cell output).
sum(p.numel() for p in vae.parameters()) / 1e6

In [None]:
def vae_loss(x, x_rec, mu, logvar, beta_kl=1e-3):
    """Return ``(total, recon, kl)`` for a VAE batch.

    ``recon`` is the mean squared error between ``x_rec`` and ``x``.
    ``kl`` is the KL divergence of N(mu, exp(logvar)) from N(0, 1),
    averaged over all elements. ``total`` is ``recon + beta_kl * kl``.
    """
    recon = F.mse_loss(x_rec, x)  # default reduction is 'mean'
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * torch.mean(kl_terms)
    total = recon + beta_kl * kl
    return total, recon, kl

### 4.1 Train VAE

In [None]:
# VAE training hyper-parameters and optimizer.
vae_epochs = 20
vae_lr = 1e-3
opt_vae = torch.optim.Adam(vae.parameters(), lr=vae_lr)

for epoch in range(1, vae_epochs+1):
    vae.train()
    # Running sums of total / reconstruction / KL loss and the batch count.
    tot, rec, kl_sum, n = 0.0, 0.0, 0.0, 0
    pbar = tqdm(train_loader, desc=f"VAE Epoch {epoch}/{vae_epochs}")
    # Captions are not used while training the VAE.
    for imgs, caps in pbar:
        imgs = imgs.to(device)
        x_rec, mu, logvar = vae(imgs)
        loss, recon, kl = vae_loss(imgs, x_rec, mu, logvar)

        opt_vae.zero_grad()
        loss.backward()
        opt_vae.step()

        tot += loss.item(); rec += recon.item(); kl_sum += kl.item(); n += 1
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "recon": f"{recon.item():.4f}"})

    # Per-epoch averages over batches (each batch weighted equally).
    print(f"[VAE] Epoch {epoch}: total={tot/n:.4f}, recon={rec/n:.4f}, kl={kl_sum/n:.4f}")

    # Save a reconstruction grid for the first validation batch:
    # originals in the top row, reconstructions in the bottom row.
    vae.eval()
    with torch.no_grad():
        imgs, caps = next(iter(val_loader))
        imgs = imgs.to(device)
        x_rec, mu, logvar = vae(imgs)
        grid = make_grid(torch.cat([imgs, x_rec], dim=0), nrow=imgs.size(0))
        os.makedirs("vae_recon", exist_ok=True)
        # Map from [-1, 1] back to [0, 1] before writing the PNG.
        save_image((grid+1)/2, f"vae_recon/epoch_{epoch:03d}.png")

In [None]:
# --- Visualize VAE reconstructions from the validation set ---

vae.eval()
with torch.no_grad():
    # Take the first batch from the validation loader.
    imgs, caps = next(iter(val_loader))
    imgs = imgs.to(device)

    # Reconstruct it with the VAE (the latent is sampled, so output varies per call).
    x_rec, mu, logvar = vae(imgs)

# Build a two-row grid: originals on top, reconstructions below.
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

grid = make_grid(
    torch.cat([imgs, x_rec], dim=0),  # originals followed by reconstructions
    nrow=imgs.size(0)                 # batch-size images per row -> two rows
)

# Images are in [-1, 1]; convert to channel-last [0, 1] for matplotlib.
grid_np = (grid.permute(1, 2, 0).cpu().numpy() + 1) / 2.0

plt.figure(figsize=(10, 5))
plt.imshow(grid_np)
plt.axis("off")
plt.title("Top: original images  |  Bottom: VAE reconstructions")
plt.show()


## 5. Diffusion process in latent space

In [None]:
def make_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Return ``T`` noise variances spaced linearly from ``beta_start`` to ``beta_end``."""
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=T)
    return schedule

# Number of diffusion timesteps and the per-step schedule tensors, each of length T.
T = 200
betas = make_beta_schedule(T).to(device)
alphas = 1.0 - betas
# Cumulative product of alphas up to and including step t.
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Same sequence shifted right by one, with 1.0 as the value before step 0.
alphas_cumprod_prev = torch.cat([torch.tensor([1.], device=device), alphas_cumprod[:-1]], dim=0)

# Coefficients of the clean latent and of the noise in the forward process.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Variance used by the reverse step; it is exactly 0 at t = 0 because
# alphas_cumprod_prev[0] is 1.
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

In [None]:
def q_sample_latent(z0, t, noise=None):
    """Noise a clean latent ``z0`` to timestep ``t`` in one shot.

    Computes ``sqrt_alphas_cumprod[t] * z0 + sqrt_one_minus_alphas_cumprod[t] * noise``
    with the per-sample coefficients broadcast over the three trailing
    dimensions. Fresh Gaussian noise is drawn when ``noise`` is ``None``.
    """
    eps = torch.randn_like(z0) if noise is None else noise
    signal_scale = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    noise_scale = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    return signal_scale * z0 + noise_scale * eps

class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding.

    For a 1-D tensor of timesteps ``t`` with ``B`` entries, returns a
    ``(B, dim)`` tensor whose first ``dim // 2`` columns are sines and next
    ``dim // 2`` columns are cosines of ``t`` scaled by geometrically spaced
    frequencies running from 1 down to 1/10000. When ``dim`` is odd, a final
    zero column pads the result to ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) half_dim - 1 is 0; clamp the
        # divisor so that case yields frequency 1 instead of dividing by zero.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -log_step)
        # Outer product: (B, 1) * (1, half_dim) -> (B, half_dim)
        angles = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb

## 6. UNet with cross-attention in latent space

In [None]:
class CrossAttention2D(nn.Module):
    """Multi-head cross-attention from a 2-D feature map to a context.

    Every spatial position of ``x`` (shape ``(B, C, H, W)``) forms a query;
    keys and values are linear projections of ``ctx``. The attended result is
    projected and added back to ``x`` as a residual.

    Args:
        channels: Channel count ``C`` of the feature map; must be divisible
            by ``num_heads``.
        ctx_dim: Feature size of the context vectors.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, channels, ctx_dim, num_heads=4):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.scale = (channels // num_heads) ** -0.5
        self.q = nn.Linear(channels, channels)
        self.k = nn.Linear(ctx_dim, channels)
        self.v = nn.Linear(ctx_dim, channels)
        self.proj = nn.Linear(channels, channels)

    def _split_heads(self, t):
        """Reshape ``(B, L, D)`` to ``(B, num_heads, L, D // num_heads)``."""
        B, L, D = t.shape
        return t.view(B, L, self.num_heads, D // self.num_heads).transpose(1, 2)

    def forward(self, x, ctx):
        """Attend from ``x`` to ``ctx`` and return ``x`` plus the attended update.

        ``ctx`` may be ``(B, ctx_dim)`` (treated as a single context token) or
        ``(B, L, ctx_dim)`` (a sequence of ``L`` tokens).
        """
        B, C, H, W = x.shape
        # Channel-last, then flatten the spatial grid into a sequence of queries.
        x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(1)

        q = self._split_heads(self.q(x_flat))
        k = self._split_heads(self.k(ctx))
        v = self._split_heads(self.v(ctx))

        # (B, heads, H*W, L): one row of scores per query position.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Normalize over the context tokens (last axis) so each query's weights
        # sum to 1. Normalizing over axis -2 would instead make the spatial
        # positions compete with one another for each token. With a single
        # context token the weights are all 1, and every position receives the
        # same projected value vector.
        attn = attn.softmax(dim=-1)

        out = attn @ v
        out = out.transpose(1, 2).reshape(B, H * W, C)
        out = self.proj(out)
        out = out.view(B, H, W, C).permute(0, 3, 1, 2)
        return x + out

In [None]:
class ResBlock(nn.Module):
    """Residual block conditioned on a time embedding and a text embedding.

    Layout: GroupNorm -> SiLU -> 3x3 conv, add the conditioning as a
    per-channel bias, then GroupNorm -> SiLU -> 3x3 conv, plus a skip path.
    Both embeddings are projected from ``emb_dim`` to ``out_ch``. The skip
    path is a 1x1 convolution when the channel count changes and an identity
    otherwise. GroupNorm uses 8 groups.
    """

    def __init__(self, in_ch, out_ch, emb_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=in_ch)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        self.time_fc = nn.Linear(emb_dim, out_ch)
        self.text_fc = nn.Linear(emb_dim, out_ch)

        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb, txt_emb):
        """Apply the block to ``x`` given per-sample time and text embeddings."""
        hidden = self.conv1(self.act(self.norm1(x)))

        # Sum both projections and broadcast them over the spatial grid.
        conditioning = self.time_fc(t_emb) + self.text_fc(txt_emb)
        hidden = hidden + conditioning.unsqueeze(-1).unsqueeze(-1)

        hidden = self.conv2(self.act(self.norm2(hidden)))
        return hidden + self.skip(x)

In [None]:
class LatentUNet(nn.Module):
    """Small UNet that predicts noise on a latent map, conditioned on time and text.

    The network has three resolution levels with additive skip connections
    and a bottleneck containing one cross-attention layer over the text
    embedding. The input is average-pooled by 2 three times and upsampled by
    2 three times, so the latent's height and width should be divisible by 8
    for the skip additions to line up.

    Args:
        latent_ch: Channels of the input and output latent.
        base_ch: Channel width at the highest resolution (doubled per level).
        emb_dim: Width of the time embedding and of the conditioning vectors
            consumed by every ``ResBlock``.
        ctx_dim: Width of the incoming text embedding. When it differs from
            ``emb_dim``, the text embedding is linearly projected to
            ``emb_dim`` before being handed to the ResBlocks; the
            cross-attention layer always receives the unprojected embedding.
    """

    def __init__(self, latent_ch=4, base_ch=64, emb_dim=512, ctx_dim=512):
        super().__init__()
        self.time_mlp = nn.Sequential(
            TimeEmbedding(emb_dim),
            nn.Linear(emb_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim),
        )

        self.conv_in = nn.Conv2d(latent_ch, base_ch, 3, padding=1)
        self.down1 = ResBlock(base_ch, base_ch, emb_dim)
        self.down2 = ResBlock(base_ch, base_ch*2, emb_dim)
        self.down3 = ResBlock(base_ch*2, base_ch*4, emb_dim)

        self.downsample = nn.AvgPool2d(2)

        self.mid1 = ResBlock(base_ch*4, base_ch*4, emb_dim)
        self.mid_attn = CrossAttention2D(base_ch*4, ctx_dim)
        self.mid2 = ResBlock(base_ch*4, base_ch*4, emb_dim)

        self.up3 = ResBlock(base_ch*4, base_ch*2, emb_dim)
        self.up2 = ResBlock(base_ch*2, base_ch, emb_dim)
        self.up1 = ResBlock(base_ch, base_ch, emb_dim)

        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv_out = nn.Conv2d(base_ch, latent_ch, 3, padding=1)

        # ResBlocks project their text input with Linear(emb_dim, ...), so a
        # text embedding of a different width has to be mapped to emb_dim
        # first. When the widths already match this is a parameter-free
        # Identity, which leaves the parameter set and outputs unchanged.
        if ctx_dim == emb_dim:
            self.txt_proj = nn.Identity()
        else:
            self.txt_proj = nn.Linear(ctx_dim, emb_dim)

    def forward(self, zt, t, txt_emb):
        """Predict the noise in ``zt`` at timesteps ``t`` given ``txt_emb``.

        Returns a tensor with the same channel count as ``zt``.
        """
        t_emb = self.time_mlp(t)
        txt_cond = self.txt_proj(txt_emb)

        # Encoder path: ResBlock at each level, average-pool between levels.
        x0 = self.conv_in(zt)
        x1 = self.down1(x0, t_emb, txt_cond)
        x2 = self.downsample(x1)
        x2 = self.down2(x2, t_emb, txt_cond)
        x3 = self.downsample(x2)
        x3 = self.down3(x3, t_emb, txt_cond)
        x4 = self.downsample(x3)

        # Bottleneck with cross-attention over the raw text embedding.
        m = self.mid1(x4, t_emb, txt_cond)
        m = self.mid_attn(m, txt_emb)
        m = self.mid2(m, t_emb, txt_cond)

        # Decoder path: upsample, add the same-resolution skip, then ResBlock.
        u3 = self.upsample(m) + x3
        u3 = self.up3(u3, t_emb, txt_cond)
        u2 = self.upsample(u3) + x2
        u2 = self.up2(u2, t_emb, txt_cond)
        u1 = self.upsample(u2) + x1
        u1 = self.up1(u1, t_emb, txt_cond)

        return self.conv_out(u1)

In [None]:
# Build the denoiser; both the time-embedding width and the text-context
# width are set to CLIP's text-embedding size.
base_ch = 64
diff_model = LatentUNet(latent_ch=vae_latent_ch,
                        base_ch=base_ch,
                        emb_dim=text_dim,
                        ctx_dim=text_dim).to(device)
# Parameter count in millions (displayed as the cell output).
sum(p.numel() for p in diff_model.parameters()) / 1e6

## 7. Diffusion training in latent space

In [None]:
def get_t(batch_size, T, device):
    """Draw ``batch_size`` integer timesteps uniformly from ``[0, T)`` on ``device``."""
    timesteps = torch.randint(low=0, high=T, size=(batch_size,), device=device)
    return timesteps

def p_losses_latent(model, vae, x, captions, t):
    """Noise-prediction loss for one batch in latent space.

    The images ``x`` are encoded by ``vae`` without gradients, the sampled
    latent is noised to timesteps ``t``, and the mean squared error between
    the model's prediction and the injected noise is returned.
    """
    with torch.no_grad():
        z0, _mu, _logvar = vae.encode(x)
    eps = torch.randn_like(z0)
    noisy_latent = q_sample_latent(z0, t, eps)
    text_cond = get_text_emb(captions, device=device)
    eps_hat = model(noisy_latent, t, text_cond)
    return F.mse_loss(eps_hat, eps)

In [None]:
# Diffusion training hyper-parameters and optimizer (UNet parameters only).
diff_epochs = 20
lr = 2e-4
opt_diff = torch.optim.Adam(diff_model.parameters(), lr=lr)

# The VAE is frozen from here on: no gradients, eval mode.
for p in vae.parameters():
    p.requires_grad = False
vae.eval()

# Prompts reused for every periodic sample grid, so grids saved at
# different steps share the same captions.
fixed_prompts = [
    "a man riding a bicycle on a street",
    "a dog running on the grass",
    "a group of people sitting at a table",
    "a plate of food on a table",
]

### 7.1 Reverse diffusion step and sampling

In [None]:
@torch.no_grad()
def p_sample_latent_step(model, zt, t, txt_emb):
    """Run one reverse-diffusion step from timestep ``t``.

    The model's noise prediction is turned into an estimate of the clean
    latent, which is combined with ``zt`` into the posterior mean. Gaussian
    noise scaled by the posterior standard deviation is added for every
    sample whose timestep is greater than 0.
    """
    def per_sample(table):
        # Look up the schedule value per batch element and make it broadcastable.
        return table[t].view(-1, 1, 1, 1)

    eps_hat = model(zt, t, txt_emb)

    beta_t = per_sample(betas)
    alpha_t = per_sample(alphas)
    ac = per_sample(alphas_cumprod)
    ac_prev = per_sample(alphas_cumprod_prev)
    one_minus_ac = 1.0 - ac

    # Invert the forward process to estimate the clean latent.
    z0_hat = (zt - per_sample(sqrt_one_minus_alphas_cumprod) * eps_hat) * torch.sqrt(1.0 / ac)

    # Posterior mean is a weighted sum of the clean estimate and the current latent.
    coef_z0 = torch.sqrt(ac_prev) * beta_t / one_minus_ac
    coef_zt = torch.sqrt(alpha_t) * (1.0 - ac_prev) / one_minus_ac
    mean = coef_z0 * z0_hat + coef_zt * zt

    fresh_noise = torch.randn_like(zt)
    sigma = torch.sqrt(per_sample(posterior_variance))
    # No noise is added on the final step (t == 0).
    add_noise = (t > 0).float().view(-1, 1, 1, 1)
    return mean + add_noise * sigma * fresh_noise

In [None]:
@torch.no_grad()
def sample_latent_and_decode(model, vae, prompts, n_steps=T):
    """Generate one image per prompt by reverse diffusion in latent space.

    Starts from Gaussian noise of shape
    ``(len(prompts), vae_latent_ch, image_size // 4, image_size // 4)``,
    applies ``p_sample_latent_step`` for timesteps ``n_steps - 1`` down to 0,
    clamps the final latent to [-5, 5], decodes it with ``vae`` and clamps
    the image to [-1, 1].

    The model is switched to eval mode for sampling and its previous training
    mode is restored afterwards, even if sampling raises.

    Args:
        model: Noise-prediction network called as ``model(zt, t, txt_emb)``.
        vae: Object whose ``decode`` maps latents to images.
        prompts: List of caption strings.
        n_steps: Number of reverse steps; must not exceed the length of the
            noise schedule. NOTE(review): with fewer steps than the schedule
            length, sampling still starts from pure noise but at an
            intermediate timestep — confirm that is intended before using
            values below the default.

    Returns:
        Tensor of decoded images with values in [-1, 1].

    Raises:
        ValueError: If ``n_steps`` exceeds the schedule length.
    """
    schedule_len = betas.shape[0]
    if n_steps > schedule_len:
        # Fail early: indexing the schedule out of range would otherwise raise
        # from inside the step function (a device-side assert on CUDA).
        raise ValueError(
            f"n_steps ({n_steps}) exceeds the diffusion schedule length ({schedule_len})"
        )

    model_was_training = model.training
    model.eval()
    try:
        b = len(prompts)
        txt_emb = get_text_emb(prompts, device=device)
        zt = torch.randn(b, vae_latent_ch, image_size//4, image_size//4, device=device)

        for i in reversed(range(n_steps)):
            t = torch.full((b,), i, device=device, dtype=torch.long)
            zt = p_sample_latent_step(model, zt, t, txt_emb)

        z0 = zt.clamp(-5, 5)
        return vae.decode(z0).clamp(-1, 1)
    finally:
        if model_was_training:
            model.train()

In [None]:
# Counter of optimizer steps across all epochs; a sample grid is written
# every `sample_every` steps.
global_step = 0
sample_every = 400

for epoch in range(1, diff_epochs+1):
    diff_model.train()
    epoch_losses = []
    pbar = tqdm(train_loader, desc=f"Diffusion Epoch {epoch}/{diff_epochs}")
    for imgs, caps in pbar:
        imgs = imgs.to(device)
        B = imgs.size(0)
        # One random timestep per sample in the batch.
        t = get_t(B, T, device)

        loss = p_losses_latent(diff_model, vae, imgs, caps, t)
        opt_diff.zero_grad()
        loss.backward()
        opt_diff.step()

        epoch_losses.append(loss.item())
        global_step += 1
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # Periodic qualitative check on the fixed prompts; the sampler puts
        # the model back into training mode when it returns.
        if global_step % sample_every == 0:
            with torch.no_grad():
                samples = sample_latent_and_decode(diff_model, vae, fixed_prompts, n_steps=T)
                grid = make_grid((samples+1)/2, nrow=2)
                os.makedirs("samples_latent_clip", exist_ok=True)
                save_image(grid, f"samples_latent_clip/step_{global_step:06d}.png")

    print(f"[Diffusion] Epoch {epoch}: train_loss={sum(epoch_losses)/len(epoch_losses):.4f}")

    # Validation loss: timesteps and noise are resampled every epoch, so this
    # number fluctuates from epoch to epoch even for a fixed model.
    diff_model.eval()
    val_losses = []
    with torch.no_grad():
        for imgs, caps in val_loader:
            imgs = imgs.to(device)
            B = imgs.size(0)
            t = get_t(B, T, device)
            loss = p_losses_latent(diff_model, vae, imgs, caps, t)
            val_losses.append(loss.item())
    print(f"[Diffusion] Epoch {epoch}: val_loss={sum(val_losses)/len(val_losses):.4f}")

    # End-of-epoch sample grid for the fixed prompts (2 images per row).
    with torch.no_grad():
        samples = sample_latent_and_decode(diff_model, vae, fixed_prompts, n_steps=T)
        grid = make_grid((samples+1)/2, nrow=2)
        os.makedirs("samples_latent_clip", exist_ok=True)
        save_image(grid, f"samples_latent_clip/epoch_{epoch:03d}.png")

## 8. Inference: generate from custom prompts

In [None]:
# Prompts to generate from; edit freely.
user_prompts = [
    "a cat sitting on a chair",
    "a person skiing on a snowy mountain",
    "a red car driving on a road",
    "a cup of coffee on a table",
]

# Sample one image per prompt, rescale from [-1, 1] to [0, 1], and save a
# grid with 2 images per row.
with torch.no_grad():
    samples = sample_latent_and_decode(diff_model, vae, user_prompts, n_steps=T)
    grid = make_grid((samples+1)/2, nrow=2)
    os.makedirs("samples_latent_clip", exist_ok=True)
    save_image(grid, "samples_latent_clip/user_prompts.png")

# Display the same grid inline (channel-last layout for matplotlib).
import matplotlib.pyplot as plt
plt.imshow(grid.permute(1,2,0).cpu().numpy())
plt.axis("off")