# Preparation

In [None]:
# Run this cell if you only opened this notebook in colab, and need to get all the sources
!git clone https://github.com/kokoslik/cartoonface-bot.git
import os
os.chdir('cartoonface-bot/src')

In [None]:
# Importing stuff that we need
import os
from torch.utils.data import DataLoader
import torchvision.transforms as tt
import torch
from torch import nn
from data.dataset import SingleFolderDataset
from model.generator import Generator
from model.discriminator import Discriminator
from utils.image_pool import ImagePool
from utils.train_loop import train

In [None]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
device  # last expression, so the chosen device is shown as the cell output

# Option 1: Horse2Zebra dataset

Run the cells in this section to train on the horse2zebra dataset.

In [None]:
!wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip

In [None]:
!unzip -q horse2zebra.zip

In [None]:
# Preprocessing shared by every horse2zebra split: resize, centre-crop to
# pic_size, convert to a tensor, then normalise each channel with
# mean 0.5 / std 0.5.
pic_size = 128
means = [0.5, 0.5, 0.5]
stds = [0.5, 0.5, 0.5]
transform = tt.Compose([
    tt.Resize(pic_size),
    tt.CenterCrop(pic_size),
    tt.ToTensor(),
    tt.Normalize(means, stds),
])

# Dataset key -> folder it is read from.
# NOTE(review): the A/B suffixes of the keys and of the folders are crossed
# (key 'trainA' reads folder 'trainB', and so on). This mapping is kept exactly
# as it was; confirm that the swap is intentional.
split_folders = {
    'trainA': 'horse2zebra/trainB',
    'trainB': 'horse2zebra/trainA',
    'testA': 'horse2zebra/testB',
    'testB': 'horse2zebra/testA',
}
datasets = {
    split: SingleFolderDataset(folder, transform=transform, cache=True)
    for split, folder in split_folders.items()
}

# Option 2: Cartoon Faces dataset

Download `faces2k.zip` (TODO: the download link is missing from this notebook; add it here)

and unzip it in the `src` directory.

In [None]:
pic_size = 128
means = [0.5, 0.5, 0.5]
stds = [0.5, 0.5, 0.5]


def _resize_crop_normalize():
    """Return a fresh list of the steps that both pipelines end with.

    The steps are: resize, centre-crop to pic_size, convert to a tensor and
    normalise each channel with the configured means / stds.
    """
    return [
        tt.Resize(pic_size),
        tt.CenterCrop(pic_size),
        tt.ToTensor(),
        tt.Normalize(means, stds),
    ]


# transformB first centre-crops a 320-pixel square and then applies the common
# steps; transformA applies the common steps only.
transformB = tt.Compose([tt.CenterCrop(320)] + _resize_crop_normalize())
transformA = tt.Compose(_resize_crop_normalize())

# Keyword arguments shared by the datasets of each kind.
options_a = dict(transform=transformA, cache=True)
options_b = dict(transform=transformB, cache=True, ext='png')

# NOTE(review): the A/B suffixes of the keys and of the folders are crossed
# (key 'trainA' reads folder 'trainB', and so on). This mapping is kept exactly
# as it was; confirm that the swap is intentional.
datasets = {
    'trainA': SingleFolderDataset('faces2k/trainB', **options_a),
    'trainB': SingleFolderDataset('faces2k/trainA', **options_b),
    'testA': SingleFolderDataset('faces2k/testB', **options_a),
    'testB': SingleFolderDataset('faces2k/testA', **options_b),
}

# Dataloaders creation

In [None]:
batch_size = 8

# Training loaders shuffle and drop the last incomplete batch; test loaders
# keep the dataset order and every sample.
train_loader_opts = dict(num_workers=2, shuffle=True, batch_size=batch_size, drop_last=True)
test_loader_opts = dict(num_workers=2, shuffle=False, batch_size=batch_size)

# One loader per entry of `datasets`, created in the order
# trainA, trainB, testA, testB.
dataloaders = {
    split: DataLoader(
        datasets[split],
        **(train_loader_opts if split.startswith('train') else test_loader_opts),
    )
    for split in ('trainA', 'trainB', 'testA', 'testB')
}

# Models, losses and pools creation

In [None]:
# Hyper-parameters: learning rate passed to train() and the size handed to
# each ImagePool.
lr = 1e-4
pool_size = 50

# Two generators and two discriminators, all built with instance_norm=True and
# moved to the selected device. Construction order is genA, genB, disA, disB.
model = {
    **{key: Generator(instance_norm=True).to(device) for key in ("genA", "genB")},
    **{key: Discriminator(instance_norm=True).to(device) for key in ("disA", "disB")},
}

# MSE for the discriminator and generator terms, L1 for the cycle and
# identity terms.
criterion = dict(
    disA=nn.MSELoss(),
    disB=nn.MSELoss(),
    cycle=nn.L1Loss(),
    identity=nn.L1Loss(),
    gen=nn.MSELoss(),
)

# One image pool per domain; pic_size comes from the dataset cell run earlier.
pools = {
    domain: ImagePool(pool_size, pic_size, device, logic='new')
    for domain in ('A', 'B')
}

# Training

In [None]:
# Make sure the 'dumps' folder exists before training starts. According to the
# original note, checkpoints are saved into it every 10 epochs (that happens
# inside train(), which is not shown in this notebook).
num_epochs = 200
checkpoint_dir = 'dumps'
os.makedirs(checkpoint_dir, exist_ok=True)

# Run the training loop and keep whatever train() returns in `losses`.
losses = train(dataloaders, model, criterion, num_epochs, lr, pools, device)