# Diffusion Models in PyTorch: Text-to-Image (COCO) + Inpainting (CelebA)

This notebook implements two related diffusion tasks:
1. Text-to-Image generation on MS COCO captions
2. Image inpainting on CelebA

Set the `RUN_TEXT_TO_IMAGE` and `RUN_INPAINTING` flags in the first code cell (Section 0) to choose which part(s) to run.


In [None]:
import os
import math
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, datasets
from torchvision.datasets import CocoCaptions, CelebA
from torchvision.utils import make_grid, save_image

from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel



# ------- Flags to control sections ---------------

RUN_TEXT_TO_IMAGE = True    # Part 1: COCO text-to-image (set to False to skip)
RUN_INPAINTING = True       # Part 2: CelebA inpainting training + sampling


# --------- General configuration --------

IMAGE_SIZE = 64     # 64x64 for speed
BATCH_SIZE = 96
NUM_DIFFUSION_STEPS = 300   # T steps in diffusion (increase for better results)
LEARNING_RATE = 1e-4

# To speed up / demo training
MAX_TRAIN_SAMPLES_COCO   = 5000   # None for full dataset
MAX_TRAIN_SAMPLES_CELEBA = 20000  # None for full dataset

NUM_EPOCHS_TEXT2IMG = 25  # expect usable results only after 90+ epochs
NUM_EPOCHS_INPAINT  = 10  # expect very good results after 50-60 epochs

SAMPLE_STEPS_TEXT2IMG = 50   # number of sampling steps to actually use (<= NUM_DIFFUSION_STEPS)
SAMPLE_STEPS_INPAINT  = 50


# COCO: point these to COCO2017 train images and captions JSON
COCO_ROOT     = "content/coco/train2017"
COCO_ANN_FILE = "content/coco/annotations/captions_train2017.json"

# CelebA root
CELEBA_ROOT = "content/celeba"

# Seeds: make Python, NumPy and PyTorch RNGs reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Device: pick the best available backend automatically instead of requiring
# a manual edit -- NVIDIA CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)


## 1. Diffusion utilities (DDPM-style)

