In [None]:
#IMPORT LIBRARIES
import datetime
import gc
import os
import random
import time
import urllib.request
import zipfile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.optim.lr_scheduler import LambdaLR
from PIL import Image

In [None]:
#OPTIONS
## Datasets available on the CycleGAN download mirror; pick one by index.
choices = [
    'ae_photos', 'apple2orange', 'cezanne2photo', 'cityscapes',
    'facades', 'horse2zebra', 'iphone2dslr_flower', 'maps',
    'monet2photo', 'summer2winter_yosemite', 'ukiyoe2photo', 'vangogh2photo',
]
set_dataset_name = choices[5]  # 'horse2zebra'

## Paths: raw data root, plus per-dataset folders for debug output, sample images and checkpoints.
set_dataset_dir = 'cyclegan_data'
set_debugs_dir, set_outimages_dir, set_outmodels_dir = (
    f'{prefix}/{set_dataset_name}'
    for prefix in ('cyclegan_debugs', 'cyclegan_images', 'cyclegan_models')
)
## Checkpoint path prefixes (no extension); the training loop appends '@<epoch>_<batch>.pt'.
set_G_AB_file, set_G_BA_file, set_D_AB_file, set_D_BA_file = (
    f'{set_outmodels_dir}/{tag}' for tag in ('G_AB', 'G_BA', 'D_AB', 'D_BA')
)

## Reproducibility and epoch schedule (learning-rate decay starts at set_epoch_decay).
set_random_seed = 1
set_epoch_start, set_epoch_decay, set_epoch_end = 0, 50, 100
set_batch_size = 2  ##if run out of memory, choose a smaller number for set_batch_size

## Adam hyper-parameters.
set_lr = 0.0002
set_beta1, set_beta2 = 0.5, 0.999

## DataLoader worker processes and checkpoint/sample frequency (in batches).
set_num_cpu = 1
set_save_interval = 100  ##if out of storage, choose a larger number for set_save_interval

## Loss weights for the cycle-consistency and identity terms.
set_lambda_cyc = 10.0
set_lambda_id = 5.0

In [None]:
#DOWNLOAD DATASET
## Download <set_dataset_name>.zip into set_dataset_dir and extract it there. The dataset cell
## reads f'{set_dataset_dir}/{set_dataset_name}', i.e. it expects the archive to unpack to a
## folder named after the dataset. Each step is skipped when its result already exists, so
## Restart & Run All does not repeat the download/extraction.
download_file = set_dataset_name + '.zip'
### Mirrors are tried in order. NOTE(review): the efrosgans host is the one used by the official
### CycleGAN download script; confirm both hosts are still live.
mirror_urls = [
    'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/' + download_file,
    'http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/' + download_file,
]
zip_path = os.path.join(set_dataset_dir, download_file)
extracted_dir = os.path.join(set_dataset_dir, set_dataset_name)
os.makedirs(set_dataset_dir, exist_ok=True)
if not os.path.isdir(extracted_dir):
    if not os.path.isfile(zip_path):
        ### Download to a temporary name so an interrupted transfer never looks like a finished archive.
        partial_path = zip_path + '.part'
        download_error = None
        for url in mirror_urls:
            try:
                urllib.request.urlretrieve(url, partial_path)
            except OSError as error:  ###URLError / HTTPError are OSError subclasses
                download_error = error
                continue
            ### A dead mirror may answer with an HTML page instead of a zip; treat that as a failure too.
            if zipfile.is_zipfile(partial_path):
                os.replace(partial_path, zip_path)
                break
            download_error = zipfile.BadZipFile(f'{url} did not return a zip archive')
        else:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise RuntimeError(f'Could not download {download_file} from any mirror') from download_error
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(set_dataset_dir)

