In [2]:
import torch
import torch.nn as nn

In [58]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "data/train"
VAL_DIR = "data/val"
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
LAMBDA_GP = 10
NUM_EPOCHS = 500
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

both_transform = A.Compose(
    [A.Resize(width=256, height=256),], additional_targets={"image0": "image"},is_check_shapes=False
)

transform_only_input = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ],
    is_check_shapes=False
)

transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ],
    is_check_shapes=False
)

In [39]:
# UTILS

import torch
from torchvision.utils import save_image

def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # remove normalization#
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr



In [40]:
class CNNBlock(nn.Module):
  def __init__(self, in_channel, out_channel, stride=2):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channel, out_channel, 4, stride, bias=False, padding_mode="reflect"),
        nn.BatchNorm2d(out_channel),
        nn.LeakyReLU(0.2)

    )

  def forward(self, x):
    return self.conv(x)


# x, y <- concatenate these along the channels

class Discriminator(nn.Module):
  def __init__(self, in_channels=3, features=[64, 128, 256, 512]): # 256 input -> 30x30 output
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2)
    )
    layers = []
    in_channels = features[0]
    for feature in features[1:]:
      layers.append(
          CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2)
      )
      in_channels = feature

    layers.append(
        nn.Conv2d(
            in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
        )
    )

    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    x = torch.cat([x,y], dim=1)
    x = self.initial(x)
    return self.model(x)


def test():
  x = torch.randn((1, 3, 256, 256))
  y = torch.randn((1, 3, 256, 256))
  model = Discriminator()
  preds = model(x,y)
  print(preds.shape)

test()

torch.Size([1, 1, 26, 26])


# UNET


In [41]:
class Block(nn.Module):
  def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4,2,1,bias=False,padding_mode="reflect")
        if down
        else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        # nn.InstanceNorm2d(out_channels, affine=True), # ACCORDING TO CYCLE GAN PAPER
        nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
    )
    self.use_dropout = use_dropout
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    x = self.conv(x)
    return self.dropout(x) if self.use_dropout else x


class Generator(nn.Module):
  def __init__(self, in_channels=3, features=64):
    super().__init__()
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_channels, features, 4,2,1, padding_mode="reflect"),
        nn.LeakyReLU(0.2)
    )

    self.down1 = Block(features, features*2, down=True, act="leaky", use_dropout=False) #64
    self.down2 = Block(features*2, features*4, down=True, act="leaky", use_dropout=False) #32
    self.down3 = Block(features*4, features*8, down=True, act="leaky", use_dropout=False) #16
    self.down4 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False) #8
    self.down5 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False) #4
    self.down6 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False) #2
    self.bottleneck = nn.Sequential(
        nn.Conv2d(features*8, features*8, 4, 2, 1, padding_mode="reflect"), nn.ReLU(),
    )
    self.up1 = Block(features*8, features*8, down=False, act="relu", use_dropout=True)
    self.up2 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
    self.up3 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
    self.up4 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=False)
    self.up5 = Block(features*8*2, features*4, down=False, act="relu", use_dropout=False)
    self.up6 = Block(features*4*2, features*2, down=False, act="relu", use_dropout=False)
    self.up7 = Block(features*2*2, features, down=False, act="relu", use_dropout=False)
    self.final_up = nn.Sequential(
        nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    )

  def forward(self,x):
    d1 = self.initial_down(x)
    d2 = self.down1(d1)
    d3 = self.down2(d2)
    d4 = self.down3(d3)
    d5 = self.down4(d4)
    d6 = self.down5(d5)
    d7 = self.down6(d6)
    bottleneck = self.bottleneck(d7)
    up1 = self.up1(bottleneck)
    up2 = self.up2(torch.cat([up1, d7], 1))
    up3 = self.up3(torch.cat([up2, d6], 1))
    up4 = self.up4(torch.cat([up3, d5], 1))
    up5 = self.up5(torch.cat([up4, d4], 1))
    up6 = self.up6(torch.cat([up5, d3], 1))
    up7 = self.up7(torch.cat([up6, d2], 1))

    return self.final_up(torch.cat([up7, d1], 1))



