# Simple Text-to-Image Diffusion Demo
This notebook shows a tiny diffusion pipeline for text-to-image using the provided shape sketches. I keep the comments short and simple.


## Install and Imports
I install the basic libraries and import what I need for tensors, plotting, and loading the dataset.


In [None]:
# If your runtime misses these libraries, run this cell once.
# In many university labs this is enough to get numpy, torch, and plotting tools.
!pip install -q numpy torch torchvision matplotlib tqdm


In [None]:
import math
import random
from typing import List

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


## Data loading and prep
I load the tiny shape dataset (circle, square, triangle) and pair each image with a simple text prompt. I resize/normalize to a small square so training stays fast.


In [None]:
class ShapeTextDataset(Dataset):
    """Shape sketches (circle, square, triangle) paired with fixed text prompts.

    Reads ``circle.npy``, ``square.npy`` and ``triangle.npy`` from *data_dir*.
    Each file is expected to hold an array shaped (N, H, W) or (N, H, W, C);
    for 4-D arrays only channel 0 is kept.

    Args:
        data_dir: Directory that contains the three ``.npy`` files.
        img_size: Side length of the square images returned by ``__getitem__``.
    """

    def __init__(self, data_dir="data", img_size=32):
        self.img_size = img_size
        self.samples = []
        prompt_map = {
            "circle": "a simple black circle on white",
            "square": "a simple black square on white",
            "triangle": "a simple black triangle on white",
        }
        for name, prompt in prompt_map.items():
            imgs = np.load(f"{data_dir}/{name}.npy")
            if imgs.ndim == 4:
                # Keep only the first channel so every sample is a 2-D array.
                imgs = imgs[..., 0]
            imgs = imgs.astype(np.float32)
            # Images are stored at their original resolution; resizing is
            # deferred to __getitem__.
            self.samples.extend((img, prompt) for img in imgs)

    def __len__(self):
        """Return the total number of (image, prompt) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, prompt)`` where image is a (1, img_size, img_size) tensor."""
        img, prompt = self.samples[idx]
        img = torch.tensor(img)
        if img.ndim == 2:
            img = img.unsqueeze(0)  # (H, W) -> (1, H, W)
        # interpolate() needs a leading batch axis; add it, resize, drop it.
        img = torch.nn.functional.interpolate(
            img.unsqueeze(0),
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)
        # NOTE(review): assumes raw pixel values lie in [0, 255] - confirm
        # against the data files. Under that assumption this maps to [-1, 1].
        img = (img / 255.0) * 2 - 1
        return img, prompt

def get_dataloader(batch_size=32, img_size=32):
    """Return a shuffling DataLoader over a ShapeTextDataset.

    Args:
        batch_size: Number of samples per batch.
        img_size: Side length passed through to ShapeTextDataset.
    """
    return DataLoader(
        ShapeTextDataset(img_size=img_size),
        batch_size=batch_size,
        shuffle=True,
    )

# Module-level loader (batches of 16 images at 32x32) reused by later cells.
dataloader = get_dataloader(batch_size=16, img_size=32)
print("Total samples:", len(dataloader.dataset))


## Quick look at data
I plot a few samples to make sure the loader works and the prompts line up with the sketches.


In [None]:
# Pull one batch and show its first six images, titled with truncated prompts.
imgs, prompts = next(iter(dataloader))
fig, axes = plt.subplots(1, 6, figsize=(12, 2))
for i, ax in enumerate(axes):
    # Map pixels from [-1, 1] back to [0, 1] for display.
    preview = ((imgs[i].squeeze() + 1) / 2).clamp(0, 1)
    ax.imshow(preview, cmap="gray")
    ax.set_title(prompts[i][:10] + "...")
    ax.axis("off")
plt.tight_layout()
plt.show()


## Diffusion helpers
I set up the cosine beta schedule and the helper functions for the forward diffusion process. The sinusoidal time embedding is defined in the next section, alongside the model.