In [None]:

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Build a linear noise schedule.

    Returns a 1-D float32 tensor of ``timesteps`` values evenly spaced from
    ``beta_start`` to ``beta_end`` (both endpoints included).
    """
    schedule = torch.linspace(
        start=beta_start,
        end=beta_end,
        steps=timesteps,
        dtype=torch.float32,
    )
    return schedule

# Precompute every per-timestep quantity once; all tensors are 1-D with
# NUM_DIFFUSION_STEPS entries and live on `device`.
betas = linear_beta_schedule(NUM_DIFFUSION_STEPS).to(device)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Cumulative product shifted right by one step, with 1.0 as the first entry.
alphas_cumprod_prev = torch.cat(
    (torch.ones(1, device=device), alphas_cumprod[:-1]), dim=0
)

# Coefficients used by the forward process (q_sample) and reverse step (p_sample).
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
sqrt_recip_alphas = (1.0 / alphas).sqrt()
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)


# For gpu use this
# def extract(a, t, x_shape):
#     """Extract values from a 1-D tensor a at positions t and reshape."""
#     b = t.shape[0]
#     out = a.gather(-1, t.cpu()).to(t.device)
#     return out.view(b, *([1] * (len(x_shape) - 1)))

# For mps compatibility (I'm on mac)
def extract(a, t, x_shape):
    """
    Look up one value of `a` per batch element and shape it for broadcasting.

    a: (T,) tensor
    t: (B,) int64 timesteps on device
    x_shape: shape of the tensor the result will be combined with
    returns: tensor of shape (B, 1, ..., 1) with len(x_shape) dimensions
    """
    # Move the table to the index tensor's device before gathering.
    picked = a.to(t.device).gather(0, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.view(t.shape[0], *trailing_ones)


def q_sample(x_start, t, noise=None):
    """Forward diffusion q(x_t | x_0): blend the clean input with noise.

    x_start: clean batch; t: (B,) timesteps; noise: optional tensor shaped
    like x_start (drawn from a standard normal when omitted).
    Returns sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise.
    """
    eps = torch.randn_like(x_start) if noise is None else noise
    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * eps

@torch.no_grad()
def p_sample(model, x, t, text_emb=None, mask=None, x_known=None):
    """One reverse step x_t -> x_{t-1}.

    The coefficients are read from betas[t] and posterior_variance[t], i.e.
    the update models a single step of the training schedule.

    model:    noise predictor called as model(net_input, t, text_emb)
    x:        current sample x_t
    t:        (B,) timesteps
    text_emb: optional conditioning passed through to the model
    mask / x_known: when both are given, pixels where mask == 1 are taken
              from x_known, both in the network input (with the mask appended
              as an extra channel) and in the returned sample.
    """
    inpainting = mask is not None and x_known is not None

    if inpainting:
        blended = x * (1 - mask) + x_known * mask
        net_input = torch.cat([blended, mask], dim=1)
    else:
        net_input = x
    eps_theta = model(net_input, t, text_emb)

    shape = x.shape
    beta_t = extract(betas, t, shape)
    sqrt_one_minus_ac_t = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    sqrt_recip_alpha_t = extract(sqrt_recip_alphas, t, shape)
    model_mean = sqrt_recip_alpha_t * (x - beta_t * eps_theta / sqrt_one_minus_ac_t)

    sigma_t = torch.sqrt(extract(posterior_variance, t, shape))
    # No noise is added for batch entries that are already at t == 0.
    add_noise = (t > 0).float().view(-1, 1, 1, 1)
    x_prev = model_mean + add_noise * sigma_t * torch.randn_like(x)

    if inpainting:
        x_prev = x_prev * (1 - mask) + x_known * mask
    return x_prev


## 2. UNet backbone with time + optional text conditioning

In [None]:

class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages with a
    conditioning embedding added between them.

    in_ch / out_ch: input and output channel counts
    emb_dim:        size of the conditioning embedding passed to forward()
    groups:         number of GroupNorm groups
    """

    def __init__(self, in_ch, out_ch, emb_dim, groups=8):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.emb_proj = nn.Linear(emb_dim, out_ch)
        # A 1x1 conv matches channel counts on the shortcut only when needed.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x, emb):
        """Return the block output for feature map `x` and embedding `emb`."""
        shortcut = self.skip(x)
        h = self.conv1(self.act(self.norm1(x)))
        # Append two trailing singleton dims so the projected embedding
        # broadcasts over the spatial dimensions.
        h = h + self.emb_proj(emb)[..., None, None]
        h = self.conv2(self.act(self.norm2(h)))
        return h + shortcut

def sinusoidal_time_embedding(timesteps, dim):
    """Sinusoidal embedding of integer timesteps.

    timesteps: (B,) tensor of timesteps
    dim:       embedding width
    returns:   (B, dim) float tensor laid out as [sin | cos], with one
               zero-padded trailing column when `dim` is odd.

    Frequencies are geometrically spaced from 1 down to 1/10000 across
    dim // 2 channels.
    """
    target_device = timesteps.device
    half_dim = dim // 2
    # With a single frequency (dim 2 or 3) the spacing denominator would be
    # zero; clamp it so small widths work instead of raising ZeroDivisionError.
    emb_factor = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, device=target_device, dtype=torch.float32) * -emb_factor
    )
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if dim % 2 == 1:
        # Pad one zero column so the output width is exactly `dim`.
        emb = F.pad(emb, (0, 1))
    return emb

