# Simple Text-to-Image Diffusion Demo

This notebook walks through a tiny diffusion-style text-to-image pipeline, step by step: data loading, model components, the noising process, training, and sampling. The code uses lightweight components and synthetic data so it runs quickly while still showing the full process.

## Data loading and preprocessing

We treat the dataset as pairs of images and text prompts. Real data (for example, from the shared dataset links) can be downloaded manually into a local folder named `data/`, where each row should pair an image file with its caption. For simplicity, the demo below builds a synthetic dataset of small random images and simple text prompts, so the notebook runs without external downloads.

In [None]:
import os
import math
import random
from typing import List, Tuple

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# Fix torch's global RNG so weight init, the synthetic images and DataLoader
# shuffling repeat identically across runs.
SEED = 0
torch.manual_seed(SEED)

### Tiny synthetic dataset

The dataset yields 32x32 RGB tensors and a fixed-length sequence of integer token ids derived from each prompt's characters. To use real images and captions (e.g. from the shared Google Drive links), adjust the `ImageTextDataset` class to read files instead of generating noise; it already defines a resize-and-`ToTensor` transform for that purpose.

In [None]:
class ImageTextDataset(Dataset):
    """Pairs of (image, token ids) built from a list of prompts.

    Images are synthetic uniform noise in [0, 1]; ``transform`` is prepared for
    when real image files are loaded instead.

    Args:
        prompts: one caption per example.
        image_size: side length of the square RGB images.
        max_tokens: every prompt is truncated / right-padded to this many ids.

    Raises:
        ValueError: if ``max_tokens`` is smaller than 1.
    """

    PAD_ID = 0  # reserved for padding; no character is ever mapped to it

    def __init__(self, prompts: List[str], image_size: int = 32, max_tokens: int = 16):
        if max_tokens < 1:
            raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
        self.prompts = prompts
        self.image_size = image_size
        self.max_tokens = max_tokens
        # Unused by the synthetic path; apply it to PIL images when loading real files.
        self.transform = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])

    def __len__(self) -> int:
        return len(self.prompts)

    def _tokenize(self, text: str) -> torch.Tensor:
        """Map characters to ids in [1, 255], truncate, then right-pad with PAD_ID.

        Returns a (max_tokens,) long tensor. All ids are below 256, matching the
        text encoder's default vocabulary size.
        """
        # Shift into 1..255 so no character (e.g. NUL or U+0100) collides with
        # the padding id, which downstream code may want to ignore.
        tokens = [1 + ord(c) % 255 for c in text.lower()][: self.max_tokens]
        tokens += [self.PAD_ID] * (self.max_tokens - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        prompt = self.prompts[idx]
        # Synthetic image: uniform noise. Replace with real image loading if available.
        img = torch.rand(3, self.image_size, self.image_size)
        return img, self._tokenize(prompt)

# Example prompts; swap these with captions from the shared dataset.
prompts = [
    "a calm sunset over low mountains",
    "bright flowers in a glass vase",
    "a small orange cat resting on a pillow",
    "city skyline at night with lights",
]

# 32x32 synthetic images, two prompts per shuffled mini-batch.
train_ds = ImageTextDataset(prompts, image_size=32)
train_dl = DataLoader(dataset=train_ds, batch_size=2, shuffle=True)

## Model components

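The model couples a tiny text encoder with a light U-Net-style network (without skip connections). The text encoder embeds the token ids and averages the embeddings into one vector per prompt. The network projects that text embedding to a per-channel bias, which is added to the downsampled features before the middle convolutions. Unlike a standard diffusion U-Net, it receives no timestep input, so it must predict the noise without being told the noise level.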

In [None]:
class SimpleTextEncoder(nn.Module):
    """Embed token ids and pool them into one vector per prompt.

    Positions holding ``pad_id`` (the dataset pads with 0) are excluded from the
    average, so a short prompt is not diluted toward the padding embedding.

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: size of each embedding and of the pooled output.
        pad_id: token id treated as padding.
    """

    def __init__(self, vocab_size: int = 256, embed_dim: int = 64, pad_id: int = 0):
        super().__init__()
        self.pad_id = pad_id
        # padding_idx keeps the pad row at zero and excludes it from gradient updates.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Return the masked mean embedding, shape (B, D), for tokens of shape (B, T)."""
        emb = self.embed(tokens)  # (B, T, D)
        mask = (tokens != self.pad_id).unsqueeze(-1).to(emb.dtype)  # (B, T, 1)
        # clamp avoids 0/0 for an all-padding row; such rows pool to zeros.
        counts = mask.sum(dim=1).clamp(min=1.0)  # (B, 1)
        return (emb * mask).sum(dim=1) / counts  # (B, D)

class TinyUNet(nn.Module):
    """Minimal encoder / bottleneck / decoder conv net used as the noise predictor.

    The text embedding is projected to one value per bottleneck channel and added
    to the downsampled feature map before the middle convolutions.

    Args:
        text_dim: width of the incoming text embedding.
    """

    def __init__(self, text_dim: int = 64):
        super().__init__()
        # Encoder: 3 -> 32 channels at full resolution, then a stride-2 conv to 64.
        encoder_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
        ]
        self.down = nn.Sequential(*encoder_layers)

        # Bottleneck: two same-size 64-channel convs.
        bottleneck_layers = [
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
        ]
        self.mid = nn.Sequential(*bottleneck_layers)

        # Decoder: transposed conv doubles H and W back, then project to 3 channels.
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
        ]
        self.up = nn.Sequential(*decoder_layers)

        self.text_to_bias = nn.Linear(text_dim, 64)

    def forward(self, x: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        batch = text_emb.size(0)
        # (B, 64) -> (B, 64, 1, 1) so the bias broadcasts over every spatial position.
        channel_bias = self.text_to_bias(text_emb).view(batch, -1, 1, 1)
        conditioned = self.down(x) + channel_bias
        return self.up(self.mid(conditioned))

# The encoder's output width must equal the U-Net's text_dim; share one constant.
TEXT_EMBED_DIM = 64
text_encoder = SimpleTextEncoder(embed_dim=TEXT_EMBED_DIM)
denoiser = TinyUNet(text_dim=TEXT_EMBED_DIM)

## Diffusion utilities

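The forward process adds Gaussian noise according to a linear beta schedule. It can be sampled in closed form as $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, where $\bar\alpha_t = \prod_{s \le t}(1-\beta_s)$ and $\epsilon \sim \mathcal{N}(0, I)$. The reverse process asks the U-Net to predict the added noise $\epsilon$. Here we keep only 10 steps for speed.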

In [None]:
timesteps = 10
# Linear beta schedule over only a few steps, for speed.
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)


def add_noise(x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Sample x_t from the forward process in closed form.

    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        x0: clean batch with the batch on dim 0, shape (B, ...).
        t: integer step indices in [0, timesteps), shape (B,); may live on any device.
        noise: tensor with the same shape as ``x0``.

    Returns:
        The noised batch, same shape, device and dtype as ``x0``.
    """
    # The schedule is created on the CPU. Move it and the indices to x0's device,
    # so GPU batches with GPU indices do not hit a device-mismatch error.
    ab = alpha_bars.to(device=x0.device, dtype=x0.dtype)[t.to(x0.device)]
    # (B,) -> (B, 1, ..., 1) to broadcast over every non-batch dimension.
    ab = ab.view(-1, *([1] * (x0.dim() - 1)))
    return ab.sqrt() * x0 + (1 - ab).sqrt() * noise


## Training loop (tiny demo)

We train for two epochs to show the mechanics. For a real project, increase the data size, number of epochs, and model capacity. The loss is the mean squared error between the predicted noise and the noise that was actually injected.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
text_encoder.to(device)
denoiser.to(device)
optimizer = torch.optim.Adam(
    list(text_encoder.parameters()) + list(denoiser.parameters()), lr=1e-3
)
num_epochs = 2

# Ensure training mode: a sampling cell may have switched the modules to eval.
text_encoder.train()
denoiser.train()

for epoch in range(num_epochs):
    running_loss, num_batches = 0.0, 0
    for x0, tokens in train_dl:
        x0 = x0.to(device)
        tokens = tokens.to(device)
        # One random diffusion step per example in the batch.
        t = torch.randint(0, timesteps, (x0.size(0),), device=device)
        noise = torch.randn_like(x0)
        noisy = add_noise(x0, t, noise)
        text_emb = text_encoder(tokens)
        pred_noise = denoiser(noisy, text_emb)
        # Noise-prediction objective: regress the injected noise.
        loss = nn.functional.mse_loss(pred_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise RuntimeError('train_dl yielded no batches; check the dataset')
    # Report the epoch mean rather than the (noisy) last-batch value.
    print(f'Epoch {epoch+1}: loss={running_loss / num_batches:.4f}')


## Sampling from noise

Starting from pure Gaussian noise, we iteratively denoise using the learned model. The function below performs one reverse step per timestep and returns the generated image as a tensor.

In [None]:
@torch.no_grad()
def sample(prompt: str) -> torch.Tensor:
    """Generate one image for ``prompt`` by running the reverse diffusion chain.

    Starts from standard Gaussian noise and applies ``timesteps`` reverse updates
    using the denoiser's noise prediction (DDPM sampling, Ho et al. 2020,
    Algorithm 2, with sigma_t^2 = beta_t). The modules' previous train/eval
    modes are restored afterwards.

    Returns:
        A (3, H, W) CPU tensor clamped to [0, 1], with H = W = train_ds.image_size.
    """
    encoder_was_training = text_encoder.training
    denoiser_was_training = denoiser.training
    denoiser.eval()
    text_encoder.eval()
    try:
        size = train_ds.image_size
        x = torch.randn(1, 3, size, size, device=device)
        tokens = train_ds._tokenize(prompt).unsqueeze(0).to(device)
        text_emb = text_encoder(tokens)
        for i in reversed(range(timesteps)):
            # The denoiser takes no step index, so only x and the prompt are passed.
            noise_pred = denoiser(x, text_emb)
            alpha = alphas[i]
            alpha_bar = alpha_bars[i]
            # Mean of x_{i-1}: (x_i - beta_i / sqrt(1 - alpha_bar_i) * eps) / sqrt(alpha_i)
            x = (1 / torch.sqrt(alpha)) * (x - (1 - alpha) / torch.sqrt(1 - alpha_bar) * noise_pred)
            if i > 0:
                # Fresh noise on every step except the last, which returns the mean.
                x = x + torch.sqrt(betas[i]) * torch.randn_like(x)
        return x.clamp(0, 1).cpu().squeeze(0)
    finally:
        text_encoder.train(encoder_was_training)
        denoiser.train(denoiser_was_training)

demo_prompt = 'a calm sunset over low mountains'
sample_image = sample(demo_prompt)
# imshow expects channels-last (H, W, C); the sample is (C, H, W).
plt.imshow(sample_image.permute(1, 2, 0))
plt.axis('off')
plt.show()
