In [None]:
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import os


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.utils import save_image
from torch import optim
from tqdm import tqdm

from fastai.vision.all import untar_data, URLs
import sys

from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset managment and creates mini batches


## Dataloader

In [None]:
class FlowerDataset(Dataset):
    """Greyscale/colour training pairs read from a directory of image files.

    Only the leftmost ``colour_width`` pixel columns of each file are used as the
    colour image. ``__getitem__`` returns ``(image_bw, image_colour)``:

    * ``image_colour``: 3xHxW float tensor, normalised with mean 0.5 / std 0.5,
      i.e. mapped from [0, 1] to [-1, 1].
    * ``image_bw``: 1xHxW greyscale version of the same crop, computed BEFORE
      normalisation, so it stays in [0, 1] (the two tensors use different scales).

    Args:
        dir: directory containing the image files. Hidden entries (names starting
            with ".", e.g. the ``.ipynb_checkpoints`` folder Colab creates) and
            sub-directories are skipped.
        colour_width: number of leading pixel columns that hold the colour image.
    """

    def __init__(self, dir, colour_width=512):
        self.dir = dir
        self.colour_width = colour_width
        # Sorted so index -> file is stable across filesystems (the test loader is
        # unshuffled, so this also makes the saved example grids reproducible).
        self.list_of_files = sorted(
            name
            for name in os.listdir(self.dir)
            if not name.startswith(".") and os.path.isfile(os.path.join(self.dir, name))
        )

        # Stateless transforms: build once instead of on every __getitem__ call.
        self.to_tensor = transforms.ToTensor()
        self.bw_transform = transforms.Grayscale()
        self.normalise = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

    def __len__(self):
        return len(self.list_of_files)

    def __getitem__(self, index):
        path = os.path.join(self.dir, self.list_of_files[index])

        # `with` closes the file Image.open keeps open; convert("RGB") guarantees a
        # HxWx3 array even for greyscale, RGBA or palette images.
        with Image.open(path) as image:
            pixels = np.array(image.convert("RGB"))

        # Keep only the colour half of the file, then HxWxC uint8 -> CxHxW float.
        image_colour = self.to_tensor(pixels[:, :self.colour_width, :])

        image_bw = self.bw_transform(image_colour)
        image_colour = self.normalise(image_colour)

        return (image_bw, image_colour)

## Config

In [None]:
# --- Hardware -----------------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Data locations -----------------------------------------------------------
# For a quick smoke test, point these at "quick_data/train" / "quick_data/test".
TRAIN_DIR = "Data/train"
TEST_DIR = "Data/test"
EXAMPLE_DIR = "Example_results"

# --- Optimisation -------------------------------------------------------------
LEARNING_RATE = 2e-4
BATCH_SIZE = 1          # the paper trains with a batch size of 1
NUM_WORKERS = 0
NUM_EPOCHS = 500
L1_LAMBDA = 100         # weight of the L1 term in the generator loss

# --- Image format -------------------------------------------------------------
IMAGE_SIZE = 512
CHANNELS_IMG = 3

# --- Checkpointing ------------------------------------------------------------
LOAD_MODEL = False
SAVE_MODEL = True
SAVE_MODEL_EVERY_NTH = 5   # save checkpoints every N-th epoch
CHECKPOINT_DISC = "Checkpoint/discriminator.pth.tar"
CHECKPOINT_GEN = "Checkpoint/generator.pth.tar"


## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Patch discriminator that scores a pair of images.

    ``forward(x, y)`` concatenates ``x`` and ``y`` along the channel axis, so each
    of them must have ``in_channels`` channels (the first conv sees twice that).
    Structure: reflect-padded Conv(stride 2) + LeakyReLU(0.2), then one
    ``downsampling_conv`` per entry of ``features[1:]`` (stride 2, except stride 1
    for the last entry), then a stride-1 Conv to a single channel.

    Args:
        in_channels: channel count of EACH of the two images given to ``forward``.
        features: output widths of the successive stages.

    Returns (from ``forward``): a 1-channel map of raw logits, one per patch; no
    sigmoid is applied (pair it with ``BCEWithLogitsLoss``).
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; callers may still pass a list.
        features = list(features)

        self.initial = nn.Sequential(
            # in_channels * 2: the two images are concatenated along the channel
            # dimension (e.g. 3 + 3 = 6 channels).
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4,
                      stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2)
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1

        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride by POSITION: comparing values would give every
            # stage stride 1 when widths repeat (e.g. [64, 512, 512]).
            layers.append(
                downsampling_conv(prev_channels, feature,
                                  stride=1 if index == last_index else 2)
            )
            prev_channels = feature

        layers.append(
            nn.Conv2d(
                prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            )
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score the pair (x, y); both must share batch size and spatial size."""
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        return self.model(x)


class downsampling_conv(nn.Module):
    """Discriminator stage: reflect-padded Conv2d -> optional BatchNorm2d -> LeakyReLU(0.2).

    Args:
        in_planes, out_planes: input / output channel counts.
        ks, stride, pad: convolution kernel size, stride and padding.
        bn: when True, ``forward`` applies ``batch_norm`` after the convolution.
            The BatchNorm2d module is created either way, so the state_dict
            layout does not depend on ``bn``.
    """

    def __init__(self, in_planes, out_planes, ks=4, stride=2, pad=1, bn=True):
        super().__init__()
        self.bn = bn

        self.conv_downsample = nn.Conv2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=ks,
            stride=stride,
            padding=pad,
            padding_mode="reflect",
        )
        self.batch_norm = nn.BatchNorm2d(num_features=out_planes)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        features = self.conv_downsample(x)
        normalised = self.batch_norm(features) if self.bn else features
        return self.leaky_relu(normalised)


## Generator

In [None]:
class Generator(nn.Module):
    """U-Net generator following the pix2pix encoder/decoder layout.

    Encoder: C64 (``input_conv``, no BatchNorm) - C128 - C256 - C512 x 4, plus a
    C512 bottleneck. Every conv has kernel 4, stride 2 and padding 1, so each one
    halves the spatial size. Input height/width must therefore be divisible by
    256 for the skip connections to line up.

    Decoder: CD512 - CD1024 - CD1024 - C1024 - C1024 - C512 - C256 - C128. Here
    "CD" is an upsampling stage with dropout, and the numbers count the channels
    after concatenating the skip connection. A final ConvTranspose2d and tanh map
    the result to ``out_channels`` values in [-1, 1].

    Every stage owns its own module. An earlier version reused one module for
    several stages, which made those stages share weights. Checkpoints saved from
    that version have a different state_dict layout and will not load.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # input (C64): BatchNorm is omitted only on this first layer.
        self.input_conv = nn.Conv2d(
            in_channels, 64, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(0.2)

        # encoder: C128-C256-C512-C512-C512-C512 + C512 bottleneck
        self.enc_c64_128 = enc_downsampling_conv(64, 128)
        self.enc_c128_256 = enc_downsampling_conv(128, 256)
        self.enc_c256_512 = enc_downsampling_conv(256, 512)
        self.enc_c512_512 = enc_downsampling_conv(512, 512)      # down 4
        self.enc_c512_512_d5 = enc_downsampling_conv(512, 512)   # down 5
        self.enc_c512_512_d6 = enc_downsampling_conv(512, 512)   # down 6
        self.enc_bottleneck = enc_downsampling_conv(512, 512)

        # decoder (U-Net version): three dropout stages, then plain stages
        self.dec_cd512_512 = dec_dropout_upsampling_conv(512, 512)       # up 1
        self.dec_cd1024_512 = dec_dropout_upsampling_conv(1024, 512)     # up 2
        self.dec_cd1024_512_u3 = dec_dropout_upsampling_conv(1024, 512)  # up 3
        self.dec_c1024_512 = dec_upsampling_conv(1024, 512)              # up 4
        self.dec_c1024_256 = dec_upsampling_conv(1024, 256)              # up 5
        self.dec_c512_128 = dec_upsampling_conv(512, 128)                # up 6
        self.dec_c256_64 = dec_upsampling_conv(256, 64)                  # up 7

        # output
        self.output_conv = nn.ConvTranspose2d(
            128, out_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, out):
        # INPUT
        out = self.leaky_relu(self.input_conv(out))
        skips = [out]

        # ENCODE: every stage except the bottleneck feeds a skip connection.
        for stage in (self.enc_c64_128, self.enc_c128_256, self.enc_c256_512,
                      self.enc_c512_512, self.enc_c512_512_d5, self.enc_c512_512_d6):
            out = stage(out)
            skips.append(out)

        # bottleneck
        out = self.enc_bottleneck(out)

        # DECODE: each later stage takes the previous output concatenated with
        # the matching encoder activation (deepest first).
        out = self.dec_cd512_512(out)
        for stage in (self.dec_cd1024_512, self.dec_cd1024_512_u3, self.dec_c1024_512,
                      self.dec_c1024_256, self.dec_c512_128, self.dec_c256_64):
            out = stage(torch.cat((out, skips.pop()), 1))

        # OUTPUT: last skip is the input_conv activation.
        out = torch.cat((out, skips.pop()), 1)
        out = self.output_conv(out)
        return self.tanh(out)


class enc_downsampling_conv(nn.Module):
    """Generator encoder stage: Conv2d -> BatchNorm2d (if ``bn``) -> LeakyReLU(0.2).

    The convolution uses reflect padding; with the defaults (ks=4, stride=2,
    pad=1) it halves the spatial size. ``batch_norm`` is always constructed, so
    ``bn`` only changes what ``forward`` does, not the parameters.
    """

    def __init__(self, in_planes, out_planes, ks=4, stride=2, pad=1, bn=True):
        super().__init__()
        self.bn = bn  # whether forward() applies batch_norm

        conv_options = dict(kernel_size=ks, stride=stride, padding=pad, padding_mode="reflect")
        self.conv_downsample = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.batch_norm = nn.BatchNorm2d(out_planes)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.conv_downsample(x)
        if not self.bn:
            return self.leaky_relu(out)
        return self.leaky_relu(self.batch_norm(out))

class dec_dropout_upsampling_conv(nn.Module):
    """Generator decoder stage with dropout:
    ConvTranspose2d -> BatchNorm2d -> Dropout2d(p=0.5) -> ReLU.

    With the defaults (ks=4, stride=2, pad=1) the spatial size doubles.
    NOTE(review): Dropout2d zeroes whole channels; confirm that channel-wise
    rather than element-wise dropout is intended.
    """

    def __init__(self, in_planes, out_planes, ks=4, stride=2, pad=1):
        super().__init__()
        self.conv_upsample = nn.ConvTranspose2d(
            in_planes, out_planes, kernel_size=ks, stride=stride, padding=pad)
        self.batch_norm = nn.BatchNorm2d(out_planes)
        self.dropout = nn.Dropout2d(p=0.5)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        for stage in (self.conv_upsample, self.batch_norm, self.dropout, self.relu):
            x = stage(x)
        return x

class dec_upsampling_conv(nn.Module):
    """Generator decoder stage without dropout: ConvTranspose2d -> BatchNorm2d -> ReLU.

    With the defaults (ks=4, stride=2, pad=1) the spatial size doubles.
    """

    def __init__(self, in_planes, out_planes, ks=4, stride=2, pad=1):
        super().__init__()
        self.conv_upsample = nn.ConvTranspose2d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=ks,
            stride=stride,
            padding=pad,
        )
        self.batch_norm = nn.BatchNorm2d(num_features=out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.batch_norm(self.conv_upsample(x)))

## Saving Examples and Checkpoints

In [None]:
import itertools

def save_some_examples(gen, val_loader, epoch, folder, num_photos=12):
    """Save a comparison grid of generator outputs vs. ground-truth colour images.

    Runs ``gen`` (in eval mode, without gradients) on the first ``num_photos``
    batches of ``val_loader``. Batches are placed side by side along the width.
    The top row holds the generated images and the bottom row the targets.
    Both are mapped back from [-1, 1] to [0, 1].

    The grid is written to ``<folder>/epoch<epoch>.png``; ``folder`` is created
    if missing. Nothing is written if the loader yields no batches. ``gen`` is
    switched back to train mode afterwards, even if an exception occurs.
    """
    os.makedirs(folder, exist_ok=True)

    fakes, labels = [], []
    gen.eval()
    try:
        with torch.no_grad():
            for x, y in itertools.islice(val_loader, num_photos):
                x = x.to(DEVICE, dtype=torch.float)
                y = y.to(DEVICE, dtype=torch.float)

                y_fake = gen(x)

                # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
                labels.append(y * 0.5 + 0.5)
                fakes.append(y_fake * 0.5 + 0.5)
    finally:
        gen.train()

    if not fakes:
        return

    # dim 3 = width (batches side by side); dim 2 = height (outputs above labels).
    output_tensor = torch.cat(fakes, 3)
    label_tensor = torch.cat(labels, 3)
    save_image(torch.cat((output_tensor, label_tensor), 2),
               os.path.join(folder, f"epoch{epoch}.png"))


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Save ``model`` and ``optimizer`` state dicts to ``filename`` with torch.save.

    The file holds a dict with keys ``"state_dict"`` (model) and ``"optimizer"``.
    The parent directory of ``filename`` is created if it does not exist yet;
    torch.save alone fails for a missing directory such as ``Checkpoint/``.
    """
    print("=> Saving checkpoint")
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a file written by save_checkpoint.

    Tensors are mapped onto DEVICE. After the optimizer state is restored, every
    param group's learning rate is overwritten with ``lr``. Otherwise the old
    learning rate stored in the checkpoint would silently win.

    NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)

    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    for group in optimizer.param_groups:
        group.update(lr=lr)

## Training Function

In [None]:
def train_func(disc, gen, loader, optim_disc, optim_gen, l1_loss, bce, g_scaler, d_scaler, epoch_no):
    """Run one epoch of conditional-GAN training.

    For each batch ``(x, y)`` from ``loader`` (x: greyscale input, y: colour
    target), the function does two updates:

    1. Discriminator: BCE on real pairs (x, y) -> 1 and fake pairs
       (x, gen(x)) -> 0, halved. The optimiser step only runs while the mean
       fake probability is above 0.35.
    2. Generator: BCE pushing D(x, gen(x)) -> 1, plus ``L1_LAMBDA * l1_loss``.

    The discriminator is conditioned on the INPUT ``x``. Conditioning on the
    target ``y`` makes every real pair ``(y, y)`` trivially recognisable. A
    single-channel ``x`` is repeated to the target's channel count, so it matches
    a ``Discriminator`` built for 3-channel images.

    Forward passes run under ``torch.cuda.amp.autocast``; the two GradScalers
    scale the backward passes.
    """
    # tqdm is for progress bar
    loop = tqdm(loader, leave=True)
    for idx, (x, y) in enumerate(loop):
        # x is bw image, y is colour image; many ops need float tensors
        x = x.to(DEVICE).float()
        y = y.to(DEVICE).float()

        # Condition image for the discriminator (see docstring).
        x_cond = x.expand(-1, y.shape[1], -1, -1) if x.shape[1] == 1 else x

        # Train the discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)

            # Loss from fake pair (generator output detached: D step only)
            D_fake = disc(x_cond, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))

            # Loss from real pair
            D_real = disc(x_cond, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))

            # Halved to slow the discriminator relative to the generator.
            D_loss = (D_real_loss + D_fake_loss) / 2

        # Skip the D update once it already rejects fakes confidently
        # (mean fake probability <= 0.35), presumably so D does not overpower G.
        if torch.sigmoid(D_fake).mean().item() > 0.35:
            disc.zero_grad()
            d_scaler.scale(D_loss).backward()
            d_scaler.step(optim_disc)
            d_scaler.update()

        # Train the generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x_cond, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))

            # L1 (least absolute deviations) reconstruction term
            L1 = l1_loss(y_fake, y) * L1_LAMBDA
            G_loss = G_fake_loss + L1

        optim_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(optim_gen)
        g_scaler.update()

        if idx == 0:
            print(
                f"Epoch [{epoch_no}/{NUM_EPOCHS}] Batch {idx}/{len(loader)} "
                f"Loss D: {D_loss.item():.4f}, loss G: {G_loss.item():.4f}"
            )

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )




def main():
    """Build the models and train them for NUM_EPOCHS epochs.

    Optionally resumes from CHECKPOINT_DISC / CHECKPOINT_GEN (LOAD_MODEL) and
    trains on TRAIN_DIR. After every epoch, example images from TEST_DIR are
    written to EXAMPLE_DIR. When SAVE_MODEL is set, checkpoints are saved every
    SAVE_MODEL_EVERY_NTH epochs and also after the final epoch.
    """
    disc = Discriminator(in_channels=3).to(DEVICE)
    gen = Generator(in_channels=1, out_channels=3).to(DEVICE)

    # Adam's standard beta1 is 0.9, but the paper uses 0.5.
    optim_disc = optim.Adam(
        disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    optim_gen = optim.Adam(
        gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    # Loss
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_DISC, disc, optim_disc, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_GEN, gen, optim_gen, LEARNING_RATE)

    # Training dataset
    train_dataset = FlowerDataset(TRAIN_DIR)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=NUM_WORKERS)

    # float16 (mixed precision) loss scaling
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    # Testing/validation dataset (unshuffled so example grids are comparable)
    test_dataset = FlowerDataset(TEST_DIR)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Output folders must exist before save_image / torch.save write into them.
    os.makedirs(EXAMPLE_DIR, exist_ok=True)
    if SAVE_MODEL:
        for checkpoint_path in (CHECKPOINT_GEN, CHECKPOINT_DISC):
            checkpoint_dir = os.path.dirname(checkpoint_path)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        print(f'====\nEpoch {epoch}\n====')

        train_func(
            disc, gen, train_loader, optim_disc, optim_gen, L1_LOSS, BCE, g_scaler, d_scaler, epoch,
        )

        if SAVE_MODEL and epoch % SAVE_MODEL_EVERY_NTH == 0:
            save_checkpoint(gen, optim_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, optim_disc, filename=CHECKPOINT_DISC)

        save_some_examples(gen, test_loader, epoch, folder=EXAMPLE_DIR)

    # Save the final epoch if the periodic save above skipped it. Uses
    # NUM_EPOCHS rather than the loop variable, which is undefined when
    # NUM_EPOCHS == 0.
    last_epoch = NUM_EPOCHS - 1
    if SAVE_MODEL and NUM_EPOCHS > 0 and last_epoch % SAVE_MODEL_EVERY_NTH != 0:
        save_checkpoint(gen, optim_gen, filename=CHECKPOINT_GEN)
        save_checkpoint(disc, optim_disc, filename=CHECKPOINT_DISC)



## Main Loop

In [None]:
# Run the full training loop defined in main(): NUM_EPOCHS epochs with periodic
# checkpoints and per-epoch example images.
main()