In [None]:
def cosine_beta_schedule(timesteps, s=0.008):
    """Return ``timesteps`` betas derived from a squared-cosine cumulative curve.

    The cumulative alpha curve is evaluated at ``timesteps + 1`` evenly spaced
    points, normalized by its first value, and each beta is one minus the ratio
    of consecutive curve values, clamped to [0.0001, 0.9999].

    Args:
        timesteps: Number of diffusion steps.
        s: Offset added to the normalized time before the cosine is taken.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clamp(1 - step_ratio, min=0.0001, max=0.9999)

# Number of diffusion steps and the schedule tensors derived from it.
T = 200
betas = cosine_beta_schedule(T)
alphas = 1.0 - betas
# Running product of the alphas, indexed by timestep.
alpha_cum = torch.cumprod(alphas, dim=0)


def extract(a, t, x_shape):
    """Look up per-sample schedule values and shape them for broadcasting.

    Args:
        a: 1-D tensor of schedule values indexed by timestep.
        t: 1-D long tensor of timesteps, one per batch element.
        x_shape: Shape of the tensor the result will be broadcast against.

    Returns:
        A float32 tensor on ``t``'s device with shape ``(len(t), 1, ..., 1)``
        and as many dimensions as ``x_shape``.
    """
    batch = t.shape[0]
    # Gather on the device that holds the table (instead of forcing CPU), so
    # the lookup works whether the schedule lives on CPU or GPU.
    picked = a.gather(-1, t.to(a.device)).float().to(t.device)
    return picked.reshape(batch, *((1,) * (len(x_shape) - 1)))


def q_sample(x_start, t, noise=None):
    """Return a noised version of ``x_start`` at timesteps ``t``.

    Computes ``sqrt(alpha_cum[t]) * x_start + sqrt(1 - alpha_cum[t]) * noise``.
    Fresh Gaussian noise is drawn when ``noise`` is not supplied.
    """
    eps = torch.randn_like(x_start) if noise is None else noise
    keep = extract(alpha_cum, t, x_start.shape)
    return torch.sqrt(keep) * x_start + torch.sqrt(1 - keep) * eps


## Text encoder and tiny U-Net
I encode the text with a GRU into a single vector and feed that into a light U-Net. The text and time embeddings modulate the conv blocks.


In [None]:
class TextEncoder(nn.Module):
    """Encode a batch of token ids into one vector per prompt with a GRU.

    Args:
        embed_dim: Size of each token embedding.
        hidden_dim: GRU hidden size; also the size of the returned vector.
        vocab_size: Number of rows in the embedding table. Token ids must be
            smaller than this value. Defaults to 1000.
    """

    def __init__(self, embed_dim=128, hidden_dim=128, vocab_size=1000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, tokens):
        """Map (batch, seq_len) token ids to a (batch, hidden_dim) tensor."""
        emb = self.embedding(tokens)
        _, h = self.rnn(emb)
        # h is (num_layers, batch, hidden_dim); take the final layer's state.
        return self.proj(h[-1])


def tokenize(prompts: List[str], vocab=None, max_len=16):
    """Turn prompts into a zero-padded batch of word ids.

    Words are lower-cased and split on whitespace. A word not yet in ``vocab``
    is assigned the id ``len(vocab) + 1``; id 0 is used only for padding.
    ``vocab`` is updated in place and also returned.

    Args:
        prompts: Text prompts to encode.
        vocab: Existing word-to-id mapping to extend; a new dict when None.
        max_len: Maximum number of tokens kept per prompt. Every word is still
            added to ``vocab``, including those past the cut-off.

    Returns:
        ``(tokens, vocab)`` where ``tokens`` is a long tensor shaped
        ``(len(prompts), longest_kept_prompt)``. An empty ``prompts`` list
        gives a ``(0, 0)`` tensor instead of raising.
    """
    if vocab is None:
        vocab = {}
    token_lists = []
    for text in prompts:
        ids = [vocab.setdefault(word, len(vocab) + 1) for word in text.lower().split()]
        token_lists.append(ids[:max_len])
    # default=0 keeps an empty batch from raising in max().
    width = max((len(ids) for ids in token_lists), default=0)
    padded = [ids + [0] * (width - len(ids)) for ids in token_lists]
    # The reshape pins the result to 2-D even when there are no rows.
    batch = torch.tensor(padded, dtype=torch.long).reshape(len(padded), width)
    return batch, vocab


def sinusoidal_embedding(n, d):
    """Build an (n, d) table of sinusoidal position embeddings.

    Row ``p`` holds ``sin`` in the even columns and ``cos`` in the odd columns
    of ``p / 10000 ** (2 * (col // 2) / d)``.
    """
    positions = torch.arange(n).unsqueeze(1)
    columns = torch.arange(d).unsqueeze(0)
    exponent = 2 * (columns // 2) / d
    phase = positions / (10000 ** exponent)

    even, odd = slice(0, None, 2), slice(1, None, 2)
    table = torch.zeros((n, d))
    table[:, even] = torch.sin(phase[:, even])
    table[:, odd] = torch.cos(phase[:, odd])
    return table


class ConvBlock(nn.Module):
    """Two stacked stages of 3x3 conv (padding 1), single-group GroupNorm, SiLU.

    The first stage maps ``in_ch`` to ``out_ch`` channels; the second keeps
    ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.GroupNorm(1, out_ch),
                nn.SiLU(),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.block(x)


class TinyUNet(nn.Module):
    """Two-level U-Net conditioned on a time embedding and a text embedding.

    The time and text embeddings are merged into one ``time_dim``-sized vector,
    which is then projected to each stage's channel count before being added
    to that stage's feature map. Without these per-stage projections the
    addition fails whenever ``time_dim`` differs from the stage width (with the
    defaults, 128 versus 32 and 64 channels).

    Args:
        base: Channel count of the first stage; later stages use 2x and 4x.
        time_dim: Size of the incoming time embedding and of the merged
            conditioning vector.
        text_dim: Size of the incoming text embedding.
    """

    def __init__(self, base=32, time_dim=128, text_dim=128):
        super().__init__()
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, time_dim), nn.SiLU())
        self.text_proj = nn.Linear(text_dim, time_dim)

        # One projection per stage so the conditioning matches its channels.
        self.cond1 = nn.Linear(time_dim, base)
        self.cond2 = nn.Linear(time_dim, base * 2)
        self.cond_mid = nn.Linear(time_dim, base * 4)

        self.down1 = ConvBlock(1, base)
        self.down2 = ConvBlock(base, base * 2)
        self.to_vec = nn.Conv2d(base * 2, base * 4, 3, padding=1)
        self.up1 = ConvBlock(base * 4, base * 2)
        self.up2 = ConvBlock(base * 2, base)
        self.out = nn.Conv2d(base, 1, 1)

    @staticmethod
    def _add_cond(feat, vec):
        """Add a (batch, channels) vector to a (batch, channels, H, W) map."""
        return feat + vec[:, :, None, None]

    def forward(self, x, t_embed, text_embed):
        """Predict a single-channel output with the same spatial size as ``x``.

        Args:
            x: Input images, shape (batch, 1, H, W).
            t_embed: Time embeddings, shape (batch, time_dim).
            text_embed: Text embeddings, shape (batch, text_dim).
        """
        cond = self.time_mlp(t_embed) + self.text_proj(text_embed)

        d1 = self._add_cond(self.down1(x), self.cond1(cond))
        d2 = self._add_cond(
            self.down2(nn.functional.avg_pool2d(d1, 2)), self.cond2(cond)
        )
        mid = self._add_cond(self.to_vec(d2), self.cond_mid(cond))

        # Skip connections join tensors with equal channel counts:
        # up1's output (base * 2) meets d2, up2's output (base) meets d1.
        u1 = self.up1(mid) + d2
        # Upsample to d1's exact size so odd input sizes also line up.
        u1 = nn.functional.interpolate(u1, size=d1.shape[-2:], mode="nearest")
        u2 = self.up2(u1) + d1
        return self.out(u2)


