<br> Ayman FAHSI | A20440820

Mouhammad BAZZI | A20522180


CS512 - Spring 2023<br> <h1><b><font color='red'>Project</font></b></h1>

## 0. **Library Imports**


In [None]:
# We import the libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models import vgg19
from mmcv.ops import DeformConv2d
from torch.utils.data import Dataset
from PIL import Image
import os
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.optim as optim
import time
from pytorch_ssim import ssim
from pytorch_fid import fid_score
import cv2
import numpy as np
from torch.utils.data import random_split
from torch.utils.data import DataLoader

## 1. **Custom Classes For Neural Networks**

### 1.1. **Nested Deformable Multi-Head Attention**

In [None]:
class NestedDMHA(nn.Module):
    """Nested deformable multi-head attention (no explicit head split).

    Key and Value are computed from the decoder features ``De``. Each goes
    through a 1x1 convolution followed by a deformable convolution whose
    sampling offsets are predicted from ``De``. The encoder features ``En``
    are used directly as the Query.

    All products below use ``torch.matmul`` on the feature maps. For 4-D
    inputs this is a batched matrix product over the last two (spatial)
    dimensions of every channel, so the feature maps must be square (H == W).
    NOTE(review): assumes ``De`` and ``En`` are (B, C, H, W) with
    C == in_channels — confirm against callers.

    Args:
        in_channels: Channel count of both ``De`` and ``En``.
        num_heads: Stored as an attribute; ``forward`` does not split heads.
    """

    # Kernel size shared by both deformable convolutions and their offset predictors.
    _deform_kernel_size = 3

    def __init__(self, in_channels, num_heads):
        super(NestedDMHA, self).__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads

        k = self._deform_kernel_size
        # A deformable convolution needs one 2-D offset per kernel tap, so the
        # offset predictors output 2 * k * k channels (18 for k = 3).
        offset_channels = 2 * k * k
        # "Same" padding for the odd kernel size keeps H and W unchanged.
        pad = k // 2

        # Pointwise convolution and deformable convolution for Key
        self.key_conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key_offset = nn.Conv2d(in_channels, offset_channels, kernel_size=k, padding=pad)
        self.key_conv2 = DeformConv2d(in_channels, in_channels, kernel_size=k, padding=pad)

        # Pointwise convolution and deformable convolution for Value
        self.value_conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value_offset = nn.Conv2d(in_channels, offset_channels, kernel_size=k, padding=pad)
        self.value_conv2 = DeformConv2d(in_channels, in_channels, kernel_size=k, padding=pad)

        # Pointwise convolution for the output Y
        self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # Softmax over the last dimension of the attention scores
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, De, En):
        """Attend from ``En`` (query) to ``De`` (key/value).

        Returns:
            A tuple ``(Y, De)``: the attention output after a 1x1 convolution,
            and the unchanged ``De`` input (callers chain it into the next block).
        """
        # Key: pointwise convolution, then deformable convolution with offsets predicted from De
        key_offset = self.key_offset(De)
        Key = self.key_conv2(self.key_conv1(De), key_offset)

        # Value: same structure as Key, with separate weights
        value_offset = self.value_offset(De)
        Value = self.value_conv2(self.value_conv1(De), value_offset)

        # Query: the encoder features are used as-is
        Query = En

        # Project Key, Value and Query with a spatial matrix product
        Key_R = torch.matmul(De, Key)
        Value_R = torch.matmul(De, Value)
        Query_R = torch.matmul(En, Query)

        # Attention scores: Query_R @ Key_R^T, normalized over the last dimension
        product = torch.matmul(Query_R, Key_R.transpose(-1, -2))
        attention = self.softmax(product)

        # Weight the values by the attention matrix
        Y = torch.matmul(attention, Value_R)

        # Apply pointwise convolution to Y before returning
        Y = self.out_conv(Y)

        return Y, De

### 1.2. **Gated Feed Forward Layer**