In [None]:
#DATASET
##define custom dataset
class CustomDataset(Dataset):
    """Unpaired two-domain image dataset laid out as <root>/{train,test}{A,B}/.

    Item `index` pairs the A image at `index % len(A)` with a uniformly random B image.
    The DataLoader sampler draws indices from range(len(self)), so every A image is used
    at least once per epoch while the A/B pairing stays unaligned.

    Args:
        root: dataset folder containing trainA/trainB (and testA/testB).
        train: read the train* folders if True, otherwise the test* folders.
        transforms: optional callable applied to each RGB PIL image.

    Raises:
        ValueError: if a domain folder is missing or contains no files.
    """
    def __init__(self, root, train=True, transforms=None):
        self.train = train
        split = 'train' if train else 'test'
        self.abs_paths_A = self._list_files(os.path.join(root, split + 'A'))
        self.abs_paths_B = self._list_files(os.path.join(root, split + 'B'))
        ##transform
        self.transforms = transforms

    @staticmethod
    def _list_files(folder):
        """Return the sorted paths of the files directly inside `folder`.

        Raises ValueError when `folder` is missing or empty (os.walk yields nothing
        for a missing folder, so the default tuple below covers that case).
        """
        _, _, file_names = next(os.walk(folder), (folder, [], []))
        if not file_names:
            raise ValueError(f'No image files found in {folder}')
        ### sorted() makes the ordering independent of the filesystem's listing order
        return [os.path.join(folder, name) for name in sorted(file_names)]

    def convert_rgb(self, image):
        """Return an RGB version of `image` (grayscale, palette, RGBA, ... are converted)."""
        # Image.convert performs the channel conversion and always returns a new image
        # whose pixel data is loaded, so it stays valid after the source file is closed.
        return image.convert("RGB")

    def _load_rgb(self, path):
        """Open `path`, convert it to RGB and close the underlying file handle."""
        with Image.open(path) as image:
            return self.convert_rgb(image)

    def __getitem__(self, index):
        ###pair (image_A, image_B): A follows the sampler's index, B is drawn at random
        path_A = self.abs_paths_A[index % len(self.abs_paths_A)]
        path_B = self.abs_paths_B[random.randint(0, len(self.abs_paths_B) - 1)]
        image_A = self._load_rgb(path_A)
        image_B = self._load_rgb(path_B)
        ###transform to tensor
        if self.transforms is not None:
            image_A = self.transforms(image_A)
            image_B = self.transforms(image_B)
        return {"A": image_A, "B": image_B}

    def __len__(self):
        return max(len(self.abs_paths_A), len(self.abs_paths_B))


##image transformations
### augmentation (random flip, resize to 286, random 256 crop), then scale pixels to [-1, 1]
custom_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize((286, 286)),
    transforms.RandomCrop((256, 256)),
    transforms.ToTensor(),  ###to range [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  ###to range [-1, 1]
])

##create dataloader
train_dataset = CustomDataset(f'{set_dataset_dir}/{set_dataset_name}', True, custom_transform)
val_dataset = CustomDataset(f'{set_dataset_dir}/{set_dataset_name}', False, custom_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=set_batch_size, shuffle=True, num_workers=set_num_cpu)
val_loader = DataLoader(dataset=val_dataset, batch_size=5, shuffle=True, num_workers=set_num_cpu)
### set_save_interval is used as a modulus on the batch index during training, so it must be
### at least 1 (0 would raise ZeroDivisionError) and at most the number of batches per epoch.
save_condition = 1 <= set_save_interval <= len(train_loader)
assert save_condition, f'''
The set_save_interval must be at least 1 and at most the total number of batches in one epoch of train dataset,
so it must lie between [1, {len(train_loader)}]
'''


In [None]:
#DEFINE BASE MODELS
class ConvolutionalBlock(nn.Module):
    """A bias-free convolution layer looked up by name in torch.nn, optionally followed by an activation.

    Args:
        kernel_name: name of a torch.nn layer class, e.g. 'Conv2d' or 'ConvTranspose2d'.
        in_channels, out_channels, kernel_size, stride, padding: passed positionally to that class.
        activation_name: optional name of a torch.nn activation class (e.g. 'ReLU'), built
            with no arguments; None means the block is the convolution alone.
    """
    def __init__(self, kernel_name, in_channels, out_channels, kernel_size, stride, padding, activation_name=None):
        super().__init__()
        conv_cls = getattr(nn, kernel_name)
        # The attribute name (misspelling included) appears in every checkpoint's
        # state_dict keys, so it is kept as-is for weight loading.
        self.covolutional_block = nn.Sequential(
            conv_cls(in_channels, out_channels, kernel_size, stride, padding, bias=False))
        if activation_name is None:
            return
        self.covolutional_block.add_module('activation', getattr(nn, activation_name)())

    def forward(self, x):
        return self.covolutional_block(x)
