<h1>Image Recolouring Project</h1>

Project developed by Alejandro Cano Caldero and Jesús Moncada Ramírez for the subject Neural Networks and Deep Learning, University of Padova, 2022-23.


In [1]:
from PIL import Image

import numpy as np

import torch

import matplotlib.pyplot as plt

from torchvision import transforms, datasets
from torchvision.transforms import transforms

from torch.utils.data import DataLoader, Dataset  

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as op

In [2]:
# Define the execution device 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


<h2>1. Dataset</h2>

For the dataset we have used [ImageNette](https://github.com/fastai/imagenette), a reduced version of ImageNet, specifically the fill size images version.

In [5]:
class ImageDataset(Dataset):
    def __init__(self, image_path, transform=None):
        super(ImageDataset, self).__init__()
        self.data = datasets.ImageFolder(image_path,  transform)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return x

    def __len__(self):
        return len(self.data)

In [6]:
class TwoImagesDataset(Dataset):
  def __init__(self, dataset1, dataset2):
        super(TwoImagesDataset, self).__init__()
        self.dataset1 = dataset1
        self.dataset2 = dataset2

  def __getitem__(self, idx):
        return self.dataset1[idx], self.dataset2[idx]

  def __len__(self):
        return len(self.dataset1)

In [7]:
img_path = 'drive/MyDrive/imagenette2/train'

# Define normalization [0, 255] --> [-1, 1] (Owing to the use of the tanh activation function)

colored_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

grayscale_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Grayscale(3),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

colored_data = ImageDataset(img_path, colored_transform)
grayscale_data = ImageDataset(img_path, grayscale_transform)
dataset = TwoImagesDataset(colored_data, grayscale_data)

In [8]:
def print_image(img):
  plt.imshow(img.permute(1, 2, 0))

In [9]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, kernel_size, stride=1):
        super().__init__()
        
        self.layer1 = nn.Conv2d(in_channels, out_channels=64, kernel_size=kernel_size, stride=2, padding=1)
        self.layer2 = nn.Conv2d(64, out_channels=128, kernel_size=kernel_size, stride=2, padding=1)
        self.layer2_bn = nn.BatchNorm2d(128)
        self.layer3 = nn.Conv2d(128, out_channels=256, kernel_size=kernel_size, stride=2, padding=1)
        self.layer3_bn = nn.BatchNorm2d(256)
        self.layer4 = nn.Conv2d(256, out_channels=512, kernel_size=kernel_size, padding=1) # stride = 1
        self.layer4_bn = nn.BatchNorm2d(512)
        self.layer5 = nn.Conv2d(512, out_channels=1, kernel_size=kernel_size, padding=1)
        
    
    def forward(self, x):
        d = F.leaky_relu(self.layer1(x), 0.2)
        print(d.shape)
        d = F.leaky_relu(self.layer2_bn(self.layer2(d)), 0.2)
        print(d.shape)
        d = F.leaky_relu(self.layer3_bn(self.layer3(d)), 0.2)
        print(d.shape)
        d = F.leaky_relu(self.layer4_bn(self.layer4(d)), 0.2)
        print(d.shape)
        d = self.layer5(d)
        print(d.shape)
        
        # Each (1×1) of the 30×30 represents a 70×70 dimension 
        # in the input image (256×256), classifying a single patch of the original 
        # image as real or fake.
        return torch.sigmoid(d)
        

discriminator = Discriminator(3, 4)
discriminator.to(device)

loss_fn = nn.BCELoss(weight=torch.tensor(0.5))
optimizer = op.Adam(discriminator.parameters(), lr=0.0002, weight_decay=0.5)

In [None]:
for input_img, tg_img in dataset:
  disc_in = torch.cat((input_img, tg_img), 1)
  var = discriminator.forward(disc_in.unsqueeze(0))
  print(var)
  # print(var.shape)

