In [1]:
# Import required libraries
import os
import numpy as np

# List all image paths in the directory.
# os.listdir returns entries in arbitrary order, so the names are sorted: the
# positional train/val split below (and any later positional indexing into
# img_paths) must be reproducible across runs and machines.
file_names = sorted(os.listdir('./img_align_celeba'))
img_paths = ['./img_align_celeba/' + file_name for file_name in file_names]

# Split dataset into training and validation sets: the first num_train paths
# are used for training and every remaining path for validation.
num_train = 150000
train_imgpaths = img_paths[:num_train]
val_imgpaths = img_paths[num_train:]

# Function to generate mask image from bounding box
def bbox2mask(img_shape, bbox, dtype='uint8'):
    """
    Build a single-channel mask that is 1 inside a bounding box and 0 elsewhere.

    Args:
        img_shape (tuple[int]): Image shape; only the first two entries
            (height, width) are used.
        bbox (tuple[int]): Box description as (top, left, height, width).
        dtype (str): Data type of the returned array.

    Returns:
        np.ndarray: Array of shape (height, width, 1) holding the mask.
    """
    top, left = bbox[0], bbox[1]
    bottom, right = top + bbox[2], left + bbox[3]

    rows, cols = img_shape[:2]
    canvas = np.zeros((rows, cols, 1), dtype=dtype)
    # Slicing clips silently, so a box reaching past the border is truncated.
    canvas[top:bottom, left:right, :] = 1
    return canvas

FileNotFoundError: [Errno 2] No such file or directory: './img_align_celeba'

In [None]:
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