def test():
  x = torch.randn((1,3,256,256))
  model = Generator(in_channels=3, features=64)
  preds = model(x)
  print(preds.shape)

test()






torch.Size([1, 3, 256, 256])


# Dataset Loading

In [42]:
!pip install -q kaggle

In [12]:
!ls -lha kaggle.json
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

!kaggle datasets download -d ktaebum/anime-sketch-colorization-pair

-rw-r--r-- 1 root root 71 Nov 22 10:36 kaggle.json
Downloading anime-sketch-colorization-pair.zip to /content
100% 11.6G/11.6G [02:26<00:00, 127MB/s]
100% 11.6G/11.6G [02:26<00:00, 85.2MB/s]


In [13]:
!unzip anime-sketch-colorization-pair.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: data/train/2906115.png  
  inflating: data/train/2906116.png  
  inflating: data/train/2906119.png  
  inflating: data/train/2906139.png  
  inflating: data/train/2906140.png  
  inflating: data/train/2906141.png  
  inflating: data/train/2906143.png  
  inflating: data/train/2907043.png  
  inflating: data/train/2907052.png  
  inflating: data/train/2907059.png  
  inflating: data/train/2907062.png  
  inflating: data/train/2907075.png  
  inflating: data/train/2907105.png  
  inflating: data/train/2907107.png  
  inflating: data/train/2907108.png  
  inflating: data/train/2907113.png  
  inflating: data/train/2907130.png  
  inflating: data/train/2907138.png  
  inflating: data/train/2907144.png  
  inflating: data/train/2907146.png  
  inflating: data/train/2907148.png  
  inflating: data/train/2907149.png  
  inflating: data/train/2908002.png  
  inflating: data/train/2908023.png  
  inflating: data/train

In [15]:
DEVICE

'cuda'

In [57]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

class MapDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        input_image = image[:, 600:, :]
        target_image = image[:, :600, :]

        augmentations = both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = transform_only_input(image=input_image)["image"]
        target_image = transform_only_mask(image=target_image)["image"]

        return input_image, target_image

dataset = MapDataset("data/train")
len(dataset)
# loader = DataLoader(dataset, batch_size=5)
# for x, y in loader:
#     print(x.shape)
#     save_image(x, "x.png")
#     save_image(y, "y.png")
#     import sys

#     sys.exit()

14224

# TRAIN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from tqdm import tqdm # PROGRESS BAR
from torchvision.utils import save_image

torch.backends.cudnn.benchmark = True

def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler):
  loop = tqdm(loader, leave=True)

  for idx, (x, y) in enumerate(loop):
      x = x.to(DEVICE)
      y = y.to(DEVICE)

      # Train Discriminator
      with torch.cuda.amp.autocast():
          y_fake = gen(x)
          D_real = disc(x, y)
          D_real_loss = bce(D_real, torch.ones_like(D_real))
          D_fake = disc(x, y_fake.detach())
          D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
          D_loss = (D_real_loss + D_fake_loss) / 2

      disc.zero_grad()
      d_scaler.scale(D_loss).backward()
      d_scaler.step(opt_disc)
      d_scaler.update()

      # Train generator
      with torch.cuda.amp.autocast():
          D_fake = disc(x, y_fake)
          G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
          L1 = l1_loss(y_fake, y) * L1_LAMBDA # SPECIFIED IN PAPER TO COMPUTE L1 LOSS
          G_loss = G_fake_loss + L1

      opt_gen.zero_grad()
      g_scaler.scale(G_loss).backward()
      g_scaler.step(opt_gen)
      g_scaler.update()

      if idx % 10 == 0:
          loop.set_postfix(
              D_real=torch.sigmoid(D_real).mean().item(),
              D_fake=torch.sigmoid(D_fake).mean().item(),
          )
    #loss.backward(retain_graph=True) alternative to y_fake.detach() for avoid breaking the computational graph