##--------------------------------------------------------------------------------------------------------------   
class ResidualBlock(nn.Module):
    """Two 3x3 stride-1, padding-1 convolution blocks (each ending in ReLU) plus an identity skip.

    Both convolutions keep `in_channels` channels and the spatial size, so the
    output has the same shape as the input and can be added to it.
    """
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.residual = nn.Sequential(
            ConvolutionalBlock('Conv2d', in_channels, in_channels, 3, 1, 1, 'ReLU'),
            ConvolutionalBlock('Conv2d', in_channels, in_channels, 3, 1, 1, 'ReLU'))

    def forward(self, x):
        # Out-of-place add: the main path ends in ReLU, whose backward reads its own
        # output, so an in-place `+=` on it makes backward() fail autograd's
        # in-place version check.
        return x + self.residual(x)
##--------------------------------------------------------------------------------------------------------------
class MultipleResiduals(nn.Module):
    """A chain of `num_repeat` ResidualBlock(in_channels) modules applied one after another.

    Children are registered as '0th_multiple_residuals', '1th_multiple_residuals', ...;
    those names are part of the state_dict keys, so they stay fixed for checkpoint loading.
    """
    def __init__(self, in_channels, num_repeat):
        super().__init__()
        chain = nn.Sequential()
        for position in range(num_repeat):
            chain.add_module(f'{position}th_multiple_residuals', ResidualBlock(in_channels))
        self.multiple_residuals = chain

    def forward(self, x):
        return self.multiple_residuals(x)

<img src="static/depict/cyclegan_generator.png" style="width:70%"/>

In [None]:
class Generator(nn.Module):
    """Image-to-image generator: convolutional encoder -> 6 residual blocks -> decoder ending in Tanh.

    Maps (N, 3, H, W) to (N, 3, H, W). Tanh keeps outputs in [-1, 1]. When H and W are
    divisible by 4, the two stride-2 convolutions and the two stride-2 transposed
    convolutions cancel out exactly.
    """
    def __init__(self):
        super().__init__()
        # Conv2d output size: o = floor((i - k + 2p) / s) + 1
        # (https://sebastianraschka.com/pdf/lecture-notes/stat479ss19/L13_intro-cnn-part2_slides.pdf);
        # with s = 1 and o = i this gives p = (k - 1) / 2, e.g. p = 3 for k = 7.
        # Arguments: kernel_name, in_channels, out_channels, kernel_size, stride, padding, activation_name
        encoder = [
            ConvolutionalBlock('Conv2d', 3, 64, 7, 1, 3, 'ReLU'),     # H
            ConvolutionalBlock('Conv2d', 64, 128, 3, 2, 1, 'ReLU'),   # H/2
            ConvolutionalBlock('Conv2d', 128, 256, 3, 2, 1, 'ReLU'),  # H/4
        ]
        transformer = MultipleResiduals(256, 6)                       # H/4
        # ConvTranspose2d output size: o = s(i - 1) + k - 2p; with s = 2, k = 2, p = 0 this is o = 2i.
        decoder = [
            ConvolutionalBlock('ConvTranspose2d', 256, 128, 2, 2, 0, 'ReLU'),  # H/2
            ConvolutionalBlock('ConvTranspose2d', 128, 64, 2, 2, 0, 'ReLU'),   # H
            ConvolutionalBlock('Conv2d', 64, 3, 7, 1, 3, 'Tanh'),              # H
        ]
        self.generator = nn.Sequential(*encoder, transformer, *decoder)

    def forward(self, x):
        # for 256x256 inputs the output is (batch, 3, 256, 256)
        return self.generator(x)

