## Imports, GPU

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from PIL import Image
from random import randrange


In [None]:
"""
Enable CUDA if the GPU is available
"""
# Single module-level device handle used by every `.to(device)` call below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Config

In [None]:
# --- Paths (relative to the working directory; Drive must be mounted) ---
DRIVE_PREFIX = './drive/MyDrive/UChicago/Computer Vision/Vision Final'
TRAIN_DIR = f"{DRIVE_PREFIX}/train"
VAL_DIR = f"{DRIVE_PREFIX}/val"

# --- Run / optimisation settings ---
RUN_NUM = 0              # suffix used in saved sample-image file names
BATCH_SIZE = 1           # Changing this to 100 used up all GPU memory
LEARNING_RATE = 1e-5
LAMBDA_CYCLE = 10        # weight of the cycle-consistency terms in the generator loss
NUM_WORKERS = 2
NUM_EPOCHS = 150

# --- Checkpointing switches ---
LOAD_MODEL = True
SAVE_MODEL = False

# --- Checkpoint file locations ---
CHECKPOINT_GEN_H = f"{DRIVE_PREFIX}/genh.pth.tar"
CHECKPOINT_GEN_S = f"{DRIVE_PREFIX}/gens.pth.tar"
CHECKPOINT_CRITIC_H = f"{DRIVE_PREFIX}/critich.pth.tar"
CHECKPOINT_CRITIC_S = f"{DRIVE_PREFIX}/critics.pth.tar"

# Shared preprocessing/augmentation pipeline applied to both image domains.
# `additional_targets` registers a second input named "image0" that is handled
# as an image, so a single call can transform a (Simpsons, human) pair together
# via `transform(image=..., image0=...)`.
transforms_list = A.Compose(
    [
        A.Resize(width=256, height=256),  # every image is resized to 256x256
        A.HorizontalFlip(p=0.5),  # flip with probability 0.5
        # mean/std of 0.5 with max_pixel_value=255: presumably maps 8-bit pixels to
        # about [-1, 1], matching the `* 0.5 + 0.5` applied before save_image — confirm.
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),  # convert the result to a torch tensor
    ],
    additional_targets={"image0": "image"},
)

## Utils

