# Image Inpainting with Diffusion on Simple Shapes

This notebook builds a diffusion-based image inpainting model for simple grayscale shapes. Image inpainting fills in missing regions of an image; here we use a denoising diffusion process together with a small U-Net to reconstruct masked portions. The dataset contains 32×32 single-channel images of circles, squares, and triangles stored as NumPy arrays.

## Table of Contents

- [Dataset and preprocessing](#Dataset-and-preprocessing)
- [Inpainting mask](#Inpainting-mask)
- [Diffusion model and U-Net](#Diffusion-model-and-U-Net)
- [Training loop](#Training-loop)
- [Inpainting and evaluation](#Inpainting-and-evaluation)
- [Conclusion](#Conclusion)

## Dataset and preprocessing

Images are loaded from `.npy` files (`circle.npy`, `square.npy`, `triangle.npy`) located in the `data` folder. Each grayscale image is converted to `float32`, resized to a tensor of shape `(1, 32, 32)` with bilinear interpolation, and normalized to the range `[−1, 1]` by dividing by 127.5 and subtracting 1, which assumes raw pixel values in `[0, 255]`. An optional limit `max_items_per_class` keeps only the first N samples of each shape class. A `DataLoader` is created with the configured batch size.
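The normalization above can be checked on a few scalar values. This is a minimal sketch that assumes 8-bit intensities in `[0, 255]`; the helper names `normalize_pixel` and `denormalize_pixel` are hypothetical and exist only for this illustration.

```python
def normalize_pixel(value: float) -> float:
    """Map an intensity in [0, 255] (assumed range) to [-1, 1]."""
    return value / 127.5 - 1.0


def denormalize_pixel(value: float) -> float:
    """Map a normalized value in [-1, 1] to [0, 1] for plotting."""
    return (value + 1.0) / 2.0


# Black, mid-gray, and white in the assumed 8-bit range.
samples = [0.0, 127.5, 255.0]
normalized = [normalize_pixel(v) for v in samples]
displayed = [denormalize_pixel(v) for v in normalized]
print(normalized)
print(displayed)
```

If the stored arrays are already scaled to `[0, 1]`, the divisor would need to change accordingly.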

In [None]:

import os
import math
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Paths and hyperparameters
DATASET_PATH = "data"
IMAGE_SIZE = 32
CHANNELS = 1
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
EPOCHS = 30
TIMESTEPS = 300
MASK_RATIO = 0.5
max_items_per_class = None  # set an integer to limit samples per shape

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_tensor(img: np.ndarray) -> torch.Tensor:
    """Convert a numpy image to a normalized torch tensor in [-1, 1]."""
    img = np.asarray(img, dtype=np.float32)
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim == 1:
        side = int(math.sqrt(img.size))
        if side * side != img.size:
            raise ValueError(f"Cannot reshape flat image of size {img.size} into a square.")
        img = img.reshape(side, side)
    if img.ndim != 2:
        raise ValueError(f"Expected 2D grayscale image, got shape {img.shape}.")

    tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
    tensor = F.interpolate(tensor, size=(IMAGE_SIZE, IMAGE_SIZE), mode="bilinear", align_corners=False)
    tensor = tensor.squeeze(0)  # (1, IMAGE_SIZE, IMAGE_SIZE)
    tensor = (tensor / 127.5) - 1.0
    return tensor


class SimpleShapesDataset(Dataset):
    """Dataset loading simple shape numpy arrays and returning tensors."""

    def __init__(self, root: str, max_items_per_class=None):
        super().__init__()
        self.samples = []
        for name in ["circle", "square", "triangle"]:
            path = os.path.join(root, f"{name}.npy")
            data = np.load(path)
            if max_items_per_class is not None:
                data = data[:max_items_per_class]
            self.samples.extend([to_tensor(x) for x in data])
        self.samples = torch.stack(self.samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Load dataset and create dataloader
dataset = SimpleShapesDataset(DATASET_PATH, max_items_per_class=max_items_per_class)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

print(f"Loaded {len(dataset)} samples on {DEVICE}.")


## Inpainting mask

We remove a central square region from each image using a binary mask with value `1` inside the missing area and `0` elsewhere. The `MASK_RATIO` controls the fraction of the image side covered by the square. The same mask is used during training and inference to hide the target region and to guide reconstruction.
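Because `MASK_RATIO` scales the side of the square, the hidden area grows with its square. The sketch below works this out for a single axis; `central_hole` and `hole_coverage` are hypothetical helper names, not part of the notebook.

```python
def central_hole(side: int, ratio: float) -> tuple[int, int]:
    """Return (start, size) of the central square hole along one axis."""
    size = int(side * ratio)
    start = (side - size) // 2
    return start, size


def hole_coverage(side: int, ratio: float) -> float:
    """Fraction of all pixels that fall inside the hole."""
    _, size = central_hole(side, ratio)
    return (size * size) / (side * side)


start, size = central_hole(32, 0.5)
print(start, size)             # the hole spans rows/cols [start, start + size)
print(hole_coverage(32, 0.5))  # area fraction, not side fraction
```

With a 32×32 image and a ratio of 0.5, the hole is 16 pixels wide and covers a quarter of the image.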

In [None]:

def central_square_mask(image_shape, ratio: float):
    """Create a central square mask with ones in the missing region."""
    _, h, w = image_shape
    mask = torch.zeros(image_shape, dtype=torch.float32)
    size_h, size_w = int(h * ratio), int(w * ratio)
    start_h = (h - size_h) // 2
    start_w = (w - size_w) // 2
    mask[:, start_h : start_h + size_h, start_w : start_w + size_w] = 1.0
    return mask


def apply_inpainting_mask(imgs: torch.Tensor, ratio: float = MASK_RATIO):
    """Apply a central mask to a batch of images and return masked images and mask."""
    b, c, h, w = imgs.shape
    mask = central_square_mask((c, h, w), ratio).to(imgs.device)
    mask = mask.unsqueeze(0).expand(b, -1, -1, -1)
    # Zero out the masked region for visualization
    masked_imgs = imgs * (1 - mask)
    return masked_imgs, mask


# Visualize a few masked samples
batch = next(iter(loader))
masked_batch, vis_mask = apply_inpainting_mask(batch)

fig, axes = plt.subplots(2, 6, figsize=(12, 4))
for i in range(6):
    axes[0, i].imshow(batch[i].squeeze(), cmap="gray", vmin=-1, vmax=1)
    axes[0, i].axis("off")
    axes[0, i].set_title("Original")
    axes[1, i].imshow(masked_batch[i].squeeze(), cmap="gray", vmin=-1, vmax=1)
    axes[1, i].axis("off")
    axes[1, i].set_title("Masked")
plt.tight_layout()
plt.show()


## Diffusion model and U-Net

A Denoising Diffusion Probabilistic Model (DDPM) gradually adds Gaussian noise to an image through a forward process \(q(x_t \mid x_0)\) controlled by a noise schedule. With pre-computed betas \(\beta_t\), we derive \(\alpha_t = 1 - \beta_t\) and the cumulative product \(\bar{\alpha}_t\). The forward sampling is

\[ x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon. \]
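To get a feel for the two coefficients in the forward formula, the following sketch computes the cumulative product for a linear beta schedule in plain Python. It assumes the notebook's defaults (300 steps, betas from `1e-4` to `0.02`); the function name `linear_alpha_bar` is hypothetical.

```python
import math


def linear_alpha_bar(timesteps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> list[float]:
    """Cumulative product of alpha_t = 1 - beta_t for a linear beta schedule."""
    alpha_bar, running = [], 1.0
    for i in range(timesteps):
        # Evenly spaced betas from beta_start to beta_end (inclusive).
        frac = i / (timesteps - 1) if timesteps > 1 else 0.0
        beta = beta_start + frac * (beta_end - beta_start)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar


alpha_bar = linear_alpha_bar(300)
for t in (0, 99, 199, 299):
    signal = math.sqrt(alpha_bar[t])       # coefficient on x_0
    noise = math.sqrt(1.0 - alpha_bar[t])  # coefficient on epsilon
    print(f"t={t:3d}  signal={signal:.3f}  noise={noise:.3f}")
```

With these defaults the final cumulative product should come out at roughly 0.05 rather than 0, so the last forward sample still carries a small signal component. A longer schedule or a larger `beta_end` would push it closer to pure noise.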

The model learns to predict the noise \(\epsilon\) given a noisy input. We concatenate the masked noisy image with the binary mask and provide a sinusoidal time embedding to a compact U-Net that outputs the predicted noise.
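The sinusoidal time embedding can be illustrated for a single scalar timestep without any tensor library. This is a sketch only: `sinusoidal_embedding_1d` is a hypothetical name, and the `max(half - 1, 1)` guard for very small dimensions is an assumption added here.

```python
import math


def sinusoidal_embedding_1d(t: float, dim: int) -> list[float]:
    """Embed a scalar timestep as sines followed by cosines at geometric frequencies."""
    half = dim // 2
    denom = max(half - 1, 1)  # assumed guard against division by zero for tiny dims
    freqs = [math.exp(-math.log(10000.0) * i / denom) for i in range(half)]
    emb = [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]
    if dim % 2 == 1:
        emb.append(0.0)  # pad odd dimensions with a zero
    return emb


emb0 = sinusoidal_embedding_1d(0, 8)
emb5 = sinusoidal_embedding_1d(5, 8)
print(emb0)
print([round(v, 3) for v in emb5])
```

Each timestep maps to a distinct vector with entries in `[-1, 1]`, which the per-block linear layer then projects to the channel count.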

In [None]:
# Diffusion utilities


def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    timesteps = timesteps.to(torch.float32)
    device = timesteps.device
    half = dim // 2
    # Geometric frequencies from 1 down to 1/10000; max() avoids division by zero when dim < 4
    freqs = torch.exp(-math.log(10000) * torch.arange(half, device=device).float() / max(half - 1, 1))
    args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2 == 1:
        emb = torch.cat([emb, torch.zeros(timesteps.size(0), 1, device=device)], dim=1)
    return emb


class ConvBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, time_dim: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, out_ch),
            nn.ReLU(inplace=True)
        )
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        h = self.act(self.bn1(self.conv1(x)))
        t_emb = self.time_mlp(t).view(t.size(0), -1, 1, 1)
        h = h + t_emb
        h = self.act(self.bn2(self.conv2(h)))
        return h


class Diffusion:
    def __init__(self, timesteps: int, device: torch.device, beta_start: float = 1e-4, beta_end: float = 0.02):
        self.timesteps = timesteps
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, timesteps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)

    def q_sample(self, x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor | None = None) -> torch.Tensor:
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_ab = torch.sqrt(self.alpha_cumprod[t]).view(-1, 1, 1, 1)
        sqrt_one_minus_ab = torch.sqrt(1 - self.alpha_cumprod[t]).view(-1, 1, 1, 1)
        return sqrt_ab * x_start + sqrt_one_minus_ab * noise


In [None]:

class Down(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch, time_dim)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x, t):
        h = self.conv(x, t)
        p = self.pool(h)
        return p, h  # pooled for next stage, h saved as skip


class Up(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.conv = ConvBlock(in_ch + out_ch, out_ch, time_dim)

    def forward(self, x, skip, t):
        x = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
        x = torch.cat([x, skip], dim=1)
        return self.conv(x, t)


class SimpleUNet(nn.Module):
    def __init__(self, img_channels=1, time_dim=128):
        super().__init__()
        in_ch = img_channels + 1  # extra channel for mask
        self.time_dim = time_dim

        self.inc = ConvBlock(in_ch, 32, time_dim)
        self.down1 = Down(32, 64, time_dim)
        self.down2 = Down(64, 128, time_dim)

        self.bot = ConvBlock(128, 128, time_dim)

        self.up2 = Up(128, 128, time_dim)
        self.up1 = Up(128, 64, time_dim)
        self.outc = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, img_channels, 1)
        )

    def forward(self, x, t):
        t_emb = sinusoidal_embedding(t, self.time_dim)

        x1 = self.inc(x, t_emb)
        x2, skip1 = self.down1(x1, t_emb)
        x3, skip2 = self.down2(x2, t_emb)

        bottleneck = self.bot(x3, t_emb)

        x = self.up2(bottleneck, skip2, t_emb)
        x = self.up1(x, skip1, t_emb)
        return self.outc(x)


## Training loop

Each training step samples a random timestep \(t\) and noise \(\epsilon \sim \mathcal{N}(0, I)\), forms \(x_t\) from \(x_0\) with `q_sample`, and builds a masked version that removes the original signal inside the hole:

\[ x^{\text{masked}}_t = x_t\,(1-\text{mask}) + \sqrt{1-\bar{\alpha}_t}\, \epsilon \cdot \text{mask}. \]

This keeps the unmasked area identical to \(x_t\) while the masked area contains only the noise term of `q_sample`, so no information about \(x_0\) leaks into the hole. The old code incorrectly injected fresh noise inside the mask, which asked the model to predict noise that was absent from its input. Here the same noise tensor is used both in `q_sample` and in constructing \(x^{\text{masked}}_t\), so the regression target is the noise actually present in the network input. The loss weights the masked region at 0.7 and adds a global term over all pixels at 0.3.
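The construction of the masked input and the weighted loss can be checked on a toy batch. The sketch below uses NumPy instead of PyTorch so it runs standalone; the `alpha_bar` values, the 8×8 image size, and the stand-in prediction are illustrative assumptions, not values from the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy batch: 2 images of shape (1, 8, 8) with values in [-1, 1].
x0 = rng.uniform(-1.0, 1.0, size=(2, 1, 8, 8))
eps = rng.standard_normal(x0.shape)

# Hypothetical per-sample alpha_bar values (not taken from the real schedule).
alpha_bar = np.array([0.9, 0.3]).reshape(2, 1, 1, 1)

# Central 4x4 hole: 1 inside the missing region, 0 elsewhere.
mask = np.zeros_like(x0)
mask[:, :, 2:6, 2:6] = 1.0

# Forward sample, then drop the x0 term inside the hole.
x_t = np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps
x_t_masked = x_t * (1.0 - mask) + np.sqrt(1.0 - alpha_bar) * eps * mask

hole = mask.astype(bool)
same_outside = np.allclose(x_t_masked[~hole], x_t[~hole])
noise_only_inside = np.allclose(x_t_masked[hole], (np.sqrt(1.0 - alpha_bar) * eps)[hole])
print(same_outside, noise_only_inside)

# Weighted loss: masked term normalized by hole area, plus a global mean.
pred = eps + 0.1  # stand-in prediction with a constant error of 0.1
masked_loss = np.sum(((pred - eps) * mask) ** 2) / mask.sum()
global_loss = np.mean((pred - eps) ** 2)
loss = 0.7 * masked_loss + 0.3 * global_loss
print(loss)
```

Normalizing the masked term by the number of hole pixels keeps its scale comparable to the global mean, so the 0.7 and 0.3 weights act as intended regardless of hole size.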

In [None]:

diffusion = Diffusion(timesteps=TIMESTEPS, device=DEVICE)
model = SimpleUNet(img_channels=CHANNELS, time_dim=128).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
mse_loss = nn.MSELoss()


def train_epoch(loader):
    model.train()
    epoch_loss = 0.0
    for imgs in tqdm(loader, leave=False):
        imgs = imgs.to(DEVICE)
        _, mask = apply_inpainting_mask(imgs, ratio=MASK_RATIO)

        optimizer.zero_grad()
        b = imgs.size(0)
        t = torch.randint(0, diffusion.timesteps, (b,), device=DEVICE)
        noise = torch.randn_like(imgs)

        x_t = diffusion.q_sample(imgs, t, noise)
        alpha_bar_t = diffusion.alpha_cumprod[t].view(b, 1, 1, 1)
        sqrt_one_minus_ac = torch.sqrt(1 - alpha_bar_t)

        # Remove the clean component inside the mask; keep unmasked region intact
        x_t_masked = x_t * (1 - mask) + sqrt_one_minus_ac * noise * mask

        net_inp = torch.cat([x_t_masked, mask], dim=1)
        pred_noise = model(net_inp, t)

        masked_loss = F.mse_loss(pred_noise * mask, noise * mask, reduction="sum") / (mask.sum() + 1e-8)
        global_loss = mse_loss(pred_noise, noise)
        loss = 0.7 * masked_loss + 0.3 * global_loss

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss / len(loader)


def train(epochs=EPOCHS):
    losses = []
    for epoch in range(1, epochs + 1):
        loss = train_epoch(loader)
        losses.append(loss)
        print(f"Epoch {epoch:03d} | loss: {loss:.4f}")
    return losses

# Uncomment to train (may take a few minutes depending on hardware)
# training_losses = train()


## Inpainting and evaluation

To inpaint, we start from random noise and iterate the reverse diffusion steps. At each timestep we:

1. Hide the masked region with fresh noise and keep the known pixels.
2. Predict the noise with the U-Net and compute \(x_{t-1}\) using the DDPM reverse formula.
3. Enforce the known pixels outside the mask to match the original image.
4. Blend the update with the predicted clean image using a guidance weight.
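The arithmetic of steps 2 and 4 can be traced for a single pixel. This is a sketch under stated assumptions: the schedule values are hypothetical, the noise prediction is taken to be perfect so the algebra can be checked, and the blend is applied to the posterior mean (the sampling noise is omitted).

```python
import math

# Hypothetical schedule values for one timestep t (not from the real schedule).
beta_t = 0.01
alpha_t = 1.0 - beta_t
alpha_bar_t = 0.5

x0, eps = 0.8, -0.3  # toy clean pixel and noise draw
x_t = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps

# Assume a perfect noise prediction to check the algebra.
eps_pred = eps

# Clean-image estimate implied by the predicted noise.
x0_pred = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)

# DDPM posterior mean for x_{t-1}; sampling adds sigma * z for t > 0.
coeff = (1.0 - alpha_t) / math.sqrt(1.0 - alpha_bar_t)
mean = (x_t - coeff * eps_pred) / math.sqrt(alpha_t)
sigma = math.sqrt(beta_t)

# Guidance blend between the reverse-step result and the clean estimate.
guide_strength = 0.3
blended = (1.0 - guide_strength) * mean + guide_strength * x0_pred

print(x0_pred, mean, sigma, blended)
```

When the predicted noise equals the true noise, the clean estimate recovers the original pixel, which is why blending toward it pulls the sample toward a plausible image.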

We visualise the original, masked, and inpainted images, and optionally report the mean squared error inside the masked region.
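One way to restrict the error to the hole is to sum squared errors under the mask and divide by the number of masked pixels. The sketch below uses NumPy on a toy 4×4 example; `masked_mse_np` is a hypothetical name and the values are made up for illustration.

```python
import numpy as np


def masked_mse_np(pred: np.ndarray, target: np.ndarray, mask: np.ndarray) -> float:
    """MSE over pixels where mask == 1, normalized by the number of masked pixels."""
    sq_err = ((pred - target) ** 2) * mask
    return float(sq_err.sum() / max(mask.sum(), 1.0))


target = np.zeros((1, 1, 4, 4))
pred = np.zeros_like(target)
mask = np.zeros_like(target)
mask[:, :, 1:3, 1:3] = 1.0  # 2x2 hole
pred[:, :, 1:3, 1:3] = 0.5  # error of 0.5 inside the hole only

hole_error = masked_mse_np(pred, target, mask)         # averaged over the 4 hole pixels
full_error = float(np.mean((pred - target) ** 2))      # averaged over all 16 pixels
print(hole_error, full_error)
```

Averaging over all pixels dilutes the error by the fraction of the image that is hidden, so the two numbers differ by that factor.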

In [None]:

@torch.no_grad()
def p_sample_step(x, t, eps_pred, diffusion: Diffusion):
    beta_t = diffusion.betas[t]
    alpha_t = diffusion.alphas[t]
    alpha_bar_t = diffusion.alpha_cumprod[t]
    sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar_t)
    sqrt_recip_alpha = torch.sqrt(1 / alpha_t)

    coeff = (1 - alpha_t) / sqrt_one_minus_ab
    mean = sqrt_recip_alpha * (x - coeff * eps_pred)

    if t == 0:
        return mean
    noise = torch.randn_like(x)
    return mean + torch.sqrt(beta_t) * noise


@torch.no_grad()
def inpaint(imgs: torch.Tensor, ratio: float = MASK_RATIO, steps: int = TIMESTEPS, guide_strength: float = 0.3):
    model.eval()
    imgs = imgs.to(DEVICE)
    b = imgs.size(0)
    masked_imgs, mask = apply_inpainting_mask(imgs, ratio=ratio)

    x = torch.randn_like(imgs)  # start from pure noise
    for step in reversed(range(steps)):
        t = torch.full((b,), step, device=DEVICE, dtype=torch.long)
        alpha_bar_t = diffusion.alpha_cumprod[t].view(b, 1, 1, 1)
        sqrt_one_minus_ab = torch.sqrt(1 - alpha_bar_t)

        # Hide the target region with matching noise statistics
        masked_noise = torch.randn_like(x) * sqrt_one_minus_ab
        x_masked = x * (1 - mask) + masked_noise * mask
        net_inp = torch.cat([x_masked, mask], dim=1)

        eps_pred = model(net_inp, t)
        x0_pred = (x - sqrt_one_minus_ab * eps_pred) / torch.sqrt(alpha_bar_t)

        x = p_sample_step(x, step, eps_pred, diffusion)

        # Keep known pixels close to the original
        x = x * mask + imgs * (1 - mask)
        # Blend with predicted clean image
        x = (1 - guide_strength) * x + guide_strength * x0_pred

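    # Re-impose the known pixels after the final blend so they match the input exactly
    x = x * mask + imgs * (1 - mask)
    return x.clamp(-1, 1), masked_imgs, mask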


def denorm(x):
    return (x + 1) / 2


def visualize_inpainting(batch):
    inpainted, masked_imgs, _ = inpaint(batch, ratio=MASK_RATIO, steps=TIMESTEPS)
    batch = batch.cpu()
    masked_imgs = masked_imgs.cpu()
    inpainted = inpainted.cpu()

    fig, axes = plt.subplots(3, 6, figsize=(12, 6))
    for i in range(6):
        axes[0, i].imshow(denorm(batch[i]).squeeze(), cmap="gray", vmin=0, vmax=1)
        axes[0, i].axis("off")
        axes[0, i].set_title("Original")

        axes[1, i].imshow(denorm(masked_imgs[i]).squeeze(), cmap="gray", vmin=0, vmax=1)
        axes[1, i].axis("off")
        axes[1, i].set_title("Masked")

        axes[2, i].imshow(denorm(inpainted[i]).squeeze(), cmap="gray", vmin=0, vmax=1)
        axes[2, i].axis("off")
        axes[2, i].set_title("Inpainted")
    plt.tight_layout()
    plt.show()


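def masked_mse(inpainted, original, mask):
    """Mean squared error over masked pixels only (normalized by the mask area)."""
    sq_err = F.mse_loss(inpainted * mask, original * mask, reduction="sum")
    return sq_err / mask.sum().clamp(min=1.0)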


# Example usage (requires a trained model)
# visualize_inpainting(next(iter(loader)))


## Conclusion

A compact diffusion model paired with a U-Net can inpaint missing regions in simple shape images. The critical fix was to ensure the same noise tensor is used both for generating \(x_t\) and for constructing the masked input; previously, unrelated noise inside the mask prevented learning. After training, comparing masked and reconstructed images visually, together with the masked-region MSE, indicates how well the model recovers the withheld central area.