# Text-Conditioned 28×28 Diffusion (MNIST-style)

Minimal notebook to train and demo a tiny text-conditioned diffusion model that generates 28×28 grayscale images in the style of MNIST digits and FashionMNIST clothing. Intended for fast iteration and interview demos: keep the batch size and training steps small for a quick run, or increase them for better quality.

**Contents**
- Optional lightweight installs
- Data: MNIST + FashionMNIST paired with simple text prompts
- Model: small text encoder + UNet with FiLM
- Diffusion training loop (classifier-free guidance ready)
- DDIM sampling with adjustable guidance scale and step count
- Quick visualization grid



In [None]:
# Optional: install dependencies (usually available in most ML envs)
# !pip install torch torchvision tqdm matplotlib



In [None]:
import math
import random
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# --- Reproducibility & device selection ---
seed = 42
torch.manual_seed(seed)  # seeds PyTorch's CPU and CUDA generators
random.seed(seed)        # prompt templates are drawn with `random.choice`
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)



In [None]:
# Data: MNIST + FashionMNIST with text prompts
# We train on (image, prompt) pairs so the model learns real text conditioning.

# Human-readable class names, indexed by each dataset's integer label.
MNIST_NAMES = [
    "zero", "one", "two", "three", "four",
    "five", "six", "seven", "eight", "nine",
]
FASHION_NAMES = [
    "t-shirt/top", "trouser", "pullover", "dress", "coat",
    "sandal", "shirt", "sneaker", "bag", "ankle boot",
]

# Phrasings to choose from; `{name}` is replaced by the class name.
MNIST_TEMPLATES = [
    "digit {name}",
    "handwritten digit {name}",
]
FASHION_TEMPLATES = [
    "fashion {name}",
    "clothing {name}",
]


def _random_prompt(names, templates, y) -> str:
    """Look up the class name for label ``y`` and drop it into a randomly chosen template."""
    label_name = names[int(y)]
    chosen = random.choice(templates)
    return chosen.format(name=label_name)


def prompt_mnist(y: int) -> str:
    """Return a random MNIST prompt for label ``y``, e.g. ``"handwritten digit three"``."""
    return _random_prompt(MNIST_NAMES, MNIST_TEMPLATES, y)


def prompt_fashion(y: int) -> str:
    """Return a random FashionMNIST prompt for label ``y``, e.g. ``"clothing sneaker"``."""
    return _random_prompt(FASHION_NAMES, FASHION_TEMPLATES, y)

# ToTensor maps pixels to [0, 1]; Normalize computes (x - 0.5) / 0.5, mapping them to
# [-1, 1], the range the sampler clamps its outputs to.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

class PromptedDataset(torch.utils.data.Dataset):
    """Wrap an (image, label) dataset so each item becomes an (image, prompt) pair.

    The prompt is produced by ``prompt_fn(label)`` on every access, so a randomised
    ``prompt_fn`` gives a fresh prompt each time an index is read.
    """

    def __init__(self, base_ds, prompt_fn):
        self.base_ds, self.prompt_fn = base_ds, prompt_fn

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        image, label = self.base_ds[idx]
        prompt = self.prompt_fn(label)
        return image, prompt

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
fashion_train = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

use_dataset = 'both'  # 'mnist' | 'fashion' | 'both'
if use_dataset == 'mnist':
    train_ds = PromptedDataset(mnist_train, prompt_mnist)
elif use_dataset == 'fashion':
    train_ds = PromptedDataset(fashion_train, prompt_fashion)
elif use_dataset == 'both':
    # ConcatDataset indexes all MNIST items first, then all FashionMNIST items.
    train_ds = torch.utils.data.ConcatDataset([
        PromptedDataset(mnist_train, prompt_mnist),
        PromptedDataset(fashion_train, prompt_fashion),
    ])
else:
    # Fail loudly on a typo instead of silently training on both datasets.
    raise ValueError(f"use_dataset must be 'mnist', 'fashion' or 'both', got {use_dataset!r}")

# Larger batches when a GPU is available.
batch_size = 256 if torch.cuda.is_available() else 128
print('Dataset size:', len(train_ds), 'batch_size:', batch_size)



In [None]:
# Building blocks: time embedding, text encoder, FiLM-UNet

@dataclass
class DiffusionConfig:
    """Hyper-parameters shared by the text encoder, the FiLM-UNet and the noise schedule.

    Validated on construction, so invalid combinations fail here with a clear message
    instead of deep inside model construction or on the first forward pass.

    Raises:
        ValueError: if a field value cannot work with the UNet / schedule in this notebook.
    """
    img_size: int = 28                 # square side; must be divisible by 4 (two 2x down/up stages)
    base_channels: int = 32            # channel width of the first UNet level
    channel_mults: tuple = (1, 2, 2)   # UNet reads indices 1 and 2; index 0 is unused
    text_dim: int = 128                # GRU embedding/hidden size and text-FiLM input width
    time_dim: int = 128                # sinusoidal time-embedding width and time-MLP width
    num_heads: int = 4                 # NOTE(review): not read by any model code in this notebook
    num_layers_text: int = 2           # stacked GRU layers in the text encoder
    dropout: float = 0.1               # dropout between GRU layers
    timesteps: int = 400               # length of the diffusion noise schedule
    beta_start: float = 1e-4           # first beta of the linear schedule
    beta_end: float = 0.02             # last beta of the linear schedule

    def __post_init__(self):
        if self.time_dim < 4 or self.time_dim % 2:
            raise ValueError(f"time_dim must be an even number >= 4, got {self.time_dim}")
        if len(self.channel_mults) < 3:
            raise ValueError(
                f"channel_mults needs at least 3 entries (UNet reads indices 1 and 2), got {self.channel_mults}"
            )
        widths = (
            self.base_channels,
            self.base_channels * self.channel_mults[1],
            self.base_channels * self.channel_mults[2],
        )
        # Every ResBlock normalises with GroupNorm(8, channels).
        if any(w <= 0 or w % 8 for w in widths):
            raise ValueError(f"UNet channel widths {widths} must be positive multiples of 8")
        # Two AvgPool2d(2) + two stride-2 ConvTranspose2d must reproduce the skip shapes.
        if self.img_size <= 0 or self.img_size % 4:
            raise ValueError(f"img_size must be a positive multiple of 4, got {self.img_size}")
        if self.timesteps < 1:
            raise ValueError(f"timesteps must be >= 1, got {self.timesteps}")
        for name in ("beta_start", "beta_end"):
            beta = getattr(self, name)
            if not 0.0 < beta < 1.0:
                raise ValueError(f"{name} must be in (0, 1), got {beta}")

def sinusoidal_time_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """Transformer-style sinusoidal embedding of diffusion timesteps.

    Args:
        timesteps: (B,) tensor of timesteps (integer or float dtype).
        dim: output width. The first ``dim // 2`` columns are sines and the next
            ``dim // 2`` are cosines, over frequencies spaced geometrically from 1 down
            to 1/10000. For odd ``dim`` a zero column is appended, so the result is
            always exactly ``dim`` wide.

    Returns:
        (B, dim) float tensor.
    """
    half = dim // 2
    # With half == 1 the spacing denominator (half - 1) would be 0 and the single
    # frequency would become 0/0 = NaN; use 1 instead (frequency exp(0) = 1).
    denom = max(half - 1, 1)
    freqs = torch.exp(-math.log(10000) * torch.arange(half, device=timesteps.device) / denom)
    args = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2:
        # Pad odd widths so layers sized by `dim` (e.g. the time MLP) get a matching input.
        emb = F.pad(emb, (0, 1))
    return emb

class TextEncoder(nn.Module):
    """Encode right-padded token ids into one fixed-size vector per prompt.

    A bidirectional GRU reads each prompt, with padding excluded via packing. The final
    hidden states of the top layer's forward and backward directions, each of which has
    seen the whole prompt, are concatenated and projected to ``hidden_dim``.
    """

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.gru = nn.GRU(emb_dim, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=True, dropout=dropout)
        self.proj = nn.Sequential(nn.Linear(hidden_dim * 2, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim))

    def forward(self, tokens, lengths):
        """
        Args:
            tokens: (B, L) long tensor of token ids, padded on the right.
            lengths: (B,) tensor with the number of valid tokens per row. Zero-length
                rows are treated as a single (padding) token, because packing rejects
                empty sequences.

        Returns:
            (B, hidden_dim) prompt embedding.
        """
        x = self.embedding(tokens)
        lengths = lengths.clamp(min=1)
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        # h_n: (num_layers * 2, B, hidden_dim), returned in the original batch order.
        # Per-position outputs are not used: at the last valid position the backward
        # direction has only seen the final token, whereas h_n[-1] has seen the whole
        # prompt.
        _, h_n = self.gru(packed)
        summary = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # top-layer forward ++ backward
        return self.proj(summary)

class FiLM(nn.Module):
    """Feature-wise linear modulation driven by a conditioning vector.

    ``cond`` (B, in_dim) is projected to a per-channel (scale, shift) pair of width
    ``out_dim`` each. The 4-D feature map ``x`` (channels on dim 1) is transformed to
    ``x * (1 + scale) + shift``, broadcast over the two spatial dims.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim * 2)

    def forward(self, x, cond):
        gamma_beta = self.linear(cond)
        gamma, beta = torch.chunk(gamma_beta, 2, dim=1)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        return x * (gamma + 1) + beta

class ResBlock(nn.Module):
    """Residual conv block conditioned on time and text through two FiLM layers.

    Main path: conv1 -> GroupNorm -> SiLU -> FiLM(time) -> FiLM(text) -> conv2 ->
    GroupNorm -> SiLU. The skip path is a 1x1 conv when the channel count changes and
    identity otherwise. ``out_ch`` must be divisible by 8 because both GroupNorms use
    8 groups.
    """

    def __init__(self, in_ch, out_ch, time_dim, text_dim):
        super().__init__()
        # Construction order is kept fixed: it determines parameter-initialisation RNG
        # draws under a fixed seed and the order of `parameters()`.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.time_film = FiLM(time_dim, out_ch)
        self.text_film = FiLM(text_dim, out_ch)
        self.act = nn.SiLU()
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb, txt_emb):
        residual = self.skip(x)
        out = self.act(self.norm1(self.conv1(x)))
        out = self.text_film(self.time_film(out, t_emb), txt_emb)
        out = self.act(self.norm2(self.conv2(out)))
        return out + residual

class UNet(nn.Module):
    """Small two-level UNet whose ResBlocks are FiLM-conditioned on time and text.

    Encoder: enc1 (full res) -> pool -> enc2 (1/2) -> pool -> enc3 (1/4) -> mid.
    Decoder: up1 ++ enc2 skip -> dec1 -> up2 ++ enc1 skip -> dec2 -> 1x1 conv to one
    channel. For 28x28 inputs the resolutions are 28 -> 14 -> 7 and back. The input
    height and width must be divisible by 4 so that the upsampled maps line up with the
    skips.
    """

    def __init__(self, cfg: DiffusionConfig, text_vocab: int):
        super().__init__()
        base = cfg.base_channels
        mid_ch = base * cfg.channel_mults[1]
        deep_ch = base * cfg.channel_mults[2]
        t_dim, txt_dim = cfg.time_dim, cfg.text_dim

        # Submodules are created in a fixed order, so that under a fixed seed the
        # parameter initialisation (which draws from the global RNG) is reproducible.
        self.text_encoder = TextEncoder(
            text_vocab,
            emb_dim=txt_dim,
            hidden_dim=txt_dim,
            num_layers=cfg.num_layers_text,
            dropout=cfg.dropout,
        )
        # Learned embedding that replaces the text embedding on the unconditional path.
        self.null_text = nn.Parameter(torch.zeros(txt_dim))

        hidden = t_dim * 4
        self.time_mlp = nn.Sequential(nn.Linear(t_dim, hidden), nn.SiLU(), nn.Linear(hidden, t_dim))

        # Encoder path
        self.enc1 = ResBlock(1, base, t_dim, txt_dim)
        self.enc2 = ResBlock(base, mid_ch, t_dim, txt_dim)
        self.enc3 = ResBlock(mid_ch, deep_ch, t_dim, txt_dim)
        self.pool = nn.AvgPool2d(2)

        # Bottleneck
        self.mid = ResBlock(deep_ch, deep_ch, t_dim, txt_dim)

        # Decoder path: each stage upsamples 2x and concatenates the matching encoder skip.
        self.up1 = nn.ConvTranspose2d(deep_ch, mid_ch, 2, stride=2)
        self.dec1 = ResBlock(mid_ch * 2, mid_ch, t_dim, txt_dim)
        self.up2 = nn.ConvTranspose2d(mid_ch, base, 2, stride=2)
        self.dec2 = ResBlock(base * 2, base, t_dim, txt_dim)

        self.out = nn.Conv2d(base, 1, 1)

    def _text_condition(self, batch, txt_tokens, txt_lens, drop_text_prob):
        """Return the (batch, text_dim) conditioning vector, applying CFG null-text dropout."""
        null = self.null_text[None, :]
        # A forced-unconditional path must also work in eval(); guidance sampling needs it.
        if drop_text_prob >= 1.0:
            return null.expand(batch, -1)
        emb = self.text_encoder(txt_tokens, txt_lens)
        if not (self.training and drop_text_prob > 0.0):
            return emb
        use_null = torch.rand(emb.size(0), device=emb.device) < drop_text_prob
        mask = use_null.float()[:, None]
        return emb * (1 - mask) + null * mask

    def forward(self, x, t, txt_tokens, txt_lens, drop_text_prob: float = 0.1):
        """Predict the noise in ``x`` at timesteps ``t`` for the tokenised prompts.

        If ``drop_text_prob >= 1.0``, the learned null embedding is always used, in any
        mode. Otherwise the prompt is encoded, and in training mode each sample's
        embedding is replaced by the null embedding with probability ``drop_text_prob``.
        """
        time_feat = sinusoidal_time_embedding(t, self.time_mlp[0].in_features)
        t_emb = self.time_mlp(time_feat)
        cond = self._text_condition(x.size(0), txt_tokens, txt_lens, drop_text_prob)

        # Encoder
        skip_hi = self.enc1(x, t_emb, cond)
        skip_lo = self.enc2(self.pool(skip_hi), t_emb, cond)
        h = self.enc3(self.pool(skip_lo), t_emb, cond)
        h = self.mid(h, t_emb, cond)

        # Decoder
        h = self.dec1(torch.cat([self.up1(h), skip_lo], dim=1), t_emb, cond)
        h = self.dec2(torch.cat([self.up2(h), skip_hi], dim=1), t_emb, cond)
        return self.out(h)



In [None]:
# Tokenizer helpers (simple character-level; robust for small prompt vocab)

class SimpleCharTokenizer:
    """Character-level tokenizer with a vocabulary built from example texts.

    Id 0 is the padding token and id 1 the unknown token, followed by every distinct
    lower-cased character seen in ``texts``, in sorted order. Encoding is
    case-insensitive, and characters outside the vocabulary map to the unknown id.
    """

    def __init__(self, texts, pad_token='<pad>', unk_token='<unk>'):
        chars = sorted({c for t in texts for c in t.lower()})
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.itos = [pad_token, unk_token] + chars
        self.stoi = {c: i for i, c in enumerate(self.itos)}

    @property
    def vocab_size(self):
        """Number of token ids, including the two special tokens."""
        return len(self.itos)

    def encode(self, text: str, max_len: int = 32):
        """Encode one string into fixed-length ids.

        The text is lower-cased and then truncated to ``max_len`` characters.

        Returns:
            (ids, length): a ``(max_len,)`` long tensor right-padded with the pad id,
            and a 0-d long tensor holding the number of non-padding ids.
        """
        unk_id = self.stoi[self.unk_token]
        ids = [self.stoi.get(c, unk_id) for c in text.lower()[:max_len]]
        length = len(ids)
        ids += [self.stoi[self.pad_token]] * (max_len - length)
        return torch.tensor(ids, dtype=torch.long), torch.tensor(length, dtype=torch.long)

    def encode_batch(self, texts, max_len: int = 32):
        """Encode several strings.

        Returns:
            A ``(B, max_len)`` long tensor of ids and a ``(B,)`` long tensor of lengths.
            An empty input gives tensors with B == 0.
        """
        encoded = [self.encode(t, max_len=max_len) for t in texts]
        if not encoded:
            # zip(*[]) cannot be unpacked into two values, so build the empty result explicitly.
            return torch.empty((0, max_len), dtype=torch.long), torch.empty((0,), dtype=torch.long)
        toks, lens = zip(*encoded)
        return torch.stack(toks), torch.stack(lens)

# Build the tokenizer vocabulary from every prompt the templates can produce (both datasets).
all_prompts = []
for digit_name, garment_name in zip(MNIST_NAMES, FASHION_NAMES):
    all_prompts.extend(tpl.format(name=digit_name) for tpl in MNIST_TEMPLATES)
    all_prompts.extend(tpl.format(name=garment_name) for tpl in FASHION_TEMPLATES)

tokenizer = SimpleCharTokenizer(all_prompts)
print('Vocab size:', tokenizer.vocab_size)



In [None]:
# Diffusion utilities

class SimpleDiffusion(nn.Module):
    """DDPM utilities with schedule tensors registered as buffers (so .to(device) works).

    Builds a linear beta schedule of ``cfg.timesteps`` steps and provides forward
    noising (``add_noise``), classifier-free-guided noise prediction, an ancestral DDPM
    step (``p_sample``), a DDIM step (``ddim_step``) and a DDIM sampling loop
    (``sample``). Reads ``timesteps``, ``beta_start``, ``beta_end`` and ``img_size``
    from ``cfg``.
    """

    def __init__(self, cfg: "DiffusionConfig"):
        super().__init__()
        self.cfg = cfg

        betas = torch.linspace(cfg.beta_start, cfg.beta_end, cfg.timesteps, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cum = torch.cumprod(alphas, dim=0)  # running product of alphas ("alpha bar")

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cum', alphas_cum)

    def sample_timesteps(self, batch_size: int, device: Optional[torch.device] = None):
        """Draw ``batch_size`` timesteps uniformly from ``[0, cfg.timesteps)``.

        ``device`` defaults to the device the schedule buffers live on.
        """
        if device is None:
            device = self.betas.device
        return torch.randint(0, self.cfg.timesteps, (batch_size,), device=device)

    def add_noise(self, x0, t, noise):
        """Sample from q(x_t | x_0) = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.

        ``t`` is a (B,) long tensor on the same device as the buffers. The per-sample
        coefficients are reshaped to (B, 1, 1, 1) to broadcast over ``x0`` and ``noise``.
        """
        sqrt_ac = self.alphas_cum[t].sqrt()[:, None, None, None]
        sqrt_one_minus_ac = (1 - self.alphas_cum[t]).sqrt()[:, None, None, None]
        return sqrt_ac * x0 + sqrt_one_minus_ac * noise

    @torch.no_grad()
    def _predict_eps_cfg(self, model, x, t: int, txt_tokens, txt_lens, guidance_scale: float):
        """Classifier-free guidance: eps_null + guidance_scale * (eps_text - eps_null).

        Runs ``model`` twice at the integer timestep ``t``: once with the prompt
        (``drop_text_prob=0.0``) and once forced unconditional (``drop_text_prob=1.0``).
        """
        t_batch = torch.full((x.size(0),), t, device=x.device, dtype=torch.long)
        eps_text = model(x, t_batch, txt_tokens, txt_lens, drop_text_prob=0.0)
        eps_null = model(x, t_batch, txt_tokens, txt_lens, drop_text_prob=1.0)
        return eps_null + guidance_scale * (eps_text - eps_null)

    @torch.no_grad()
    def p_sample(self, model, x, t: int, txt_tokens, txt_lens, guidance_scale: float = 2.0):
        """Ancestral DDPM step x_t -> x_{t-1}, using variance beta_t (kept for reference).

        At ``t == 0`` the posterior mean is returned without adding noise.
        """
        eps = self._predict_eps_cfg(model, x, t, txt_tokens, txt_lens, guidance_scale)
        beta_t = self.betas[t]
        alpha_t = self.alphas[t]
        alpha_cum_t = self.alphas_cum[t]
        mean = (1 / alpha_t.sqrt()) * (x - beta_t / (1 - alpha_cum_t).sqrt() * eps)
        if t == 0:
            return mean
        noise = torch.randn_like(x)
        return mean + beta_t.sqrt() * noise

    @torch.no_grad()
    def ddim_step(self, model, x, t: int, t_prev: int, txt_tokens, txt_lens, guidance_scale: float = 2.0, eta: float = 0.0):
        """DDIM step from timestep ``t`` to ``t_prev``; supports skipping timesteps.

        The predicted x0 is clamped to [-1, 1]. When ``t_prev < 0`` (the final step)
        that clamped x0 is returned directly. ``eta`` scales the stochastic part
        (0 = deterministic DDIM).
        """
        eps = self._predict_eps_cfg(model, x, t, txt_tokens, txt_lens, guidance_scale)

        ac_t = self.alphas_cum[t]
        ac_prev = self.alphas_cum[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=x.device)

        # predict x0
        x0 = (x - (1 - ac_t).sqrt() * eps) / ac_t.sqrt()
        x0 = x0.clamp(-1, 1)

        if t_prev < 0:
            return x0

        # DDIM variance control
        sigma = eta * torch.sqrt((1 - ac_prev) / (1 - ac_t) * (1 - ac_t / ac_prev))
        noise = torch.randn_like(x) if eta > 0 else torch.zeros_like(x)

        dir_xt = (1 - ac_prev - sigma**2).sqrt() * eps
        x_prev = ac_prev.sqrt() * x0 + dir_xt + sigma * noise
        return x_prev

    @torch.no_grad()
    def sample(self, model, txt_tokens, txt_lens, steps: int = 40, guidance_scale: float = 2.0, eta: float = 0.0):
        """Generate images with DDIM and classifier-free guidance.

        Args:
            model: noise predictor called as ``model(x, t, tokens, lens, drop_text_prob=...)``.
            txt_tokens, txt_lens: tokenised prompts. The batch size and device are
                taken from ``txt_tokens``.
            steps: number of sampling steps, clamped to ``[2, cfg.timesteps]``. The
                timesteps are spaced evenly from ``cfg.timesteps - 1`` down to 0.
            guidance_scale: CFG weight (see ``_predict_eps_cfg``).
            eta: DDIM stochasticity (0 = deterministic).

        Returns:
            (B, 1, img_size, img_size) tensor: the final clamped x0 prediction in [-1, 1].

        The model is put in eval mode while sampling and then restored to its previous
        train/eval mode, also if sampling raises.
        """
        was_training = model.training
        model.eval()
        try:
            b = txt_tokens.size(0)
            x = torch.randn(b, 1, self.cfg.img_size, self.cfg.img_size, device=txt_tokens.device)

            # choose a schedule of timesteps (descending)
            steps = max(2, min(int(steps), self.cfg.timesteps))
            ts = torch.linspace(self.cfg.timesteps - 1, 0, steps, device=txt_tokens.device).long().tolist()

            for t, t_prev in zip(ts, ts[1:] + [-1]):
                x = self.ddim_step(model, x, t, t_prev, txt_tokens, txt_lens, guidance_scale=guidance_scale, eta=eta)
        finally:
            model.train(was_training)

        return x


def collate_batch(batch):
    """Collate (image, prompt) pairs into ``(images, token_ids, lengths, prompts)``.

    Images are stacked into one tensor. Prompts are tokenised with the module-level
    ``tokenizer`` (``max_len=32``) and are also returned unchanged as a list.
    """
    pixel_batch, text_batch = zip(*batch)
    token_ids, lengths = tokenizer.encode_batch(text_batch, max_len=32)
    return torch.stack(pixel_batch), token_ids, lengths, list(text_batch)

# DataLoader (now that tokenizer + collate exist)
# Pinned host memory only helps when batches are copied to a GPU; on CPU-only machines
# requesting it just triggers a warning.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_batch,
)
print('Train batches:', len(train_loader))



In [None]:
# Model + training loop

cfg = DiffusionConfig()
diffusion = SimpleDiffusion(cfg).to(device)

base_model = UNet(cfg, text_vocab=tokenizer.vocab_size).to(device)

# Use both T4s on Kaggle if available: wrap in DataParallel whenever >1 GPU is visible.
model = base_model
if torch.cuda.device_count() > 1:
    print('Using DataParallel on', torch.cuda.device_count(), 'GPUs')
    model = nn.DataParallel(base_model)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)

# EMA helps samples look cleaner with the same number of training steps
def unwrap(m: nn.Module) -> nn.Module:
    """Return the inner module of a DataParallel-style wrapper, else ``m`` itself.

    Any object exposing a ``.module`` attribute is treated as a wrapper.
    """
    return m.module if hasattr(m, 'module') else m


@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float = 0.999):
    """Blend ``model``'s weights into ``ema_model`` in place.

    Parameters are updated as ``ema = decay * ema + (1 - decay) * p``. Buffers (e.g.
    normalisation running statistics) are not trained, so they are copied directly to
    keep the EMA copy in sync. ``model`` may be wrapped (see ``unwrap``).

    Raises:
        ValueError: if the two models differ in their number of parameters or buffers.
    """
    src = unwrap(model)

    ema_params, src_params = list(ema_model.parameters()), list(src.parameters())
    if len(ema_params) != len(src_params):
        raise ValueError(f"parameter count mismatch: ema has {len(ema_params)}, model has {len(src_params)}")
    for ema_p, p in zip(ema_params, src_params):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)

    ema_bufs, src_bufs = list(ema_model.buffers()), list(src.buffers())
    if len(ema_bufs) != len(src_bufs):
        raise ValueError(f"buffer count mismatch: ema has {len(ema_bufs)}, model has {len(src_bufs)}")
    for ema_b, b in zip(ema_bufs, src_bufs):
        ema_b.copy_(b)

ema_model = UNet(cfg, text_vocab=tokenizer.vocab_size).to(device)
ema_model.load_state_dict(unwrap(model).state_dict())
# The EMA copy is only changed in place by ema_update (under no_grad), never by the
# optimizer, so it does not need autograd tracking.
ema_model.requires_grad_(False)
ema_decay = 0.999

# Mixed precision for speed on T4 (scaler and autocast are disabled when CUDA is absent)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

num_steps = 30_000  # "complete" baseline for MNIST/FashionMNIST
log_interval = 200

model.train()
step = 0
# Using tqdm as a context manager closes the bar even if the loop is interrupted
# (e.g. KeyboardInterrupt in a notebook).
with tqdm(total=num_steps, desc='train') as pbar:
    # Cycle over the DataLoader until `num_steps` optimizer steps have been taken.
    while step < num_steps:
        for x, tokens, lens, _prompts in train_loader:
            x = x.to(device, non_blocking=True)
            tokens = tokens.to(device, non_blocking=True)
            lens = lens.to(device, non_blocking=True)

            # Noise clean images to random timesteps; the model learns to predict that noise.
            t = diffusion.sample_timesteps(x.size(0), device)
            noise = torch.randn_like(x)
            x_noisy = diffusion.add_noise(x, t, noise)

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                # Each prompt is swapped for the null embedding with probability 0.15,
                # which trains the unconditional branch used by classifier-free guidance.
                pred = model(x_noisy, t, tokens, lens, drop_text_prob=0.15)
                loss = F.mse_loss(pred, noise)

            # Unscale before clipping so the norm is measured on the true gradients.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            ema_update(ema_model, model, decay=ema_decay)

            step += 1
            pbar.update(1)
            if step % log_interval == 0:
                pbar.set_postfix(loss=float(loss.detach()))
            if step >= num_steps:
                break
print('Finished training steps:', step)



In [None]:
# Sampling demo
# Use EMA weights for nicer samples
ema_model.eval()

# Demo prompts (digits + fashion)
prompts = [
    "digit zero",
    "digit three",
    "digit seven",
    "fashion sneaker",
    "fashion ankle boot",
]

# Tokenise all prompts at once, with the same max_len=32 padding used by the training collate_fn.
tokens, lens = tokenizer.encode_batch(prompts, max_len=32)
tokens, lens = tokens.to(device), lens.to(device)

with torch.no_grad():
    # try steps=30..60, guidance_scale=1.5..3.0
    samples = diffusion.sample(ema_model, tokens, lens, steps=40, guidance_scale=2.0, eta=0.0)

imgs = (samples.clamp(-1, 1) * 0.5 + 0.5).cpu()  # back to [0,1]
# squeeze=False keeps `axes` 2-D even for a single prompt, so the loop always works.
fig, axes = plt.subplots(1, len(prompts), figsize=(len(prompts) * 2, 2), squeeze=False)
for ax, img, p in zip(axes[0], imgs, prompts):
    # Fixed [0, 1] display range: without it imshow rescales each image to its own
    # min/max, which exaggerates faint or near-uniform samples.
    ax.imshow(img[0], cmap='gray', vmin=0.0, vmax=1.0)
    ax.axis('off')
    ax.set_title(p, fontsize=9)  # smaller font so long prompts fit the 2-inch panels
plt.show()



## Tips / Next steps
- For better quality, increase `num_steps` (the number of training iterations) and/or `base_channels` (the model width).
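- For more diverse prompts, add templates to `MNIST_TEMPLATES` / `FASHION_TEMPLATES`. The model only learns phrasings it sees during training, and the character vocabulary is built from these templates, so any character outside it maps to `<unk>`.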
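- Save/Load: prefer the EMA (or unwrapped) weights, e.g. `torch.save(ema_model.state_dict(), 'cfdiffusion.pt')`, and load them with `ema_model.load_state_dict(torch.load('cfdiffusion.pt'))`. Under `DataParallel`, `model.state_dict()` prefixes every key with `module.`, so use `unwrap(model).state_dict()` to save the raw training weights.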
- The DDIM sampler skips timesteps for speed: it uses `steps` evenly spaced timesteps out of `cfg.timesteps`. For best quality, pass `steps=cfg.timesteps` to use the full schedule.
- To condition on richer text, swap the GRU for a tiny Transformer encoder.

