In [None]:
import os

# List everything available under ../input.
input_entries = os.listdir("../input")
print(input_entries)

# References
GitHub implementations:
1. https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cyclegan/cyclegan.py
2. https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

Paper: [CycleGAN — Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593.pdf)


In [None]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from PIL import Image
import glob
import random
import itertools
import time
import datetime


%matplotlib inline

# Seed every RNG the notebook draws from, not just torch:
#   * `random` picks the unaligned B image in ImageDataset and drives ReplayBuffer swaps;
#   * torch drives weight initialisation (torch.nn.init.normal_) and random tensors;
#   * numpy is seeded for completeness.
seed = 2018
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
from os import walk

# Print every directory below the dataset root (file listings are not printed).
for dirpath, _subdirs, _files in walk("../input/monet2photo/"):
    print("Directory path: ", dirpath)

In [None]:
import imageio


def show_img(file_path):
    """Draw the image stored at ``file_path`` on the current matplotlib axes.

    Args:
        file_path: path to an image file readable by ``imageio.imread``.
    """
    pixels = imageio.imread(file_path)
    plt.imshow(pixels)


sample_monet_path = '../input/monet2photo/monet2photo/monet2photo/trainA/00001.jpg'
show_img(sample_monet_path)

In [None]:
class ImageDataset(Dataset):
    """Two-domain image dataset for CycleGAN training.

    Files are read from ``<root>/<mode>A`` and ``<root>/<mode>B``. Each item is
    ``{'A': image_A, 'B': image_B}`` after the composed ``transforms_`` are applied.

    Args:
        root: directory that contains the ``<mode>A`` and ``<mode>B`` folders.
        transforms_: list of torchvision transforms applied to every image.
            ``None`` means no transforms (items are then PIL images).
        unaligned: if True, image B is drawn at random for every item instead
            of being paired with A by index.
        mode: split prefix, e.g. ``'train'`` or ``'test'``.

    Raises:
        FileNotFoundError: if either domain folder is missing or empty.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        self.unaligned = unaligned
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
        # Fail early with a clear message. An empty list would otherwise only
        # surface later, e.g. as a ZeroDivisionError from `index % 0` in __getitem__.
        for domain, files in (('A', self.files_A), ('B', self.files_B)):
            if not files:
                raise FileNotFoundError(
                    'No images found in %s' % os.path.join(root, '%s%s' % (mode, domain)))
        # Compose(None) fails when called, so treat a missing list as "no transforms".
        self.transforms = transforms.Compose(transforms_ if transforms_ is not None else [])

    def __getitem__(self, index):
        item_A = self._load(self.files_A[index % len(self.files_A)])
        if self.unaligned:
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]
        return {'A': item_A, 'B': self._load(path_B)}

    def __len__(self):
        # The longer domain defines the epoch; the shorter one wraps around via modulo.
        return max(len(self.files_A), len(self.files_B))

    def _load(self, path):
        """Read ``path`` as a 3-channel RGB image and apply ``self.transforms``."""
        # convert('RGB') guards against grayscale/RGBA/palette files, which would
        # break the 3-channel Normalize((0.5, 0.5, 0.5), ...) used in this notebook.
        # It also forces the pixels to load, so the context manager can close the
        # file handle immediately instead of leaving it to the garbage collector.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transforms(rgb)
        

In [None]:
# --- Data -------------------------------------------------------------------
data_root = '../input/monet2photo/monet2photo/monet2photo'
image_size = 256            # side length of the square training crops
nc = 3                      # image channels — 3 for RGB

# --- Loaders / hardware -----------------------------------------------------
workers, test_workers = 2, 1        # DataLoader worker processes (train / test)
ngpu = 1                            # GPUs requested; 0 forces CPU
batch_size, test_batch_size = 1, 5

# --- Optimisation schedule --------------------------------------------------
epoch = 0                   # starting epoch (offset passed to the LR schedule)
n_epochs = 200
decay_epoch = 100           # epoch at which the linear LR decay begins
lr = 0.0002
beta1, beta2 = 0.5, 0.999   # Adam betas

# --- Model / logging --------------------------------------------------------
n_residual_blocks = 9
sample_interval = 100       # batches between sample-image plots

In [None]:
# Use the first GPU only when CUDA is available and at least one GPU is requested.
cuda = torch.cuda.is_available() and ngpu > 0
print(cuda)
if cuda:
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

In [None]:
# Training augmentation:
#   * resize by 1.12x (286 for image_size=256), then take a random image_size crop;
#   * apply a random horizontal flip;
#   * map pixels from [0, 1] to [-1, 1], the range of the generator's Tanh output.
load_size = int(image_size * 1.12)
transforms_ = [
    transforms.Resize(load_size, Image.BICUBIC),
    transforms.RandomCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

In [None]:
# Training loader: unaligned A/B pairs, reshuffled on every pass.
train_dataset = ImageDataset(data_root, transforms_=transforms_, unaligned=True)
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

In [None]:
# Validation loader used by sample_images(). The worker count comes from the
# config's `test_workers`; `ngpu` is a GPU count, not a worker count.
# Note: it reuses the training transforms (random crop/flip), so sampled views
# differ between calls.
val_dataloader = DataLoader(ImageDataset(data_root, transforms_=transforms_, unaligned=True, mode='test'),
                            batch_size=test_batch_size, shuffle=True, num_workers=test_workers)

# Domain A is Monet paintings; domain B is photographs

In [None]:
train_batch = next(iter(dataloader))
# Show the A (Monet) and B (photo) images of one training batch, repeated twice.
trainAB_batch = torch.cat([train_batch['A'], train_batch['B']] * 2, 0)
train_grid = vutils.make_grid(trainAB_batch.to(device), padding=2, normalize=True).cpu()
plt.figure(figsize=(12, 12))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(train_grid, (1, 2, 0)))

In [None]:
def weights_init_normal(m):
    """Initialise one module's parameters in place; intended for ``net.apply(...)``.

    * Modules whose class name contains ``'Conv'`` (Conv2d, ConvTranspose2d, ...):
      weight ~ N(0, 0.02), bias (if present) = 0.
    * Modules whose class name contains ``'BatchNorm2d'``: weight ~ N(1, 0.02),
      bias = 0. Layers without affine parameters are skipped.

    All other modules are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        # Zero the bias too, as the referenced CycleGAN implementations do;
        # otherwise it keeps PyTorch's default init.
        if getattr(m, 'bias', None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm2d') != -1:
        # BatchNorm2d(affine=False) has weight/bias set to None.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)

# ResNet

In [None]:
class ResidualBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is two reflection-padded 3x3 convolutions with InstanceNorm, and a ReLU
    between them. Padding 1 with a 3x3 kernel keeps H and W unchanged, and the
    channel count stays ``in_features``, so the skip connection adds directly.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv3x3():
            # ReflectionPad2d(1) followed by a 3x3 conv preserves spatial size.
            return [nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3)]

        self.conv_block = nn.Sequential(
            *padded_conv3x3(),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            *padded_conv3x3(),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

# Generator (architecture from *Perceptual Losses for Real-Time Style Transfer and Super-Resolution*)
Johnson et al., 2016: https://arxiv.org/pdf/1603.08155.pdf

In [None]:
class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layout:
      * a 7x7 stem;
      * two stride-2 downsampling convs (64 -> 128 -> 256 channels);
      * ``res_blocks`` residual blocks;
      * two stride-2 transposed convs (256 -> 128 -> 64 channels);
      * a 7x7 output conv followed by Tanh.

    H and W are halved twice and then doubled twice, so the output has the same
    spatial size as the input when H and W are divisible by 4.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        res_blocks: number of ``ResidualBlock`` layers at the bottleneck.
    """

    def __init__(self, in_channels=3, out_channels=3, res_blocks=9):
        super().__init__()

        channels = 64
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, channels, 7),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves H/W and doubles the channel count.
        for _ in range(2):
            layers.extend([
                nn.Conv2d(channels, channels * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(inplace=True),
            ])
            channels *= 2

        layers.extend(ResidualBlock(channels) for _ in range(res_blocks))

        # Upsampling: with k=3, s=2, p=1 and output_padding=1 each stage exactly
        # doubles H/W; the channel count is halved.
        for _ in range(2):
            layers.extend([
                nn.ConvTranspose2d(channels, channels // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(inplace=True),
            ])
            channels //= 2

        # `channels` is back to 64 here.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, out_channels, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

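Shape bookkeeping for the generator (input tensor $(N, C, H, W)$).

Convolution output size (PyTorch `Conv2d`) with padding $p$, dilation $d$, kernel $k$ and stride $s$:
$$H_{out} = \left\lfloor \frac{H + 2p - d(k-1) - 1}{s} \right\rfloor + 1$$
For the first downsampling conv ($k=3$, $s=2$, $p=1$, $d=1$): $\lfloor (256 + 2 - 2 - 1)/2 \rfloor + 1 = 128$. The second downsampling conv then gives $64$.

Inverting a convolution (no padding): $H_{in} = (H_{out} - 1)\cdot s + k$, e.g. $(128 - 1)\cdot 2 + 4 = 258$.

Transposed convolution (PyTorch `ConvTranspose2d`):
$$H_{out} = (H_{in} - 1)\,s - 2p + d(k-1) + \text{output\_padding} + 1$$
For the upsampling layers ($k=3$, $s=2$, $p=1$, output_padding $=1$): $(64 - 1)\cdot 2 - 2 + 2 + 1 + 1 = 128$.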


In [None]:
class Discriminator(nn.Module):
    """Patch-level (PatchGAN-style) discriminator returning a grid of scores.

    Four stride-2 4x4 conv blocks shrink H and W by 16. The final
    ZeroPad2d((1, 0, 1, 0)) plus a 4x4 conv with padding=1 keeps that size, so a
    256x256 input yields a 1x16x16 score map.

    Args:
        in_channels: channels of the input image.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def down(in_filters, out_filters, normalize=True):
            """Stride-2 4x4 conv, optional InstanceNorm, LeakyReLU(0.2); halves H/W."""
            block = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                block.append(nn.InstanceNorm2d(out_filters))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        # The first block has no normalization.
        layers = down(in_channels, 64, normalize=False)
        for in_f, out_f in ((64, 128), (128, 256), (256, 512)):
            layers += down(in_f, out_f)
        layers += [
            nn.ZeroPad2d((1, 0, 1, 0)),  # (left, right, top, bottom)
            nn.Conv2d(512, 1, 4, padding=1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [None]:
# A -> B and B -> A generators, plus one discriminator per domain (built in this order).
G_AB, G_BA = (GeneratorResNet(res_blocks=n_residual_blocks) for _ in range(2))
D_A, D_B = Discriminator(), Discriminator()

In [None]:
# Least-squares adversarial loss (LSGAN), plus L1 for the
# cycle-consistency and identity terms (two independent instances).
criterion_GAN = nn.MSELoss()
criterion_cycle, criterion_identity = nn.L1Loss(), nn.L1Loss()

In [None]:
if cuda:
    # Module.cuda() moves parameters in place and returns the module itself.
    G_AB, G_BA, D_A, D_B = (net.cuda() for net in (G_AB, G_BA, D_A, D_B))
    for criterion in (criterion_GAN, criterion_cycle, criterion_identity):
        criterion.cuda()


In [None]:
# Apply the custom initialisation to every sub-module of all four networks.
for net in (G_AB, G_BA, D_A, D_B):
    net.apply(weights_init_normal)

In [None]:
# Weight of the cycle-consistency term; the identity term gets half of it.
lambda_cyc = 10
lambda_identity = lambda_cyc * 0.5

In [None]:
def sample_images(batch_i):
    """Plot generator translations of one validation batch.

    Concatenates, in order, real A, G_AB(real A), real B and G_BA(real B) into a
    single normalized image grid and displays it immediately.

    Args:
        batch_i: number shown in the figure title (the training loop passes
            the global batch count).
    """
    imgs = next(iter(val_dataloader))
    # Inference only: no_grad avoids building and holding an autograd graph for
    # four generator passes each time this is called from the training loop.
    with torch.no_grad():
        real_A = imgs['A'].to(device)
        real_B = imgs['B'].to(device)
        fake_A = G_BA(real_B)
        fake_B = G_AB(real_A)
        img_sample = torch.cat((real_A, fake_B, real_B, fake_A), 0)
        grid = vutils.make_grid(img_sample, padding=2, normalize=True).cpu()
    plt.figure(figsize=(12, 12))
    plt.axis("off")
    plt.title("Sample Images {}".format(batch_i))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    # Render now. Otherwise every figure created inside the training loop stays
    # in memory and is only displayed when the whole training cell finishes.
    plt.show()

In [None]:
adam_betas = (beta1, beta2)
# One optimizer updates both generators jointly; each discriminator has its own.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=adam_betas)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=adam_betas)
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=adam_betas)

In [None]:
class LambdaLR:
    """Linear-decay multiplier for ``torch.optim.lr_scheduler.LambdaLR``.

    ``step(epoch)`` returns 1.0 while ``epoch + offset`` is at most
    ``decay_start_epoch``. After that it falls linearly, reaching 0.0 when
    ``epoch + offset == n_epochs``.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs > decay_start_epoch), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_length = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_length
        

In [None]:
# Linear-decay multiplier shared by all three schedulers; LambdaLR.step only
# reads attributes, so one instance is enough.
# NOTE: `epoch` is the start epoch from the config cell. The training loop below
# rebinds `epoch`, so re-running this cell after training would shift the schedule.
lr_lambda = LambdaLR(n_epochs, epoch, decay_epoch).step
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)


In [None]:
class ReplayBuffer:
    """Pool of previously generated images that are fed to a discriminator.

    Args:
        max_size: number of images kept in the pool (must be > 0).
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Mix the new batch ``data`` with images stored in the pool.

        Each image along dim 0 is handled in turn:
          * while the pool is not full, it is stored and returned unchanged;
          * afterwards, with probability 1/2 a random stored image is returned
            (cloned) and the new image takes its slot;
          * otherwise the new image is returned unchanged.

        Returns a batch with the same number of images as ``data``, without gradient history.
        """
        picked = []
        for image in data.data:
            image = image.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(image)
                picked.append(image)
                continue
            if random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.max_size - 1)
                picked.append(self.data[slot].clone())
                self.data[slot] = image
            else:
                picked.append(image)
        return Variable(torch.cat(picked))

In [None]:
# Separate pools of generated images for the two discriminators (default capacity).
fake_A_buffer, fake_B_buffer = ReplayBuffer(), ReplayBuffer()

# Training

In [None]:
# Shape of one discriminator output: the four stride-2 blocks shrink H and W by 2**4.
patch_side = image_size // 2**4
patch = (1, patch_side, patch_side)
print(patch)

In [None]:
def train_discriminator(D, optimizer, real, generated, buffer, valid, fake):
    """Run one least-squares GAN update of discriminator ``D``.

    Args:
        D: discriminator to update.
        optimizer: optimizer over ``D``'s parameters.
        real: batch of real images from D's domain (target ``valid``).
        generated: generator outputs for D's domain (target ``fake``). They are
            routed through ``buffer`` and detached, so no gradient reaches the
            generators.
        buffer: ``ReplayBuffer`` holding earlier generated images.
        valid, fake: target tensors of ones / zeros shaped like ``D``'s output.

    Returns:
        The averaged real/fake discriminator loss (scalar tensor).
    """
    optimizer.zero_grad()
    loss_real = criterion_GAN(D(real), valid)
    replayed = buffer.push_and_pop(generated)
    loss_fake = criterion_GAN(D(replayed.detach()), fake)
    loss = (loss_real + loss_fake) / 2
    loss.backward()
    optimizer.step()
    return loss


prev_time = time.time()
for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):

        # Model inputs. Plain tensors are enough: Variable and Tensor were merged in PyTorch 0.4.
        real_A = batch['A'].to(device)
        real_B = batch['B'].to(device)

        # LSGAN targets shaped like the discriminator output. Size them from the
        # actual batch instead of a hard-coded 1, so batch_size > 1 (or a short
        # final batch) matches D's output rather than relying on broadcasting.
        n_imgs = real_A.size(0)
        valid = torch.ones((n_imgs, *patch), device=device)
        fake = torch.zeros((n_imgs, *patch), device=device)

        # ---- Generators ----
        optimizer_G.zero_grad()

        # Identity loss: each generator should leave images of its target domain unchanged.
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # Adversarial loss: generated images should be scored as real.
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle-consistency loss: A -> B -> A and B -> A -> B should reconstruct the input.
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_identity * loss_identity
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminators ----
        loss_D_A = train_discriminator(D_A, optimizer_D_A, real_A, fake_A, fake_A_buffer, valid, fake)
        loss_D_B = train_discriminator(D_B, optimizer_D_B, real_B, fake_B, fake_B_buffer, valid, fake)
        loss_D = (loss_D_A + loss_D_B) / 2

        # ---- Progress ----
        batch_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batch_done
        # The ETA extrapolates from the duration of the batch that just finished.
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        # Log and plot only every `sample_interval` batches. Printing every batch
        # would write n_epochs * len(dataloader) lines into the notebook output.
        if batch_done % sample_interval == 0:
            print('[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}, adv: {}, cycle: {}, identity: {}] ETA: {}'.format(
                                                            epoch, n_epochs,
                                                            i, len(dataloader),
                                                            loss_D.item(), loss_G.item(),
                                                            loss_GAN.item(), loss_cycle.item(),
                                                            loss_identity.item(), time_left))
            sample_images(batch_done)
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()