In [None]:
def save_checkpoint(model, optimizer, filename):
    """Save model and optimizer state locally, then copy the file into DRIVE_PREFIX.

    Args:
        model: module whose ``state_dict()`` is stored under ``"state_dict"``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``"optimizer"``.
        filename: checkpoint path *without* extension; ``'.pth.tar'`` is appended.

    The checkpoint is first written to ``filename + '.pth.tar'`` and then copied
    into the ``DRIVE_PREFIX`` directory under the same base name.
    """
    import shutil  # local import: only this helper needs it

    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    local_path = filename + '.pth.tar'
    torch.save(checkpoint, local_path)

    # The previous shell line (`!cp {...} DRIVE_PREFIX`) passed the literal word
    # "DRIVE_PREFIX" to cp instead of the variable's value, so the checkpoint never
    # reached Drive. The real path also contains spaces, which an unquoted shell
    # argument would split. Copying from Python avoids both problems.
    drive_path = os.path.join(DRIVE_PREFIX, os.path.basename(local_path))
    if os.path.abspath(local_path) != os.path.abspath(drive_path):
        shutil.copy(local_path, drive_path)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint and reset the learning rate.

    Args:
        checkpoint_file: path to a file holding a dict with the keys
            ``"state_dict"`` and ``"optimizer"``.
        model: module that receives the stored ``"state_dict"``.
        optimizer: optimizer that receives the stored ``"optimizer"`` state.
        lr: learning rate written into every parameter group after loading.
    """
    print("=> Loading checkpoint")
    saved_state = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(saved_state["state_dict"])
    optimizer.load_state_dict(saved_state["optimizer"])

    # Loading the optimizer state also restores the checkpoint's learning rate,
    # so overwrite it with the requested value in every parameter group.
    for group in optimizer.param_groups:
        group["lr"] = lr

## Load datasets

In [None]:
# Mount Google Drive into the Colab filesystem so the DRIVE_PREFIX paths resolve.
# NOTE(review): DRIVE_PREFIX is a relative path ('./drive/...'), so this presumably
# relies on the working directory being /content — confirm.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Dataset of (Simpsons, human) image pairs, used for both the training and validation loaders.
class SimpsonsHumansDataset(Dataset):
    """Dataset yielding one Simpsons image and one human image per index.

    The two folders may hold different numbers of files. The dataset length is
    the larger count, and the shorter folder is cycled with a modulo on the index.

    Args:
        root_simpsons: directory containing the Simpsons images.
        root_human: directory containing the human images.
        transform: optional callable invoked as
            ``transform(image=<simpsons>, image0=<human>)`` and expected to
            return a mapping with the keys ``"image"`` and ``"image0"``.
    """

    def __init__(self, root_simpsons, root_human, transform=None):
        super().__init__()
        self.root_simpsons, self.root_human = root_simpsons, root_human
        self.transform = transform

        # Only file names are stored; they are joined with the roots on access.
        self.simpsons_images = os.listdir(root_simpsons)
        self.human_images = os.listdir(root_human)

        self.simpsons_len = len(self.simpsons_images)
        self.humans_len = len(self.human_images)
        self.length_dataset = max(self.simpsons_len, self.humans_len)

    def __len__(self):
        """Return the size of the larger of the two image folders."""
        return self.length_dataset

    @staticmethod
    def _read_rgb(folder, file_name):
        """Open ``folder/file_name`` and return it as an RGB numpy array."""
        return np.array(Image.open(os.path.join(folder, file_name)).convert("RGB"))

    def __getitem__(self, index):
        """Return ``(simpsons_img, human_img)`` for ``index`` (wrapping per folder)."""
        simpsons_name = self.simpsons_images[index % self.simpsons_len]
        human_name = self.human_images[index % self.humans_len]

        simpsons_img = self._read_rgb(self.root_simpsons, simpsons_name)
        human_img = self._read_rgb(self.root_human, human_name)

        if not self.transform:
            return simpsons_img, human_img

        transformed = self.transform(image=simpsons_img, image0=human_img)
        return transformed["image"], transformed["image0"]

## Generator

In [None]:
class ConvBlock(nn.Module):
  """Convolution (or transposed convolution) + InstanceNorm + optional ReLU.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    down: if True use ``nn.Conv2d`` with reflect padding, otherwise
      ``nn.ConvTranspose2d``.
    use_act: if True append an in-place ReLU, otherwise an identity layer.
    **kwargs: forwarded to the convolution constructor (kernel_size, stride, ...).
  """

  def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
    super().__init__()
    if down:
      conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
    else:
      conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
    norm_layer = nn.InstanceNorm2d(out_channels)
    act_layer = nn.ReLU(inplace=True) if use_act else nn.Identity()
    # Keep the (conv, norm, act) order so state_dict keys stay `conv.0.*`.
    self.conv = nn.Sequential(conv_layer, norm_layer, act_layer)

  def forward(self, x):
    """Apply the conv/norm/activation stack to ``x``."""
    return self.conv(x)

class ResidualBlock(nn.Module):
  """Two 3x3 ConvBlocks with a skip connection; the second has no activation.

  Args:
    channels: channel count of both the input and the output.
  """

  def __init__(self, channels):
    super().__init__()
    conv_args = dict(kernel_size=3, padding=1)
    with_act = ConvBlock(channels, channels, **conv_args)
    without_act = ConvBlock(channels, channels, use_act=False, **conv_args)
    self.block = nn.Sequential(with_act, without_act)

  def forward(self, x):
    """Return ``x`` plus the output of the two-conv branch."""
    branch = self.block(x)
    return x + branch

class Generator(nn.Module):
    """Image-to-image generator built from ConvBlocks and ResidualBlocks.

    Layout: 7x7 reflect-padded stem, two stride-2 downsampling blocks,
    ``num_residuals`` residual blocks, two stride-2 transposed-convolution
    upsampling blocks, and a final 7x7 convolution followed by tanh.

    Args:
        img_channels: channel count of the input and output images.
        num_features: channel width after the stem (doubled at each downsample).
        num_residuals: number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        base = num_features
        double = num_features * 2
        quad = num_features * 4

        # Settings shared by the first and the last (7x7) convolutions.
        wide_conv = dict(kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        down_args = dict(kernel_size=3, stride=2, padding=1)
        up_args = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, base, **wide_conv),
            nn.InstanceNorm2d(base),
            nn.ReLU(inplace=True),
        )

        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, **down_args)
                for c_in, c_out in ((base, double), (double, quad))
            ]
        )

        bottleneck = [ResidualBlock(quad) for _ in range(num_residuals)]
        self.res_blocks = nn.Sequential(*bottleneck)

        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, **up_args)
                for c_in, c_out in ((quad, double), (double, base))
            ]
        )

        # To RGB
        self.last = nn.Conv2d(base, img_channels, **wide_conv)

    def forward(self, x):
        """Map an image batch to the other domain; tanh bounds the output to [-1, 1]."""
        out = self.initial(x)
        for down in self.down_blocks:
            out = down(out)
        out = self.res_blocks(out)
        for up in self.up_blocks:
            out = up(out)
        return self.last(out).tanh()