# Build dataset
class InpaintingDataset(Dataset):
    """Image dataset that pairs every image with a mask and masked variants."""

    def __init__(self, img_paths, mask_mode, image_size=[256, 256]):
        """
        Args:
            img_paths (list): Paths of the image files to serve.
            mask_mode (str): Name of the mask strategy; only 'center' is handled.
            image_size (list): Target size as [height, width] used for resizing
                and for building the mask.
        """
        target_hw = (image_size[0], image_size[1])
        pipeline = [
            transforms.Resize(target_hw),
            transforms.ToTensor(),
            # mean = std = 0.5 per channel: (x - 0.5) / 0.5
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]

        self.img_paths = img_paths
        self.tfs = transforms.Compose(pipeline)
        self.mask_mode = mask_mode
        self.image_size = image_size

    def __getitem__(self, index):
        """
        Load one image and derive its masked companions.

        Args:
            index (int): Position of the image in `img_paths`.

        Returns:
            dict: With the keys
                  - 'gt_image': the transformed ground-truth image.
                  - 'cond_image': ground truth with Gaussian noise inside the mask.
                  - 'mask_image': ground truth with the masked area set to 1.
                  - 'mask': the mask used for this sample.
                  - 'path': path of the source file.
        """
        path = self.img_paths[index]
        gt = self.tfs(Image.open(path).convert('RGB'))

        mask = self.get_mask()
        # Pixels outside the mask are kept as they are in both derived images.
        visible = gt * (1. - mask)
        noisy_fill = mask * torch.randn_like(gt)

        sample = {
            'gt_image': gt,
            'cond_image': visible + noisy_fill,
            'mask_image': visible + mask,
            'mask': mask,
            'path': path
        }
        return sample

    def __len__(self):
        """
        Returns:
            int: Number of image paths held by the dataset.
        """
        return len(self.img_paths)

    def get_mask(self):
        """
        Create the mask selected by `mask_mode`.

        Returns:
            torch.Tensor: Mask in channel-first layout (1, height, width).

        Raises:
            NotImplementedError: If `mask_mode` is anything other than 'center'.
        """
        if self.mask_mode != 'center':
            raise NotImplementedError(
                f'Mask mode {self.mask_mode} has not been implemented.'
            )

        h, w = self.image_size
        # Centered box covering half of the height and half of the width.
        center_box = (h // 4, w // 4, h // 2, w // 2)
        mask_hwc = bbox2mask(self.image_size, center_box)
        return torch.from_numpy(mask_hwc).permute(2, 0, 1)

In [None]:
from torch import nn
import math

# Block class definition
class Block(nn.Module):
    """Two 3x3 convolutions with a time-embedding injection, then a resampling layer.

    With `up=False` the block ends in a strided convolution; with `up=True`
    it ends in a transposed convolution and its first convolution expects
    `2 * in_ch` input channels.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        first_in = 2 * in_ch if up else in_ch
        resample_cls = nn.ConvTranspose2d if up else nn.Conv2d

        # Submodules are created in a fixed order so that seeded
        # initialisation stays reproducible.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(first_in, out_ch, 3, padding=1)
        # kernel 4, stride 2, padding 1
        self.transform = resample_cls(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply the block to feature map `x` using time embedding `t`."""
        feat = self.conv1(x)
        feat = self.bnorm1(self.relu(feat))

        # Project the time embedding to out_ch values and append two singleton
        # axes so it broadcasts over the spatial dimensions.
        t_proj = self.relu(self.time_mlp(t))
        feat = feat + t_proj[..., None, None]

        feat = self.conv2(feat)
        feat = self.bnorm2(self.relu(feat))
        return self.transform(feat)


# SinusoidalPositionEmbeddings class definition
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of a batch of scalar timesteps.

    For `half_dim = dim // 2` the frequencies are
    `exp(-log(10000) * i / (half_dim - 1))` for `i = 0 .. half_dim - 1`,
    i.e. they decay geometrically from 1 down to 1/10000. The output is
    the concatenation of the sines and the cosines of `time * frequency`.
    """

    def __init__(self, dim):
        """
        Args:
            dim (int): Requested embedding width. The output has
                `2 * (dim // 2)` columns.
        """
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time (torch.Tensor): 1-D tensor of timesteps, one per sample.

        Returns:
            torch.Tensor: Tensor of shape (len(time), 2 * (dim // 2)).
        """
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half_dim == 1.
        scale = math.log(10000) / max(half_dim - 1, 1)
        # The exponent must be negative: without the minus sign the
        # frequencies grow up to 10000 instead of decaying to 1/10000.
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


# SimpleUnet class definition
class SimpleUnet(nn.Module):
    """Small U-Net built from `Block`s and conditioned on a timestep embedding.

    The encoder stores the output of every down block; the decoder consumes
    those stored tensors in reverse order, concatenating each one with its
    current features along the channel axis.
    """

    def __init__(self):
        super().__init__()
        image_channels = 3
        out_dim = 3
        time_emb_dim = 32
        down_channels = (64, 128, 256, 512, 1024)  # Encoder channel widths
        up_channels = (1024, 512, 256, 128, 64)    # Decoder channel widths

        # Timestep -> sinusoidal features -> linear layer -> ReLU
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Lift the image to the first encoder width
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # One encoder block per consecutive pair of widths
        self.downs = nn.ModuleList([
            Block(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(down_channels[:-1], down_channels[1:])
        ])

        # One decoder block per consecutive pair of widths
        self.ups = nn.ModuleList([
            Block(c_in, c_out, time_emb_dim, up=True)
            for c_in, c_out in zip(up_channels[:-1], up_channels[1:])
        ])

        # 1x1 convolution back to the output channel count
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        """Run the network on image batch `x` for the given `timestep` batch."""
        feat = self.conv0(x)
        t_emb = self.time_mlp(timestep)

        # Encoder: remember each block's output for the decoder.
        skips = []
        for down_block in self.downs:
            feat = down_block(feat, t_emb)
            skips.append(feat)

        # Decoder: the most recently stored tensor is used first.
        for up_block in self.ups:
            feat = up_block(torch.cat((feat, skips.pop()), dim=1), t_emb)

        return self.output(feat)

In [None]:
from tqdm import tqdm
from functools import partial
import numpy as np
import torch

def make_beta_schedule(schedule, n_timestep, linear_start=1e-5, linear_end=1e-2):
    """
    Build the beta values used by the diffusion process.

    Args:
        schedule (str): Schedule name; only 'linear' is supported.
        n_timestep (int): Number of betas to produce.
        linear_start (float): First beta of the linear schedule.
        linear_end (float): Last beta of the linear schedule.

    Returns:
        np.ndarray: float64 array with `n_timestep` evenly spaced betas.

    Raises:
        NotImplementedError: If `schedule` is not 'linear'.
    """
    if schedule != 'linear':
        raise NotImplementedError(f"Schedule type '{schedule}' is not implemented.")

    return np.linspace(linear_start, linear_end, num=n_timestep, dtype=np.float64)

def get_index_from_list(vals, t, x_shape=(1, 1, 1, 1)):
    """
    Pick one value of `vals` per batch element and shape it for broadcasting.

    Args:
        vals (torch.Tensor): 1-D tensor of per-timestep values.
        t (torch.Tensor): 1-D integer tensor of indices, one per batch element.
            It must live on the same device as `vals`.
        x_shape (tuple): Shape of the tensor the result will be broadcast
            against; only its length is used.

    Returns:
        torch.Tensor: Tensor of shape (batch, 1, ..., 1) with `len(x_shape)`
        dimensions, located on the device of `vals`.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t)
    # The gathered tensor already lives on the device of `vals`, so no extra
    # transfer is needed. Relying on a notebook-level `device` variable here
    # would make the helper fail whenever that global is not defined yet.
    trailing = (1,) * (len(x_shape) - 1)
    return out.reshape(batch_size, *trailing)

In [None]:
import torch
from torch import nn
import numpy as np
from functools import partial

class InpaintingGaussianDiffusion(nn.Module):
    """Conditional Gaussian diffusion wrapper around a denoising network.

    The denoiser receives the conditioning image concatenated with the noisy
    image along the channel axis, plus the noise level (gamma), and its
    output is treated as the predicted noise.
    `set_new_noise_schedule` and `set_loss` must be called before `forward`
    or `restoration`, because they create the buffers, `num_timesteps` and
    `loss_fn` that those methods read.
    """

    def __init__(self, unet_config, beta_schedule, **kwargs):
        """
        Args:
            unet_config (dict): Keyword arguments forwarded to `UNet`.
            beta_schedule (dict): Keyword arguments forwarded to
                `make_beta_schedule` by `set_new_noise_schedule`.
        """
        super(InpaintingGaussianDiffusion, self).__init__(**kwargs)
        # NOTE(review): `UNet` is not defined in the visible part of the
        # notebook; it must be provided before this class is instantiated.
        self.denoise_fn = UNet(**unet_config)  # Denoising UNet model
        self.beta_schedule = beta_schedule

    def set_new_noise_schedule(self, device):
        """Compute the schedule coefficients and register them as buffers on `device`."""
        # Convert numpy arrays to float32 tensors on the target device
        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        betas = make_beta_schedule(**self.beta_schedule)
        alphas = 1. - betas

        # gammas[t] is the cumulative product of alphas up to step t;
        # gammas_prev is the same sequence shifted by one with a leading 1.
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        gammas = np.cumprod(alphas, axis=0)
        gammas_prev = np.append(1., gammas[:-1])

        # Register buffers for diffusion process parameters
        self.register_buffer("gammas", to_torch(gammas))
        self.register_buffer("sqrt_recip_gammas", to_torch(np.sqrt(1. / gammas)))
        self.register_buffer("sqrt_recipm1_gammas", to_torch(np.sqrt(1. / gammas - 1)))
        posterior_variance = betas * (1. - gammas_prev) / (1. - gammas)
        # The variance is 0 at the first step; clip before taking the log.
        self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer("posterior_mean_coef1", to_torch(betas * np.sqrt(gammas_prev) / (1. - gammas)))
        self.register_buffer("posterior_mean_coef2", to_torch((1. - gammas_prev) * np.sqrt(alphas) / (1. - gammas)))

    def set_loss(self, loss_fn):
        """Store the callable used by `forward` to compare true and predicted noise."""
        self.loss_fn = loss_fn

    def predict_start_from_noise(self, y_t, t, noise):
        """Estimate the clean image from the noisy image `y_t` and a noise prediction."""
        return (
            get_index_from_list(self.sqrt_recip_gammas, t, y_t.shape) * y_t -
            get_index_from_list(self.sqrt_recipm1_gammas, t, y_t.shape) * noise
        )

    def q_posterior(self, y_0_hat, y_t, t):
        """Return the posterior mean and clipped log-variance for step `t`."""
        posterior_mean = (
            get_index_from_list(self.posterior_mean_coef1, t, y_t.shape) * y_0_hat +
            get_index_from_list(self.posterior_mean_coef2, t, y_t.shape) * y_t
        )
        posterior_log_variance_clipped = get_index_from_list(self.posterior_log_variance_clipped, t, y_t.shape)
        return posterior_mean, posterior_log_variance_clipped

    def p_mean_variance(self, y_t, t, clip_denoised: bool, y_cond=None):
        """Run the denoiser and return the model mean and log-variance for step `t`."""
        noise_level = get_index_from_list(self.gammas, t, x_shape=(1, 1)).to(y_t.device)
        y_0_hat = self.predict_start_from_noise(
            y_t, t=t, noise=self.denoise_fn(torch.cat([y_cond, y_t], dim=1), noise_level)
        )
        if clip_denoised:
            y_0_hat.clamp_(-1., 1.)
        model_mean, posterior_log_variance = self.q_posterior(y_0_hat=y_0_hat, y_t=y_t, t=t)
        return model_mean, posterior_log_variance

    def q_sample(self, y_0, sample_gammas, noise=None):
        """Return `sqrt(gamma) * y_0 + sqrt(1 - gamma) * noise`.

        Fresh Gaussian noise is drawn when `noise` is None.
        """
        noise = noise if noise is not None else torch.randn_like(y_0)
        return (
            sample_gammas.sqrt() * y_0 +
            (1 - sample_gammas).sqrt() * noise
        )

    def forward(self, y_0, y_cond=None, mask=None, noise=None):
        """
        Compute the training loss for one batch.

        Args:
            y_0 (torch.Tensor): Clean images.
            y_cond (torch.Tensor): Conditioning images; concatenated with the
                noisy input, so it is required in practice.
            mask (torch.Tensor, optional): When given, only the masked region
                is noised and only that region contributes to the loss.
            noise (torch.Tensor, optional): Noise to use; drawn from a
                standard normal when None.

        Returns:
            The value returned by `loss_fn` for the true and predicted noise.
        """
        b, *_ = y_0.shape
        # One random step per sample in [1, num_timesteps), then a gamma drawn
        # uniformly between the gammas of step t - 1 and step t.
        t = torch.randint(1, self.num_timesteps, (b,), device=y_0.device).long()
        gamma_t1 = get_index_from_list(self.gammas, t - 1, x_shape=(1, 1))
        gamma_t2 = get_index_from_list(self.gammas, t, x_shape=(1, 1))
        sample_gammas = (gamma_t2 - gamma_t1) * torch.rand((b, 1), device=y_0.device) + gamma_t1

        # The noise must be materialised here: the loss below needs the exact
        # tensor that was used to corrupt y_0, and q_sample does not return
        # noise it draws internally.
        if noise is None:
            noise = torch.randn_like(y_0)

        y_noisy = self.q_sample(y_0, sample_gammas.view(-1, 1, 1, 1), noise)
        if mask is not None:
            noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy * mask + (1. - mask) * y_0], dim=1), sample_gammas)
            loss = self.loss_fn(mask * noise, mask * noise_hat)
        else:
            noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy], dim=1), sample_gammas)
            loss = self.loss_fn(noise, noise_hat)
        return loss

    @torch.no_grad()
    def p_sample(self, y_t, t, clip_denoised=True, y_cond=None):
        """Draw the next, less noisy sample for step `t`."""
        model_mean, model_log_variance = self.p_mean_variance(y_t=y_t, t=t, clip_denoised=clip_denoised, y_cond=y_cond)
        # No noise is added when every entry of t is 0.
        noise = torch.randn_like(y_t) if bool((t > 0).any()) else torch.zeros_like(y_t)
        return model_mean + noise * (0.5 * model_log_variance).exp()

    @torch.no_grad()
    def restoration(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8):
        """
        Run the full reverse process.

        Args:
            y_cond (torch.Tensor): Conditioning images.
            y_t (torch.Tensor, optional): Starting point; standard normal
                noise shaped like `y_cond` when None.
            y_0 (torch.Tensor, optional): Known image content; required when
                `mask` is given.
            mask (torch.Tensor, optional): Region to restore. Outside it the
                result is overwritten with `y_0` after every step.
            sample_num (int): Approximate number of intermediate results kept.

        Returns:
            tuple: (final image batch, starting point and kept intermediate
            batches concatenated along dim 0).
        """
        b, *_ = y_cond.shape
        # At least 1, otherwise `i % sample_inter` divides by zero whenever
        # sample_num exceeds the number of timesteps.
        sample_inter = max(1, self.num_timesteps // sample_num)
        # Resolve the starting point before seeding ret_arr, so that ret_arr
        # is never None when it reaches torch.cat.
        y_t = y_t if y_t is not None else torch.randn_like(y_cond)
        ret_arr = y_t
        for i in reversed(range(0, self.num_timesteps)):
            t = torch.full((b,), i, device=y_cond.device, dtype=torch.long)
            y_t = self.p_sample(y_t, t, y_cond=y_cond)
            if mask is not None:
                y_t = y_0 * (1. - mask) + mask * y_t
            if i % sample_inter == 0:
                ret_arr = torch.cat([ret_arr, y_t], dim=0)
        return y_t, ret_arr

In [None]:
import torch.nn.functional as F
from torch import nn

def mse_loss(output, target):
    """
    Mean squared error between a prediction and its target.

    Args:
        output (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.

    Returns:
        torch.Tensor: Scalar tensor with the mean of the squared differences.
    """
    squared_error_mean = F.mse_loss(input=output, target=target)
    return squared_error_mean

def mae(input, target):
    """
    Mean absolute error, evaluated without tracking gradients.

    Args:
        input (torch.Tensor): Predicted values.
        target (torch.Tensor): Ground truth values.

    Returns:
        torch.Tensor: Scalar tensor with the mean absolute difference; it is
        detached from the autograd graph.
    """
    criterion = nn.L1Loss()
    with torch.no_grad():
        return criterion(input, target)

In [2]:
import time
import torch
from tqdm import tqdm

class Trainer:
    """Training / validation driver for the inpainting diffusion model.

    The model is expected to expose `set_loss`, `set_new_noise_schedule`,
    `restoration` and a forward pass returning the loss, as
    `InpaintingGaussianDiffusion` does.
    """

    def __init__(self, model, optimizers, train_loader, val_loader, epochs, sample_num, device, save_model):
        """
        Args:
            model (nn.Module): Diffusion model to train.
            optimizers (dict): Keyword arguments for `torch.optim.Adam`.
            train_loader: Iterable of training batches (dicts with the keys
                'gt_image', 'cond_image' and 'mask').
            val_loader: Iterable of validation batches with the same keys.
            epochs (int): Number of epochs to run.
            sample_num (int): Forwarded to `model.restoration` during validation.
            device: Device the model and the batches are moved to.
            save_model (str): Directory in which `best_model.pth` is stored.
        """
        self.model = model.to(device)
        # Only parameters that require gradients are optimised.
        self.optimizer = torch.optim.Adam(
            list(filter(lambda p: p.requires_grad, self.model.parameters())),
            **optimizers
        )
        self.model.set_loss(mse_loss)
        self.model.set_new_noise_schedule(device)
        self.sample_num = sample_num
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.epochs = epochs
        self.save_model = save_model + "/best_model.pth"

    def train_step(self):
        """
        Run one training epoch.

        Returns:
            float: Mean training loss over all batches.
        """
        # val_step leaves the model in eval mode, so training mode has to be
        # restored explicitly (dropout / normalisation layers depend on it).
        self.model.train()
        losses = []
        for batch in tqdm(self.train_loader, desc="Training Batch"):
            gt_image = batch['gt_image'].to(self.device)
            cond_image = batch['cond_image'].to(self.device)
            mask = batch['mask'].to(self.device)

            self.optimizer.zero_grad()
            loss = self.model(gt_image, cond_image, mask=mask)
            loss.backward()
            losses.append(loss.item())
            self.optimizer.step()

        return sum(losses) / len(losses)

    def val_step(self):
        """
        Run one pass over the validation data.

        Returns:
            tuple: (mean validation loss, mean MAE between the ground truth
            and the restored images).
        """
        # Evaluate deterministically: no dropout, no batch-statistics updates.
        self.model.eval()
        losses, metrics = [], []
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation Batch"):
                gt_image = batch['gt_image'].to(self.device)
                cond_image = batch['cond_image'].to(self.device)
                mask = batch['mask'].to(self.device)

                loss = self.model(gt_image, cond_image, mask=mask)
                output, _ = self.model.restoration(
                    cond_image, y_t=cond_image, y_0=gt_image, mask=mask, sample_num=self.sample_num
                )
                mae_score = mae(gt_image, output)

                losses.append(loss.item())
                metrics.append(mae_score.item())

        return sum(losses) / len(losses), sum(metrics) / len(metrics)

    def train(self):
        """
        Train for `epochs` epochs, checkpointing whenever the validation MAE
        improves, and finally reload the checkpoint file if it exists.
        """
        best_mae = float('inf')
        for epoch in range(self.epochs):
            epoch_start_time = time.time()

            train_loss = self.train_step()
            val_loss, val_mae = self.val_step()

            if val_mae < best_mae:
                best_mae = val_mae
                # torch.save does not create missing directories.
                os.makedirs(os.path.dirname(self.save_model) or ".", exist_ok=True)
                torch.save(self.model.state_dict(), self.save_model)

            print("-" * 59)
            print(
                "| End of epoch {:3d} | Time: {:5.2f}s | Train Loss: {:8.3f} | Valid Loss: {:8.3f} | Valid MAE: {:8.3f} |".format(
                    epoch + 1, time.time() - epoch_start_time, train_loss, val_loss, val_mae
                )
            )
            print("-" * 59)

        # No checkpoint file exists if no epoch ran or the MAE never improved
        # (e.g. it was NaN); keep the current weights in that case.
        if os.path.isfile(self.save_model):
            self.model.load_state_dict(torch.load(self.save_model, map_location=self.device))

# Training hyperparameters and output location
epochs = 200
sample_num = 8
save_model = "./save_model"
optimizers = dict(lr=5e-5, weight_decay=0)  # keyword arguments for the Adam optimizer

# Prefer the GPU when one is visible to PyTorch
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Keyword arguments forwarded to the denoising network
unet_config = dict(
    in_channel=6,
    out_channel=3,
    inner_channel=64,
    channel_mults=[1, 2, 4, 8],
    attn_res=[16],
    num_head_channels=32,
    res_blocks=2,
    dropout=0.2,
    image_size=256,
)

# Keyword arguments forwarded to make_beta_schedule
beta_schedule = dict(
    schedule="linear",
    n_timestep=20,
    linear_start=1e-4,
    linear_end=0.09,
)

inpainting_model = InpaintingGaussianDiffusion(unet_config, beta_schedule)

# Build the trainer.
# NOTE(review): train_loader and val_loader are not created in this cell;
# they must already exist in the kernel when it runs.
trainer = Trainer(
    inpainting_model,
    optimizers,
    train_loader,
    val_loader,
    epochs,
    sample_num,
    device,
    save_model,
)

NameError: name 'InpaintingGaussianDiffusion' is not defined

In [None]:
# Build the model and register the noise-schedule buffers. This has to
# happen before the strict load below, because set_new_noise_schedule is the
# call that registers the schedule buffers on the module.
inpainting_model = InpaintingGaussianDiffusion(unet_config, beta_schedule)
inpainting_model.set_new_noise_schedule(device)

# Load pre-trained weights. map_location remaps the stored tensors to the
# current device, so a checkpoint written on a GPU also loads on a CPU-only
# machine.
load_state = torch.load('./save_model/best_model_200.pth', map_location=device)
inpainting_model.load_state_dict(load_state, strict=True)
inpainting_model.eval().to(device)

In [None]:
# Pick one image for a qualitative check
test_imgpath = img_paths[16]

# Wrap it in a one-element dataset that uses the centered mask
test_dataset = InpaintingDataset(img_paths=[test_imgpath], mask_mode='center')

# The dataset holds a single item, so its first element is the sample
test_sample = test_dataset[0]

In [None]:
def inference(model, test_sample):
    """
    Restore a single sample with the trained model.

    Args:
        model (nn.Module): Trained inpainting model exposing `restoration`.
        test_sample (dict): Un-batched sample holding at least the keys
            'cond_image' and 'mask'.

    Returns:
        tuple: The two values returned by `model.restoration` (the output
        batch and the accompanying visuals).
    """
    def _batched(key):
        # Add a leading batch axis and move the tensor to the global device.
        return test_sample[key].unsqueeze(0).to(device)

    with torch.no_grad():
        # The conditioning image serves as condition, as starting point and
        # as the known content outside the mask.
        restored, intermediates = model.restoration(
            _batched('cond_image'),
            y_t=_batched('cond_image'),
            y_0=_batched('cond_image'),
            mask=_batched('mask')
        )
    return restored, intermediates

# Run inference on the single test sample: `output` and `visuals` are the two
# values returned by `inference` (the restored batch and the visuals).
output, visuals = inference(inpainting_model, test_sample)

In [None]:
def show_tensor_image(image, show=True):
    """
    Visualize a tensor image that is normalized to [-1, 1].

    Args:
        image (torch.Tensor): Image in CHW layout, or a batch in NCHW layout
            (only the first image of a batch is shown). It may live on any
            device and may be attached to an autograd graph.
        show (bool): Whether to display the image.

    Returns:
        None: Displays the image.
    """
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),  # Unnormalize to [0, 1]
        # Values outside [0, 1] would wrap around in the uint8 cast below
        transforms.Lambda(lambda t: t.clamp(0., 1.)),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),  # Scale to [0, 255]
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),  # Convert to numpy
        transforms.ToPILImage()  # Convert to PIL image
    ])

    # Take the first image in the batch if batched
    if len(image.shape) == 4:
        image = image[0, :, :, :]

    # .numpy() only works on CPU tensors that do not require grad
    image = image.detach().cpu()

    if show:
        # NOTE(review): `plt` is not imported in the visible cells; make sure
        # `import matplotlib.pyplot as plt` has run before calling this.
        plt.imshow(reverse_transforms(image))
        plt.axis('off')  # Remove axes for better visualization