In [15]:
class Generator(nn.Module):
  def __init__(self, in_channels, stride=1):
    super().__init__()

    # ----------------------------- ENCODER ----------------------------
    #            ENCODER MODEL: C64-C128-C256-C512-C512-C512-C512
    self.layer1 = nn.Conv2d(in_channels, out_channels=64, kernel_size=4, stride=2, padding=1)
    self.layer2 = nn.Conv2d(64, out_channels=128, kernel_size=4, stride=2, padding=1)
    self.layer2_bn = nn.BatchNorm2d(128)
    self.layer3 = nn.Conv2d(128, out_channels=256, kernel_size=4, stride=2, padding=1)
    self.layer3_bn = nn.BatchNorm2d(256)
    self.layer4 = nn.Conv2d(256, out_channels=512, kernel_size=4, stride=2, padding=1)
    self.layer4_bn = nn.BatchNorm2d(512)
    self.layer5 = nn.Conv2d(512, out_channels=512, kernel_size=4, stride=2, padding=1)
    self.layer5_bn = nn.BatchNorm2d(512)
    self.layer6 = nn.Conv2d(512, out_channels=512, kernel_size=4, stride=2, padding=1)
    self.layer6_bn = nn.BatchNorm2d(512)
    self.layer7 = nn.Conv2d(512, out_channels=512, kernel_size=4, stride=2, padding=1)
    self.layer7_bn = nn.BatchNorm2d(512)
    # ----------------------------- /ENCODER ----------------------------


    # ----------------------------- BOTTLENECK ----------------------------
    self.bottleneck_layer = nn.Conv2d(512, out_channels=512, kernel_size=4, stride=2, padding=1)
    # ----------------------------- /BOTTLENECK ----------------------------


    # ----------------------------- DECODER ----------------------------
    #           DECODER MODEL: CD512-CD512-CD512-CD512-CD256-CD128-CD64
    self.layer8 = nn.ConvTranspose2d(512, out_channels=512, kernel_size=4, stride=2, padding=1)
    self.layer8_bn = nn.BatchNorm2d(512)
    self.layer8_dpout = nn.Dropout()
    self.layer9 = nn.ConvTranspose2d(512, out_channels=512, kernel_size=4, stride=2, padding=1)
    self.layer9_bn = nn.BatchNorm2d(512)
    self.layer9_dpout = nn.Dropout()
    self.layer10 = nn.ConvTranspose2d(512, out_channels=512, kernel_size=4, stride=2, padding=1)
    self.layer10_bn = nn.BatchNorm2d(512)
    self.layer10_dpout = nn.Dropout()
    self.layer11 = nn.ConvTranspose2d(512, out_channels=512, kernel_size=4, stride=2, padding=1)
    self.layer11_bn = nn.BatchNorm2d(512)
    self.layer12 = nn.ConvTranspose2d(512, out_channels=256, kernel_size=4, stride=2, padding=1)
    self.layer12_bn = nn.BatchNorm2d(256)
    self.layer13 = nn.ConvTranspose2d(256, out_channels=128, kernel_size=4, stride=2, padding=1)
    self.layer13_bn = nn.BatchNorm2d(128)
    self.layer14 = nn.ConvTranspose2d(128, out_channels=64, kernel_size=4, stride=2, padding=1)
    self.layer14_bn = nn.BatchNorm2d(64)
    # ----------------------------- /DECODER ----------------------------

    # ----------------------------- OUTPUT ----------------------------
    self.layer15 = nn.ConvTranspose2d(64, out_channels=3, kernel_size=4, stride=2, padding=1)

  def forward(self, x):

    # ----------------------------- ENCODER ----------------------------
    e1 = F.leaky_relu(self.layer1(x), 0.2)
    e2 = F.leaky_relu(self.layer2_bn(self.layer2(e1)), 0.2)
    e3 = F.leaky_relu(self.layer3_bn(self.layer3(e2)), 0.2)
    e4 = F.leaky_relu(self.layer4_bn(self.layer4(e3)), 0.2)
    e5 = F.leaky_relu(self.layer5_bn(self.layer5(e4)), 0.2)
    e6 = F.leaky_relu(self.layer6_bn(self.layer6(e5)), 0.2)
    e7 = F.leaky_relu(self.layer7_bn(self.layer7(e6)), 0.2) 
    # ----------------------------- /ENCODER ----------------------------


    # ----------------------------- BOTTLENECK ----------------------------
    b = F.relu(self.bottleneck_layer(e7))


    # ----------------------------- DECODER ----------------------------
    d1 = F.relu(torch.cat((self.layer8_dpout(self.layer8_bn(self.layer8(b))), e7)))
    d2 = F.relu(torch.cat((self.layer9_dpout(self.layer9_bn(self.layer9(d1))), e6)))
    d3 = F.relu(torch.cat((self.layer10_dpout(self.layer10_bn(self.layer10(d2))), e5)))
    d4 = F.relu(torch.cat((self.layer11_bn(self.layer11(d3))), e4))
    d5 = F.relu(torch.cat((self.layer12_bn(self.layer12(d4))), e3))
    d6 = F.relu(torch.cat((self.layer13_bn(self.layer13(d5))), e2))
    d7 = F.relu(torch.cat((self.layer14_bn(self.layer14(d6))), e1))

    # ----------------------------- OUTPUT ----------------------------
    o = F.tanh(self.layer15(d7))

    return o

In [16]:
generator = Generator(3)

In [19]:
for input_img, tg_img in dataset:
  var = generator.forward(input_img.unsqueeze(0))
  print(var)

RuntimeError: ignored

In [None]:
adversarial_loss = nn.BCELoss(weight=torch.tensor(0.5))
l1_loss = nn.L1Loss()

In [None]:
def generator_loss(generator_image, target_image, discriminator_predictions, real_target):
  gen_loss = adversarial_loss(discriminator_predictions, real_target)
  l1_l = l1_loss(generator_image, target_image)
  result = gen_loss + (100 * l1_l)

  return result

In [None]:
def discriminator_loss(output, label):
  return adversarial_loss(output, label)