## Training loop
I pick a small number of epochs so it runs on a laptop. The loss trains the U-Net to predict the noise added at each timestep.


In [None]:
# Instantiate both networks on the chosen device; a single Adam optimizer
# updates the U-Net and the text encoder together.
text_encoder = TextEncoder().to(device)
model = TinyUNet().to(device)
optimizer = torch.optim.Adam(list(model.parameters()) + list(text_encoder.parameters()), lr=1e-3)
# Word-to-id mapping; tokenize() adds entries as new words are encountered.
vocab = {}


def get_time_embedding(timesteps, dim=128):
    """Return the sinusoidal embedding row for each entry of ``timesteps``.

    The lookup table has at least ``T`` rows, and more if ``timesteps``
    contains a larger index. It is rebuilt on every call and moved to the
    module-level ``device``.
    """
    rows = max(T, timesteps.max().item() + 1)
    table = sinusoidal_embedding(rows, dim).to(device)
    return table[timesteps]


def p_losses(x0, prompts):
    """Return the noise-prediction MSE loss for one batch.

    Draws a random timestep per sample, noises ``x0`` accordingly, encodes the
    prompts, and compares the model output with the noise that was added.
    Rebinds the module-level ``vocab`` to the mapping returned by tokenize().
    """
    global vocab
    batch = x0.shape[0]
    t = torch.randint(0, T, (batch,), device=device).long()
    noise = torch.randn_like(x0)
    x_noisy = q_sample(x0, t, noise)

    tokens, vocab = tokenize(list(prompts), vocab)
    text_emb = text_encoder(tokens.to(device))

    predicted = model(x_noisy, get_time_embedding(t), text_emb)
    return nn.functional.mse_loss(predicted, noise)