class UNetModel(nn.Module):
    """UNet noise predictor conditioned on the timestep and, optionally, text.

    in_channels / out_channels: channels of the network input and output
    base_channels:    width of the first level; the embedding width is 4x this
    channel_mults:    per-level width multipliers (one down and one up stage each)
    time_emb_dim:     width of the sinusoidal timestep embedding
    text_context_dim: width of the text vector, or None for no text branch

    Every down stage halves the resolution with a stride-2 conv and every up
    stage doubles it with a transposed conv.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64,
                 channel_mults=(1,2,4), time_emb_dim=256, text_context_dim=None):
        super().__init__()
        self.time_emb_dim = time_emb_dim
        emb_dim = 4 * base_channels

        def make_mlp(in_features):
            # Two-layer MLP projecting a conditioning vector to emb_dim.
            return nn.Sequential(
                nn.Linear(in_features, emb_dim),
                nn.SiLU(),
                nn.Linear(emb_dim, emb_dim),
            )

        self.time_mlp = make_mlp(time_emb_dim)
        self.text_mlp = None if text_context_dim is None else make_mlp(text_context_dim)

        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder: record each level's width so the decoder can size its
        # skip-connection inputs.
        self.downs = nn.ModuleList()
        ch = base_channels
        self.down_channels = [ch]
        for mult in channel_mults:
            width = base_channels * mult
            stage = nn.ModuleDict({
                "block1": ResBlock(ch, width, emb_dim),
                "block2": ResBlock(width, width, emb_dim),
                "down": nn.Conv2d(width, width, 3, stride=2, padding=1),
            })
            self.downs.append(stage)
            ch = width
            self.down_channels.append(ch)

        # Bottleneck.
        self.mid_block1 = ResBlock(ch, ch, emb_dim)
        self.mid_block2 = ResBlock(ch, ch, emb_dim)

        # Decoder: mirrors the encoder, consuming the recorded widths in reverse.
        self.ups = nn.ModuleList()
        for mult in reversed(channel_mults):
            width = base_channels * mult
            skip_ch = self.down_channels.pop()
            stage = nn.ModuleDict({
                "up": nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1),
                "block1": ResBlock(ch + skip_ch, width, emb_dim),
                "block2": ResBlock(width, width, emb_dim),
            })
            self.ups.append(stage)
            ch = width

        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, out_channels, 3, padding=1)

    def forward(self, x, t, text_context=None):
        """Predict noise for input `x` at timesteps `t`.

        `text_context` is used only when the model was built with a text
        branch and a tensor is passed; otherwise conditioning is time-only.
        """
        emb = self.time_mlp(sinusoidal_time_embedding(t, self.time_emb_dim))
        if text_context is not None and self.text_mlp is not None:
            emb = emb + self.text_mlp(text_context)

        h = self.init_conv(x)
        skips = [h]
        for stage in self.downs:
            h = stage["block1"](h, emb)
            h = stage["block2"](h, emb)
            skips.append(h)  # stored before downsampling
            h = stage["down"](h)

        h = self.mid_block2(self.mid_block1(h, emb), emb)

        for stage in self.ups:
            upsampled = stage["up"](h)
            h = torch.cat([upsampled, skips.pop()], dim=1)
            h = stage["block1"](h, emb)
            h = stage["block2"](h, emb)

        return self.out_conv(F.silu(self.out_norm(h)))


## 3. Helper functions

In [None]:

def denormalize_to_uint8(x):
    """Clamp `x` to [-1, 1] and rescale it linearly to [0, 1].

    Despite the name, no integer conversion happens here: the result keeps
    the input's floating-point dtype.
    """
    clipped = torch.clamp(x, min=-1, max=1)
    return (clipped + 1) / 2.0

def show_images_grid(images, nrow=4, filename=None):
    """Tile a batch of [-1, 1] images into one grid tensor.

    images:   batch of images, rescaled to [0, 1] before tiling
    nrow:     forwarded to make_grid
    filename: when given, the grid is also written to this path
    returns:  the grid tensor
    """
    grid = make_grid(denormalize_to_uint8(images), nrow=nrow)
    if filename is None:
        return grid
    save_image(grid, filename)
    return grid

def random_rectangle_mask(batch_size, height, width, min_ratio=0.25, max_ratio=0.5, device=None):
    """Create one random rectangular hole per sample.

    Returns a (batch_size, 1, height, width) float tensor that is 1 everywhere
    except inside a rectangle of zeros. The rectangle's side lengths are drawn
    uniformly between min_ratio and max_ratio of the corresponding image side,
    and its position is uniform among placements that fit inside the image.
    The tensor is created on `device` (CPU when None). Randomness comes from
    Python's `random` module.
    """
    target = torch.device("cpu") if device is None else device
    h_lo, h_hi = int(height * min_ratio), int(height * max_ratio)
    w_lo, w_hi = int(width * min_ratio), int(width * max_ratio)

    masks = torch.ones(batch_size, 1, height, width, device=target)
    for i in range(batch_size):
        # Draw order (height, width, top, left) fixes the RNG sequence.
        hole_h = random.randint(h_lo, h_hi)
        hole_w = random.randint(w_lo, w_hi)
        top = random.randint(0, height - hole_h)
        left = random.randint(0, width - hole_w)
        rows = slice(top, top + hole_h)
        cols = slice(left, left + hole_w)
        masks[i, :, rows, cols] = 0.0
    return masks


# Part 1 — Text-to-Image on MS COCO

In [None]:

if RUN_TEXT_TO_IMAGE:
    # Pretrained text encoder used purely as a frozen feature extractor:
    # eval mode, and no parameter receives gradients.
    TEXT_ENCODER_MODEL_NAME = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL_NAME)
    text_encoder = AutoModel.from_pretrained(TEXT_ENCODER_MODEL_NAME)
    text_encoder = text_encoder.to(device).eval()
    text_encoder.requires_grad_(False)

    # Width of the encoder's hidden states; the UNet text branch is sized to it.
    text_embedding_dim = text_encoder.config.hidden_size
    print("Text encoder dim:", text_embedding_dim)


In [None]:

if RUN_TEXT_TO_IMAGE:
    # Resize to a square, convert to a tensor, then normalize each channel
    # with mean 0.5 / std 0.5.
    coco_transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])

    class COCOTxt2ImgDataset(Dataset):
        """(image, caption) pairs drawn from torchvision's CocoCaptions.

        root / annFile: forwarded to CocoCaptions
        transform:      image transform forwarded to CocoCaptions
        max_samples:    when set, keep only that many randomly chosen images
        """

        def __init__(self, root, annFile, transform=None, max_samples=None):
            super().__init__()
            self.coco = CocoCaptions(root=root, annFile=annFile, transform=transform)
            indices = list(range(len(self.coco)))
            if max_samples is not None:
                # Random subset; the choice follows the global `random` seed.
                random.shuffle(indices)
                indices = indices[:max_samples]
            self.indices = indices

        def __len__(self):
            return len(self.indices)

        def __getitem__(self, idx):
            real_idx = self.indices[idx]
            img, captions = self.coco[real_idx]
            # When several captions are returned, pick one at random per access.
            caption = random.choice(captions) if isinstance(captions, list) else captions
            return img, caption

    print("Loading COCO captions dataset....")
    coco_dataset = COCOTxt2ImgDataset(
        root=COCO_ROOT,
        annFile=COCO_ANN_FILE,
        transform=coco_transform,
        max_samples=MAX_TRAIN_SAMPLES_COCO,
    )
    coco_loader = DataLoader(
        coco_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,  # 0 keeps loading in the main process (safe on macOS); raise if needed
        # Pinned host memory only helps CUDA transfers; on mps/cpu it just warns.
        pin_memory=(device.type == "cuda"),
        drop_last=True,
    )
    print("COCO samples:", len(coco_dataset))


In [None]:

if RUN_TEXT_TO_IMAGE:
    # Text-conditioned noise predictor and its optimizer.
    text2img_model = UNetModel(
        in_channels=3,
        out_channels=3,
        base_channels=64,
        channel_mults=(1,2,4),
        time_emb_dim=256,
        text_context_dim=text_embedding_dim,
    ).to(device)
    opt_text2img = torch.optim.AdamW(text2img_model.parameters(), lr=LEARNING_RATE)
    os.makedirs("outputs_text2img", exist_ok=True)

    for epoch in range(NUM_EPOCHS_TEXT2IMG):
        text2img_model.train()
        pbar = tqdm(coco_loader, desc=f"[Text2Img] Epoch {epoch+1}/{NUM_EPOCHS_TEXT2IMG}")
        for imgs, captions in pbar:
            imgs = imgs.to(device)

            # Encode captions with the frozen text encoder (no gradients).
            enc = tokenizer(list(captions), padding=True, truncation=True, max_length=32, return_tensors="pt")
            input_ids = enc["input_ids"].to(device)
            attention_mask = enc["attention_mask"].to(device)
            with torch.no_grad():
                text_out = text_encoder(input_ids=input_ids, attention_mask=attention_mask)
                # NOTE(review): this averages over every token position,
                # padding included; a mask-weighted mean may be preferable,
                # but it would have to be changed at sampling time as well.
                text_emb = text_out.last_hidden_state.mean(dim=1)

            # Noise each image at a random timestep and regress the noise.
            b = imgs.size(0)
            t = torch.randint(0, NUM_DIFFUSION_STEPS, (b,), device=device).long()
            noise = torch.randn_like(imgs)
            x_t = q_sample(imgs, t, noise)
            pred_noise = text2img_model(x_t, t, text_emb)
            loss = F.mse_loss(pred_noise, noise)

            opt_text2img.zero_grad()
            loss.backward()
            opt_text2img.step()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # One checkpoint per epoch.
        ckpt_path = f"outputs_text2img/text2img_unet_epoch{epoch+1}.pt"
        torch.save(text2img_model.state_dict(), ckpt_path)
        print("Saved:", ckpt_path)


In [None]:

if RUN_TEXT_TO_IMAGE:
    @torch.no_grad()
    def sample_text2img(prompts, num_steps=SAMPLE_STEPS_TEXT2IMG):
        """Generate one image per prompt with respaced ancestral sampling.

        prompts:   a string or a list of strings
        num_steps: number of reverse steps, spread evenly over
                   [0, NUM_DIFFUSION_STEPS - 1]
        returns:   (len(prompts), 3, IMAGE_SIZE, IMAGE_SIZE) tensor

        Only a subset of the training timesteps is visited, so each update
        must bridge the gap between consecutive *visited* timesteps. The
        effective per-step alpha is therefore alphas_cumprod[t] /
        alphas_cumprod[t_prev] rather than the single-step alpha[t]. When
        num_steps == NUM_DIFFUSION_STEPS this reduces to the ordinary
        one-step DDPM update.
        """
        text2img_model.eval()
        if isinstance(prompts, str):
            # A bare string would otherwise be counted character by character.
            prompts = [prompts]
        prompts = list(prompts)

        enc = tokenizer(prompts, padding=True, truncation=True, max_length=32, return_tensors="pt")
        text_out = text_encoder(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
        )
        # Same pooling as in training (plain mean over token positions).
        text_emb = text_out.last_hidden_state.mean(dim=1)

        B = len(prompts)
        x = torch.randn(B, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)
        # Plain Python ints: usable both as indices and as torch.full fill values.
        timesteps = torch.linspace(0, NUM_DIFFUSION_STEPS-1, num_steps, dtype=torch.long).tolist()

        for i in tqdm(reversed(range(num_steps)), total=num_steps, desc="Sampling (text2img)"):
            t_cur = timesteps[i]
            t = torch.full((B,), t_cur, device=device, dtype=torch.long)
            eps_theta = text2img_model(x, t, text_emb)

            ac_t = alphas_cumprod[t_cur]
            # Cumulative alpha at the next visited (less noisy) timestep; 1.0
            # once the chain reaches the clean image.
            ac_prev = alphas_cumprod[timesteps[i - 1]] if i > 0 else torch.ones_like(ac_t)
            alpha_eff = ac_t / ac_prev
            beta_eff = 1.0 - alpha_eff

            mean = (x - beta_eff / torch.sqrt(1.0 - ac_t) * eps_theta) / torch.sqrt(alpha_eff)
            if i > 0:
                var = beta_eff * (1.0 - ac_prev) / (1.0 - ac_t)
                x = mean + torch.sqrt(var) * torch.randn_like(x)
            else:
                # Final step: return the mean without adding noise.
                x = mean
        return x

    demo_prompts = ["a red car parked on the street", "a dog playing in the park",
                    "a bowl of fruits on a table", "a city skyline at sunset"]
    samples = sample_text2img(demo_prompts)
    show_images_grid(samples, nrow=4, filename="outputs_text2img/samples_text2img.png")


# Part 2 — Image Inpainting on CelebA

In [None]:
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class CelebAImagesOnly(Dataset):
    """Label-free image dataset over the ``*.jpg`` files of one directory.

    img_dir:     directory scanned (non-recursively) for ``*.jpg`` files
    transform:   optional callable applied to each loaded RGB image
    max_samples: when set, keep only the first `max_samples` paths in sorted order
    """

    def __init__(self, img_dir, transform=None, max_samples=None):
        paths = sorted(Path(img_dir).glob("*.jpg"))
        self.img_paths = paths if max_samples is None else paths[:max_samples]
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load image `idx` as RGB and return it, transformed when a transform is set."""
        path = self.img_paths[idx]
        img = Image.open(path).convert("RGB")
        return self.transform(img) if self.transform else img

