In [None]:
# download dataset and unzip it to train
!wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
!unzip horse2zebra.zip

In [2]:

#discriminator model
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,
                out_channels, 
                kernel_size = 4, 
                stride = stride, 
                padding = 1, 
                bias=True, 
                padding_mode='reflect'), #padding model reflect helps to deal with artifacts
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    def forward(self, x):
        return self.conv(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels = 3, features = [64, 128, 256, 512]):
        super().__init__()

        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size = 4,
                stride = 2,
                padding = 1,
                padding_mode = 'reflect'
            ),
            nn.LeakyReLU(0.2)
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(Block(in_channels, feature, stride = 1 if feature == features[-1] else 2))
            in_channels = feature
        layers.append(nn.Conv2d(
            in_channels, 
            out_channels = 1, 
            kernel_size=4,
            stride = 1,
            padding = 1,
            padding_mode = 'reflect'
        ))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x))
    

def test():
    x = torch.randn((5, 3, 256, 256))
    model = Discriminator(in_channels = 3)
    pred = model(x)
    print(pred.shape)
test()

torch.Size([5, 1, 30, 30])


In [3]:
# generator

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act = True, **kwargs):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode= 'reflect', **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity()
        )

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act = False, kernel_size = 3, padding=1)
        )
    
    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, img_channels, num_features = 64, num_residuals = 9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, num_features, kernel_size = 7, stride = 1, padding = 3, padding_mode = 'reflect'),
            nn.ReLU(inplace=True)
        )
        self.down_blocks = nn.ModuleList([
            ConvBlock(num_features, num_features*2, kernel_size=3, stride = 2, padding = 1),
            ConvBlock(num_features*2, num_features*4, kernel_size=3, stride = 2, padding = 1)
        ])
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features*4) for _ in range(num_residuals)]
        )
        self.up_blocks = nn.ModuleList([
            ConvBlock(num_features*4, num_features*2, down=False, kernel_size = 3, stride = 2, padding = 1, output_padding=1),
            ConvBlock(num_features*2, num_features, down=False, kernel_size = 3, stride = 2, padding = 1, output_padding=1),
        ])
        self.last = nn.Conv2d(num_features*1, img_channels, kernel_size = 7, stride = 1, padding=3, padding_mode='reflect')

    def forward(self, x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x= layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        return torch.tanh(self.last(x))


def test():
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    gen = Generator(img_channels, 9)
    print(gen(x).shape)
test()

torch.Size([2, 3, 256, 256])


In [13]:
# config file
# import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TRAIN_DIR = '/content/horse2zebra'
VAL_DIR = '/content/horse2zebra'
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHES = 10
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = '/content/cpt/gen_horse.pth.tar'
CHECKPOINT_GEN_Z = '/content/cpt/gen_zebra.pth.tar'
CHECKPOINT_CRITIC_H = '/content/cpt/critic_horse.pth.tar'
CHECKPOINT_CRITIC_Z = '/content/cpt/critic_zebra.pth.tar'

transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets = {'image0': 'image'}
)

In [14]:
#@title Default title text
# utils.py

import random, torch, os
import numpy as np
# import config
import copy