## Discriminator

In [None]:
class Block(nn.Module):
  """4x4 reflect-padded convolution + InstanceNorm + LeakyReLU(0.2).

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    stride: stride of the convolution (2 halves the spatial size).
  """

  def __init__(self, in_channels, out_channels, stride):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
        nn.InstanceNorm2d(out_channels), # normalizes each sample, not entire batch
        nn.LeakyReLU(0.2)
    )

  def forward(self, x):
    """Apply the conv/norm/activation stack to ``x``."""
    return self.conv(x)

class Discriminator(nn.Module):
  """Convolutional discriminator that outputs a grid of per-patch scores in [0, 1].

  Layout: a stride-2 stem (no normalization), one ``Block`` per entry in
  ``features[1:]`` (stride 2 for all but the last, which uses stride 1), and a
  final stride-1 convolution to a single channel followed by a sigmoid.

  Args:
    in_channels: channel count of the input images.
    features: channel widths; ``features[0]`` is the stem width. The default
      list is never mutated.

  Note:
    The earlier condition ``1 if features[-1] else 2`` was always truthy, so
    every block ran with stride 1 and nothing was downsampled after the stem.
    Convolution weight shapes do not depend on stride, so existing checkpoints
    still load, but they were trained on the larger, non-downsampled feature
    maps and should be retrained.
  """

  def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2)
    )

    layers = []
    in_channels = features[0]
    last_idx = len(features) - 1
    for idx, feature in enumerate(features[1:], start=1):
      # Downsample (stride 2) in every block except the last one. Position is
      # compared rather than value so repeated widths are handled correctly.
      stride = 1 if idx == last_idx else 2
      layers.append(Block(in_channels, feature, stride=stride))
      in_channels = feature
    layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')) # add final conv2d
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Return the sigmoid-activated patch score map for ``x``."""
    x = self.initial(x)
    return torch.sigmoid(self.model(x))

## Training

In [None]:
def train_fn(
    disc_H, disc_S, gen_S, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler,
    epoch=None,
):
    """Run one training epoch for both discriminators and both generators.

    Args:
        disc_H, disc_S: discriminators for the human and Simpsons domains.
        gen_S, gen_H: generators producing Simpsons / human images.
        loader: iterable yielding ``(simpsons, human)`` batches.
        opt_disc, opt_gen: optimizers for the discriminators / generators.
        l1: loss used for the cycle-consistency terms.
        mse: loss used for the adversarial terms.
        d_scaler, g_scaler: gradient scalers for the two optimizers.
        epoch: current epoch index, used in sample-image file names and to
            detect the last epoch. If omitted, the module-level ``epoch``
            variable is used, as before.

    Side effects:
        Updates all four models, writes sample images under
        ``DRIVE_PREFIX/mid`` and appends the epoch's mean discriminator and
        generator losses (as plain floats) to the files under
        ``DRIVE_PREFIX/loss``.
    """
    if epoch is None:
        # Backward compatibility: this function used to read the notebook-level
        # loop variable directly.
        try:
            epoch = globals()["epoch"]
        except KeyError:
            raise NameError(
                "train_fn needs `epoch`: pass it explicitly or define a global `epoch`"
            ) from None

    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)
    avg_loss_D, avg_loss_G = None, None

    for idx, (simpsons, human) in enumerate(loop):
        simpsons = simpsons.to(device)
        human = human.to(device)

        # Train Discriminators H and S
        with torch.cuda.amp.autocast():
            fake_human = gen_H(simpsons)
            D_H_real = disc_H(human)
            # detach: the discriminator step must not backpropagate into the generator
            D_H_fake = disc_H(fake_human.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_simpsons = gen_S(human)
            D_S_real = disc_S(simpsons)
            D_S_fake = disc_S(fake_simpsons.detach())
            D_S_real_loss = mse(D_S_real, torch.ones_like(D_S_real))
            D_S_fake_loss = mse(D_S_fake, torch.zeros_like(D_S_fake))
            D_S_loss = D_S_real_loss + D_S_fake_loss

            # put it together
            D_loss = (D_H_loss + D_S_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and S
        with torch.cuda.amp.autocast():
            # adversarial loss for both generators
            D_H_fake = disc_H(fake_human)
            D_S_fake = disc_S(fake_simpsons)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_S = mse(D_S_fake, torch.ones_like(D_S_fake))

            # cycle loss
            cycle_simpsons = gen_S(fake_human)
            cycle_human = gen_H(fake_simpsons)
            cycle_simpsons_loss = l1(simpsons, cycle_simpsons)
            cycle_human_loss = l1(human, cycle_human)

            # add all together
            G_loss = (
                loss_G_S
                + loss_G_H
                + cycle_simpsons_loss * LAMBDA_CYCLE
                + cycle_human_loss * LAMBDA_CYCLE
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Running mean of the losses over the epoch. Plain floats are used so no
        # autograd graph is kept alive across iterations. The step is
        # 1 / (idx + 1); the earlier `(1 / idx + 1)` evaluated to `1/idx + 1`.
        d_value = D_loss.item()
        g_value = G_loss.item()
        if avg_loss_D is None:
            avg_loss_D, avg_loss_G = d_value, g_value
        else:
            avg_loss_D += (d_value - avg_loss_D) / (idx + 1)
            avg_loss_G += (g_value - avg_loss_G) / (idx + 1)

        # Save a sample every 200 batches; in the last epoch save every batch.
        if idx % 200 == 0 or epoch == NUM_EPOCHS - 1:
            save_image(fake_human * 0.5 + 0.5, DRIVE_PREFIX + f"/mid/human{epoch}_{idx}_{RUN_NUM}.png")
            save_image(fake_simpsons * 0.5 + 0.5, DRIVE_PREFIX + f"/mid/simpsons{epoch}_{idx}_{RUN_NUM}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))

        # Free up GPU memory before the next batch. The names are parenthesised
        # because an unparenthesised multi-line `del` only deletes the names on
        # its first line; the following lines are no-op tuple expressions.
        del (
            simpsons, human,
            fake_human, D_H_real, D_H_loss, D_H_real_loss, D_H_fake_loss,
            fake_simpsons, D_S_real, D_S_loss, D_S_real_loss, D_S_fake_loss,
            D_loss,
            D_H_fake, D_S_fake, loss_G_H, loss_G_S,
            cycle_simpsons, cycle_human, cycle_simpsons_loss, cycle_human_loss,
            G_loss,
        )

    if avg_loss_D is None:
        # The loader produced no batches, so there is nothing to log.
        return

    # NOTE(review): the backtick in 'training_loss_G`.txt' looks like a typo, but it
    # is kept so existing log files keep being appended to — confirm before renaming.
    with open(DRIVE_PREFIX + '/loss/training_loss_D.txt', 'a') as loss_D_f:
        loss_D_f.write(f'{avg_loss_D}\n')
    with open(DRIVE_PREFIX + '/loss/training_loss_G`.txt', 'a') as loss_G_f:
        loss_G_f.write(f'{avg_loss_G}\n')

## Testing

In [None]:
# Basically the same as the training code, except we don't update the model.
# Choose one random image to test loss for each epoch, and one consistent image to save to track progress
# If "final" is set to True, then evaluate and save all images
def test_fn(
    disc_H, disc_S, gen_S, gen_H, loader, opt_disc, l1, mse, final=False, epoch=None
):
    """Evaluate the models on ``loader`` without updating any weights.

    In the default mode only two batches are evaluated per call: batch 0, whose
    translated images are saved to track progress, and one randomly chosen
    batch, whose losses are appended to the test-loss files. With
    ``final=True`` every batch is evaluated and its translated images are
    saved under ``DRIVE_PREFIX/final_output``.

    Args:
        disc_H, disc_S: discriminators for the human and Simpsons domains.
        gen_S, gen_H: generators producing Simpsons / human images.
        loader: iterable yielding ``(simpsons, human)`` batches.
        opt_disc: unused; kept so existing call sites keep working.
        l1: loss used for the cycle-consistency terms.
        mse: loss used for the adversarial terms.
        final: evaluate and save every batch instead of sampling.
        epoch: epoch index used in progress-image file names (non-final mode
            only). If omitted, the module-level ``epoch`` variable is used.
    """
    if not final and epoch is None:
        # Backward compatibility: this function used to read the notebook-level
        # loop variable directly.
        try:
            epoch = globals()["epoch"]
        except KeyError:
            raise NameError(
                "test_fn needs `epoch`: pass it explicitly or define a global `epoch`"
            ) from None

    # Pick the batch whose loss is logged. The range was hard-coded to 100;
    # use the real number of batches when the loader reports one.
    try:
        num_batches = len(loader)
    except TypeError:
        num_batches = 100
    rand_idx = randrange(num_batches) if num_batches > 0 else 0

    H_reals = 0
    H_fakes = 0
    evaluated = 0  # batches actually scored; skipped batches are not counted
    loop = tqdm(loader, leave=True)

    with open(DRIVE_PREFIX + '/loss/test_loss_D.txt', 'a') as loss_D_f, \
         open(DRIVE_PREFIX + '/loss/test_loss_G.txt', 'a') as loss_G_f:
        for idx, (simpsons, human) in enumerate(loop):
            if not final and idx != rand_idx and idx != 0:
                continue
            simpsons = simpsons.to(device)
            human = human.to(device)

            # No gradients are needed for evaluation; this avoids building
            # autograd graphs and the memory they hold.
            with torch.no_grad(), torch.cuda.amp.autocast():
                # Discriminator losses for H and S
                fake_human = gen_H(simpsons)
                D_H_real = disc_H(human)
                D_H_fake = disc_H(fake_human)
                H_reals += D_H_real.mean().item()
                H_fakes += D_H_fake.mean().item()
                D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
                D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
                D_H_loss = D_H_real_loss + D_H_fake_loss

                fake_simpsons = gen_S(human)
                D_S_real = disc_S(simpsons)
                D_S_fake = disc_S(fake_simpsons)
                D_S_real_loss = mse(D_S_real, torch.ones_like(D_S_real))
                D_S_fake_loss = mse(D_S_fake, torch.zeros_like(D_S_fake))
                D_S_loss = D_S_real_loss + D_S_fake_loss

                # put it together
                D_loss = (D_H_loss + D_S_loss) / 2

                # Generator adversarial losses. Without gradients the detached
                # and non-detached discriminator passes give the same scores,
                # so the ones computed above are reused.
                loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
                loss_G_S = mse(D_S_fake, torch.ones_like(D_S_fake))

                # cycle loss
                cycle_simpsons = gen_S(fake_human)
                cycle_human = gen_H(fake_simpsons)
                cycle_simpsons_loss = l1(simpsons, cycle_simpsons)
                cycle_human_loss = l1(human, cycle_human)

                # add all together
                G_loss = (
                    loss_G_S
                    + loss_G_H
                    + cycle_simpsons_loss * LAMBDA_CYCLE
                    + cycle_human_loss * LAMBDA_CYCLE
                )

            evaluated += 1

            # Keep track of loss over time (plain floats, one line per call)
            if not final and idx == rand_idx:
                loss_D_f.write(f'{D_loss.item()}\n')
                loss_G_f.write(f'{G_loss.item()}\n')

            # Save first image at every epoch
            if not final and idx == 0:
                save_image(fake_human * 0.5 + 0.5, DRIVE_PREFIX + f"/tests/human{epoch}_{idx}_{RUN_NUM}.png")
                save_image(fake_simpsons * 0.5 + 0.5, DRIVE_PREFIX + f"/tests/simpsons{epoch}_{idx}_{RUN_NUM}.png")

            if final:
                save_image(fake_human * 0.5 + 0.5, DRIVE_PREFIX + f"/final_output/human_{idx}.png")
                save_image(fake_simpsons * 0.5 + 0.5, DRIVE_PREFIX + f"/final_output/simpsons_{idx}.png")

            # Average over the batches actually evaluated, not the loop index.
            loop.set_postfix(H_real=H_reals / evaluated, H_fake=H_fakes / evaluated)

            # Free up GPU memory before the next batch. The names are
            # parenthesised because an unparenthesised multi-line `del` only
            # deletes the names on its first line.
            del (
                simpsons, human,
                fake_human, D_H_real, D_H_loss, D_H_real_loss, D_H_fake_loss,
                fake_simpsons, D_S_real, D_S_loss, D_S_real_loss, D_S_fake_loss,
                D_loss,
                D_H_fake, D_S_fake, loss_G_H, loss_G_S,
                cycle_simpsons, cycle_human, cycle_simpsons_loss, cycle_human_loss,
                G_loss,
            )

## Run Training

In [None]:
# One discriminator per domain and one generator per translation direction.
# The construction order (disc_H, disc_S, gen_S, gen_H) is kept as-is.
disc_H = Discriminator(in_channels=3).to(device)
disc_S = Discriminator(in_channels=3).to(device)
gen_S = Generator(img_channels=3, num_residuals=9).to(device)
gen_H = Generator(img_channels=3, num_residuals=9).to(device)

# A single Adam optimizer drives both discriminators, and another both generators.
opt_disc = optim.Adam(
    [*disc_H.parameters(), *disc_S.parameters()],
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)
opt_gen = optim.Adam(
    [*gen_S.parameters(), *gen_H.parameters()],
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

# L1 for the cycle-consistency terms, MSE for the adversarial terms.
L1 = nn.L1Loss()
mse = nn.MSELoss()

if LOAD_MODEL:
    # (checkpoint base name, model to restore, optimizer to restore)
    for _ckpt_name, _ckpt_model, _ckpt_opt in (
        ('CHECKPOINT_GEN_H', gen_H, opt_gen),
        ('CHECKPOINT_GEN_S', gen_S, opt_gen),
        ('CHECKPOINT_CRITIC_H', disc_H, opt_disc),
        ('CHECKPOINT_CRITIC_S', disc_S, opt_disc),
    ):
        load_checkpoint(
            DRIVE_PREFIX + '/' + _ckpt_name + '.pth.tar',
            _ckpt_model,
            _ckpt_opt,
            LEARNING_RATE,
        )

# Training and validation datasets share the same transform pipeline.
dataset = SimpsonsHumansDataset(
    root_simpsons=f"{TRAIN_DIR}/simpsons",
    root_human=f"{TRAIN_DIR}/humans",
    transform=transforms_list,
)
val_dataset = SimpsonsHumansDataset(
    root_simpsons=f"{VAL_DIR}/simpsons",
    root_human=f"{VAL_DIR}/humans",
    transform=transforms_list,
)

# Validation: one pair at a time, in a fixed order.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True)

# Training: shuffled, loaded by background workers.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Separate gradient scalers for the generator and discriminator optimizers.
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()


=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint


In [None]:
# The loop variable must stay named `epoch`: train_fn and test_fn read it as a
# global when building sample-image file names.
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch}/{NUM_EPOCHS-1}")

    train_fn(
        disc_H, disc_S, gen_S, gen_H,
        loader,
        opt_disc, opt_gen,
        L1, mse,
        d_scaler, g_scaler,
    )

    if SAVE_MODEL:
        # (model, optimizer, checkpoint base name)
        for _save_model, _save_opt, _save_name in (
            (gen_H, opt_gen, 'CHECKPOINT_GEN_H'),
            (gen_S, opt_gen, 'CHECKPOINT_GEN_S'),
            (disc_H, opt_disc, 'CHECKPOINT_CRITIC_H'),
            (disc_S, opt_disc, 'CHECKPOINT_CRITIC_S'),
        ):
            save_checkpoint(_save_model, _save_opt, filename=_save_name)

    # Per-epoch validation pass (samples two batches, see test_fn).
    test_fn(disc_H, disc_S, gen_S, gen_H, val_loader, opt_disc, L1, mse)

## Generate final images

In [None]:
# Evaluate every validation pair and save the translated images.
test_fn(disc_H, disc_S, gen_S, gen_H, val_loader, opt_disc, L1, mse, final=True)

100%|██████████| 100/100 [01:52<00:00,  1.13s/it, H_fake=0.393, H_real=0.634]