CELEBA_IMG_DIR = "content/celeba/img_align_celeba"

# Center-crop, resize to the training resolution, convert to a tensor, then
# normalize each channel with mean 0.5 / std 0.5.
celeba_transform = transforms.Compose([
    transforms.CenterCrop(148),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
])

# Only touch the CelebA files when the inpainting part is enabled, so the
# text-to-image part can run on a machine that has no CelebA download.
if RUN_INPAINTING:
    celeba = CelebAImagesOnly(
        CELEBA_IMG_DIR,
        transform=celeba_transform,
        max_samples=MAX_TRAIN_SAMPLES_CELEBA
    )
    if len(celeba) == 0:
        # Fail with an actionable message instead of an opaque sampler error.
        raise FileNotFoundError(
            f"No .jpg images found in {CELEBA_IMG_DIR!r}; "
            "check CELEBA_IMG_DIR or set RUN_INPAINTING = False."
        )

    celeba_loader = DataLoader(
        celeba,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,  # 0 keeps loading in the main process (safe on macOS); raise if needed
        # Pinned host memory only helps CUDA transfers; on mps/cpu it just warns.
        pin_memory=(device.type == "cuda"),
        drop_last=True
    )

    print("CelebA images:", len(celeba))


In [None]:

if RUN_INPAINTING:
    # 4 input channels = 3 image channels + 1 mask channel; no text branch.
    inpaint_model = UNetModel(in_channels=4, out_channels=3, base_channels=64,
                              channel_mults=(1,2,4), time_emb_dim=256,
                              text_context_dim=None).to(device)
    opt_inpaint = torch.optim.AdamW(inpaint_model.parameters(), lr=LEARNING_RATE)
    os.makedirs("outputs_inpaint", exist_ok=True)

    for epoch in range(NUM_EPOCHS_INPAINT):
        inpaint_model.train()
        pbar = tqdm(celeba_loader, desc=f"[Inpaint] Epoch {epoch+1}/{NUM_EPOCHS_INPAINT}")
        for imgs in pbar:
            imgs = imgs.to(device)
            B, C, H, W = imgs.shape
            mask = random_rectangle_mask(B, H, W, device=device)  # 1 known, 0 hole
            t = torch.randint(0, NUM_DIFFUSION_STEPS, (B,), device=device).long()
            noise = torch.randn_like(imgs)
            x_t = q_sample(imgs, t, noise)
            # The network sees clean pixels where the mask is 1 and noised
            # pixels inside the hole, plus the mask as an extra channel.
            x_t_in = x_t * (1 - mask) + imgs * mask
            net_input = torch.cat([x_t_in, mask], dim=1)
            pred_noise = inpaint_model(net_input, t, None)

            # Masked MSE over the hole only. Expand the 1-channel mask to all
            # image channels so the denominator counts the same elements the
            # numerator sums over; the clamp avoids 0/0 if a batch has no hole.
            hole = (1 - mask).expand_as(noise)
            loss = ((pred_noise - noise)**2 * hole).sum() / hole.sum().clamp(min=1.0)

            opt_inpaint.zero_grad()
            loss.backward()
            opt_inpaint.step()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # One checkpoint per epoch.
        ckpt = f"outputs_inpaint/inpaint_unet_epoch{epoch+1}.pt"
        torch.save(inpaint_model.state_dict(), ckpt)
        print("Saved:", ckpt)