<img src="static/depict/cyclegan_discriminator.png" style="width:70%"/>

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a grid of per-patch probabilities.

    Four stride-2 4x4 convolutions (LeakyReLU) reduce H and W by 16. An asymmetric zero
    pad (one row on top, one column on the left) followed by a stride-1 4x4 convolution
    with padding 1 keeps that resolution, and Sigmoid maps every cell into (0, 1).
    For 256x256 inputs the output is (N, 1, 16, 16).
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) of the four downsampling stages: H/2, H/4, H/8, H/16
        widths = [(3, 64), (64, 128), (128, 256), (256, 512)]
        downsampling = [ConvolutionalBlock('Conv2d', c_in, c_out, 4, 2, 1, 'LeakyReLU') for c_in, c_out in widths]
        self.discriminator = nn.Sequential(
            *downsampling,
            nn.ZeroPad2d((1, 0, 1, 0)),  # https://pytorch.org/docs/master/generated/torch.nn.ZeroPad2d.html
            ConvolutionalBlock('Conv2d', 512, 1, 4, 1, 1, 'Sigmoid'),  # H/16
        )

    def forward(self, x):
        return self.discriminator(x)

In [None]:
#SETUP MODELS
## initialize models; seed torch and Python's `random` (CustomDataset uses it to pick B images)
torch.manual_seed(set_random_seed)
random.seed(set_random_seed)
G_AB = Generator()
G_BA = Generator()
D_A = Discriminator()
D_B = Discriminator()

##device
###one gpu or cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G_AB = G_AB.to(device)
G_BA = G_BA.to(device)
D_A = D_A.to(device)
D_B = D_B.to(device)
### multiple gpus: wrap in DataParallel; the wrapped model is then reachable as `.module`
parallel = device.type != 'cpu' and torch.cuda.device_count() > 1
if parallel:
    G_AB = nn.DataParallel(G_AB)
    G_BA = nn.DataParallel(G_BA)
    D_A = nn.DataParallel(D_A)
    D_B = nn.DataParallel(D_B)

## load pretrained model weights if you train the model over several sessions (set_epoch_start > 0)
### Before resuming, rename the latest checkpoints <name>@<epoch>_<batch>.pt to <name>.pt
### (G_AB.pt, G_BA.pt, D_AB.pt, D_BA.pt). The set_*_file prefixes are NOT overwritten here,
### so checkpoints saved after resuming keep the <name>@<epoch>_<batch>.pt naming.
if set_epoch_start > 0:
    for prefix, model in zip([set_G_AB_file, set_G_BA_file, set_D_AB_file, set_D_BA_file], [G_AB, G_BA, D_A, D_B]):
        target = model.module if parallel else model
        ### map_location lets weights saved on a GPU load on a CPU-only machine (and vice versa)
        target.load_state_dict(torch.load(f'{prefix}.pt', map_location=device))

##optimizers
optim_G_AB = torch.optim.Adam(G_AB.parameters(), lr=set_lr, betas=(set_beta1, set_beta2))
optim_G_BA = torch.optim.Adam(G_BA.parameters(), lr=set_lr, betas=(set_beta1, set_beta2))
optim_D_A = torch.optim.Adam(D_A.parameters(), lr=set_lr, betas=(set_beta1, set_beta2))
optim_D_B = torch.optim.Adam(D_B.parameters(), lr=set_lr, betas=(set_beta1, set_beta2))

##learning rate scheduler: factor 1 until set_epoch_decay, then linear decay reaching 0 at set_epoch_end
### LambdaLR counts epochs from 0 on every (re)start, so set_epoch_start is added back. This also
### makes resuming after the decay has begun (set_epoch_start >= set_epoch_decay) valid.
condition_1 = 0 <= set_epoch_start < set_epoch_end
condition_2 = 0 <= set_epoch_decay < set_epoch_end
assert condition_1 and condition_2, 'Require 0 <= set_epoch_start < set_epoch_end and 0 <= set_epoch_decay < set_epoch_end'
lambda_func = lambda epoch: 1 - max(0, epoch + set_epoch_start - set_epoch_decay) / (set_epoch_end - set_epoch_decay)
scheduler_G_AB = LambdaLR(optim_G_AB, lr_lambda=lambda_func)
scheduler_G_BA = LambdaLR(optim_G_BA, lr_lambda=lambda_func)
scheduler_D_A = LambdaLR(optim_D_A, lr_lambda=lambda_func)
scheduler_D_B = LambdaLR(optim_D_B, lr_lambda=lambda_func)

