# Simple Text-to-Image Diffusion Demo

We train a tiny diffusion model that generates 32x32 grayscale drawings from short text prompts.

## Imports and setup

We load the libraries we need and pick the GPU if one is available, otherwise the CPU.

In [None]:
import jsonfrom pathlib import Pathimport mathimport randomimport torchfrom torch import nnfrom torch.utils.data import Dataset, DataLoaderfrom torchvision import transformsfrom PIL import Imageimport matplotlib.pyplot as pltdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')device

## Load paired text-image data

We read the provided QuickDraw sketches and keep each record's image file path and caption.

In [None]:
data_path = Path('quickdraw_dataset/data.jsonl')

# Each non-empty line is one JSON record. The dataset code below reads an
# 'image' (file path) and a 'text' (caption) key from every record.
with data_path.open(encoding='utf-8') as f:
    # Skip blank lines (e.g. a trailing newline) so json.loads does not fail on them.
    pairs = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(pairs)} pairs")
pairs[:2]

## Simple preprocessing

We resize images to 32x32 grayscale tensors scaled to [-1, 1], and split lowercased captions into word tokens.

In [None]:
# Image pipeline: 1 channel, 32x32, values mapped from [0, 1] (ToTensor) to [-1, 1].
image_tf = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

# Build the word vocabulary from every caption. Sorting makes indices
# deterministic across runs; index 0 is reserved for padding.
all_words = {w for p in pairs for w in p['text'].lower().split()}
word_to_idx = {w: i + 1 for i, w in enumerate(sorted(all_words))}
vocab_size = len(word_to_idx) + 1


