# Inpainting with RePaint

One of the many downstream applications of diffusion models is inpainting, the task of filling in missing data given some context. [Lugmayr et al. 2022](https://arxiv.org/abs/2201.09865) proposed one approach to inpainting, named RePaint, which adapts the sampling strategy of a standard diffusion model. If you are interested in how to train such models from scratch for new datasets, head over to this [tutorial].

## Background

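RePaint leaves the pretrained, unconditional diffusion model untouched and only changes the reverse (sampling) process. At every denoising step, the known pixels are obtained by noising the input image to the current noise level, the missing pixels are taken from the model's denoising step, and the two are merged using the mask. To harmonize the generated content with its context, RePaint additionally resamples: it repeatedly jumps back a few steps by re-adding noise and then denoises again.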
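
The masked merge at the heart of RePaint's sampling can be sketched in a few lines. This is a minimal illustration and not the implementation used in this tutorial: the function name `repaint_merge_step` and its arguments are hypothetical, the model's proposal for the previous step is passed in as a plain tensor, and the resampling jumps described in the paper are left out.

```python
import torch


def repaint_merge_step(x0_known, mask, x_prev_unknown, alpha_bar_prev):
    """Merge known and generated pixels for one reverse diffusion step.

    Hypothetical helper for illustration only.

    Args:
        x0_known: clean image, valid where mask == 1, shape [B, C, H, W]
        mask: 1 for known pixels, 0 for missing ones, shape [B, 1, H, W]
        x_prev_unknown: the model's proposal for x_{t-1}, shape [B, C, H, W]
        alpha_bar_prev: cumulative product of alphas at step t-1, in [0, 1]
    """
    # Known region: noise the clean image to the noise level of step t-1
    noise = torch.randn_like(x0_known)
    x_prev_known = (alpha_bar_prev ** 0.5) * x0_known + (
        (1 - alpha_bar_prev) ** 0.5
    ) * noise

    # Keep the noised context where mask == 1, the model proposal elsewhere
    return mask * x_prev_known + (1 - mask) * x_prev_unknown


# Toy usage with random tensors standing in for real data and model output
x0 = torch.rand(2, 3, 8, 8)
mask = torch.ones(2, 1, 8, 8)
mask[:, :, 2:6, 2:6] = 0
proposal = torch.randn(2, 3, 8, 8)
x_prev = repaint_merge_step(x0, mask, proposal, alpha_bar_prev=0.9)
```

In a full sampler this merge would run once per reverse step, with `x_prev_unknown` coming from the pretrained diffusion model.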

## Imports

In [19]:
import numpy as np
import kornia.augmentation as K
from torch.utils.data import DataLoader

from denoising_diffusion_pytorch.repaint import GaussianDiffusion as RePaint

from torchgeo.datasets import MillionAID

import torch
import matplotlib.pyplot as plt
from torch import Tensor

## Dataloader

The model has been pretrained on the [MillionAID dataset](https://captain-whu.github.io/DiRS/), which we will load with [torchgeo](https://github.com/microsoft/torchgeo).

In [14]:
ds = MillionAID(root="/mnt/SSD2/nils/ocean_bench_exps/diffusion/data/million", split="train")


def collate_fn(batch):
    """Collate function that resizes images to the same size and scales them to [0, 1]."""
    resize = K.Resize(size=(224, 224))
    images = [resize(item["image"].float()) for item in batch]
    images = torch.stack(images) / 255.
    return images

# Data loader to easily generate a batch of resized, normalized images
dl = DataLoader(ds, batch_size=16, shuffle=True, collate_fn=collate_fn)

## Inpainting Task

For inpainting, some areas of the image are missing, and we use the diffusion model to fill ("inpaint") these areas to obtain a complete image. We will simulate this by creating masks so we can visualize the results.

The implementation expects the following inputs:

- images with applied mask to be inpainted
- the mask itself (0 denotes missing pixels, 1 denotes pixels that can be used as context for inpainting)
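
As a quick illustration of this convention, the toy example below applies a single-channel mask to a small multi-channel batch. The tensor sizes are arbitrary and chosen only for readability; the point is that a mask of shape `[B, 1, H, W]` broadcasts across the channel dimension.

```python
import torch

# Toy batch: 2 images, 3 channels, 4x4 pixels, values in [0, 1)
image = torch.rand(2, 3, 4, 4)

# Mask convention: 1 = known context, 0 = missing pixels
mask = torch.ones(2, 1, 4, 4)
mask[:, :, 1:3, 1:3] = 0  # hide the central 2x2 block

# The single mask channel broadcasts over all 3 image channels
masked_image = image * mask

# Fraction of pixels that the model has to fill in
missing_fraction = 1 - mask.mean().item()
```

Pixels inside the hidden block are zero in every channel of `masked_image`, while the context pixels are left unchanged.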

In [16]:
# TODO inpainting code

def plot_results(target, masked_gt, mask, inpainted):
    """Plot results.

    Args:
        target: full target
        masked_gt: target with mask applied
        mask: mask tensor
        inpainted: inpainted tensor
    """
    batch_size = target.size(0)

    # squeeze=False keeps axs two-dimensional even for a batch of size 1
    fig, axs = plt.subplots(batch_size, 4, figsize=(30, 5 * batch_size), squeeze=False)

    for i in range(batch_size):
        target_np = target[i].numpy().transpose(1, 2, 0)
        masked_gt_np = masked_gt[i].numpy().transpose(1, 2, 0)
        mask_np = mask[i].numpy().transpose(1, 2, 0)
        inpainted_np = inpainted[i].numpy().transpose(1, 2, 0)

        axs[i, 0].imshow(target_np)
        axs[i, 0].axis("off")

        axs[i, 1].imshow(masked_gt_np)
        axs[i, 1].axis("off")

        axs[i, 2].imshow(mask_np, cmap="gray")
        axs[i, 2].axis("off")

        axs[i, 3].imshow(inpainted_np)
        axs[i, 3].axis("off")

    axs[0, 0].set_title("Original Image", fontsize=40)
    axs[0, 1].set_title("Masked Input", fontsize=40)
    axs[0, 2].set_title("Mask", fontsize=40)
    axs[0, 3].set_title("Inpainted Image", fontsize=40)

    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    plt.tight_layout()


def create_center_square_mask(image_size: int, mask_size: int) -> Tensor:
    """Create a mask whose center square is missing (0) and the rest is context (1)."""
    assert image_size >= mask_size, "Mask size should be smaller or equal to image size"

    # Start with all pixels known, then zero out the center square
    mask = torch.ones((image_size, image_size))
    start = (image_size - mask_size) // 2
    end = start + mask_size
    mask[start:end, start:end] = 0

    return mask

def create_middle_column_mask(image_size: int, mask_size: int) -> Tensor:
    """Create a mask whose center column is missing (0) and the rest is context (1)."""
    assert image_size >= mask_size, "Mask size should be smaller or equal to image size"

    # Start with all pixels known, then zero out the center column
    mask = torch.ones((image_size, image_size))
    start = (image_size - mask_size) // 2
    end = start + mask_size
    mask[:, start:end] = 0
    return mask

def prepare_data(imgs: Tensor) -> dict[str, Tensor]:
    """Prepare images for inpainting."""
    image_size = imgs.shape[-1]  # assumes square images, as produced by the resize
    mask_size = image_size // 3  # hide a square covering a third of each side
    masks = (
        create_center_square_mask(image_size, mask_size)
        .repeat(imgs.shape[0], 1, 1)
        .unsqueeze(1)
    )

    # Apply the mask to the image
    masked_imgs = imgs * masks

    return {
        "image": imgs,
        "mask": masks,
        "masked_image": masked_imgs,
    }

In [17]:
imgs = next(iter(dl))

data = prepare_data(imgs)

## Inpainting with UQ

The diffusion model is stochastic, meaning that running it multiple times on the same input generates different realizations. We can use this variation across samples as a notion of uncertainty for the inpainted regions. The code below demonstrates how to do this and renders some visualizations to build intuition.
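
One simple way to turn repeated samples into an uncertainty map is the per-pixel standard deviation across samples. The sketch below assumes a callable `sample_fn(masked_image, mask)` that returns one inpainted batch; both `sample_fn` and `toy_sampler` are hypothetical stand-ins for the diffusion sampler and are not part of any library.

```python
import torch


def sample_with_uncertainty(sample_fn, masked_image, mask, num_samples=5):
    """Draw several inpainting samples and summarize their spread.

    sample_fn is a hypothetical callable (masked_image, mask) -> tensor of
    shape [B, C, H, W]. num_samples must be at least 2 for the standard
    deviation to be defined.
    """
    samples = torch.stack(
        [sample_fn(masked_image, mask) for _ in range(num_samples)], dim=1
    )  # [B, num_samples, C, H, W]
    mean = samples.mean(dim=1)  # [B, C, H, W]
    uncertainty = samples.std(dim=1)  # per-pixel std across samples
    return samples, mean, uncertainty


def toy_sampler(masked_image, mask):
    """Stand-in sampler: keeps the context and fills the hole with random values."""
    return mask * masked_image + (1 - mask) * torch.rand_like(masked_image)


# Toy usage with random data
mask = torch.ones(2, 1, 8, 8)
mask[:, :, 2:6, 2:6] = 0
masked_image = torch.rand(2, 3, 8, 8) * mask
samples, mean, uncertainty = sample_with_uncertainty(toy_sampler, masked_image, mask)
```

The returned shapes match what the plotting function in the next cell expects: `samples` is `[batch_size, num_samples, C, H, W]` and `uncertainty` is `[batch_size, C, H, W]`. With a sampler that leaves the context untouched, the uncertainty is concentrated in the masked region.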

In [20]:
def plot_samples_with_uq(masked_gt, mask, samples, uncertainty, num_datapoints: int = 4, num_samples: int = 5):
    """Plot inpainted samples and uncertainty for randomly chosen datapoints.

    Args:
        masked_gt: target with mask applied tensor of shape [batch_size, C, H, W]
        mask: mask tensor of shape [batch_size, 1, H, W]
        samples: sample tensor of shape [batch_size, num_samples, C, H, W]
        uncertainty: uncertainty tensor [batch_size, C, H, W]
        num_datapoints: number of datapoints (rows) to plot
        num_samples: number of samples (columns) to plot per datapoint
    """
    indices = np.random.choice(masked_gt.size(0), size=num_datapoints, replace=False)
    sample_indices = np.random.choice(samples.size(1), size=num_samples, replace=False)

    # One row per datapoint; columns: masked input, the samples, uncertainty
    uq_col = num_samples + 1
    fig, axs = plt.subplots(
        num_datapoints, num_samples + 2, figsize=(60, 5 * num_datapoints), squeeze=False
    )

    for i, idx in enumerate(indices):
        masked_gt_np = masked_gt[idx].numpy().transpose(1, 2, 0)
        uncertainty_np = uncertainty[idx].numpy().transpose(1, 2, 0)

        axs[i, 0].imshow(masked_gt_np)
        axs[i, 0].axis("off")

        # plot the samples
        for j, sample_idx in enumerate(sample_indices):
            sample_np = samples[idx, sample_idx].numpy().transpose(1, 2, 0)
            axs[i, j + 1].imshow(sample_np)
            axs[i, j + 1].axis("off")

        # the uncertainty goes in the last column
        axs[i, uq_col].imshow(uncertainty_np)
        axs[i, uq_col].axis("off")

    axs[0, 0].set_title("Masked Input", fontsize=40)
    for j in range(num_samples):
        axs[0, j + 1].set_title(f"Sample {j+1}", fontsize=40)
    axs[0, uq_col].set_title("Uncertainty", fontsize=40)

    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    plt.tight_layout()

The UQ capabilities of diffusion models have been highlighted in several publications, but we believe there are many more interesting applications that could benefit from them. We hope this tutorial provided some ideas and insights that you find useful for your own tasks or research.