In [None]:
class GFFL(nn.Module):
    """Gated feed-forward layer.

    Two parallel 1x1 convolutions are applied to the input. The GELU-activated
    output of the second one gates (multiplies element-wise) the output of the
    first: ``conv1(x) * gelu(conv2(x))``. Channel count and spatial size are
    preserved.

    Args:
        in_channels: Number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super(GFFL, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        phi = self.conv1(x)
        g_psi = F.gelu(self.conv2(x))

        # The gate is multiplicative, as in GatedConv2d / GatedDeconv2d. Adding the
        # branches (as before) would not let the gate suppress any feature.
        return phi * g_psi

### 1.3. **Nested Deformable Multi-Head Attention Layer**

In [None]:
class NDMAL(nn.Module):
    """Nested deformable multi-head attention layer.

    Two chained NestedDMHA blocks are followed by residual connections,
    normalization and a gated feed-forward layer. The output concatenates the
    decoder branch (``Yd'``) and the encoder branch (``Ye``) along the channel
    axis, so it has ``2 * in_channels`` channels.

    Args:
        in_channels: Channel count of both inputs ``De`` and ``En``.
        num_heads: Forwarded to the NestedDMHA blocks.
    """

    def __init__(self, in_channels, num_heads):
        super(NDMAL, self).__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads

        self.dmha1 = NestedDMHA(in_channels, num_heads)
        self.dmha2 = NestedDMHA(in_channels, num_heads)

        # Normalizations are built here rather than lazily inside forward().
        # Modules created lazily are missed by an optimizer constructed before
        # the first forward pass, are not moved by model.to(device), and change
        # the state_dict keys after the first call. GroupNorm with a single
        # group normalizes over (C, H, W), like LayerNorm(x.shape[1:]), but only
        # needs the channel count.
        self.norm1 = nn.GroupNorm(1, in_channels)
        self.norm2 = nn.GroupNorm(1, in_channels)
        self.norm3 = nn.GroupNorm(1, in_channels)

        self.gffl = GFFL(in_channels)

    def forward(self, De, En):
        # First DMHA: query from En, key/value from De; also hands De back.
        Yp, En_p = self.dmha1(De, En)

        # Second DMHA on the first block's output
        Yu, _ = self.dmha2(Yp, En_p)

        # Decoder branch: residual connection + normalization
        Yd = self.norm1(De + Yu)

        # Encoder branch: residual connection + normalization
        Ye = self.norm2(En + Yp)

        # Gated feed-forward with its own residual connection + normalization
        Yd_prime = self.norm3(Yd + self.gffl(Yd))

        # Concatenate Yd' and Ye along the channel axis
        Y_prime = torch.cat([Yd_prime, Ye], dim=1)

        return Y_prime

### 1.4. **Gated Convolutional Layer**

In [None]:
class GatedConv2d(nn.Module):
    """2-D gated convolution.

    Two convolutions with identical hyper-parameters run in parallel. One
    produces features, the other a gate squashed by a sigmoid. The output is
    ``conv_feature(x) * sigmoid(conv_gate(x))``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.conv_gate = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        feature_map = self.conv_feature(x)
        gate_map = self.sigmoid(self.conv_gate(x))
        return feature_map * gate_map

### 1.5. **Gated Deconvolutional Layer**

In [None]:
class GatedDeconv2d(nn.Module):
    """2-D gated transposed convolution.

    This is the up-sampling counterpart of GatedConv2d. Two transposed
    convolutions with identical hyper-parameters run in parallel, and the
    output is ``deconv_feature(x) * sigmoid(deconv_gate(x))``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
        super().__init__()
        deconv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
        )
        self.deconv_feature = nn.ConvTranspose2d(in_channels, out_channels, **deconv_options)
        self.deconv_gate = nn.ConvTranspose2d(in_channels, out_channels, **deconv_options)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        feature_map = self.deconv_feature(x)
        gate_map = self.sigmoid(self.deconv_gate(x))
        return feature_map * gate_map

## 2. **Model Architecture**

In [None]:
class CustomInpainting(nn.Module):
    """Gated-convolution U-shaped inpainting network with NDMAL skip fusion.

    The image and its mask are concatenated on the channel axis, so
    ``in_channels`` counts both, and the network outputs ``in_channels - 1``
    channels passed through a sigmoid. The encoder uses four unpadded stride-2
    gated convolutions. Each decoder stage fuses the current features with the
    matching encoder output through an NDMAL layer before up-sampling.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Encoder: stride-2, unpadded 3x3 gated convolutions
        self.gated_conv_A_1 = GatedConv2d(in_channels, 64, 3, stride=2)
        self.gated_conv_A_2 = GatedConv2d(64, 128, 3, stride=2)
        self.gated_conv_A_3 = GatedConv2d(128, 256, 3, stride=2)
        self.gated_conv_A_4 = GatedConv2d(256, 512, 3, stride=2)

        # Bottleneck: resolution-preserving 512-channel gated convolutions
        self.gated_conv_B_1 = GatedConv2d(512, 512, 3, stride=1, padding=1)
        self.gated_conv_B_2 = GatedConv2d(512, 512, 3, stride=1, padding=1)
        self.gated_conv_B_3 = GatedConv2d(512, 512, 3, stride=1, padding=1)
        self.gated_conv_B_4 = GatedConv2d(512, 512, 3, stride=1, padding=1)

        # Decoder: NDMAL fusion (doubles channels) followed by gated up-sampling
        self.ndmal1 = NDMAL(512, num_heads=8)
        self.deconv1 = GatedDeconv2d(1024, 256, 3, stride=2)
        self.ndmal2 = NDMAL(256, num_heads=8)
        self.deconv2 = GatedDeconv2d(512, 128, 3, stride=2)
        self.ndmal3 = NDMAL(128, num_heads=8)
        self.deconv3 = GatedDeconv2d(256, 64, 3, stride=2)
        self.ndmal4 = NDMAL(64, num_heads=8)

        # Output projection back to image channels, squashed to [0, 1]
        self.output_deconv = GatedDeconv2d(128, in_channels-1, 3, 2, output_padding=1)
        self.output_act = nn.Sigmoid()

    def forward(self, x, mask):
        lrelu = torch.nn.functional.leaky_relu
        # The mask is appended to the image as an extra input channel
        h = torch.cat([x, mask], dim=1)

        encoder_layers = (self.gated_conv_A_1, self.gated_conv_A_2, self.gated_conv_A_3, self.gated_conv_A_4)
        skips = []
        for conv in encoder_layers:
            h = lrelu(conv(h), negative_slope=0.2)
            skips.append(h)
        x_A1, x_A2, x_A3, x_A4 = skips

        for conv in (self.gated_conv_B_1, self.gated_conv_B_2, self.gated_conv_B_3, self.gated_conv_B_4):
            h = lrelu(conv(h), negative_slope=0.2)

        # Each decoder stage fuses with the encoder output of matching resolution
        h = self.deconv1(self.ndmal1(h, x_A4))
        h = self.deconv2(self.ndmal2(h, x_A3))
        h = self.deconv3(self.ndmal3(h, x_A2))
        h = self.ndmal4(h, x_A1)

        return self.output_act(self.output_deconv(h))


**Stable but incomplete model** (`ndmal3` and `ndmal4` are created but not yet used in the decoder)

In [None]:
import torch
import torch.nn as nn

class InpaintingModel(nn.Module):
    """Gated-convolution encoder/decoder inpainting network (partial NDMAL use).

    The image and its mask are concatenated on the channel axis, so
    ``in_channels`` counts both (e.g. 3 image channels + 1 mask channel = 4).
    The output has ``in_channels - 1`` channels passed through a sigmoid.

    Only ``ndmal1`` and ``ndmal2`` are used in ``forward``. ``ndmal3``,
    ``ndmal4`` and ``activation`` are constructed but not yet wired in.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Encoder: the first conv keeps the resolution, the next three halve it
        self.encoder_conv1 = GatedConv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
        self.encoder_conv2 = GatedConv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.encoder_conv3 = GatedConv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.encoder_conv4 = GatedConv2d(256, 512, kernel_size=3, stride=2, padding=1)

        # Bottleneck: four resolution-preserving 512-channel gated convolutions
        self.gated_conv_B_1 = GatedConv2d(512, 512, 3, stride=1, padding=1)
        self.gated_conv_B_2 = GatedConv2d(512, 512, 3, stride=1, padding=1)
        self.gated_conv_B_3 = GatedConv2d(512, 512, 3, stride=1, padding=1)
        self.gated_conv_B_4 = GatedConv2d(512, 512, 3, stride=1, padding=1)

        # Decoder: three x2 up-samplings, then a 3x3 projection to the output channels
        self.decoder_conv1 = GatedDeconv2d(1024, 256, kernel_size=4, stride=2, padding=1)
        self.decoder_conv2 = GatedDeconv2d(512, 128, kernel_size=4, stride=2, padding=1)
        self.decoder_conv3 = GatedDeconv2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.decoder_conv4 = GatedDeconv2d(64, in_channels-1, kernel_size=3, stride=1, padding=1)

        # NDMAL fusion layers (each doubles its input channel count)
        self.ndmal1 = NDMAL(512, num_heads=8)
        self.ndmal2 = NDMAL(256, num_heads=8)
        self.ndmal3 = NDMAL(128, num_heads=8)
        self.ndmal4 = NDMAL(64, num_heads=8)

        self.activation = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, mask):
        lrelu = torch.nn.functional.leaky_relu
        # The mask is appended to the image as an extra input channel
        h = torch.cat([x, mask], dim=1)

        encoder_layers = (self.encoder_conv1, self.encoder_conv2, self.encoder_conv3, self.encoder_conv4)
        skips = []
        for conv in encoder_layers:
            h = lrelu(conv(h), negative_slope=0.2)
            skips.append(h)
        # Only the two deepest encoder outputs are fused in this version
        x_A3, x_A4 = skips[2], skips[3]

        for conv in (self.gated_conv_B_1, self.gated_conv_B_2, self.gated_conv_B_3, self.gated_conv_B_4):
            h = lrelu(conv(h), negative_slope=0.2)

        h = lrelu(self.decoder_conv1(self.ndmal1(h, x_A4)), negative_slope=0.2)
        h = lrelu(self.decoder_conv2(self.ndmal2(h, x_A3)), negative_slope=0.2)
        h = lrelu(self.decoder_conv3(h), negative_slope=0.2)

        return self.sigmoid(self.decoder_conv4(h))


## 3. **Loss Function**

In [None]:
class CustomLoss(nn.Module):
    """Weighted sum of L1, Sobel-edge and VGG19 perceptual losses.

    ``total = lambdas['l1'] * L1 + lambdas['edge'] * edge + lambdas['perceptual'] * perceptual``

    The adversarial term is not implemented: ``adversarial_loss`` is a stub
    and its weight is not added to the total.
    """

    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        # Pretrained VGG19 convolutional trunk, used as a fixed feature extractor.
        # NOTE(review): torchvision's pretrained VGG expects ImageNet mean/std
        # normalized inputs; images here are fed as-is. Confirm this is intended.
        self.vgg = vgg19(pretrained=True).features.eval()
        # Freeze VGG. The optimizer only holds the model's parameters, so without
        # this every backward pass would compute VGG gradients that are never
        # zeroed and keep accumulating. Gradients still flow to the inputs.
        self.vgg.requires_grad_(False)
        # `device` is the module-level global selected in the data section.
        self.vgg = self.vgg.to(device)
        self.lambdas = {
            'l1': 10,
            'adv': 0,
            'edge': 2,
            'perceptual': 3
        }

    def adversarial_loss(self, corrupted, target):
        """Placeholder for an adversarial loss; not implemented (returns None)."""
        pass

    def perceptual_loss(self, target, inpainted):
        """L1 distance between VGG19 feature maps of ``target`` and ``inpainted``."""
        real_features = self.vgg(target)
        inpainted_features = self.vgg(inpainted)
        return self.l1_loss(real_features, inpainted_features)

    @staticmethod
    def _to_gray(images):
        """Convert an RGB batch (B, 3, H, W) to luma (B, 1, H, W) with BT.601 weights.

        Slicing with ``c:c+1`` keeps the channel axis, so ``conv2d`` sees a
        proper 4-D batch for any batch size.
        """
        return 0.299 * images[:, 0:1] + 0.587 * images[:, 1:2] + 0.114 * images[:, 2:3]

    def edge_loss(self, target, inpainted):
        """L1 distance between Sobel gradient magnitudes of the grayscale images."""
        target = self._to_gray(target)
        inpainted = self._to_gray(inpainted)
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=target.dtype, device=target.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                               dtype=target.dtype, device=target.device).view(1, 1, 3, 3)
        real_edges_x = nn.functional.conv2d(target, sobel_x, padding=1)
        real_edges_y = nn.functional.conv2d(target, sobel_y, padding=1)
        inpainted_edges_x = nn.functional.conv2d(inpainted, sobel_x, padding=1)
        inpainted_edges_y = nn.functional.conv2d(inpainted, sobel_y, padding=1)
        epsilon = 1e-7  # keeps sqrt differentiable where the gradient magnitude is 0
        real_edges = torch.sqrt(real_edges_x ** 2 + real_edges_y ** 2 + epsilon)
        inpainted_edges = torch.sqrt(inpainted_edges_x ** 2 + inpainted_edges_y ** 2 + epsilon)
        return self.l1_loss(real_edges, inpainted_edges)

    def forward(self, target, inpainted):
        """Return the weighted total loss between ``target`` and ``inpainted``."""
        l1_loss = self.l1_loss(target, inpainted)
        perc_loss = self.perceptual_loss(target, inpainted)
        edge_loss = self.edge_loss(target, inpainted)
        # The adversarial term (lambdas['adv']) is not added: adversarial_loss is unimplemented.
        total_loss = (
            self.lambdas['l1'] * l1_loss
            + self.lambdas['edge'] * edge_loss
            + self.lambdas['perceptual'] * perc_loss
        )
        return total_loss


## 4. **Metrics**

### 4.1. **PSNR: Peak Signal-to-Noise Ratio**

In [None]:
def compute_psnr(img1, img2, max_pixel_value=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    ``PSNR = 20 * log10(max_pixel_value / sqrt(MSE))``, with the MSE averaged
    over every element.

    Args:
        img1, img2: Tensors of the same shape.
        max_pixel_value: Largest possible pixel value (1.0 for [0, 1] images).

    Returns:
        A 0-dim tensor. Identical inputs (MSE == 0) yield 100.0 instead of +inf.
    """
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        # Returned as a tensor (not a Python int) so callers can call .item()
        # exactly as they do for the regular case.
        return mse.new_tensor(100.0)
    return 20 * torch.log10(max_pixel_value / torch.sqrt(mse))


### 4.2. **SSIM: Structural Similarity Index**

In [None]:
# SSIM
def ssim_manual(img1, img2, window_size=11, k1=0.01, k2=0.03, L=1.0):
    C1 = (k1 * L) ** 2
    C2 = (k2 * L) ** 2

    # Compute the means of img1 and img2
    mu1 = F.avg_pool2d(img1, kernel_size=window_size, padding=window_size//2, stride=1)
    mu2 = F.avg_pool2d(img2, kernel_size=window_size, padding=window_size//2, stride=1)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    # Compute the variances and covariance
    sigma1_sq = F.avg_pool2d(img1 * img1, kernel_size=window_size, padding=window_size//2, stride=1) - mu1_sq
    sigma2_sq = F.avg_pool2d(img2 * img2, kernel_size=window_size, padding=window_size//2, stride=1) - mu2_sq
    sigma12 = F.avg_pool2d(img1 * img2, kernel_size=window_size, padding=window_size//2, stride=1) - mu1_mu2

    # Calculate SSIM
    ssim_numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    ssim_denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_value = ssim_numerator / ssim_denominator

    # Return the mean SSIM value
    return torch.mean(ssim_value)

### 4.3. **L1 Loss**

In [None]:
# L1 loss
def l1_loss(image1, image2):
    return F.l1_loss(image1, image2)

## 5. **Data Loading**

In [None]:
####################
# Global variables #
####################

# Device to use for training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Batch size
batch_size = 1
# Image paths
image_folder = '/content/drive/MyDrive/testFaces/test/corrupt train'
mask_folder = '/content/drive/MyDrive/testFaces/test/masks'
target_folder = '/content/drive/MyDrive/testFaces/test/target train'

# For data split
training_percentage = 0.6
validation_percentage = 0.2

# Random state
random_state = 0


In [None]:
class ImageDataset(Dataset):
    """Folder-backed dataset returning one normalized image tensor per file.

    Only files whose name before the first '.' is an integer are used (e.g.
    ``"12.png"``). They are ordered by that integer, so datasets built on
    parallel folders (inputs / masks / targets) line up index by index.

    Args:
        root_dir: Directory containing the image files.
        rgb: Load as 3-channel RGB when True, as 1-channel grayscale otherwise.
        size: Side length, in pixels, of the square output images.
    """

    def __init__(self, root_dir, rgb=True, size=512):
        self.root_dir = root_dir
        self.rgb = rgb
        self.size = size
        # Skip entries such as ".DS_Store" or ".ipynb_checkpoints", whose names
        # are not integers and would otherwise crash int().
        self.image_files = [name for name in os.listdir(root_dir) if self._numeric_stem(name) is not None]
        self.image_numbers = [self._numeric_stem(name) for name in self.image_files]
        self.sorted_image_files = [x for _, x in sorted(zip(self.image_numbers, self.image_files))]

    @staticmethod
    def _numeric_stem(file_name):
        """Return the integer before the first '.' of *file_name*, or None if it is not one."""
        try:
            return int(file_name.split(".")[0])
        except ValueError:
            return None

    def __len__(self):
        return len(self.sorted_image_files)

    def __getitem__(self, idx):
        """Load, resize and scale the ``idx``-th image in numeric file order.

        Returns:
            A float32 tensor with values in [0, 1], of shape (3, size, size)
            when ``rgb`` is True, else (1, size, size).

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        img_path = os.path.join(self.root_dir, self.sorted_image_files[idx])
        read_flag = cv2.IMREAD_COLOR if self.rgb else cv2.IMREAD_GRAYSCALE
        image = cv2.imread(img_path, read_flag)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        if self.rgb:
            # OpenCV loads color images in BGR channel order
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.size, self.size))
        # Scale 8-bit pixel values to [0, 1]
        image = image / 255.0
        tensor = torch.from_numpy(image).float()
        if self.rgb:
            # HWC -> CHW
            return tensor.permute(2, 0, 1)
        # HW -> 1HW
        return tensor.unsqueeze(0)


# Instantiate the dataset
image_dataset = ImageDataset(root_dir=image_folder, rgb=True)
mask_dataset = ImageDataset(root_dir=mask_folder, rgb=False)
target_dataset = ImageDataset(root_dir=target_folder, rgb=True)


# We split the data into training, validation and test sets
train_size = int(training_percentage * len(image_dataset))
val_size = int(validation_percentage * len(image_dataset))
test_size = len(image_dataset) - train_size - val_size


train_image, val_image, test_image = random_split(image_dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(random_state))
train_mask, val_mask, test_mask = random_split(mask_dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(random_state))
train_target, val_target, test_target = random_split(target_dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(random_state))

# We create the data loaders
train_loader = DataLoader(train_image, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_image, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_image, batch_size=batch_size, shuffle=False)

# We create the mask loaders
train_mask_loader = DataLoader(train_mask, batch_size=batch_size, shuffle=False)
val_mask_loader = DataLoader(val_mask, batch_size=batch_size, shuffle=False)
test_mask_loader = DataLoader(test_mask, batch_size=batch_size, shuffle=False)

# We create the target loaders
train_target_loader = DataLoader(train_target, batch_size=batch_size, shuffle=False)
val_target_loader = DataLoader(val_target, batch_size=batch_size, shuffle=False)
test_target_loader = DataLoader(test_target, batch_size=batch_size, shuffle=False)

# We will merge the train and validation sets to create a new training set
train_val_loader = DataLoader(train_image + val_image, batch_size=batch_size, shuffle=False)
train_val_mask_loader = DataLoader(train_mask + val_mask, batch_size=batch_size, shuffle=False)
train_val_target_loader = DataLoader(train_target + val_target, batch_size=batch_size, shuffle=False)


In [None]:
# Plot one image
fig, ax = plt.subplots(1, 3, figsize=(15, 15))
for i, (image, mask, target) in enumerate(zip(train_val_loader, train_val_mask_loader, train_val_target_loader)):
    ax[0].imshow(image[0].squeeze().numpy().transpose(1, 2, 0))
    ax[0].set_title("Image")
    ax[1].imshow(mask[0].squeeze().numpy(), cmap="gray")
    ax[1].set_title("Mask")
    ax[2].imshow(target[0].squeeze().numpy().transpose(1, 2, 0))
    ax[2].set_title("Target")
    break

## 6. **Training**

In [None]:
####################
# Global variables #
####################

# Number of epochs
num_epochs = 12
# Learning rate
learning_rate = 0.00008
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Beta2 hyperparam for Adam optimizers
beta2 = 0.99

In [None]:
# Instantiate the model (3 image channels + 1 mask channel) and move it to the device
model = InpaintingModel(in_channels=4)
model = model.to(device)

# Loss criterion and optimizer
criterion = CustomLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, beta2))

# Print progress and plot a sample every `log_interval` batches
log_interval = 200


def _show_sample(input_image, mask, inpainted_images, target):
    """Plot input, mask, inpainted output and target for the first sample of a batch."""
    panels = [
        (input_image[0].permute(1, 2, 0), None),
        (mask[0][0], "gray"),
        (inpainted_images[0].permute(1, 2, 0), None),
        (target[0].permute(1, 2, 0), None),
    ]
    plt.figure(figsize=(20, 20))
    for position, (panel, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        # .detach().cpu() is a no-op for CPU tensors, so no per-device branch is needed
        plt.imshow(panel.detach().cpu(), cmap=cmap)
    plt.show()


# Per-epoch history of the mean loss and metrics
loss_list = []
psnr_list = []
ssim_list = []
l1_list = []

model.train()
for epoch in range(num_epochs):
    time_start = time.time()
    loss_epoch = []
    psnr_epoch = []
    ssim_epoch = []
    l1_epoch = []
    batches = zip(train_val_loader, train_val_mask_loader, train_val_target_loader)
    for count_batch, (input_image, mask, target) in enumerate(batches):
        # Move the tensors to the device
        input_image = input_image.to(device)
        mask = mask.to(device)
        target = target.to(device)

        # Forward pass
        inpainted_images = model(input_image, mask)

        if count_batch % log_interval == 0:
            _show_sample(input_image, mask, inpainted_images, target)

        # Calculate the loss
        loss = criterion(target, inpainted_images)
        loss_epoch.append(loss.item())

        # Metrics are for monitoring only, so no autograd graph is needed.
        # `ssim_value` avoids shadowing the `ssim` function imported from pytorch_ssim.
        with torch.no_grad():
            psnr = compute_psnr(target, inpainted_images)
            ssim_value = ssim_manual(target, inpainted_images)
            l1 = l1_loss(target, inpainted_images)
        # float() accepts both a tensor and a plain-number PSNR sentinel
        psnr_epoch.append(float(psnr))
        ssim_epoch.append(ssim_value.item())
        l1_epoch.append(l1.item())

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print running means every `log_interval` batches
        if count_batch % log_interval == 0:
            print("Epoch: {}/{}...".format(epoch+1, num_epochs),
                    "Batch: {}...".format(count_batch),
                    "Loss: {:.4f}...".format(np.mean(loss_epoch)),
                    "PSNR: {:.4f}...".format(np.mean(psnr_epoch)),
                    "SSIM: {:.4f}...".format(np.mean(ssim_epoch)),
                    "L1: {:.4f}".format(np.mean(l1_epoch)))

    # Record the epoch means
    loss_list.append(np.mean(loss_epoch))
    psnr_list.append(np.mean(psnr_epoch))
    ssim_list.append(np.mean(ssim_epoch))
    l1_list.append(np.mean(l1_epoch))

    time_end = time.time()
    print("Epoch: {}/{}...".format(epoch+1, num_epochs),
            "Loss: {:.4f}...".format(np.mean(loss_epoch)),
            "PSNR: {:.4f}...".format(np.mean(psnr_epoch)),
            "SSIM: {:.4f}...".format(np.mean(ssim_epoch)),
            "L1: {:.4f}...".format(np.mean(l1_epoch)),
            "Time: {:.4f}".format(time_end - time_start))


# Plot the loss and the three metrics, one panel each (four axes are needed)
fig, ax = plt.subplots(1, 4, figsize=(20, 5))
for axis, values, title in zip(ax, (loss_list, psnr_list, ssim_list, l1_list), ("Loss", "PSNR", "SSIM", "L1")):
    axis.plot(values)
    axis.set_title(title)
plt.show()

# Save the trained model
torch.save(model.state_dict(), "inpainting_model.pth")


In [None]:
# Now we evaluate the trained model on the test set

# Load the trained model (uncomment to evaluate a saved checkpoint)
# !!!!!!!!!!!!!
#model.load_state_dict(torch.load("inpainting_model.pth"))
# !!!!!!!!!!!!!

# Set the model to evaluation mode
model.eval()

# Per-batch test loss and metrics
test_loss_list = []
test_psnr_list = []
test_ssim_list = []
test_l1_list = []
inpainted_images_list = []

# Up to `num_samples_to_plot` (input, mask, inpainted, target) samples, kept on
# the CPU and gathered across batches so plotting works for any batch size
num_samples_to_plot = 2
samples_to_plot = []

# Evaluation needs no gradients. Without no_grad, every output stored in
# inpainted_images_list would keep its whole autograd graph alive.
with torch.no_grad():
    for (input_image, mask, target) in zip(test_loader, test_mask_loader, test_target_loader):
        # Move the tensors to the device
        input_image = input_image.to(device)
        mask = mask.to(device)
        target = target.to(device)

        # Forward pass, timed
        start_time = time.time()
        inpainted_images = model(input_image, mask)
        if device.type == "cuda":
            # CUDA kernels run asynchronously; wait for them so the timing is real
            torch.cuda.synchronize(device)
        end_time = time.time()
        print("Time: {:.4f}".format(end_time - start_time))
        inpainted_images_list.append(inpainted_images)

        # Loss and metrics (`ssim_value` avoids shadowing the imported `ssim`)
        loss = criterion(target, inpainted_images)
        psnr = compute_psnr(target, inpainted_images)
        ssim_value = ssim_manual(target, inpainted_images)
        l1 = l1_loss(target, inpainted_images)
        test_loss_list.append(loss.item())
        # float() accepts both a tensor and a plain-number PSNR sentinel
        test_psnr_list.append(float(psnr))
        test_ssim_list.append(ssim_value.item())
        test_l1_list.append(l1.item())

        for k in range(input_image.shape[0]):
            if len(samples_to_plot) >= num_samples_to_plot:
                break
            samples_to_plot.append(
                (input_image[k].cpu(), mask[k].cpu(), inpainted_images[k].cpu(), target[k].cpu())
            )

# Print the test metrics
print("Test Loss: {:.4f}".format(np.mean(test_loss_list)),
        "Test PSNR: {:.4f}".format(np.mean(test_psnr_list)),
        "Test SSIM: {:.4f}".format(np.mean(test_ssim_list)),
        "Test L1: {:.4f}".format(np.mean(test_l1_list)))

# Plot the collected test samples: input, mask, inpainted output, target
for sample_input, sample_mask, sample_inpainted, sample_target in samples_to_plot:
    plt.figure(figsize=(20, 20))
    plt.subplot(1, 4, 1)
    plt.imshow(sample_input.permute(1, 2, 0))
    plt.subplot(1, 4, 2)
    plt.imshow(sample_mask[0], cmap="gray")
    plt.subplot(1, 4, 3)
    plt.imshow(sample_inpainted.permute(1, 2, 0))
    plt.subplot(1, 4, 4)
    plt.imshow(sample_target.permute(1, 2, 0))
    plt.show()