class QuickDrawTextImage(Dataset):
    """Dataset yielding ``(image_tensor, token_ids)`` for each record in ``pairs``.

    Each record must provide an ``'image'`` file path and a ``'text'`` caption.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        item = self.pairs[idx]
        # Open the file in a context manager so its handle is released right away;
        # Image.open is lazy and otherwise keeps the file open until garbage collection.
        with Image.open(item['image']) as im:
            img = image_tf(im.convert('L'))
        words = item['text'].lower().split()
        # Words outside the vocabulary map to the padding index (as in sampling),
        # so the dataset also works on records that were not used to build it.
        token_ids = [word_to_idx.get(w, 0) for w in words]
        return img, torch.tensor(token_ids, dtype=torch.long)


def collate(batch, max_len=12):
    """Stack images and right-pad (or truncate) token sequences to ``max_len``.

    Args:
        batch: list of ``(image, token_ids)`` pairs from ``QuickDrawTextImage``.
        max_len: fixed token length of the output; 0 is used as padding.

    Returns:
        ``(images, tokens)`` where ``tokens`` has shape ``(len(batch), max_len)``.
    """
    imgs, seqs = zip(*batch)
    imgs = torch.stack(imgs)
    padded = torch.zeros(len(seqs), max_len, dtype=torch.long)
    for i, seq in enumerate(seqs):
        length = min(len(seq), max_len)
        padded[i, :length] = seq[:length]
    return imgs, padded


dataset = QuickDrawTextImage(pairs)
dloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)

# Sanity-check one batch: images (B, 1, 32, 32), tokens (B, max_len).
imgs, tokens = next(iter(dloader))
print(imgs.shape, tokens.shape)
print(tokens[0])

## Diffusion helpers

We create a cosine beta schedule plus functions that add noise to images and embed timesteps.

In [None]:
def cosine_beta_schedule(T, s=0.008):
    """Return ``T`` betas from the cosine noise schedule.

    ``alpha_bar`` follows a squared cosine of ``x / T`` (offset by ``s``) and is
    normalized so ``alpha_bar[0] == 1``. Betas are the per-step ratios
    ``1 - alpha_bar[t] / alpha_bar[t - 1]``, clipped to ``[1e-4, 0.9999]``.
    """
    steps = T + 1
    x = torch.linspace(0, T, steps)
    alphas_cumprod = torch.cos(((x / T) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


T = 200  # number of diffusion timesteps used in training
betas = cosine_beta_schedule(T)
alphas = 1. - betas
alpha_bars = torch.cumprod(alphas, 0)  # alpha_bars[t] = alphas[0] * ... * alphas[t]


def extract(a, t, x_shape):
    """Gather ``a[t]`` per batch element, shaped to broadcast against ``x_shape``.

    ``a`` is a 1-D schedule tensor; it must live on the CPU because the indices
    are moved there for ``gather``. The result has shape ``(len(t), 1, ..., 1)``
    and is placed on ``t``'s device.
    """
    bs = t.shape[0]
    out = a.gather(-1, t.cpu()).to(t.device)
    return out.view(bs, *((1,) * (len(x_shape) - 1)))


def add_noise(x0, t, noise=None):
    """Sample ``x_t`` from ``q(x_t | x_0)`` in closed form.

    Returns ``(sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise, noise)``.
    ``noise`` defaults to fresh standard Gaussian noise shaped like ``x0``.
    """
    if noise is None:
        noise = torch.randn_like(x0)
    ab = extract(alpha_bars, t, x0.shape)
    return torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * noise, noise


class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a learned linear projection."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Linear(dim, dim)

    def forward(self, t):
        """Embed a 1-D tensor of timesteps into shape ``(len(t), dim)``."""
        half = self.dim // 2
        # Geometric frequencies from 1 down to 1/10000. max(..., 1) avoids a
        # division by zero (and all-NaN output) when dim < 4.
        freqs = torch.exp(
            torch.arange(0, half, device=t.device, dtype=torch.float32)
            * -(math.log(10000) / max(half - 1, 1))
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2:
            # Odd dim: sin/cos give dim - 1 features; pad one zero column to match self.proj.
            emb = nn.functional.pad(emb, (0, 1))
        return self.proj(emb)

## Text encoder

We embed the words, average them (ignoring padding), and project the result to condition the U-Net.

In [None]:
class TextEncoder(nn.Module):
    """Bag-of-words caption encoder.

    Looks up an embedding per token, averages over the non-padding tokens
    (token id 0 is padding), then applies a linear projection to ``out_dim``.
    """

    def __init__(self, vocab, embed_dim=64, out_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab, embed_dim, padding_idx=0)
        self.proj = nn.Linear(embed_dim, out_dim)

    def forward(self, tokens):
        """Encode a ``(B, L)`` tensor of token ids into ``(B, out_dim)`` features."""
        keep = tokens.ne(0).unsqueeze(-1).float()    # (B, L, 1): 1 for words, 0 for padding
        vectors = self.embed(tokens) * keep          # (B, L, E) with padded slots zeroed
        counts = torch.clamp(keep.sum(dim=1), min=1)  # all-padding rows divide by 1, not 0
        return self.proj(vectors.sum(dim=1) / counts)

## Tiny conditional U-Net

We predict the added noise with a few convolutions, conditioning each residual block on the time and text embeddings.

In [None]:
class ResBlock(nn.Module):
    """Two 3x3 convolutions with a residual connection, conditioned on an embedding.

    The embedding is projected to ``2 * out_c`` values and split into two
    per-channel biases. Both are *added*: the first after ``conv1`` (before the
    activation), the second after ``conv2``. A 1x1 conv matches channel counts
    on the skip path when ``in_c != out_c``.
    """

    def __init__(self, in_c, out_c, embed_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.act = nn.SiLU()
        self.emb_proj = nn.Linear(embed_dim, out_c*2)
        self.skip = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()

    def forward(self, x, emb):
        pre_bias, post_bias = self.emb_proj(emb).chunk(2, dim=1)
        # (B, C) -> (B, C, 1, 1) so the biases broadcast over height and width.
        pre_bias = pre_bias.unsqueeze(-1).unsqueeze(-1)
        post_bias = post_bias.unsqueeze(-1).unsqueeze(-1)
        hidden = self.act(self.conv1(x) + pre_bias)
        hidden = self.conv2(hidden) + post_bias
        residual = self.skip(x)
        return self.act(hidden + residual)


class UNet(nn.Module):
    """Two-level conditional U-Net that predicts noise for 1-channel images.

    ``down1`` runs at input resolution; its output is average-pooled by 2 and
    passed through ``down2`` and ``mid``; a stride-2 transposed conv restores
    the resolution, and the result is concatenated with ``down1``'s features
    before ``up_block`` and a 1x1 output conv. Input height and width must be
    even so the upsampled map lines up with the skip features. The time
    embedding and the projected text embedding (both ``embed_dim`` wide) are
    summed into one conditioning vector shared by every block.
    """

    def __init__(self, base=32, embed_dim=128):
        super().__init__()
        self.time_emb = TimeEmbedding(embed_dim)
        self.text_proj = nn.Linear(embed_dim, embed_dim)
        self.down1 = ResBlock(1, base, embed_dim)
        self.down2 = ResBlock(base, base*2, embed_dim)
        self.pool = nn.AvgPool2d(2)
        self.mid = ResBlock(base*2, base*2, embed_dim)
        self.up = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.up_block = ResBlock(base*2, base, embed_dim)
        self.out = nn.Conv2d(base, 1, 1)

    def forward(self, x, t, text_emb):
        cond = self.time_emb(t) + self.text_proj(text_emb)
        skip_feats = self.down1(x, cond)
        low_res = self.mid(self.down2(self.pool(skip_feats), cond), cond)
        merged = torch.cat([self.up(low_res), skip_feats], dim=1)
        return self.out(self.up_block(merged, cond))

## Training loop

We train for a few epochs, teaching the model to predict the noise added at random timesteps.

In [None]:
model = UNet().to(device)
text_encoder = TextEncoder(vocab_size).to(device)
# A single optimizer updates the denoiser and the text encoder jointly.
opt = torch.optim.Adam(
    [*model.parameters(), *text_encoder.parameters()],
    lr=1e-3,
)
EPOCHS = 8

for epoch in range(1, EPOCHS + 1):
    total_loss = 0
    for imgs, tokens in dloader:
        imgs = imgs.to(device)
        tokens = tokens.to(device)
        n = imgs.size(0)
        # One random timestep per image, uniform over the full training schedule.
        t = torch.randint(0, T, (n,), device=device)
        cond = text_encoder(tokens)
        noisy, noise = add_noise(imgs, t)
        # Epsilon-prediction objective: regress the noise that was added.
        loss = nn.functional.mse_loss(model(noisy, t, cond), noise)

        opt.zero_grad()
        loss.backward()
        opt.step()
        # Weight by batch size so the epoch figure is a per-sample mean
        # even when the final batch is smaller.
        total_loss += loss.item() * n
    print(f"Epoch {epoch}: loss {total_loss/len(dataset):.4f}")

## Sampling function

We reverse the noise process to draw an image for any short prompt.

In [None]:
@torch.no_grad()
def sample(prompt, steps=50):
    """Generate one image for ``prompt`` by ancestral (DDPM) sampling.

    The model is trained on the full ``T``-step schedule, so the reverse chain
    must start at timestep ``T - 1`` (pure noise) and end at 0. When
    ``steps < T`` the chain visits ``steps`` timesteps spread evenly over that
    range, and each jump uses the effective ``alpha = alpha_bar[t] /
    alpha_bar[t_next]`` between visited timesteps (timestep respacing). With
    ``steps == T`` this reduces to the usual one-step DDPM update.

    Args:
        prompt: caption text; words missing from the vocabulary map to the
            padding index 0, which the text encoder ignores.
        steps: number of denoising steps, clamped to ``[1, T]``.

    Returns:
        Tensor of shape ``(1, 1, 32, 32)`` clamped to ``[-1, 1]``.
    """
    ids = [word_to_idx.get(w, 0) for w in prompt.lower().split()]
    # Explicit dtype: an empty prompt would otherwise produce a float tensor,
    # which nn.Embedding rejects.
    tokens = torch.tensor([ids], dtype=torch.long, device=device)
    text_emb = text_encoder(tokens)

    steps = max(1, min(int(steps), T))
    # Evenly spaced, strictly descending timesteps from T - 1 down to 0.
    timesteps = torch.linspace(T - 1, 0, steps).round().long().tolist()

    x = torch.randn(1, 1, 32, 32, device=device)
    for k, step in enumerate(timesteps):
        is_last = k == len(timesteps) - 1
        alpha_bar = alpha_bars[step]
        # alpha_bar of the next visited timestep; 1.0 after the last one (clean image).
        alpha_bar_next = torch.tensor(1.0) if is_last else alpha_bars[timesteps[k + 1]]
        alpha = alpha_bar / alpha_bar_next  # effective alpha for this (possibly multi-step) jump
        bet = 1 - alpha

        t = torch.tensor([step], device=device)
        pred_noise = model(x, t, text_emb)
        coef1 = 1 / torch.sqrt(alpha)
        coef2 = bet / torch.sqrt(1 - alpha_bar)
        x = coef1 * (x - coef2 * pred_noise)
        if not is_last:
            x = x + torch.sqrt(bet) * torch.randn_like(x)
    return x.clamp(-1, 1)


def show_image(tensor_img, title):
    """Show a single-channel image tensor with values in ``[-1, 1]``."""
    img = tensor_img.squeeze().cpu().numpy()
    # Fix the color range to the model's output range; otherwise imshow
    # contrast-stretches each image to its own min/max.
    plt.imshow(img, cmap='gray', vmin=-1, vmax=1)
    plt.title(title)
    plt.axis('off')


sample_img = sample("a hand-drawn circle")
show_image(sample_img, "Generated circle")
plt.show()

## Quick evaluation

We draw a few prompts to check that the generated shapes follow the text.

In [None]:
prompts = [
    "a hand-drawn circle",
    "a hand-drawn square",
    "a hand-drawn star",
]

# squeeze=False always returns a 2-D array of Axes, so the loop also works
# with a single prompt; the width scales with the number of prompts.
fig, axes = plt.subplots(1, len(prompts), figsize=(3 * len(prompts), 3), squeeze=False)
for ax, p in zip(axes[0], prompts):
    out = sample(p)
    # Fixed [-1, 1] range so panels are directly comparable.
    ax.imshow(out.squeeze().cpu().numpy(), cmap='gray', vmin=-1, vmax=1)
    ax.set_title(p.split()[-1])
    ax.axis('off')
plt.tight_layout()
plt.show()