def train(epochs=3):
    """Train the U-Net and text encoder for ``epochs`` passes over ``dataloader``.

    Each batch runs one optimizer step on the loss from p_losses(); the current
    loss is shown in the tqdm progress bar.
    """
    model.train()
    text_encoder.train()
    for epoch_idx in range(epochs):
        progress = tqdm(dataloader, desc=f"epoch {epoch_idx+1}")
        for batch_imgs, batch_prompts in progress:
            loss = p_losses(batch_imgs.to(device), batch_prompts)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            progress.set_postfix(loss=loss.item())

# Run a short two-epoch training pass.
train(epochs=2)


## Sampling from the model
I start from random noise and apply the reverse diffusion steps while conditioning on the text prompt. A few prompts show the link between words and shapes.


In [None]:
@torch.no_grad()
def p_sample(model, x, t, text_emb):
    """Run one reverse-diffusion step for a batch.

    Args:
        model: Noise-prediction network called as ``model(x, t_emb, text_emb)``.
        x: Current noisy images, shape (batch, ...).
        t: 1-D long tensor holding the timestep of each batch element.
        text_emb: Text conditioning passed through to ``model``.

    Returns:
        The predicted mean of the previous step. Noise scaled by
        ``sqrt(beta_t)`` is added for every sample whose timestep is not 0.
    """
    betas_t = extract(betas, t, x.shape).to(x.device)
    sqrt_one_minus_ac = torch.sqrt(1 - extract(alpha_cum, t, x.shape)).to(x.device)
    sqrt_recip_alpha = torch.sqrt(1.0 / extract(alphas, t, x.shape)).to(x.device)

    eps_pred = model(x, get_time_embedding(t), text_emb)
    model_mean = sqrt_recip_alpha * (x - betas_t * eps_pred / sqrt_one_minus_ac)

    # Per-sample mask instead of t.item(): .item() raises when the batch holds
    # more than one element, and a mask also handles mixed timesteps.
    add_noise = (t != 0).to(x.dtype).reshape(-1, *((1,) * (x.dim() - 1)))
    return model_mean + add_noise * torch.sqrt(betas_t) * torch.randn_like(x)


@torch.no_grad()
def sample(prompts: List[str], img_size=32):
    """Generate one image per prompt by running the reverse process from noise.

    Puts both networks in eval mode, encodes the prompts with the module-level
    ``vocab``, then applies p_sample() for timesteps ``T - 1`` down to 0.

    Returns:
        A CPU tensor shaped (len(prompts), 1, img_size, img_size) with values
        clamped and rescaled to [0, 1].
    """
    model.eval()
    text_encoder.eval()

    tokens, _ = tokenize(prompts, vocab)
    text_emb = text_encoder(tokens.to(device))

    batch = len(prompts)
    img = torch.randn((batch, 1, img_size, img_size), device=device)
    for step in range(T - 1, -1, -1):
        t_batch = torch.full((batch,), step, device=device, dtype=torch.long)
        img = p_sample(model, img, t_batch, text_emb)

    return ((img.clamp(-1, 1) + 1) / 2).cpu()

# One prompt per shape the dataset contains.
example_prompts = [
    "a simple black circle on white",
    "a simple black square on white",
    "a simple black triangle on white",
]

samples = sample(example_prompts)
grid = make_grid(samples, nrow=len(example_prompts), normalize=False)
plt.figure(figsize=(9, 3))
# make_grid expands single-channel images to 3 channels and returns a
# (C, H, W) tensor, which imshow rejects; reorder it to (H, W, C).
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis("off")
plt.show()
