<font size="6"> 
<b>
TITLE
</b>
</font>

To start, install the following packages by running the code below.

`pip install ipykernel numpy torch albumentations pillow tqdm torchvision`

In [2]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

import copy

import numpy as np

import os

from PIL import Image

import random

import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from torchvision.utils import save_image

from tqdm import tqdm

<font size="5"> 
<b>
Config
</b>
</font>

The configurations for the GAN.

In [3]:
# Global configuration shared by the dataset, the training loop and main().
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Prefer to work on GPU if available
TRAIN_DIR = "Data/train"  # TODO: directly install from the internet
VAL_DIR = "Data/val"
BATCH_SIZE = 1  # Should probably be higher?
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 0.0  # Weight of the identity loss in the generator loss
LAMBDA_CYCLE = 10  # Weight of the cycle loss in the generator loss
NUM_WORKERS = 4
NUM_EPOCHS = 10
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = "genh.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITIC_H = "critich.pth.tar"
CHECKPOINT_CRITIC_Z = "criticz.pth.tar"

# Photo/Monet aliases for the checkpoint paths above: the photo models
# (gen_P, disc_P) are stored in the *_H files and the Monet models
# (gen_M, disc_M) in the *_Z files.
CHECKPOINT_GEN_P = CHECKPOINT_GEN_H
CHECKPOINT_GEN_M = CHECKPOINT_GEN_Z
CHECKPOINT_CRITIC_P = CHECKPOINT_CRITIC_H
CHECKPOINT_CRITIC_M = CHECKPOINT_CRITIC_Z

transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        # Random mirroring with probability 0.5. This is an augmentation; it
        # does not change the number of samples in the dataset.
        A.HorizontalFlip(p=0.5),
        # With mean = std = 0.5 and max_pixel_value = 255 this maps pixel
        # values from [0, 255] to [-1, 1].
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    # Lets the pipeline accept a second image under the key "image0".
    additional_targets={"image0": "image"},
)

<font size="5"> 
<b>
Dataset
</b>
</font>

We define the dataset class holding the photos and the Monet pictures, along with its constructor, length, and item retrieval function. The retrieval picks the photo and the Monet picture for the given index (wrapping around when one collection is shorter than the other), loads both images from their respective paths, applies the given transformation to them (if any), and returns the two pictures.

In [None]:
class PhotoMonetDataset(Dataset):
    """Unpaired dataset of photos and Monet pictures.

    The dataset length is the size of the larger of the two collections;
    the smaller collection is cycled through with a modulo on the index.

    Args:
        root_photo: Directory containing the photo files.
        root_monet: Directory containing the Monet picture files.
        transform: Optional callable invoked as
            ``transform(image=photo, image0=monet)`` that returns a mapping
            with the keys ``"image"`` and ``"image0"``.

    Raises:
        ValueError: If either directory contains no entries.
    """

    def __init__(self, root_photo, root_monet, transform=None):
        self.root_photo = root_photo  # dir to photos
        self.root_monet = root_monet  # dir to monet pictures
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible between runs and machines.
        self.photo_images = sorted(os.listdir(root_photo))
        self.monet_images = sorted(os.listdir(root_monet))
        self.photo_len = len(self.photo_images)
        self.monet_len = len(self.monet_images)
        if self.photo_len == 0 or self.monet_len == 0:
            # Fail here with a clear message instead of a ZeroDivisionError
            # from the modulo in __getitem__.
            raise ValueError(
                f"Both image directories must be non-empty "
                f"(photos: {self.photo_len} in {root_photo!r}, "
                f"monets: {self.monet_len} in {root_monet!r})"
            )
        self.length_dataset = max(self.photo_len, self.monet_len)

    def __len__(self):
        """Return the size of the larger of the two image collections."""
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read the image at ``path`` and return it as an RGB numpy array."""
        # The context manager closes the underlying file right away.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        """Return the ``(photo, monet)`` pair for ``index``.

        Raises whatever ``PIL.Image.open`` raises when a file is missing or
        cannot be read.
        """
        # The modulo wraps the index around for the shorter collection.
        photo_name = self.photo_images[index % self.photo_len]
        monet_name = self.monet_images[index % self.monet_len]

        photo_img = self._load_rgb(os.path.join(self.root_photo, photo_name))
        monet_img = self._load_rgb(os.path.join(self.root_monet, monet_name))

        if self.transform is not None:
            augmentations = self.transform(image=photo_img, image0=monet_img)
            photo_img = augmentations["image"]
            monet_img = augmentations["image0"]

        return photo_img, monet_img

<font size="5"> 
<b>
Models
</b>
</font>

Here we describe what is being done, and why

<font size="4"> 
<b>
Discriminator model
</b>
</font>

__Block__: Shorthand for a convolution block, consisting of a 2D convolution layer, an instance normalization, and a LeakyReLU. Used to build the discriminator model.

__Discriminator__: Classifies images as "Real" or "Fake". It is trained alongside the generators, and its output is itself used as part of the generators' loss.

In [None]:
import torch
import torch.nn as nn

class Block(nn.Module):
    """Convolution block: 4x4 conv -> instance norm -> LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        # Initialise nn.Module first so that sub-modules get registered.
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        normalisation = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        features = self.conv(x)
        return features

class Discriminator(nn.Module):
    """Convolutional discriminator producing a grid of real/fake scores.

    Args:
        in_channels: Number of channels of the input images.
        features: Channel widths of the successive convolution stages. The
            first entry is used by the initial convolution (no normalisation),
            every following entry by one ``Block``.

    The output of ``forward`` is passed through a sigmoid, so every score
    lies in [0, 1].
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for position, feature in enumerate(features[1:], start=1):
            # Only the final block uses stride 1. Decide by position rather
            # than by comparing channel counts, which would also match earlier
            # stages that happen to share the last stage's width.
            stride = 1 if position == last_index else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        # Final convolution collapses the features to one score channel.
        layers.append(
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect",
            )
        )
        self.model = nn.Sequential(*layers)  # unpack the list into positional args

    def forward(self, x):
        """Return the sigmoid-activated score map for the image batch ``x``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

    @staticmethod
    def test():
        """Smoke test: print the output shape for a random 5x3x256x256 batch."""
        x = torch.randn((5, 3, 256, 256))
        model = Discriminator(in_channels=3)
        preds = model(x)
        print(preds.shape)

In [None]:
# Run this block to test the discriminator: it builds a Discriminator, feeds
# it a random batch of five 256x256 RGB images and prints the output shape.
Discriminator.test()

<font size="4"> 
<b>
Generator model
</b>
</font>

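__ConvBlock__: A convolution followed by instance normalization and an optional ReLU. With `down=True` it uses a regular convolution, which downsamples (shrinks the feature map) when the stride is larger than 1; with `down=False` it uses a transposed convolution, which upsamples (enlarges the feature map) when the stride is larger than 1. TODO: read up on down- and upsampling in GANs.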

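__ResidualBlock__: Consists of two ConvBlock(down=True), the second one without an activation. The forward pass adds the input back onto the block's output (a residual connection) to counter vanishing gradients in deep networks.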

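__Generator__: A 7x7 convolution, two downsampling ConvBlocks, a stack of ResidualBlocks, two upsampling ConvBlocks, and a final 7x7 convolution followed by tanh, so the generated image values lie in [-1, 1].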

In [None]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> instance norm -> optional ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        down: If true, use ``nn.Conv2d`` with reflect padding; otherwise use
            ``nn.ConvTranspose2d``.
        use_act: If true, finish with an in-place ReLU; otherwise with an
            identity.
        **kwargs: Forwarded unchanged to the chosen convolution layer
            (e.g. ``kernel_size``, ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()

        if down:
            resample = nn.Conv2d(
                in_channels, out_channels, padding_mode="reflect", **kwargs
            )
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        self.conv = nn.Sequential(
            resample,
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        """Run ``x`` through the convolution, normalisation and activation."""
        out = self.conv(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks whose output is added back onto the input.

    Args:
        channels: Number of channels; identical for input and output.
    """

    def __init__(self, channels):
        super().__init__()
        # First block keeps its ReLU, the second one is left without activation.
        first = ConvBlock(channels, channels, kernel_size=3, padding=1, stride=1)
        second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked ConvBlocks."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: stem, two down blocks, residuals, two up blocks.

    Args:
        img_channels: Number of channels of the input and output images.
        num_features: Channel width after the first convolution; the two down
            blocks widen it to ``2 * num_features`` and ``4 * num_features``.
        num_residuals: Number of ``ResidualBlock`` modules in the middle.

    The output goes through ``tanh`` and therefore lies in [-1, 1].
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        self.initial = nn.Sequential(
            # The stem must emit ``num_features`` channels, because that is
            # what the first down block expects as its input width.
            nn.Conv2d(
                img_channels,
                num_features,
                kernel_size=7,
                stride=1,
                padding=3,
                padding_mode="reflect",
            ),
            nn.ReLU(inplace=True),
        )
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(num_features, num_features * 2, kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features * 2, num_features * 4, kernel_size=3, stride=2, padding=1),
            ]
        )
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
        )
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(
                    num_features * 4,
                    num_features * 2,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                ConvBlock(
                    num_features * 2,
                    num_features * 1,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
            ]
        )
        self.last = nn.Conv2d(
            num_features * 1,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

    def forward(self, x):
        """Translate the image batch ``x`` and return values in [-1, 1]."""
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        return torch.tanh(self.last(x))

def generator_test():
    """Smoke test: print the Generator output shape for a random 2x3x256x256 batch."""
    channels, size = 3, 256
    sample = torch.randn((2, channels, size, size))
    model = Generator(channels, 64, 9)
    output = model(sample)
    print(output.shape)


In [None]:
# Run to test the generator: it builds a Generator, feeds it a random batch of
# two 256x256 RGB images and prints the output shape.
generator_test()

<font size="5"> 
<b>
Utils
</b>
</font>

Utility functions used in training.

In [None]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer), serialised with ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file``.

    The checkpoint is mapped onto ``DEVICE`` while loading. Afterwards the
    learning rate of every optimizer parameter group is set to ``lr``.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # The restored optimizer state carries the learning rate stored in the
    # checkpoint; overwrite it so the requested ``lr`` is the one in effect.
    for group in optimizer.param_groups:
        group["lr"] = lr

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and set cuDNN to deterministic mode.

    Also exports ``seed`` as the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

<font size="5"> 
<b>
Training
</b>
</font>

What are we doing and why

In [None]:
def train_fn(disc_P, disc_M, gen_P, gen_M, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Run one pass over ``loader``, updating both discriminators and generators.

    Args:
        disc_P: Discriminator scoring real versus generated photos.
        disc_M: Discriminator scoring real versus generated Monet pictures.
        gen_P: Generator turning Monet pictures into photos.
        gen_M: Generator turning photos into Monet pictures.
        loader: Iterable yielding ``(photo, monet)`` batches.
        opt_disc: Optimizer stepping the discriminator parameters.
        opt_gen: Optimizer stepping the generator parameters.
        l1: Loss used for the cycle and identity terms.
        mse: Loss used for the adversarial terms.
        d_scaler: Gradient scaler for the discriminator step.
        g_scaler: Gradient scaler for the generator step.

    Every 20th iteration the current generated photo and Monet batches are
    written to ``saved_images/``.
    """
    # save_image does not create missing directories.
    os.makedirs("saved_images", exist_ok=True)

    loop = tqdm(loader, leave=True)  # progress bar
    # The dataset returns (photo, monet); unpack in that same order.
    for idx, (photo, monet) in enumerate(loop):
        photo = photo.to(DEVICE)
        monet = monet.to(DEVICE)

        # Train discriminators P and M.
        with torch.amp.autocast('cuda'):
            fake_photo = gen_P(monet)
            D_P_real = disc_P(photo)
            # detach(): the discriminator step must not backpropagate into
            # the generators.
            D_P_fake = disc_P(fake_photo.detach())
            D_P_real_loss = mse(D_P_real, torch.ones_like(D_P_real))
            D_P_fake_loss = mse(D_P_fake, torch.zeros_like(D_P_fake))
            D_P_loss = D_P_real_loss + D_P_fake_loss

            fake_monet = gen_M(photo)
            # Monet images are judged by the Monet discriminator.
            D_M_real = disc_M(monet)
            D_M_fake = disc_M(fake_monet.detach())
            D_M_real_loss = mse(D_M_real, torch.ones_like(D_M_real))
            D_M_fake_loss = mse(D_M_fake, torch.zeros_like(D_M_fake))
            D_M_loss = D_M_real_loss + D_M_fake_loss

            D_loss = (D_P_loss + D_M_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generators P and M.
        with torch.amp.autocast('cuda'):
            # Adversarial loss: the generators want their fakes scored as real.
            D_P_fake = disc_P(fake_photo)
            D_M_fake = disc_M(fake_monet)
            loss_G_P = mse(D_P_fake, torch.ones_like(D_P_fake))
            loss_G_M = mse(D_M_fake, torch.ones_like(D_M_fake))

            # Cycle loss: translating there and back should recover the input.
            cycle_monet = gen_M(fake_photo)
            cycle_photo = gen_P(fake_monet)
            cycle_monet_loss = l1(cycle_monet, monet)
            cycle_photo_loss = l1(cycle_photo, photo)

            G_loss = (
                loss_G_M
                + loss_G_P
                + cycle_monet_loss * LAMBDA_CYCLE
                + cycle_photo_loss * LAMBDA_CYCLE
            )

            # Identity loss: a generator fed an image of its own target domain
            # should leave it unchanged. Skipped entirely when its weight is
            # zero, which saves two generator forward passes per iteration.
            if LAMBDA_IDENTITY != 0:
                identity_photo_loss = l1(gen_P(photo), photo)
                identity_monet_loss = l1(gen_M(monet), monet)
                G_loss = (
                    G_loss
                    + identity_photo_loss * LAMBDA_IDENTITY
                    + identity_monet_loss * LAMBDA_IDENTITY
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 20 == 0:
            # ``* 0.5 + 0.5`` maps the tanh range [-1, 1] back to [0, 1].
            save_image(fake_photo * 0.5 + 0.5, f"saved_images/photo_{idx}.png")
            save_image(fake_monet * 0.5 + 0.5, f"saved_images/monet_{idx}.png")

<font size="5"> 
<b>
Main
</b>
</font>

Main function

In [None]:
def main():
    """Build the models, optimizers and data loader, then train for NUM_EPOCHS.

    When ``LOAD_MODEL`` is set, the four networks are first restored from
    their checkpoint files. When ``SAVE_MODEL`` is set, they are written back
    to the same files after every epoch.
    """
    disc_P = Discriminator(in_channels=3).to(DEVICE)
    disc_M = Discriminator(in_channels=3).to(DEVICE)
    gen_P = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_M = Generator(img_channels=3, num_residuals=9).to(DEVICE)

    # One optimizer per role, each covering both networks of that role.
    opt_disc = optim.Adam(
        list(disc_P.parameters()) + list(disc_M.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_P.parameters()) + list(gen_M.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # Single source of truth for which model lives in which checkpoint file,
    # so that loading and saving can never disagree about the paths.
    checkpoints = (
        (CHECKPOINT_GEN_H, gen_P, opt_gen),
        (CHECKPOINT_GEN_Z, gen_M, opt_gen),
        (CHECKPOINT_CRITIC_H, disc_P, opt_disc),
        (CHECKPOINT_CRITIC_Z, disc_M, opt_disc),
    )

    # Checkpoints allow training to resume instead of starting from scratch.
    if LOAD_MODEL:
        for path, model, optimizer in checkpoints:
            load_checkpoint(path, model, optimizer, LEARNING_RATE)

    dataset = PhotoMonetDataset(
        root_photo=os.path.join(TRAIN_DIR, "Photo"),
        root_monet=os.path.join(TRAIN_DIR, "Monet"),
        transform=transforms,
    )
    # TODO: add a validation dataset and loader (batch_size=1, shuffle=False).
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    g_scaler = torch.amp.GradScaler('cuda')
    d_scaler = torch.amp.GradScaler('cuda')

    for _ in range(NUM_EPOCHS):
        train_fn(
            disc_P,
            disc_M,
            gen_P,
            gen_M,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler,
        )

        if SAVE_MODEL:
            for path, model, optimizer in checkpoints:
                save_checkpoint(model, optimizer, filename=path)

In [None]:
# Guard the entry point: with num_workers > 0 the DataLoader starts worker
# processes, and on platforms that spawn them the main module is imported
# again in every worker. Without the guard, running this file as a script
# would start training recursively in each worker.
if __name__ == "__main__":
    main()