# CycleGAN implementation

[Original video](https://youtu.be/4LktBHGCNfw)

[Paper walkthrough video](https://youtu.be/5jziBapziYE)

[Paper](https://arxiv.org/abs/1703.10593)

[Horse2Zebra Dataset 1](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)

[Horse2Zebra Dataset 2](https://www.kaggle.com/balraj98/horse2zebra-dataset)

[Source code](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/GANs/CycleGAN)

## Get dataset. Import libraries

In [1]:
# Google Colab ships an old version of the albumentations library, so update it.
# You may need to restart the runtime after the update.
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir

Collecting git+https://github.com/albu/albumentations
  Cloning https://github.com/albu/albumentations to /tmp/pip-req-build-txxbxp6k
  Running command git clone -q https://github.com/albu/albumentations /tmp/pip-req-build-txxbxp6k
Collecting imgaug>=0.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl (948kB)
[K     |████████████████████████████████| 952kB 7.4MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.5.2-cp37-none-any.whl size=88144 sha256=3a4e73ae3c55f3768829a52536c206412c8b4da32312b6c8591d4b7489de77e4
  Stored in directory: /tmp/pip-ephem-wheel-cache-_e_bkee1/wheels/45/8b/e4/2837bbcf517d00732b8e394f8646f22b8723ac00993230188b
Successfully built albumentations
Installing collected packages: imgaug, albumentations
  Found exis

In [2]:
import multiprocessing
import os
import random

import albumentations as A
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from tqdm.notebook import tqdm  # for the progressbar

In [3]:
# # Delete dataset
# !rm -rf testA testB trainA trainB horse2zebra-dataset.zip metadata.csv
# !rm -rf horse2zebra

In [4]:
# # Download dataset from the CycleGAN webpage
# !wget 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip'


# Download dataset from Kaggle

# Info on how to get your api key (kaggle.json) here:
# https://github.com/Kaggle/kaggle-api#api-credentials

# Install kaggle packages if necessary. Not necessary for CoLab
# !pip install -q kaggle
# !pip install -q kaggle-cli

# Colab's file access feature
from google.colab import files

# Upload the `kaggle.json` file.
# NOTE(review): kaggle.json is the Kaggle API credential file (see the link
# above) -- avoid committing it or echoing its contents in saved outputs.
uploaded = files.upload()

# Report every uploaded file: `uploaded` is indexed by file name, and the
# length of each stored value is printed as the file size.
for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))


# Then copy kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!ls ~/.kaggle

# Download the dataset
!kaggle datasets download -d balraj98/horse2zebra-dataset
# !kaggle datasets list -s horse2zebra-dataset  # show all visible datasets

Saving kaggle.json to kaggle.json
User uploaded file "kaggle.json" with length 65 bytes
kaggle.json
Downloading horse2zebra-dataset.zip to /content
 83% 93.0M/111M [00:02<00:00, 38.6MB/s]
100% 111M/111M [00:02<00:00, 51.5MB/s] 


In [5]:
# Extract data
import zipfile
import tarfile

def extract(fname, path='.'):
    """Unpack an archive into ``path`` (the current directory by default).

    The archive type is chosen from the file extension:
    ``.tar.gz``/``.tgz``, ``.tar``, ``.tar.bz2``/``.tbz`` and ``.zip``.

    Args:
        fname: path of the archive file.
        path: directory to extract into; defaults to the current directory,
            which is where the original implementation always extracted.

    Raises:
        ValueError: if the extension is not one of the supported formats
            (previously this surfaced as an UnboundLocalError).

    NOTE(review): ``extractall`` trusts the member paths stored in the
    archive; only use this on archives from a trusted source.
    """
    if fname.endswith(('.tar.gz', '.tgz')):
        opener, mode = tarfile.open, 'r:gz'
    elif fname.endswith('.tar'):
        opener, mode = tarfile.open, 'r:'
    elif fname.endswith(('.tar.bz2', '.tbz')):
        opener, mode = tarfile.open, 'r:bz2'
    elif fname.endswith('.zip'):
        opener, mode = zipfile.ZipFile, 'r'
    else:
        raise ValueError(f'Unsupported archive format: {fname!r}')

    # The context manager closes the archive even if extraction fails.
    with opener(fname, mode=mode) as ref:
        ref.extractall(path)

extract('horse2zebra-dataset.zip')

In [6]:
# List the file names of every extracted split (A = horses, B = zebras).
horse_images, zebra_images, horse_images_val, zebra_images_val = (
    os.listdir(_folder) for _folder in ('trainA', 'trainB', 'testA', 'testB')
)

# Print how many files each split contains.
for _label, _names in (
    ('horse train:', horse_images),
    ('zebra train:', zebra_images),
    ('horse test:', horse_images_val),
    ('zebra test:', zebra_images_val),
):
    print(_label, len(_names))

horse train: 1067
zebra train: 1334
horse test: 120
zebra test: 140


## Discriminator model

In [7]:
class Block(nn.Module):
    """Discriminator building block: 4x4 conv -> (InstanceNorm) -> LeakyReLU(0.2).

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        stride: stride of the convolution.
        norm: when False the InstanceNorm2d layer is replaced by Identity,
            which keeps the layer indices inside ``self.conv`` unchanged.
    """

    def __init__(self, in_channels, out_channels, stride, norm=True):
        super().__init__()
        self.conv = nn.Sequential(
            # kernel_size=4, padding=1 with reflect padding
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode='reflect'),
            nn.InstanceNorm2d(out_channels) if norm else nn.Identity(),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        return self.conv(x)


class Discriminator(nn.Module):
    """Patch discriminator that outputs a grid of real/fake scores in (0, 1).

    Args:
        in_channels: number of channels of the input image.
        features: channel widths of the successive ``Block`` layers. The
            first block has no normalization; every block uses stride 2
            except the one in the LAST POSITION, which uses stride 1.

    The stride is chosen by position rather than by comparing channel values,
    so a configuration such as ``(64, 128, 128)`` still downsamples in the
    middle layer. The default is an immutable tuple instead of a shared list.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        layers = [Block(in_channels, features[0], stride=2, norm=False)]
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            layers.append(Block(in_channels, feature,
                                stride=1 if index == last_index else 2))
            in_channels = feature
        # Final 1-channel convolution producing one score per patch.
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1,
                                padding=1, padding_mode='reflect'))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-patch scores squashed by a sigmoid."""
        return torch.sigmoid(self.model(x))  # output between (0, 1)


# every cell from 30x30 grid sees 70x70 pixels on the 256x256 image
def test():
    """Smoke-test the discriminator: a batch of 8 RGB 256x256 inputs must
    produce an (8, 1, 30, 30) score grid."""
    expected_shape = (8, 1, 30, 30)
    batch = torch.randn((8, 3, 256, 256))
    discriminator = Discriminator()
    scores = discriminator(batch)
    print(scores.shape)
    assert scores.shape == expected_shape
    print('Test - OK')


test()

torch.Size([8, 1, 30, 30])
Test - OK


## Generator model

In [8]:
class ConvBlock(nn.Module):
    """Generator building block: convolution -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        down: True builds a reflect-padded ``nn.Conv2d``; False builds an
            ``nn.ConvTranspose2d`` (used by the upsampling path).
        use_act: when False the ReLU is replaced by Identity.
        **kwargs: forwarded to the convolution (kernel_size, stride, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            convolution = nn.Conv2d(in_channels, out_channels,
                                    padding_mode='reflect', **kwargs)
        else:
            convolution = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        self.conv = nn.Sequential(
            convolution,
            nn.InstanceNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        """Run ``x`` through the conv/norm/activation stack."""
        out = self.conv(x)
        return out


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers wrapped in a skip connection.

    The second ConvBlock has no activation; the block output is the input
    plus the output of the two-layer stack, so the channel count is unchanged.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)  # stride=1 by default
        with_activation = ConvBlock(channels, channels, **conv_kwargs)
        without_activation = ConvBlock(channels, channels, use_act=False, **conv_kwargs)
        self.block = nn.Sequential(with_activation, without_activation)

    def forward(self, x):
        """Return ``x`` plus the residual computed by the two ConvBlocks."""
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two stride-2 downsampling blocks,
    a stack of residual blocks, two transposed-conv upsampling blocks and a
    7x7 output convolution followed by tanh.

    Args:
        img_channels: channels of the input and output images.
        num_features: channel width after the stem; the bottleneck uses 4x.
        num_residuals: number of residual blocks
            (num_residuals=9 if size >= 256x256, num_residuals=6 if size <= 128x128).
    """

    def __init__(self, img_channels=3, num_features=64, num_residuals=9):
        super().__init__()
        width_1x = num_features
        width_2x = num_features * 2
        width_4x = num_features * 4

        stem_conv = nn.Conv2d(img_channels, width_1x, kernel_size=7, stride=1,
                              padding=3, padding_mode='reflect')
        self.initial = nn.Sequential(stem_conv, nn.ReLU(inplace=True))

        down_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList([
            ConvBlock(c_in, c_out, **down_kwargs)
            for c_in, c_out in ((width_1x, width_2x), (width_2x, width_4x))
        ])

        residuals = [ResidualBlock(width_4x) for _ in range(num_residuals)]
        self.residual_blocks = nn.Sequential(*residuals)

        # output_padding adds additional padding after the convolutional block
        up_kwargs = dict(down=False, kernel_size=3, stride=2, padding=1,
                         output_padding=1)
        self.up_blocks = nn.ModuleList([
            ConvBlock(c_in, c_out, **up_kwargs)
            for c_in, c_out in ((width_4x, width_2x), (width_2x, width_1x))
        ])

        self.last = nn.Conv2d(width_1x, img_channels, kernel_size=7,
                              stride=1, padding=3, padding_mode='reflect')

    def forward(self, x):
        """Translate ``x``; the result is squashed by tanh."""
        out = self.initial(x)
        for down in self.down_blocks:
            out = down(out)
        out = self.residual_blocks(out)
        for up in self.up_blocks:
            out = up(out)
        out = self.last(out)
        return torch.tanh(out)  # output between (-1, 1)


def test():
    """Smoke-test the generator: the output must have the same shape as the
    (8, 3, 256, 256) input batch."""
    expected_shape = (8, 3, 256, 256)
    batch = torch.randn(expected_shape)
    generator = Generator()
    translated = generator(batch)
    print(translated.shape)
    assert translated.shape == expected_shape
    print('Test - OK')


test()

torch.Size([8, 3, 256, 256])
Test - OK


## Configuration

In [24]:
# --- Runtime ---------------------------------------------------------------
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_WORKERS = multiprocessing.cpu_count()  # DataLoader worker processes

# --- Data ------------------------------------------------------------------
TRAIN_DIR = '.'  # parent folder of trainA/ and trainB/
VAL_DIR = '.'    # parent folder of testA/ and testB/
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# --- Optimisation ----------------------------------------------------------
BATCH_SIZE = 16
LEARNING_RATE = 2e-4
NUM_EPOCHS = 20

# --- Loss weights ----------------------------------------------------------
LAMBDA_IDENTITY = 5.0
# Misspelled original name, kept as an alias so existing references keep working.
LAMBDA_IDENTIRY = LAMBDA_IDENTITY
LAMBDA_CYCLE = 10.0

# --- Checkpointing ---------------------------------------------------------
LOAD_MODEL = True  # resume from the checkpoint files below when they all exist
SAVE_MODEL = True
CHECKPOINT_DISC_H = "disc-h.pth.tar"
CHECKPOINT_DISC_Z = "disc-z.pth.tar"
CHECKPOINT_GEN_H = "gen-h.pth.tar"
CHECKPOINT_GEN_Z = "gen-z.pth.tar"

# Augmentation pipeline used for every dataset sample. `image0` is registered
# as an additional image target so two images can be passed in one call
# (`image=` and `image0=`). Normalizing with mean 0.5 / std 0.5 is later undone
# with `* 0.5 + 0.5` when images are saved.
transform = A.Compose(
    transforms=[
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.1),
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    additional_targets={'image0': 'image'},
)

In [25]:
def save_example(gen_horse, gen_zebra, test_loader, epoch, folder):
    """Translate the first batch of ``test_loader`` and save the results.

    Writes ``fake_horse_{epoch}.jpg`` and ``fake_zebra_{epoch}.jpg`` into
    ``folder`` (created if missing). On epoch 0 the real inputs are saved as
    ``_horse.jpg`` and ``_zebra.jpg`` too. Both generators are switched to
    eval mode for the forward pass and back to train mode afterwards.

    Args:
        gen_horse: generator applied to the zebra batch.
        gen_zebra: generator applied to the horse batch.
        test_loader: loader yielding ``(horse, zebra)`` batches.
        epoch: epoch number used in the output file names.
        folder: output directory.
    """
    real_horse, real_zebra = next(iter(test_loader))
    real_horse, real_zebra = real_horse.to(DEVICE), real_zebra.to(DEVICE)
    os.makedirs(folder, exist_ok=True)

    for generator in (gen_horse, gen_zebra):
        generator.eval()

    with torch.no_grad():
        # `* 0.5 + 0.5` removes the normalization before writing images
        translated_horse = gen_horse(real_zebra) * 0.5 + 0.5
        translated_zebra = gen_zebra(real_horse) * 0.5 + 0.5
        save_image(translated_horse, folder + f'/fake_horse_{epoch}.jpg')
        save_image(translated_zebra, folder + f'/fake_zebra_{epoch}.jpg')
        if epoch == 0:
            # Keep one copy of the real inputs for side-by-side comparison.
            save_image(real_horse * 0.5 + 0.5, folder + '/_horse.jpg')
            save_image(real_zebra * 0.5 + 0.5, folder + '/_zebra.jpg')

    for generator in (gen_zebra, gen_horse):
        generator.train()


def save_checkpoint(model, optimizer, filename):
    """Write the model and optimizer state dicts to ``filename``.

    The saved object is a dict with the keys ``'state_dict'`` (model) and
    ``'optimizer'``.
    """
    print('=> Saving checkpoint')
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file`` in place.

    The file is expected to hold a dict with the keys ``'state_dict'`` and
    ``'optimizer'``. After loading, the learning rate of every optimizer
    parameter group is overwritten with ``lr``.

    NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    that come from a trusted source.
    """
    print('=> Loading checkpoint')
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved['state_dict'])
    optimizer.load_state_dict(saved['optimizer'])

    # Replace old learning rate from the saved model
    for group in optimizer.param_groups:
        group['lr'] = lr


def seed_everything(seed=42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), sets ``PYTHONHASHSEED`` and switches cuDNN to deterministic
    mode with benchmarking disabled.

    Requires ``random`` to be imported at module level; the original import
    cell did not import it, so calling this function raised NameError.

    Args:
        seed: integer seed applied to every generator.
    """
    # Only affects str hashing in processes started after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Dataset

In [26]:
class HorseZebraDataset(Dataset):
    """Unpaired horse/zebra image dataset.

    Each item is a ``(horse, zebra)`` pair. The two folders may hold a
    different number of files; the dataset length is the larger count and the
    shorter list is cycled with a modulo index.

    Args:
        root_horse: directory containing the horse images.
        root_zebra: directory containing the zebra images.
        transform: optional callable invoked as
            ``transform(image=horse, image0=zebra)`` and returning a mapping
            with the keys ``'image'`` and ``'image0'``. When None, the raw
            RGB numpy arrays are returned.

    Raises:
        ValueError: if either directory contains no files (this previously
            surfaced later as a ZeroDivisionError in ``__getitem__``).
    """

    def __init__(self, root_horse, root_zebra, transform=None):
        super().__init__()
        self.root_horse = root_horse
        self.root_zebra = root_zebra
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.horse_images = sorted(os.listdir(self.root_horse))
        self.zebra_images = sorted(os.listdir(self.root_zebra))
        self.horse_len = len(self.horse_images)
        self.zebra_len = len(self.zebra_images)
        if self.horse_len == 0 or self.zebra_len == 0:
            raise ValueError(
                f'Both image folders must be non-empty: {root_horse!r} has '
                f'{self.horse_len} files, {root_zebra!r} has {self.zebra_len} files'
            )
        self.length_dataset = max(self.horse_len, self.zebra_len)

    def __len__(self):
        """Return the size of the larger of the two image folders."""
        return self.length_dataset

    def __getitem__(self, index):
        """Load the horse/zebra pair for ``index`` and apply the transform."""
        # Modulo wraps around the shorter of the two file lists.
        horse_img = self.horse_images[index % self.horse_len]
        zebra_img = self.zebra_images[index % self.zebra_len]
        horse_path = os.path.join(self.root_horse, horse_img)
        zebra_path = os.path.join(self.root_zebra, zebra_img)
        # *.convert('RGB') for gray scale images (1 channel to 3)
        horse_img = np.array(Image.open(horse_path).convert('RGB'))
        zebra_img = np.array(Image.open(zebra_path).convert('RGB'))
        if self.transform is None:
            return horse_img, zebra_img
        augmentations = self.transform(image=horse_img, image0=zebra_img)
        horse_img, zebra_img = augmentations['image'], augmentations['image0']
        return horse_img, zebra_img

## Train

In [27]:
def train(disc_h, disc_z, gen_h, gen_z, loader, opt_disc, opt_gen, l1, mse,
          d_scaler, g_scaler):
    """Run one pass over ``loader``, updating both discriminators and both generators.

    For every batch the discriminators are updated first, then the generators.
    Both forward passes run under ``torch.cuda.amp.autocast`` and the
    backward/step calls go through the matching gradient scaler.

    Args:
        disc_h: discriminator applied to real and generated horse images.
        disc_z: discriminator applied to real and generated zebra images.
        gen_h: generator producing horse images (fed zebra images).
        gen_z: generator producing zebra images (fed horse images).
        loader: iterable yielding ``(horse, zebra)`` batches.
        opt_disc: optimizer stepped with the discriminator loss.
        opt_gen: optimizer stepped with the generator loss.
        l1: loss used for the cycle-consistency terms.
        mse: loss used for the adversarial terms.
        d_scaler: gradient scaler for the discriminator step.
        g_scaler: gradient scaler for the generator step.

    Returns:
        None. Every 20th batch the current generated images are written to
        ``evaluation/``; the file names depend only on the batch index.
    """
    loop = tqdm(loader, leave=False)

    for idx, (horse, zebra) in enumerate(loop):
        horse = horse.to(DEVICE)
        zebra = zebra.to(DEVICE)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast():
            # These fakes are reused by the generator step further down.
            fake_horse = gen_h(zebra)
            fake_zebra = gen_z(horse)

            # .detach() keeps the discriminator loss from back-propagating
            # into the generators.
            d_h_real = disc_h(horse)
            d_h_fake = disc_h(fake_horse.detach())
            d_z_real = disc_z(zebra)
            d_z_fake = disc_z(fake_zebra.detach())

            # Targets: ones for real images, zeros for generated images.
            d_h_real_loss = mse(d_h_real, torch.ones_like(d_h_real))
            d_h_fake_loss = mse(d_h_fake, torch.zeros_like(d_h_fake))
            d_z_real_loss = mse(d_z_real, torch.ones_like(d_z_real))
            d_z_fake_loss = mse(d_z_fake, torch.zeros_like(d_z_fake))

            d_h_loss = d_h_real_loss + d_h_fake_loss
            d_z_loss = d_z_real_loss + d_z_fake_loss

            # Average of the two discriminator losses.
            d_loss = (d_h_loss + d_z_loss) / 2.0

        opt_disc.zero_grad()
        d_scaler.scale(d_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generator H and Z
        with torch.cuda.amp.autocast():
            # adversarial loss for both generators
            # (non-detached fakes, scored by the just-updated discriminators)
            d_h_fake = disc_h(fake_horse)
            d_z_fake = disc_z(fake_zebra)

            # Target is ones: the generators are rewarded when their fakes
            # are scored as real.
            g_h_loss = mse(d_h_fake, torch.ones_like(d_h_fake))
            g_z_loss = mse(d_z_fake, torch.ones_like(d_z_fake))

            # cycle loss: translating a fake back must reproduce the original
            cycle_h_loss = l1(horse, gen_h(fake_zebra))
            cycle_z_loss = l1(zebra, gen_z(fake_horse))

            # Identity loss. Necessary for coloring. Not necessary for shapes
            # (currently disabled)
            # identity_h_loss = l1(horse, gen_h(horse))
            # identity_z_loss = l1(zebra, gen_z(zebra))

            # add all together
            g_loss = (
                g_h_loss + g_z_loss
                + (cycle_h_loss + cycle_z_loss) * LAMBDA_CYCLE
                # + (identity_h_loss + identity_z_loss) * LAMBDA_IDENTIRY
            )

        opt_gen.zero_grad()
        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Periodic snapshot of the current fakes; `* 0.5 + 0.5` undoes the
        # normalization. Files are keyed by batch index only, so a later
        # epoch overwrites the snapshots of an earlier one.
        if idx % 20 == 0:
            os.makedirs('evaluation', exist_ok=True)
            save_image(fake_horse * 0.5 + 0.5, f'evaluation/horse_{idx}.jpg')
            save_image(fake_zebra * 0.5 + 0.5, f'evaluation/zebra_{idx}.jpg')


def main():
    """Build the models, optimizers and data loaders, optionally resume from
    checkpoints, then train for ``NUM_EPOCHS`` epochs.

    Checkpoints are written every 10th epoch when ``SAVE_MODEL`` is set, and
    example translations are saved to ``evaluation/`` after every epoch.
    """
    # Models are created in the same order as before: discriminators first.
    disc_h, disc_z = (Discriminator().to(DEVICE) for _ in range(2))
    gen_h, gen_z = (Generator().to(DEVICE) for _ in range(2))

    # One optimizer per role, each covering both networks of that role.
    adam_kwargs = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_disc = optim.Adam(
        [*disc_h.parameters(), *disc_z.parameters()], **adam_kwargs
    )
    opt_gen = optim.Adam(
        [*gen_h.parameters(), *gen_z.parameters()], **adam_kwargs
    )

    l1 = nn.L1Loss()  # cycle consistency loss
    mse = nn.MSELoss()  # adversary loss

    # (file, model, optimizer) triples shared by the load and save steps.
    checkpoints = (
        (CHECKPOINT_DISC_H, disc_h, opt_disc),
        (CHECKPOINT_DISC_Z, disc_z, opt_disc),
        (CHECKPOINT_GEN_H, gen_h, opt_gen),
        (CHECKPOINT_GEN_Z, gen_z, opt_gen),
    )

    # Resume only when every checkpoint file is present.
    if LOAD_MODEL and all(os.path.exists(path) for path, _, _ in checkpoints):
        for path, model, optimizer in checkpoints:
            load_checkpoint(path, model, optimizer, LEARNING_RATE)

    dataset = HorseZebraDataset(
        root_horse=os.path.join(TRAIN_DIR, 'trainA'),
        root_zebra=os.path.join(TRAIN_DIR, 'trainB'),
        transform=transform,
    )
    test_dataset = HorseZebraDataset(
        root_horse=os.path.join(VAL_DIR, 'testA'),
        root_zebra=os.path.join(VAL_DIR, 'testB'),
        transform=transform,
    )

    loader_kwargs = dict(num_workers=NUM_WORKERS, pin_memory=True)
    loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True,
                        **loader_kwargs)
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False,
                             **loader_kwargs)

    d_scaler = torch.cuda.amp.GradScaler()  # for float16 training to save memory
    g_scaler = torch.cuda.amp.GradScaler()

    for network in (disc_h, disc_z, gen_h, gen_z):
        network.train()

    for epoch in range(NUM_EPOCHS):
        train(disc_h, disc_z, gen_h, gen_z, loader, opt_disc, opt_gen, l1, mse,
              d_scaler, g_scaler)

        is_checkpoint_epoch = (epoch + 1) % 10 == 0
        if SAVE_MODEL and is_checkpoint_epoch:
            print('epoch:', epoch)
            for path, model, optimizer in checkpoints:
                save_checkpoint(model, optimizer, path)

        save_example(gen_h, gen_z, test_loader, epoch, folder='evaluation')


if __name__ == '__main__':
    main()

=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint
=> Loading checkpoint


HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))


epoch: 9
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=84.0), HTML(value='')))


epoch: 19
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


In [28]:
from google.colab import drive
drive.mount('/content/gdrive')

copy_to = '/content/gdrive/MyDrive/Colab Notebooks/PyTorch tutorial'
!zip -qr evaluation.zip evaluation/

!cp -rf evaluation.zip     '$copy_to'
!cp -rf $CHECKPOINT_DISC_H '$copy_to'
!cp -rf $CHECKPOINT_DISC_Z '$copy_to'
!cp -rf $CHECKPOINT_GEN_H  '$copy_to'
!cp -rf $CHECKPOINT_GEN_Z  '$copy_to'

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
