## Image Inpainting with a Simple Diffusion Model
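This notebook trains a minimal DDPM-style diffusion model for option 2 (image inpainting), reusing the dataset and preprocessing from Task 1. A central square of each sketch is hidden, and the model learns to fill it in while the surrounding pixels stay fixed.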

In [None]:

import math
import os
import random
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.functional as TF
from tqdm.auto import tqdm

# Reuse the same dataset path and image size from Task 1
DATASET_PATH = "data"  # folder holding circle.npy, square.npy and triangle.npy
IMAGE_SIZE = 32        # side length every sketch is resized to
BATCH_SIZE = 32
EPOCHS = 5
LEARNING_RATE = 2e-4
TIMESTEPS = 200        # length of the diffusion noise schedule
MASK_RATIO = 0.4       # central square size relative to image

# Reproducibility: seed each RNG source the notebook draws from (torch, NumPy, Python).
for _seed in (torch.manual_seed, np.random.seed, random.seed):
    _seed(42)
del _seed

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


In [None]:

class ShapeDataset(Dataset):
    """Circle/square/triangle sketches, preprocessed the same way as in Task 1.

    Reads ``circle.npy``, ``square.npy`` and ``triangle.npy`` from ``data_dir`` (at most
    ``max_items_per_class`` rows each). ``__getitem__`` returns a single-channel float
    tensor of shape (1, img_size, img_size): values are divided by 255, clamped to
    [0, 1] and mapped to [-1, 1] (so raw pixel values are assumed to lie in [0, 255]).
    """

    def __init__(self, data_dir: str = DATASET_PATH, img_size: int = IMAGE_SIZE, max_items_per_class: int = 2000, augment: bool = True):
        self.samples = []
        for name in ["circle", "square", "triangle"]:
            path = os.path.join(data_dir, f"{name}.npy")
            # Memory-map so slicing below only materialises the rows we keep.
            arr = np.load(path, mmap_mode="r")
            arr = arr[:max_items_per_class]
            if arr.ndim == 4:  # (N, H, W, C): keep the first channel
                arr = arr[..., 0]
            self.samples.extend(arr.astype(np.float32))
        self.img_size = img_size
        self.augment = augment

    def __len__(self):
        return len(self.samples)

    def _prepare(self, img: np.ndarray) -> torch.Tensor:
        """Convert one stored sample to a (1, H, W) float tensor.

        Raises:
            ValueError: if a flattened sample's length is not a perfect square.
        """
        tensor = torch.as_tensor(img, dtype=torch.float32)
        if tensor.ndim == 1:
            # Flattened drawing: recover the square side exactly (no float sqrt rounding).
            numel = tensor.numel()
            side = math.isqrt(numel)
            if side * side != numel:
                raise ValueError(f"cannot reshape a flat sample of length {numel} into a square image")
            tensor = tensor.view(side, side)
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)
        if tensor.shape[0] > 1:
            tensor = tensor.mean(dim=0, keepdim=True)
        return tensor

    def __getitem__(self, idx: int) -> torch.Tensor:
        img = self._prepare(self.samples[idx])
        img = TF.resize(img, [self.img_size, self.img_size])
        if self.augment:
            angle = random.uniform(-8, 8)
            shift = (random.uniform(-2, 2), random.uniform(-2, 2))
            # Rotation about the centre followed by a translation, done as ONE warp so the
            # sketch is resampled once instead of twice. TF.rotate turns counter-clockwise
            # while TF.affine turns clockwise, hence -angle to keep the same direction.
            img = TF.affine(img, angle=-angle, translate=shift, scale=1.0, shear=0.0, fill=0.0)
        img = (img / 255.0).clamp(0, 1) * 2 - 1  # [-1, 1]
        return img


def central_square_mask(shape: Tuple[int, int], ratio: float = MASK_RATIO) -> torch.Tensor:
    """Build a (1, H, W) float mask whose centred square hole is 1 and the rest 0.

    Args:
        shape: ``(H, W)`` of the images to mask.
        ratio: side of the hole as a fraction of ``min(H, W)``; the side is
            ``int(min(H, W) * ratio)`` pixels.

    Raises:
        ValueError: if ``ratio`` is outside [0, 1]. A larger ratio would make the
            top/left offsets negative, which slicing silently wraps to the far edge;
            a negative ratio would silently produce an empty hole.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")
    h, w = shape
    mask = torch.zeros((1, h, w))
    size = int(min(h, w) * ratio)
    top = (h - size) // 2
    left = (w - size) // 2
    mask[:, top:top + size, left:left + size] = 1.0
    return mask


def apply_inpainting_mask(imgs: torch.Tensor, ratio: float = MASK_RATIO):
    """Hide the centre of every image in ``imgs`` behind Gaussian noise.

    ``imgs`` must be 4-D (B, C, H, W). Returns ``(masked, mask)``: ``masked`` keeps the
    pixels where the mask is 0 and replaces the hole (mask == 1) with fresh N(0, 1)
    noise; ``mask`` is the (B, 1, H, W) hole mask placed on ``imgs.device``.
    """
    batch, _, height, width = imgs.shape
    hole = central_square_mask((height, width), ratio).to(imgs.device)
    hole = hole.expand(batch, -1, -1, -1)
    fill = torch.randn_like(imgs)
    keep = 1 - hole
    return imgs * keep + fill * hole, hole


def get_dataloader(batch_size=BATCH_SIZE, img_size=IMAGE_SIZE, max_items_per_class=2000, augment=True):
    """Wrap a ShapeDataset in a shuffled DataLoader (2 workers) and print its size."""
    ds = ShapeDataset(img_size=img_size, max_items_per_class=max_items_per_class, augment=augment)
    print(f"Loaded {len(ds)} sketches")
    loader_opts = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),  # pinned host memory only helps GPU copies
    )
    return DataLoader(ds, **loader_opts)


dataloader = get_dataloader()


In [None]:

# Peek at a few masked samples: top row = originals, bottom row = the same images with
# the central hole replaced by Gaussian noise.
imgs = next(iter(dataloader))[:6]
n_show = imgs.size(0)  # fewer than 6 if the batch is smaller
masked_imgs, masks = apply_inpainting_mask(imgs, ratio=MASK_RATIO)
fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
for row, (label, batch) in enumerate((("original", imgs), ("masked", masked_imgs))):
    for col in range(n_show):
        ax = axes[row, col]
        # Map [-1, 1] back to [0, 1]; fixed vmin/vmax give every panel the same grey
        # scale (imshow would otherwise stretch each image to its own min/max).
        ax.imshow(((batch[col].squeeze() + 1) / 2).clamp(0, 1), cmap="gray", vmin=0, vmax=1)
        # Hide ticks and frame but keep the axes on: ax.axis("off") would also hide the
        # row label set with set_ylabel below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(label)
plt.tight_layout()
plt.show()


In [None]:

# Simple U-Net with sinusoidal time embeddings

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """Map a 1-D batch of timesteps to (B, dim) sinusoidal features.

    Columns ``[0, dim // 2)`` hold sines and the next ``dim // 2`` hold cosines, at
    frequencies ``exp(-ln(10000) * k / (dim // 2))`` for ``k = 0 .. dim // 2 - 1``.
    An odd ``dim`` gets one trailing zero column.
    """
    half_dim = dim // 2
    exponents = torch.arange(half_dim, device=timesteps.device)
    freqs = torch.exp(-math.log(10000) * exponents / half_dim)
    phase = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat((phase.sin(), phase.cos()), dim=-1)
    if dim % 2:
        pad_col = torch.zeros_like(emb[:, :1])
        emb = torch.cat((emb, pad_col), dim=-1)
    return emb


class ConvBlock(nn.Module):
    """Two (3x3 conv -> GroupNorm(8) -> SiLU) stages plus an additive time bias.

    The timestep embedding goes through Linear + SiLU and is added as a per-channel
    bias to the conv output. ``out_ch`` must be divisible by 8 for GroupNorm(8, ...).
    Spatial size is preserved (padding=1 on 3x3 kernels).
    """

    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [nn.Conv2d(stage_in, out_ch, 3, padding=1), nn.GroupNorm(8, out_ch), nn.SiLU()]
        self.conv = nn.Sequential(*stages)
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, out_ch), nn.SiLU())

    def forward(self, x, t_emb):
        features = self.conv(x)
        bias = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)  # (B, out_ch, 1, 1)
        return features + bias


class Down(nn.Module):
    """ConvBlock followed by a learned stride-2 downsample.

    The 4x4 / stride 2 / padding 1 conv maps H x W to floor(H/2) x floor(W/2).
    ``forward`` returns ``(downsampled, skip)`` where ``skip`` is the full-resolution
    ConvBlock output, kept for the decoder's skip connection.
    """

    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.block = ConvBlock(in_ch, out_ch, time_dim)
        self.down = nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t_emb):
        skip = self.block(x, t_emb)
        pooled = self.down(skip)
        return pooled, skip


class Up(nn.Module):
    """Learned 2x upsample, concatenate the skip tensor, then a ConvBlock.

    Args:
        in_ch: channels of the incoming (coarser) feature map.
        skip_ch: channels of the skip tensor concatenated after upsampling.
        out_ch: channels produced by both the upsample and the ConvBlock.
        time_dim: size of the timestep embedding passed to the ConvBlock.
    """

    def __init__(self, in_ch, skip_ch, out_ch, time_dim):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.block = ConvBlock(out_ch + skip_ch, out_ch, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        # ConvTranspose2d(k=4, s=2, p=1) exactly doubles H and W, while a stride-2
        # downsample floors odd sizes (e.g. 15 -> 7 -> 14). Resize to the skip's size so
        # the concat also works for sides not divisible by 4; a no-op at 32x32.
        if x.shape[-2:] != skip.shape[-2:]:
            x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="nearest")
        x = torch.cat([x, skip], dim=1)
        return self.block(x, t_emb)


class SimpleUNet(nn.Module):
    """Two-level U-Net used as the noise predictor.

    Input ``x`` has ``img_channels + 1`` channels (image channels followed by the
    hole mask); ``t`` is a 1-D batch of integer timesteps. Output has
    ``img_channels`` channels and, when H and W are divisible by 4, the input's
    spatial size.
    """

    def __init__(self, img_channels=1, base=32, time_dim=128):
        super().__init__()
        c1, c2, c4 = base, base * 2, base * 4
        self.time_dim = time_dim
        self.head = ConvBlock(img_channels + 1, c1, time_dim)  # +1 for the mask channel
        self.down1 = Down(c1, c2, time_dim)
        self.down2 = Down(c2, c4, time_dim)
        self.bot = ConvBlock(c4, c4, time_dim)
        self.up1 = Up(c4, c4, c2, time_dim)
        self.up2 = Up(c2, c2, c1, time_dim)
        self.outc = nn.Conv2d(c1, img_channels, 1)

    def forward(self, x, t):
        emb = sinusoidal_embedding(t, self.time_dim)
        h = self.head(x, emb)                # H x W
        h, skip_hi = self.down1(h, emb)      # H/2 x W/2 (skip_hi: H x W)
        h, skip_lo = self.down2(h, emb)      # H/4 x W/4 (skip_lo: H/2 x W/2)
        h = self.bot(h, emb)                 # H/4 x W/4
        h = self.up1(h, skip_lo, emb)        # back to H/2 x W/2
        h = self.up2(h, skip_hi, emb)        # back to H x W
        return self.outc(h)


In [None]:

class Diffusion:
    """Linear-beta DDPM schedule with forward noising and a single reverse step.

    ``betas`` rise linearly from 1e-4 to 0.02 over ``timesteps`` steps,
    ``alphas = 1 - betas`` and ``alpha_cumprod[t]`` is the running product of
    ``alphas`` up to and including step ``t``. All schedule tensors live on ``device``,
    so timestep tensors passed in must be on the same device.
    """

    def __init__(self, timesteps: int = TIMESTEPS, device: torch.device = device):
        self.timesteps = timesteps
        self.device = device
        self.betas = torch.linspace(1e-4, 0.02, timesteps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)

    @staticmethod
    def _per_sample(values: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Gather ``values[t]`` and reshape to (B, 1, 1, 1) to broadcast over images."""
        return values[t][:, None, None, None]

    def q_sample(self, x0, t, noise=None):
        """Draw x_t = sqrt(alpha_cumprod[t]) * x0 + sqrt(1 - alpha_cumprod[t]) * noise."""
        if noise is None:
            noise = torch.randn_like(x0)
        alpha_bar = self._per_sample(self.alpha_cumprod, t)
        return torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise

    def p_sample(self, model_out, x, t):
        """One reverse step x_t -> x_{t-1} given the predicted noise ``model_out``.

        Uses the posterior mean in epsilon form and variance ``beta_t``. Samples whose
        timestep is 0 get the mean only (no noise); this is decided per sample, so a
        batch mixing t == 0 and t > 0 adds noise only where t > 0.
        """
        beta_t = self._per_sample(self.betas, t)
        sqrt_one_minus_ac = torch.sqrt(1 - self._per_sample(self.alpha_cumprod, t))
        sqrt_recip_alpha = torch.sqrt(1.0 / self._per_sample(self.alphas, t))
        model_mean = sqrt_recip_alpha * (x - beta_t / sqrt_one_minus_ac * model_out)
        if (t == 0).all():
            return model_mean
        nonzero = (t > 0).to(x.dtype)[:, None, None, None]
        noise = torch.randn_like(x)
        return model_mean + nonzero * torch.sqrt(beta_t) * noise


In [None]:

# Instantiate model, optimizer, and diffusion utilities
model = SimpleUNet().to(device)
diffusion = Diffusion()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


def train_epoch(loader):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    Each batch is built the way ``inpaint`` feeds the model: clean pixels outside the
    central hole, the forward-noised image ``x_t`` inside it, plus the mask channel.
    The model predicts the noise that produced ``x_t``, and the MSE is taken over hole
    pixels only, since outside the hole the input carries no trace of that noise.
    """
    model.train()
    total_loss = 0.0
    for imgs in tqdm(loader, desc="train"):
        imgs = imgs.to(device)
        b = imgs.size(0)
        mask = central_square_mask(imgs.shape[-2:]).to(device).expand(b, -1, -1, -1)
        t = torch.randint(0, diffusion.timesteps, (b,), device=device)
        noise = torch.randn_like(imgs)
        x_t = diffusion.q_sample(imgs, t, noise)
        # The hole must hold x_t itself (not fresh, independent noise), otherwise the
        # noise target is unpredictable exactly where the loss is measured.
        net_inp = torch.cat([x_t * mask + imgs * (1 - mask), mask], dim=1)
        pred_noise = model(net_inp, t)
        # Expand the 1-channel mask so the normaliser counts every masked element,
        # including for multi-channel images.
        hole = mask.expand_as(noise)
        loss = ((pred_noise - noise) ** 2 * hole).sum() / hole.sum().clamp_min(1.0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * b
    return total_loss / max(len(loader.dataset), 1)


for epoch in range(1, EPOCHS + 1):
    avg_loss = train_epoch(dataloader)
    print(f"Epoch {epoch}: loss={avg_loss:.4f}")


In [None]:

@torch.no_grad()
def inpaint(imgs, ratio: float = MASK_RATIO, steps: int = TIMESTEPS, guide_strength: float = 0.3):
    """Fill the central square of ``imgs`` by reverse diffusion, keeping known pixels fixed.

    Args:
        imgs: 4-D batch (B, C, H, W) of images, in the model's [-1, 1] range.
        ratio: side of the hole relative to the shorter image side.
        steps: number of reverse steps. Fewer than ``diffusion.timesteps`` uses an evenly
            strided subset of the training schedule that still starts at the noisiest
            timestep; more than that is clamped to ``diffusion.timesteps``.
        guide_strength: after every step the hole is blended towards the current
            clean-image estimate by this fraction (0 disables the blend).

    Returns:
        ``(inpainted, masked_imgs, mask)``: the result, the input with the hole filled
        by Gaussian noise (for display), and the (B, 1, H, W) hole mask (1 = hole).
    """
    model.eval()
    original = imgs.to(device)
    b, _, h, w = original.shape
    mask = central_square_mask((h, w), ratio).to(device).expand(b, -1, -1, -1)
    known = 1 - mask
    masked_imgs = original * known + torch.randn_like(original) * mask
    img = torch.randn_like(masked_imgs)

    num_t = diffusion.timesteps
    steps = max(1, min(int(steps), num_t))
    # Descending timesteps from T-1 to 0; identical to reversed(range(T)) when steps == T.
    schedule = torch.linspace(num_t - 1, 0, steps).round().long().tolist()
    schedule = sorted(set(schedule), reverse=True)

    for idx, t_cur in enumerate(schedule):
        t_prev = schedule[idx + 1] if idx + 1 < len(schedule) else -1
        t = torch.full((b,), t_cur, device=device, dtype=torch.long)
        img = img * mask + original * known
        pred_noise = model(torch.cat([img, mask], dim=1), t)

        alpha_bar_t = diffusion.alpha_cumprod[t_cur]
        if t_prev == t_cur - 1:
            # Consecutive timesteps: use the trained per-step coefficients directly.
            alpha_t = diffusion.alphas[t_cur]
            beta_t = diffusion.betas[t_cur]
        else:
            # Strided jump t_cur -> t_prev: effective one-step alpha/beta between them.
            alpha_bar_prev = diffusion.alpha_cumprod[t_prev] if t_prev >= 0 else 1.0
            alpha_t = alpha_bar_t / alpha_bar_prev
            beta_t = 1 - alpha_t
        sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar_t)

        pred_x0 = (img - sqrt_one_minus_ab * pred_noise) / torch.sqrt(alpha_bar_t)
        # Data lives in [-1, 1]; clipping keeps the 1/sqrt(alpha_bar) amplification at
        # high noise levels from leaking into the guided blend below.
        pred_x0 = pred_x0.clamp(-1.0, 1.0)
        pred_x0 = pred_x0 * mask + original * known

        model_mean = torch.rsqrt(alpha_t) * (img - beta_t / sqrt_one_minus_ab * pred_noise)
        if t_prev >= 0:
            img = model_mean + torch.sqrt(beta_t) * torch.randn_like(img)
        else:
            img = model_mean  # final step: no noise

        img = img * (1 - guide_strength) + pred_x0 * guide_strength
        img = img * mask + original * known

    return img, masked_imgs, mask


In [None]:

# Visualize a few inpainting results: originals, noise-filled inputs, and model output.
samples = next(iter(dataloader))[:6]
n_show = samples.size(0)  # fewer than 6 if the batch is smaller
# Run the full TIMESTEPS schedule the model was trained with, so the reverse chain
# starts at the noise level that pure Gaussian noise actually corresponds to.
recon, masked, mask = inpaint(samples, steps=TIMESTEPS)

rows = (("original", samples), ("masked", masked), ("inpainted", recon))
fig, axes = plt.subplots(len(rows), n_show, figsize=(2 * n_show, 6), squeeze=False)
for row, (label, batch) in enumerate(rows):
    for col in range(n_show):
        ax = axes[row, col]
        # Map [-1, 1] back to [0, 1] on a shared grey scale (no per-panel autoscaling).
        ax.imshow(((batch[col].cpu().squeeze() + 1) / 2).clamp(0, 1), cmap="gray", vmin=0, vmax=1)
        # Hide ticks and frame but keep the axes on: ax.axis("off") would also hide the
        # row label set with set_ylabel below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(label)
plt.tight_layout()
plt.show()
