# My diffusion notebook for shape prompts
I treat this notebook as my lab diary: I train a text-to-image diffusion model that draws white circles, triangles, and squares on a black background. I keep things tidy but simple, so a third-year student should be able to follow along.


## Block 1: Quick setup check
I park install commands here so I do not forget them later. If the lab GPU already has the libs, I just skip this bit.


In [None]:
# Optional setup cell: if you need a fresh environment, remove the leading `#`
# from the two `!pip` lines below and run this cell once.
# !pip install -q torch torchvision torchaudio matplotlib tqdm numpy
# !pip install -q nbformat

## Block 2: Imports and device pick
I gather the libraries, seed everything, and ask for CUDA if it exists. This is how I keep runs stable between sessions.


In [None]:
import math
import os
import random
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as TF
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
from tqdm.auto import tqdm

# Seed the Python, NumPy and torch RNGs so reruns start from the same state.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU (with fp16 autocast) when present, otherwise plain fp32 on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
    amp_dtype = torch.float16
else:
    device = torch.device('cpu')
    amp_dtype = torch.float32
print('Using device:', device)

## Block 3: Dataset and transforms
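I load the NumPy sketches and turn them into tensors. I add light random rotation, translation, and scaling so the model cannot just memorise exact pixel positions. Every image is resized to a square `resolution × resolution` grid and rescaled from `[0, 255]` to `[-1, 1]`.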


In [None]:
class ShapeTextDataset(Dataset):
    """White-on-black shape sketches paired with one fixed text prompt per class.

    Reads ``<data_dir>/circle.npy``, ``square.npy`` and ``triangle.npy`` (the
    first ``limit_per_class`` rows of each) and yields ``(image, prompt)``
    pairs. ``image`` is a ``(1, resolution, resolution)`` float32 tensor scaled
    to ``[-1, 1]``. ``prompt`` is e.g. ``'a hand-drawn circle'``.

    NOTE(review): pixels are divided by 255, so the .npy files are assumed to
    hold 0-255 intensities. Confirm this against the data source.
    """

    def __init__(self, data_dir: str = 'data', resolution: int = 64, limit_per_class: int = 4000, augment: bool = True):
        self.samples: List[Tuple[np.ndarray, str]] = []
        prompt_map = {
            'circle': 'a hand-drawn circle',
            'square': 'a hand-drawn square',
            'triangle': 'a hand-drawn triangle',
        }
        for name in ['circle', 'square', 'triangle']:
            path = os.path.join(data_dir, f'{name}.npy')
            # Copy only the capped slice out of the memory map, keeping its
            # on-disk dtype. Float32 conversion happens per item in
            # _prepare_tensor instead of materialising a float32 copy of
            # every sketch up front.
            arr = np.array(np.load(path, mmap_mode='r')[:limit_per_class])
            if arr.ndim == 4:
                # 4-D stacks: keep index 0 of the last axis (presumably a
                # channel axis).
                arr = np.ascontiguousarray(arr[..., 0])
            self.samples.extend((img, prompt_map[name]) for img in arr)
        self.resolution = resolution
        self.augment = augment

    def __len__(self):
        return len(self.samples)

    def _prepare_tensor(self, img: np.ndarray) -> torch.Tensor:
        """Convert one raw sketch to a float32 ``(1, H, W)`` tensor.

        Flat vectors are reshaped to a square image. 2-D arrays gain a
        channel axis. Multi-channel inputs are averaged to one channel.

        Raises:
            ValueError: if a flat sketch's length is not a perfect square.
        """
        tensor = torch.as_tensor(img, dtype=torch.float32)
        if tensor.ndim == 1:
            side = math.isqrt(tensor.numel())
            if side * side != tensor.numel():
                raise ValueError(
                    f'cannot reshape a flat sketch of {tensor.numel()} pixels into a square image'
                )
            tensor = tensor.view(side, side)
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)
        if tensor.shape[0] > 1:
            tensor = tensor.mean(dim=0, keepdim=True)
        return tensor

    def __getitem__(self, idx):
        img, prompt = self.samples[idx]
        img = self._prepare_tensor(img)
        img = TF.resize(img, [self.resolution, self.resolution])
        if self.augment:
            # Light random jitter: small rotation, shift and zoom. Fill with 0
            # (black) because this runs before the /255 rescale below.
            angle = random.uniform(-10, 10)
            translate = (random.uniform(-2, 2), random.uniform(-2, 2))
            scale = random.uniform(0.95, 1.05)
            img = TF.rotate(img, angle, fill=0.0)
            img = TF.affine(img, angle=0.0, translate=translate, scale=scale, shear=0.0, fill=0.0)
        # [0, 255] -> [0, 1] -> [-1, 1], the range the diffusion model works in.
        img = (img / 255.0).clamp(0, 1) * 2 - 1
        return img, prompt


def build_loaders(batch_size=32, resolution=64, limit_per_class=4000, augment=True, num_workers=2, data_dir='data'):
    """Build a ShapeTextDataset and a shuffling DataLoader over it.

    ``data_dir`` is forwarded to the dataset. It is appended last so existing
    positional calls keep working. Workers persist between epochs when
    ``num_workers > 0``. Batches are pinned when CUDA is available.

    Returns:
        ``(dataset, loader)``.
    """
    dataset = ShapeTextDataset(data_dir=data_dir, resolution=resolution, limit_per_class=limit_per_class, augment=augment)
    print(f'Loaded {len(dataset)} sketches (cap {limit_per_class} per class)')
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=bool(num_workers),
        pin_memory=torch.cuda.is_available(),
    )
    return dataset, loader


train_dataset, train_loader = build_loaders(batch_size=24, resolution=64, limit_per_class=2000, augment=True, num_workers=2)

## Block 4: Peek at the data
I always glance at a mini-batch to be sure labels and the white-on-black style look right. It calms me before training.


In [None]:
# Small, un-augmented loader just for eyeballing a handful of samples.
preview_dataset, preview_loader = build_loaders(batch_size=6, resolution=64, limit_per_class=32, augment=False, num_workers=0)
imgs, prompts = next(iter(preview_loader))
# squeeze=False keeps `axes` a 2-D array even when the batch holds a single
# image (plain subplots(1, 1) would return a bare Axes that cannot be looped).
fig, axes = plt.subplots(1, imgs.size(0), figsize=(12, 2), squeeze=False)
for ax, img, prompt in zip(axes[0], imgs, prompts):
    # Map the model range [-1, 1] back to [0, 1] for display.
    ax.imshow(((img.squeeze() + 1) / 2).clamp(0, 1), cmap='gray')
    ax.set_title(prompt)
    ax.axis('off')
plt.tight_layout()
plt.show()

## Block 5: Diffusion math bits
I define the cosine noise schedule plus helper math for the forward and reverse steps. Keeping it in one place makes it easier to debug.


In [None]:
def cosine_beta_schedule(timesteps: int, s: float = 0.008, max_beta: float = 0.999):
    """Return the cosine noise schedule (Nichol & Dhariwal, 2021) as betas.

    Args:
        timesteps: number of diffusion steps; the result has this length.
        s: small offset that keeps betas near t = 0 from being too small.
        max_beta: upper clip for each beta. The raw final beta is ~1, and the
            reverse mean divides by sqrt(1 - beta). Clipping at 0.999 (as in
            the paper) bounds that factor at ~32x; 0.9999 would make it ~100x
            on the very first sampling step.
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, max_beta)


T = 600
betas = cosine_beta_schedule(T)
alphas = 1.0 - betas
alpha_cum = torch.cumprod(alphas, dim=0)


def extract(a: torch.Tensor, t: torch.Tensor, x_shape):
    """Pick ``a[t]`` for each batch element and reshape it to broadcast over ``x_shape``.

    ``a`` is a 1-D per-timestep table and ``t`` a 1-D long tensor of indices.
    Indexing happens on ``a``'s device, so the table may live on CPU or GPU.
    The result is float32 on ``t``'s device with shape ``(b, 1, ..., 1)``.
    """
    b = t.shape[0]
    out = a.gather(0, t.to(a.device)).float().to(t.device)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def q_sample(x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
    return (
        extract(alpha_cum.sqrt(), t, x_start.shape) * x_start
        + extract((1 - alpha_cum).sqrt(), t, x_start.shape) * noise
    )


def predict_start_from_noise(x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Invert q_sample: recover x_0 from x_t and the (predicted) noise."""
    return (
        extract((1 / alpha_cum).sqrt(), t, x_t.shape) * x_t
        - extract((1 / alpha_cum - 1).sqrt(), t, x_t.shape) * noise
    )


def p_mean_variance_from_eps(pred_noise: torch.Tensor, x: torch.Tensor, t: torch.Tensor):
    """DDPM reverse step from a noise prediction.

    Returns ``(mean, var)`` where
    mean = (x - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t) and var = beta_t.
    """
    beta_t = extract(betas, t, x.shape)
    alpha_t = extract(alphas, t, x.shape)
    alpha_cum_t = extract(alpha_cum, t, x.shape)
    mean = (1 / torch.sqrt(alpha_t)) * (x - (beta_t / torch.sqrt(1 - alpha_cum_t)) * pred_noise)
    var = beta_t
    return mean, var

## Block 6: Small text encoder
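I use a tiny two-layer Transformer encoder so prompt embeddings are cheap to compute and training stays fast. The vocabulary is built on the fly: during training each new word gets the next free id (up to `MAX_VOCAB`), while at sampling time unknown words map to `<unk>`.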


In [None]:
MAX_VOCAB = 2048
MAX_LEN = 12


def tokenize(prompts: List[str], vocab: Dict[str, int], allow_new: bool = True, pad_to: int = MAX_LEN):
    """Turn prompts into right-padded word-id tensors.

    Each prompt is lower-cased and split on whitespace. When ``allow_new`` is
    true and the vocabulary has room (< MAX_VOCAB), unseen words get the next
    id and ``vocab`` is updated in place. Otherwise unseen words map to
    ``<unk>`` (default id 1). Id 0 is padding.

    Returns:
        ``(tokens, vocab)``. ``tokens`` is a ``(len(prompts), L)`` long tensor
        with ``L = max(1, min(pad_to, longest prompt))``. ``vocab`` is the same
        dict that was passed in.
    """
    tokens = []
    for text in prompts:
        ids = []
        for word in text.lower().split():
            if word not in vocab and allow_new and len(vocab) < MAX_VOCAB:
                vocab[word] = len(vocab)
            ids.append(vocab.get(word, vocab.get('<unk>', 1)))
        tokens.append(ids[:pad_to])
    # default=0 keeps an empty prompt list from crashing max().
    max_len = max(1, min(pad_to, max((len(ids) for ids in tokens), default=0)))
    padded = [ids + [0] * (max_len - len(ids)) for ids in tokens]
    # reshape keeps a (0, max_len) shape for an empty batch.
    return torch.tensor(padded, dtype=torch.long).reshape(len(padded), max_len), vocab


class TextEncoder(nn.Module):
    """Small Transformer that maps padded token ids to one vector per prompt."""

    def __init__(self, vocab_size: int = MAX_VOCAB, emb_dim: int = 256, text_dim: int = 256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.pos_embedding = nn.Parameter(torch.randn(MAX_LEN, emb_dim) * 0.01)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=4, dim_feedforward=emb_dim * 4, batch_first=True, dropout=0.05
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.proj = nn.Sequential(nn.LayerNorm(emb_dim), nn.Linear(emb_dim, text_dim))

    def forward(self, tokens: torch.Tensor):
        """Encode ``tokens`` of shape ``(batch, seq)`` (0 = pad, seq <= MAX_LEN) to ``(batch, text_dim)``.

        Padding is excluded both from attention and from the mean pool. The
        embedding therefore does not depend on how wide the batch was padded.
        This matters for classifier-free guidance: the empty prompt is padded
        to the batch width in training but to a single slot at sampling time.
        """
        pad_mask = tokens.eq(0)
        # An all-padding row (the empty prompt) would have every key masked and
        # attention would produce NaN, so keep its first slot visible.
        all_pad = pad_mask.all(dim=1)
        pad_mask[:, 0] &= ~all_pad
        pos = self.pos_embedding[: tokens.size(1)]
        x = self.embedding(tokens) + pos
        x = self.encoder(x, src_key_padding_mask=pad_mask)
        keep = (~pad_mask).unsqueeze(-1).to(x.dtype)
        pooled = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        return self.proj(pooled)

## Block 7: U-Net denoiser with attention
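I mix residual blocks, FiLM conditioning, and a little self-attention below full resolution. With FiLM, each block scales and shifts its features using the combined time and text embedding. This helps keep the shapes crisp while the model listens to the text.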


In [None]:
class ResBlock(nn.Module):
    """Two 3x3 convs with GroupNorm/SiLU, FiLM conditioning and a residual skip.

    ``cond`` (shape ``(b, cond_dim)``) is projected to a per-channel
    ``(scale, shift)`` that modulates the activations before the second conv.
    """

    def __init__(self, in_ch, out_ch, cond_dim):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.cond = nn.Linear(cond_dim, out_ch * 2)
        # 1x1 conv only when the channel count changes.
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, cond):
        h = self.conv1(self.act(self.norm1(x)))
        scale, shift = self.cond(cond).chunk(2, dim=1)
        h = self.conv2(self.act(self.norm2(h)) * (1 + scale[:, :, None, None]) + shift[:, :, None, None])
        return h + self.skip(x)


class AttentionBlock(nn.Module):
    """Residual multi-head self-attention over all spatial positions of a feature map."""

    def __init__(self, channels, heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.attn = nn.MultiheadAttention(channels, heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        # (b, c, h, w) -> (b, h*w, c): one token per pixel.
        y = self.norm(x).reshape(b, c, h * w).transpose(1, 2)
        y, _ = self.attn(y, y, y)
        y = y.transpose(1, 2).reshape(b, c, h, w)
        return x + y


class UNetDenoiser(nn.Module):
    """Three-level U-Net predicting the noise in a 1-channel image.

    For an ``H x W`` input, the feature maps are:
    - ``d1`` at H (``base`` channels);
    - ``d2`` at H/2 (``2*base`` channels);
    - ``d3`` and the middle blocks at H/4 (``3*base`` channels).

    Time and text embeddings are concatenated into one conditioning vector
    that every ResBlock receives.
    """

    def __init__(self, base=64, time_dim=256, text_dim=256):
        super().__init__()
        cond_dim = time_dim + text_dim
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, time_dim), nn.SiLU(), nn.Linear(time_dim, time_dim))
        self.text_mlp = nn.Sequential(nn.Linear(text_dim, text_dim), nn.SiLU())
        self.input = nn.Conv2d(1, base, 3, padding=1)
        self.down1 = ResBlock(base, base, cond_dim)
        self.down2 = ResBlock(base, base * 2, cond_dim)
        self.attn2 = AttentionBlock(base * 2)
        self.down3 = ResBlock(base * 2, base * 3, cond_dim)
        self.attn3 = AttentionBlock(base * 3)
        self.mid1 = ResBlock(base * 3, base * 3, cond_dim)
        self.mid_attn = AttentionBlock(base * 3)
        self.mid2 = ResBlock(base * 3, base * 3, cond_dim)
        self.up3 = ResBlock(base * 3 + base * 3, base * 2, cond_dim)
        self.up_attn3 = AttentionBlock(base * 2)
        self.up2 = ResBlock(base * 2 + base * 2, base, cond_dim)
        self.up_attn2 = AttentionBlock(base)
        self.up1 = ResBlock(base + base, base, cond_dim)
        self.to_noise = nn.Sequential(nn.GroupNorm(8, base), nn.SiLU(), nn.Conv2d(base, 1, 3, padding=1))

    def forward(self, x, t_embed, text_embed):
        t_feat = self.time_mlp(t_embed)
        txt_feat = self.text_mlp(text_embed)
        cond = torch.cat([t_feat, txt_feat], dim=1)

        # Encoder.
        x0 = self.input(x)
        d1 = self.down1(x0, cond)                                              # H
        d2 = self.attn2(self.down2(nn.functional.avg_pool2d(d1, 2), cond))     # H/2
        d3 = self.attn3(self.down3(nn.functional.avg_pool2d(d2, 2), cond))     # H/4

        # Bottleneck (stays at H/4).
        mid = self.mid1(d3, cond)
        mid = self.mid_attn(mid)
        mid = self.mid2(mid, cond)

        # Decoder: merge with the skip at the SAME resolution, then upsample to
        # the next skip's exact size. Upsampling mid before concatenating with
        # d3 would pair an H/2 map with an H/4 map and fail. Using the skip's
        # size (rather than a fixed 2x factor) also handles inputs whose sides
        # are not divisible by 4.
        u3 = self.up_attn3(self.up3(torch.cat([mid, d3], dim=1), cond))       # H/4
        u2 = nn.functional.interpolate(u3, size=d2.shape[-2:], mode='nearest')
        u2 = self.up_attn2(self.up2(torch.cat([u2, d2], dim=1), cond))        # H/2
        u1 = nn.functional.interpolate(u2, size=d1.shape[-2:], mode='nearest')
        u1 = self.up1(torch.cat([u1, d1], dim=1), cond)                        # H
        return self.to_noise(u1)

## Block 8: Training helpers
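I set up the timestep embeddings, optimizer, EMA tracking, and the noise-prediction loss. During training, 10% of prompts are swapped for the empty string. This teaches the model the unconditional prediction that classifier-free guidance needs at sampling time. Mixed precision keeps the GPU busy without melting it.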


In [None]:
def _sinusoidal_from_positions(positions: torch.Tensor, d: int) -> torch.Tensor:
    """Sinusoidal features for integer ``positions`` (any shape) -> ``positions.shape + (d,)``, on their device."""
    flat = positions.reshape(-1)
    pos = flat[:, None]
    dim = torch.arange(d, device=positions.device)[None, :]
    angle = pos / (10000 ** (2 * (dim // 2) / d))
    emb = torch.zeros((flat.shape[0], d), device=positions.device)
    emb[:, 0::2] = torch.sin(angle[:, 0::2])
    emb[:, 1::2] = torch.cos(angle[:, 1::2])
    return emb.reshape(*positions.shape, d)


def sinusoidal_embedding(n: int, d: int):
    """Return the ``(n, d)`` sinusoidal table for positions ``0 .. n-1`` (sin on even, cos on odd columns)."""
    return _sinusoidal_from_positions(torch.arange(n), d)


def get_time_embedding(timesteps: torch.Tensor, dim: int = 256):
    """Embed integer diffusion timesteps as sinusoidal features on their own device.

    Only the requested rows are computed. The previous approach rebuilt the
    whole table on the CPU each call and needed a ``.item()`` device sync.
    """
    return _sinusoidal_from_positions(timesteps, dim)


def update_ema(src: nn.Module, tgt: nn.Module, decay: float):
    """In-place EMA: ``tgt = decay * tgt + (1 - decay) * src`` for parameters; buffers are copied."""
    with torch.no_grad():
        for ps, pt in zip(src.parameters(), tgt.parameters()):
            pt.mul_(decay).add_(ps, alpha=1 - decay)
        for bs, bt in zip(src.buffers(), tgt.buffers()):
            bt.copy_(bs)


vocab: Dict[str, int] = {'<pad>': 0, '<unk>': 1}
text_encoder = TextEncoder().to(device)
model = UNetDenoiser().to(device)
ema_model = UNetDenoiser().to(device)
ema_model.load_state_dict(model.state_dict())
# The EMA copy is only ever updated by update_ema, never by the optimizer.
ema_model.requires_grad_(False)
optimizer = torch.optim.AdamW(list(model.parameters()) + list(text_encoder.parameters()), lr=3e-4, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))
ema_decay = 0.997
guidance_dropout = 0.1  # fraction of prompts replaced by '' to learn the unconditional branch


def p_losses(x0: torch.Tensor, prompts: List[str]):
    """Noise-prediction MSE for a batch of clean images ``x0`` and their prompts.

    A random timestep is drawn per image and Gaussian noise is added with
    ``q_sample``. With probability ``guidance_dropout`` a prompt is replaced by
    the empty string (for classifier-free guidance). The denoiser runs under
    autocast on CUDA. New words in ``prompts`` are added to ``vocab``.
    """
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=device).long()
    noise = torch.randn_like(x0)
    x_noisy = q_sample(x0, t, noise)
    # Draw the dropout mask on the CPU. It is consumed by a Python loop, and
    # iterating a CUDA tensor element by element forces one sync per item.
    drop_mask = (torch.rand(b) < guidance_dropout).tolist()
    mixed_prompts = ['' if drop else p for p, drop in zip(prompts, drop_mask)]
    tokens, _ = tokenize(mixed_prompts, vocab, allow_new=True)
    tokens = tokens.to(device)
    t_emb = get_time_embedding(t)
    text_emb = text_encoder(tokens)
    with torch.cuda.amp.autocast(enabled=(device.type == 'cuda'), dtype=amp_dtype):
        pred = model(x_noisy, t_emb, text_emb)
        loss = nn.functional.mse_loss(pred, noise)
    return loss

## Block 9: Training loop
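I train with a cosine learning-rate schedule, gradient clipping, and an exponential moving average (EMA) of the denoiser weights. A checkpoint is saved after every epoch. Bump the epoch count if you want even cleaner strokes.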


In [None]:
def train(epochs: int = 12, grad_clip: float = 1.0):
    """Train the denoiser and text encoder on ``train_loader`` for ``epochs`` epochs.

    The learning rate follows a cosine decay to 5e-5 over all steps, stepped
    once per batch. After every optimizer step the EMA copy is updated. After
    each epoch the model, EMA weights, text encoder and vocabulary are saved
    to ``shape_diffusion.pt``, overwriting the previous checkpoint.
    """
    params = list(model.parameters()) + list(text_encoder.parameters())
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * len(train_loader), eta_min=5e-5)
    # sample() may have switched these networks to eval mode; train in train mode.
    model.train()
    text_encoder.train()
    for epoch in range(epochs):
        loop = tqdm(train_loader, desc=f'epoch {epoch+1}')
        for imgs, prompts in loop:
            imgs = imgs.to(device)
            loss = p_losses(imgs, prompts)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # The gradients still carry the GradScaler loss scale here. Unscale
            # them first so clip_grad_norm_ compares the true norm with
            # grad_clip (a no-op when the scaler is disabled on CPU).
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(params, grad_clip)
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            update_ema(model, ema_model, ema_decay)
            loop.set_postfix(loss=loss.item())
        torch.save(
            {
                'model': model.state_dict(),
                'ema': ema_model.state_dict(),
                'text': text_encoder.state_dict(),
                'vocab': vocab,
            },
            'shape_diffusion.pt',
        )
        print(f'Epoch {epoch+1} checkpoint saved')


# Uncomment to start training when ready
# train(epochs=20)

## Block 10: Sampling time
I sample with classifier-free guidance so prompts stay clear. Raise `guidance_scale` for sharper alignment, lower it for softer strokes.


In [None]:
@torch.no_grad()
def sample(prompts: List[str], guidance_scale: float = 1.6, use_ema: bool = True, resolution: int = 64):
    """Generate one image per prompt with classifier-free guidance.

    Runs the full ``T``-step DDPM reverse chain starting from Gaussian noise.
    The guided noise estimate is
    ``eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)``.
    Unknown words map to ``<unk>``; the vocabulary is not extended.

    Args:
        prompts: one text prompt per output image.
        guidance_scale: classifier-free guidance weight.
        use_ema: sample with the EMA weights instead of the live model.
        resolution: output side length. Defaults to the 64 used in training.

    Returns:
        A ``(len(prompts), 1, resolution, resolution)`` tensor clamped to ``[-1, 1]``.
    """
    net = ema_model if use_ema else model
    # Remember the modes so sampling does not leave the networks in eval mode
    # for a later training run.
    net_was_training = net.training
    text_was_training = text_encoder.training
    net.eval()
    # Without eval() the encoder's dropout would make prompt embeddings random.
    text_encoder.eval()
    try:
        cond_tokens, _ = tokenize(prompts, vocab, allow_new=False)
        uncond_tokens, _ = tokenize([''] * len(prompts), vocab, allow_new=False)
        cond_text = text_encoder(cond_tokens.to(device))
        uncond_text = text_encoder(uncond_tokens.to(device))
        # Run the conditional and unconditional branches in one batched forward pass.
        text_pair = torch.cat([cond_text, uncond_text], dim=0)
        b = len(prompts)
        img = torch.randn(b, 1, resolution, resolution, device=device)
        for i in tqdm(reversed(range(T)), total=T, desc='sampling'):
            t = torch.full((b,), i, device=device, dtype=torch.long)
            t_emb = get_time_embedding(t)
            eps_cond, eps_uncond = net(
                torch.cat([img, img], dim=0), torch.cat([t_emb, t_emb], dim=0), text_pair
            ).chunk(2, dim=0)
            eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
            mean, var = p_mean_variance_from_eps(eps, img, t)
            if i == 0:
                # Last step: return the mean without adding fresh noise.
                img = mean
            else:
                img = mean + torch.sqrt(var) * torch.randn_like(img)
    finally:
        net.train(net_was_training)
        text_encoder.train(text_was_training)
    return img.clamp(-1, 1)


example_prompts = ['a hand-drawn circle', 'a hand-drawn triangle', 'a hand-drawn square']
# Uncomment after training to visualize
# samples = sample(example_prompts, guidance_scale=1.8)
# grid = make_grid(samples, nrow=3, normalize=True, value_range=(-1, 1))
# plt.figure(figsize=(9, 3))
# plt.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap='gray')
# plt.axis('off')
# plt.show()