In [None]:

if RUN_INPAINTING:
    @torch.no_grad()
    def inpaint_images(model, imgs, masks, num_steps=SAMPLE_STEPS_INPAINT):
        """Fill the masked-out region of `imgs` with respaced ancestral sampling.

        model:     noise predictor taking a 4-channel input (image + mask)
        imgs:      (B, 3, H, W) images in [-1, 1]
        masks:     (B, 1, H, W) with 1 = known pixel, 0 = hole
        num_steps: number of reverse steps, spread evenly over
                   [0, NUM_DIFFUSION_STEPS - 1]
        returns:   (B, 3, H, W) tensor; known pixels equal imgs * masks

        Only a subset of the training timesteps is visited, so each update
        uses the effective alpha between consecutive *visited* timesteps,
        alphas_cumprod[t] / alphas_cumprod[t_prev], instead of the single-step
        alpha[t]. When num_steps == NUM_DIFFUSION_STEPS this reduces to the
        ordinary one-step DDPM update.
        """
        model.eval()
        B = imgs.size(0)
        x_known = imgs * masks
        x = torch.randn_like(imgs)
        # Plain Python ints: usable both as indices and as torch.full fill values.
        timesteps = torch.linspace(0, NUM_DIFFUSION_STEPS-1, num_steps, dtype=torch.long).tolist()

        for i in tqdm(reversed(range(num_steps)), total=num_steps, desc="Sampling (inpaint)"):
            t_cur = timesteps[i]
            t = torch.full((B,), t_cur, device=device, dtype=torch.long)

            # Same conditioning as in training: clean known pixels + mask channel.
            x_in = x * (1 - masks) + x_known * masks
            eps_theta = model(torch.cat([x_in, masks], dim=1), t, None)

            ac_t = alphas_cumprod[t_cur]
            # Cumulative alpha at the next visited (less noisy) timestep; 1.0
            # once the chain reaches the clean image.
            ac_prev = alphas_cumprod[timesteps[i - 1]] if i > 0 else torch.ones_like(ac_t)
            alpha_eff = ac_t / ac_prev
            beta_eff = 1.0 - alpha_eff

            mean = (x - beta_eff / torch.sqrt(1.0 - ac_t) * eps_theta) / torch.sqrt(alpha_eff)
            if i > 0:
                var = beta_eff * (1.0 - ac_prev) / (1.0 - ac_t)
                x = mean + torch.sqrt(var) * torch.randn_like(x)
            else:
                # Final step: return the mean without adding noise.
                x = mean

            # Re-impose the known pixels after every step.
            x = x * (1 - masks) + x_known * masks
        return x

    # Demo: inpaint one batch and save originals / corrupted / results as rows.
    it = iter(celeba_loader)
    imgs = next(it)
    imgs = imgs.to(device)
    B, C, H, W = imgs.shape
    masks = random_rectangle_mask(B, H, W, device=device)
    corrupted = imgs * masks
    inpainted = inpaint_images(inpaint_model, imgs, masks)
    vis = torch.cat([imgs, corrupted, inpainted], dim=0)
    show_images_grid(vis, nrow=B, filename="outputs_inpaint/inpainting_results.png")


The model is currently training. Once it finishes, I will post the results on GitHub.