##generate cyclegan image from val_loader input
def sample_images(image_name):
    """Translate one validation batch in both directions and save a comparison grid.

    Rows, top to bottom: real_A, G_AB(real_A), real_B, G_BA(real_B); each row holds the
    whole batch. Pixels are mapped from [-1, 1] (the range of the Normalize transform and
    of the generator's Tanh) to [0, 1] with the fixed inverse x * 0.5 + 0.5. Per-grid
    min/max stretching is avoided because it would misrepresent the contrast of fakes.

    Args:
        image_name: file stem; the grid is written to f'{set_outimages_dir}/{image_name}.png'.

    Returns:
        The stacked grid tensor that was saved.
    """
    batch = next(iter(val_loader))
    with torch.no_grad():
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)
        fake_A = G_BA(real_B)
        fake_B = G_AB(real_A)

    def to_row(images):
        ### lay the batch out along the x-axis
        return make_grid((images * 0.5 + 0.5).clamp(0, 1), nrow=len(images))

    ### stack the rows along the y-axis
    image = torch.cat([to_row(t) for t in (real_A, fake_B, real_B, fake_A)], 1)
    os.makedirs(set_outimages_dir, exist_ok=True)
    save_image(image, f'{set_outimages_dir}/{image_name}.png', normalize=False)
    return image
device

In [None]:
# TRAINING
## LSGAN objective: each discriminator regresses real -> 1 and fake -> 0 with MSE, and the
## generators use the same MSE toward 1 on their fakes, so both sides play one least-squares game.
## Targets are built from each discriminator output's own shape (ones_like / zeros_like) instead
## of a hard-coded 16x16 patch grid, so other image sizes work too.
print('Please wait for training ...', end="")
batches_per_epoch = len(train_loader)
batches_total = (set_epoch_end - set_epoch_start) * batches_per_epoch
for epoch in range(set_epoch_start, set_epoch_end):
    for batch_idx, batch in enumerate(train_loader):
        #starting time for every batch
        time_start = time.time()
        ##Inputs
        real_A = batch["A"].to(device)
        real_B = batch["B"].to(device)

        ##Train Generators
        ###start from input A: translation, reconstruction, identity
        fake_B = G_AB(real_A)
        cycle_A = G_BA(fake_B)
        id_A = G_BA(real_A)
        ###start from input B
        fake_A = G_BA(real_B)
        cycle_B = G_AB(fake_A)
        id_B = G_AB(real_B)
        ###adversarial terms: generators want the discriminators to output 1 on their fakes
        pred_fake_B = D_B(fake_B)
        pred_fake_A = D_A(fake_A)
        loss_gan_AB = F.mse_loss(pred_fake_B, torch.ones_like(pred_fake_B))
        loss_gan_BA = F.mse_loss(pred_fake_A, torch.ones_like(pred_fake_A))
        ###cycle-consistency and identity terms
        loss_cycle_A = F.l1_loss(cycle_A, real_A)
        loss_cycle_B = F.l1_loss(cycle_B, real_B)
        loss_id_A = F.l1_loss(id_A, real_A)
        loss_id_B = F.l1_loss(id_B, real_B)
        ###total loss_generator
        loss_gan = 0.5 * (loss_gan_AB + loss_gan_BA)
        loss_cycle = 0.5 * (loss_cycle_A + loss_cycle_B)
        loss_identity = 0.5 * (loss_id_A + loss_id_B)
        loss_G = loss_gan + set_lambda_cyc * loss_cycle + set_lambda_id * loss_identity
        ###backward and update gradient
        optim_G_AB.zero_grad()
        optim_G_BA.zero_grad()
        loss_G.backward()
        optim_G_AB.step()
        optim_G_BA.step()

        ##Train Discriminators (fakes are detached so no gradient flows back into the generators;
        ##the gradients loss_G left on the discriminators are cleared by zero_grad below)
        p_real_A = D_A(real_A)
        p_fake_A = D_A(fake_A.detach())
        p_real_B = D_B(real_B)
        p_fake_B = D_B(fake_B.detach())
        loss_D_A = 0.5 * (F.mse_loss(p_real_A, torch.ones_like(p_real_A))
                          + F.mse_loss(p_fake_A, torch.zeros_like(p_fake_A)))
        loss_D_B = 0.5 * (F.mse_loss(p_real_B, torch.ones_like(p_real_B))
                          + F.mse_loss(p_fake_B, torch.zeros_like(p_fake_B)))
        loss_D = 0.5 * (loss_D_A + loss_D_B)
        optim_D_A.zero_grad()
        optim_D_B.zero_grad()
        loss_D.backward()
        optim_D_A.step()
        optim_D_B.step()

        ##every set_save_interval batches: log, sample images, checkpoint
        if batch_idx % set_save_interval != 0:
            continue
        if epoch == set_epoch_start and batch_idx == 0:
            print('...')
            continue
        ### estimate remaining time from the duration of the current batch
        batches_complete = (epoch - set_epoch_start) * batches_per_epoch + batch_idx
        batches_remain = batches_total - batches_complete
        time_remain = datetime.timedelta(seconds=batches_remain * (time.time() - time_start))
        time_remain = str(time_remain).split(".")[0]  ###remove microsecond part
        ###log text: mean discriminator outputs (computed before the discriminator step), then losses.
        ###.item() is only called here, so other batches avoid the device synchronisation.
        print('P(real_A): %.10f' % p_real_A.mean().item())
        print('P(fake_A): %.10f' % p_fake_A.mean().item())
        print('P(real_B): %.10f' % p_real_B.mean().item())
        print('P(fake_B): %.10f' % p_fake_B.mean().item())
        print('Epoch:%03d/%03d | Batch:%03d/%03d | D:%.3f | G:%.3f | adv:%.3f | cyc:%.3f | id:%.3f | remain: %s'
              % (epoch+1, set_epoch_end, batch_idx, batches_per_epoch,
                 loss_D.item(), loss_G.item(), loss_gan.item(), loss_cycle.item(), loss_identity.item(),
                 time_remain))
        print('======================================================================================================')
        ###save images
        image_name = f'{epoch+1}_{batch_idx}'
        sample_images(image_name)
        ###save models (unwrap DataParallel so checkpoints load into a plain model)
        os.makedirs(set_outmodels_dir, exist_ok=True)
        files = [set_G_AB_file, set_G_BA_file, set_D_AB_file, set_D_BA_file]
        models = [G_AB, G_BA, D_A, D_B]
        for file, model in zip(files, models):
            weights = model.module.state_dict() if parallel else model.state_dict()
            torch.save(weights, file + f'@{epoch+1}_{batch_idx}.pt')
    ##schedule learning rate once per epoch, after that epoch's optimizer steps
    scheduler_G_AB.step()
    scheduler_G_BA.step()
    scheduler_D_A.step()
    scheduler_D_B.step()
    