def save_checkpoint(model, optimizer, filename = 'my_checkpoint.pth.tar'):
    print('{} \n saving checkpoint'.format('=*'*25))
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer_dict': optimizer.state_dict()
    }

    torch.save(checkpoint, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print('{} \n loading checkpoint'.format('=*'*25))
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_dict'])
    # including learning rate for stable training
    for param in optimizer.param_groups:
        param['lr'] = lr

def seed_everything(seed=42):
    os.environment['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministics = True
    torch.backends.cudnn.benchmarcks = False

In [15]:
# dataset loading

import torch
from torch.utils.data import Dataset
from PIL import Image
import os

# import config

class HorseZebraDataset(Dataset):
    def __init__(self, root_horse, root_zebra, transform=None):
        self.root_horse = root_horse
        self.root_zebra = root_zebra
        self.transform = transform

        self.horse_images = os.listdir(self.root_horse)
        self.zebra_images = os.listdir(self.root_zebra)
        self.horse_len = len(self.horse_images)
        self.zebra_len = len(self.zebra_images)
        self.length_dataset = max(len(self.horse_images), len(self.zebra_images))

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        horse_img = self.horse_images[index % self.horse_len]
        zebra_img = self.zebra_images[index % self.zebra_len]

        horse_path = os.path.join(self.root_horse, horse_img)
        zebra_path = os.path.join(self.root_zebra, zebra_img)

        horse_img = np.array(Image.open(horse_path).convert('RGB'))
        zebra_img = np.array(Image.open(zebra_path).convert('RGB'))

        if self.transform:
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            zebra_img = augmentations['image']
            horse_img = augmentations['image0']
        
        return zebra_img, horse_img

In [16]:

# train file

import torch
# from dataset import HorseZebraDataset
import sys
# from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
# import config
from tqdm import tqdm
from torchvision.utils import save_image
# from discriminator_model import Discriminator
# from generator_model import Generator

def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (zebra, horse) in enumerate(loop):
        zebra = zebra.to(DEVICE)
        horse = horse.to(DEVICE)
        with torch.cuda.amp.autocast():
            fake_horse = gen_H(zebra)
            D_H_real = disc_H(horse)
            D_H_fake = disc_H(fake_horse.detach())
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_zebra = gen_H(horse)
            D_Z_real = disc_H(zebra)
            D_Z_fake = disc_H(fake_zebra.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            D_loss = (D_H_loss + D_Z_loss)

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        with torch.cuda.amp.autocast():
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = l1(zebra, cycle_zebra)
            cycle_horse_loss = l1(horse, cycle_horse)

            identity_zebra = gen_Z(zebra)
            identity_horse = gen_H(horse)
            identity_zebra_loss = l1(zebra, identity_zebra)
            identity_horse_loss = l1(horse, identity_horse)

            # add all togethor
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_zebra_loss * LAMBDA_CYCLE
                + cycle_horse_loss * LAMBDA_CYCLE
                + identity_horse_loss * LAMBDA_IDENTITY
                + identity_zebra_loss * LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            save_image(fake_horse*0.5+0.5, f"/content/saved_images/horse_{idx}.png")
            save_image(fake_zebra*0.5+0.5, f"/content/saved_images/zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))



def main():
    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_Z = Discriminator(in_channels=3).to(DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)

    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr = LEARNING_RATE,
        betas = (0.5, 0.999)
    )
    opt_gen = optim.Adam(
        list(gen_H.parameters()) + list(gen_Z.parameters()),
        lr = LEARNING_RATE,
        betas = (0.5, 0.999)
    )

    l1 = nn.L1Loss()
    mse = nn.MSELoss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN_H, gen_H, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_GEN_Z, gen_Z, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_H, disc_H, opt_disc, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, LEARNING_RATE,
        )
    dataset = HorseZebraDataset(TRAIN_DIR+'/trainA', TRAIN_DIR+'/trainB', transform=transforms)

    loader = DataLoader(
        dataset,
        batch_size= BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHES):
        train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler)

        if SAVE_MODEL:
            save_checkpoint(gen_H, opt_gen, filename=CHECKPOINT_GEN_H)
            save_checkpoint(gen_Z, opt_gen, filename=CHECKPOINT_GEN_Z)
            save_checkpoint(disc_H, opt_disc, filename=CHECKPOINT_CRITIC_H)
            save_checkpoint(disc_Z, opt_disc, filename=CHECKPOINT_CRITIC_Z)

# if __name__ == '__main__':
os.makedirs('/content/saved_images', exist_ok = True)
os.makedirs('/content/cpt', exist_ok= True)

main()

100%|██████████| 1334/1334 [05:19<00:00,  4.17it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:22<00:00,  4.14it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:22<00:00,  4.14it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:22<00:00,  4.14it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:21<00:00,  4.14it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:21<00:00,  4.14it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:21<00:00,  4.15it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:22<00:00,  4.14it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:22<00:00,  4.14it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint


100%|██████████| 1334/1334 [05:21<00:00,  4.15it/s, H_fake=0, H_real=0]


=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=* 
 saving checkpoint