def main():
  disc = Discriminator(in_channels=3).to(DEVICE)
  gen = Generator(in_channels=3).to(DEVICE)
  opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
  opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
  BCE = nn.BCEWithLogitsLoss()
  L1_LOSS = nn.L1Loss() # You just do different combinations to check which is best and authors already specified that

  if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

  train_dataset = MapDataset(root_dir="data/train")
  train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

  # FOR LOW VRAM
  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()

  val_dataset = MapDataset(root_dir="data/val") # Validation dataset
  val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

  for epoch in range(NUM_EPOCHS):
    train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler)

    if SAVE_MODEL and epoch % 5 == 0:
      save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
      save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

    save_some_examples(gen, val_loader, epoch, folder="evaluation") # if ERROR: Create a folder evaluation



main()


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:28<00:00,  4.27it/s, D_fake=0.48, D_real=0.387]


=> Saving checkpoint
=> Saving checkpoint


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:32<00:00,  4.18it/s, D_fake=0.383, D_real=0.279]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:32<00:00,  4.19it/s, D_fake=0.216, D_real=0.882]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:29<00:00,  4.24it/s, D_fake=0.275, D_real=0.64]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:28<00:00,  4.27it/s, D_fake=0.278, D_real=0.474]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:28<00:00,  4.27it/s, D_fake=0.334, D_real=0.795]


=> Saving checkpoint
=> Saving checkpoint


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:32<00:00,  4.19it/s, D_fake=0.205, D_real=0.663]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:30<00:00,  4.22it/s, D_fake=0.291, D_real=0.734]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:29<00:00,  4.24it/s, D_fake=0.249, D_real=0.621]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:29<00:00,  4.24it/s, D_fake=0.397, D_real=0.544]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:29<00:00,  4.25it/s, D_fake=0.353, D_real=0.585]


=> Saving checkpoint
=> Saving checkpoint


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:35<00:00,  4.12it/s, D_fake=0.195, D_real=0.896]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:35<00:00,  4.12it/s, D_fake=0.229, D_real=0.514]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:32<00:00,  4.18it/s, D_fake=0.377, D_real=0.452]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:31<00:00,  4.21it/s, D_fake=0.428, D_real=0.322]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:30<00:00,  4.23it/s, D_fake=0.257, D_real=0.78]


=> Saving checkpoint
=> Saving checkpoint


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:35<00:00,  4.13it/s, D_fake=0.118, D_real=0.833]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:33<00:00,  4.17it/s, D_fake=0.295, D_real=0.771]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:30<00:00,  4.23it/s, D_fake=0.478, D_real=0.543]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:31<00:00,  4.20it/s, D_fake=0.41, D_real=0.747]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:30<00:00,  4.21it/s, D_fake=0.391, D_real=0.705]


=> Saving checkpoint
=> Saving checkpoint


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:33<00:00,  4.17it/s, D_fake=0.428, D_real=0.56]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:33<00:00,  4.17it/s, D_fake=0.204, D_real=0.77]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:29<00:00,  4.24it/s, D_fake=0.428, D_real=0.643]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:29<00:00,  4.24it/s, D_fake=0.173, D_real=0.695]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:30<00:00,  4.23it/s, D_fake=0.148, D_real=0.837]


=> Saving checkpoint
=> Saving checkpoint


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:33<00:00,  4.17it/s, D_fake=0.271, D_real=0.851]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:30<00:00,  4.23it/s, D_fake=0.268, D_real=0.654]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:32<00:00,  4.19it/s, D_fake=0.336, D_real=0.448]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:30<00:00,  4.22it/s, D_fake=0.0599, D_real=0.929]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 889/889 [03:30<00:00,  4.22it/s, D_fake=0.166, D_real=0.64]


=> Saving checkpoint
=> Saving checkpoint


 83%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Ž | 742/889 [02:56<00:31,  4.66it/s, D_fake=0.042, D_real=0.959]

In [54]:
import os

num_workers = os.cpu_count()
num_workers

2

In [60]:
!sudo rm -r evaluation/*.png