In [None]:
#RELEASE GPU(s) MEMORY IF USING GPU(s)
## Deleting only the models frees nothing while other names still reference their tensors:
## optimizers hold the parameters and Adam state, schedulers hold their optimizer, and the
## setup/training cells leave model lists, loop variables and activations behind. pop() with a
## default also tolerates names that were never created (e.g. if training was not run).
_gpu_holders = (
    'G_AB', 'G_BA', 'D_A', 'D_B', 'models', 'model', 'm', 'target',
    'optim_G_AB', 'optim_G_BA', 'optim_D_A', 'optim_D_B',
    'scheduler_G_AB', 'scheduler_G_BA', 'scheduler_D_A', 'scheduler_D_B',
    'real_A', 'real_B', 'fake_A', 'fake_B', 'cycle_A', 'cycle_B', 'id_A', 'id_B',
    'valid', 'fake', 'pred_fake_A', 'pred_fake_B', 'p_real_A', 'p_fake_A', 'p_real_B', 'p_fake_B',
    'loss_gan_AB', 'loss_gan_BA', 'loss_cycle_A', 'loss_cycle_B', 'loss_id_A', 'loss_id_B',
    'loss_gan', 'loss_cycle', 'loss_identity', 'loss_G',
    'loss_real_A', 'loss_fake_A', 'loss_real_B', 'loss_fake_B', 'loss_D_A', 'loss_D_B', 'loss_D',
)
for _name in _gpu_holders:
    globals().pop(_name, None)
del _name, _gpu_holders
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()