# Inspired by
https://github.com/aitorzip/PyTorch-CycleGAN

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

# Paths of the training images, filled while walking the Kaggle input tree.
monet_files = []
photo_files = []

# Maps each image directory to the list that collects the files found in it.
_collectors = {
    "/kaggle/input/gan-getting-started/monet_jpg": monet_files,
    "/kaggle/input/gan-getting-started/photo_jpg": photo_files,
}

for dirname, _, filenames in os.walk('/kaggle/input'):
    print(dirname)
    collector = _collectors.get(dirname)
    if collector is not None:
        collector.extend(os.path.join(dirname, filename) for filename in filenames)


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Report how many images were found per domain.
for found in (monet_files, photo_files):
    print(len(found))


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable

import itertools
from PIL import Image

from tqdm import tqdm
import matplotlib.pyplot as plt

import random

# Dataset Def

Note that the dataset is imbalanced (7028 photos vs. 300 Monet paintings). An earlier version used only 1/7 of the photos; the dataset below iterates over all photos and loops over the Monet paintings.

Improve this in the future
- Add transforms to artificially create more monet images (random crop, horizontal flip etc.)

In [None]:
class GANDataset(Dataset):
    """Unpaired photo/Monet dataset yielding one image of each domain per index.

    The dataset length equals the number of photos. Monet images are indexed
    modulo their count, so they repeat once the index exceeds the number of
    Monet images.
    """

    def __init__(self,photo_list,monet_list,transform_monet,transform_photo):
        """Store the file lists and the per-domain transforms.

        Args:
            photo_list: Sequence of photo image paths.
            monet_list: Sequence of Monet image paths.
            transform_monet: Callable applied to each opened Monet image.
            transform_photo: Callable applied to each opened photo.
        """
        self.data = {'Photo': photo_list, 'Monet': monet_list}
        self.len_p, self.len_m = len(photo_list), len(monet_list)
        self.transform_photo = transform_photo
        self.transform_monet = transform_monet

    def __len__(self):
        """Return the number of photos."""
        return self.len_p

    def __getitem__(self, idx):
        """Open and transform the photo at ``idx`` and its Monet counterpart.

        Returns:
            Dict with keys ``'Photo'`` and ``'Monet'`` holding the transformed
            images.
        """
        # Due to the imbalance (len_p=7028, len_m=300) all photos can still be
        # used if the Monet pictures are looped through.
        photo_path = os.fspath(self.data['Photo'][idx])
        monet_path = os.fspath(self.data['Monet'][idx % self.len_m])
        photo, monet = Image.open(photo_path), Image.open(monet_path)
        sample = {}
        sample['Photo'] = self.transform_photo(photo)
        sample['Monet'] = self.transform_monet(monet)
        return sample
    
class ImgAugment(object):
    """Normalising transform that adds random flips to a fraction of the images.

    Each call draws one random number. Below 0.2 the image goes through
    ``data_augment`` (tensor conversion, normalisation, then random horizontal
    and vertical flips); otherwise it goes through ``data_transform`` (tensor
    conversion and normalisation only).
    """

    def __init__(self,mean,std):
        """Build both pipelines from the normalisation arguments ``mean`` and ``std``."""

        def base_steps():
            # Fresh transform instances for every pipeline.
            return [transforms.ToTensor(), transforms.Normalize(mean,std)]

        self.data_transform = transforms.Compose(base_steps())
        self.data_augment = transforms.Compose(
            base_steps() + [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
            ]
        )

    def __call__(self,img):
        """Return ``img`` after the augmenting or the plain pipeline."""
        use_augmentation = random.random() < 0.2
        pipeline = self.data_augment if use_augmentation else self.data_transform
        return pipeline(img)
    
    # you ran: random flips + resize + crop + gaussian blur DONE
    # todo: resize + crop + gaussian blur THIS RUN
    # todo: resize + crop DONE
    # todo: gaussian blur DONE

class ImgAugmentRandChoice(object):
    """Normalise an image and apply one randomly chosen augmentation.

    ``data_transform`` converts the image to a tensor, normalises it, applies
    one transform picked by ``RandomChoice`` (horizontal flip, vertical flip,
    Gaussian blur, or resize + rotation + centre crop) and then always resizes
    to 400x400 and takes a random 256x256 crop.
    """

    def __init__(self,mean,std):
        """Build the pipeline from the normalisation arguments ``mean`` and ``std``."""
        random_augmentation = transforms.RandomChoice([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.GaussianBlur(5),
            transforms.Compose([
                transforms.Resize([400,400]),
                transforms.RandomRotation(degrees=90),
                transforms.CenterCrop(256),
            ]),
        ])
        resize_and_crop = transforms.Compose([
            transforms.Resize([400,400]),
            transforms.RandomCrop(256),
        ])
        self.data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean,std),
            random_augmentation,
            # NOTE(review): this step is a sibling of RandomChoice, not one of
            # its options, so it runs on every image. Kept as in the original;
            # confirm whether it was meant to be a fifth random option.
            resize_and_crop,
        ])

    def __call__(self,img):
        """Return ``img`` after ``data_transform``.

        The previous implementation routed 20% of the calls to
        ``self.data_augment``, an attribute this class never defines, which
        raised ``AttributeError``. The pipeline already randomises the
        augmentation itself, so every call now goes through it.
        """
        return self.data_transform(img)
    
class ImgTransform(object):
    """Plain preprocessing: tensor conversion followed by normalisation."""

    def __init__(self,mean,std):
        """Build the pipeline from the normalisation arguments ``mean`` and ``std``."""
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(mean,std)
        self.data_transform = transforms.Compose([to_tensor, normalize])

    def __call__(self,img):
        """Return ``img`` after the preprocessing pipeline."""
        pipeline = self.data_transform
        return pipeline(img)
    
#reference: https://www.kaggle.com/adrianda/cyclegan-pytorch-style-transfer
def reverse_normalize(image, mean_=0.5, std_=0.5):
    """Undo ``(x - mean_) / std_`` normalisation and return 8-bit pixel values.

    Args:
        image: Normalised image as a torch tensor, NumPy array or scalar.
        mean_: Mean that was subtracted during normalisation.
        std_: Standard deviation the image was divided by.

    Returns:
        ``numpy.uint8`` data with the same shape as ``image``. Fractional
        values are truncated; values outside [0, 255] are clipped to that
        range.
    """
    if torch.is_tensor(image):
        image = image.detach().cpu().numpy()
    pixels = (image * std_ + mean_) * 255
    # Clip before the cast: converting out-of-range floats to uint8 wraps
    # around (e.g. 318 becomes 62) instead of saturating.
    return np.uint8(np.clip(pixels, 0, 255))

def weights_init_normal(m):
    """Initialise one layer in place; suitable for ``Module.apply``.

    Layers whose class name contains ``Conv`` get weights drawn from a normal
    distribution with mean 0.0 and std 0.02. Layers whose class name contains
    ``BatchNorm2d`` get weights drawn with mean 1.0 and std 0.02 and their
    bias set to 0.0. All other layers are left untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        # ``constant_`` is the in-place initialiser; the un-suffixed
        # ``constant`` is a deprecated alias that emits a warning.
        torch.nn.init.constant_(m.bias.data, 0.0)

# Testing DataLoader

In [None]:
# Smoke test: build the dataset, pull one shuffled sample and display both images.
batch_size = 1
train_dataset = GANDataset(
    photo_list=photo_files,
    monet_list=monet_files,
    transform_monet=ImgAugment(mean=0.5, std=0.5),
    transform_photo=ImgTransform(mean=0.5, std=0.5),
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
dl_iter = iter(train_dataloader)

data = next(dl_iter)

# Drop the batch axis, undo the normalisation and move the channel axis last
# so matplotlib can render the arrays.
p = reverse_normalize(torch.squeeze(data['Photo'])).transpose(1, 2, 0)
m = reverse_normalize(torch.squeeze(data['Monet'])).transpose(1, 2, 0)

print(p.shape)

for preview in (p, m):
    plt.imshow(preview)
    plt.show()

# Define models

https://github.com/aitorzip/PyTorch-CycleGAN/blob/master/models.py

In [None]:
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions with instance norm and a skip connection."""

    def __init__(self, in_features):
        """Build the block; input and output both have ``in_features`` channels."""
        super().__init__()

        def padded_conv():
            # Reflection pad -> 3x3 convolution -> instance norm.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        first_stage = padded_conv()
        second_stage = padded_conv()
        self.conv_block = nn.Sequential(
            *first_stage,
            nn.ReLU(inplace=True),
            *second_stage,
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two stride-2 downsampling
    convolutions, residual blocks, two transposed-convolution upsampling
    stages and a 7x7 output convolution followed by ``Tanh``.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        """Assemble the network.

        Args:
            input_nc: Number of channels of the input image.
            output_nc: Number of channels of the generated image.
            n_residual_blocks: Number of residual blocks between down- and
                upsampling.
        """
        super().__init__()

        # Initial convolution block
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: the channel count doubles at each of the two stages.
        channels = 64
        for _ in range(2):
            doubled = channels * 2
            layers.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled

        # Residual blocks at the bottleneck width
        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        # Upsampling: the channel count halves at each of the two stages.
        for _ in range(2):
            halved = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, halved, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved

        # Output layer
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential model and return the result."""
        generated = self.model(x)
        return generated

class Discriminator(nn.Module):
    """Fully convolutional discriminator whose score map is averaged to one value per image."""

    def __init__(self, input_nc):
        """Assemble the network for images with ``input_nc`` channels."""
        super().__init__()

        # First stage has no normalisation layer.
        layers = [
            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # (in_channels, out_channels, stride) of the normalised stages
        stages = ((64, 128, 2), (128, 256, 2), (256, 512, 1))
        for n_in, n_out, stride in stages:
            layers += [
                nn.Conv2d(n_in, n_out, 4, stride=stride, padding=1),
                nn.InstanceNorm2d(n_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # FCN classification layer
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor with the spatially averaged score."""
        scores = self.model(x)
        # Average pooling over the whole score map, then flatten
        pooled = F.avg_pool2d(scores, scores.size()[2:])
        return pooled.view(scores.size()[0], -1)

# Init

- Maybe add a learning-rate schedule (a learning rate that changes during training)

In [None]:
# Dataset: Monet images are randomly augmented, photos are only normalised.
train_dataset = GANDataset(
    photo_list=photo_files, monet_list=monet_files, 
    transform_monet = ImgAugment(mean=0.5,std=0.5), 
    #transform_monet = ImgAugmentRandChoice(mean=0.5,std=0.5),
    transform_photo = ImgTransform(mean=0.5, std=0.5),
)


batch_size = 1
epochs = 6


# drop_last: the input buffers below have a fixed batch dimension, so a
# smaller final batch could not be copied into them. With batch_size = 1 this
# drops nothing.
train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                                               batch_size=batch_size, shuffle=True,
                                               drop_last=True)

# Define accelerator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Two generators (photo -> Monet, Monet -> photo) and one discriminator per domain.
netG_P2M = Generator(3,3).to(device)
netG_M2P = Generator(3,3).to(device)
netD_P = Discriminator(3).to(device)
netD_M = Discriminator(3).to(device)

netG_P2M.apply(weights_init_normal)
netG_M2P.apply(weights_init_normal)
netD_P.apply(weights_init_normal)
netD_M.apply(weights_init_normal)

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Both generators share one optimizer; each discriminator has its own.
optimizer_G = torch.optim.Adam(itertools.chain(netG_P2M.parameters(), netG_M2P.parameters()),
                                lr=2e-4, betas=(0.5, 0.999))
optimizer_D_P = torch.optim.Adam(netD_P.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D_M = torch.optim.Adam(netD_M.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Inputs & targets memory allocation
# ``Tensor`` is kept for code that still refers to it; the buffers below are
# created with the factory functions directly on ``device``.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
input_P = torch.empty(batch_size, 3, 256, 256, device=device)
input_M = torch.empty(batch_size, 3, 256, 256, device=device)
# Shape (batch_size, 1) matches the output of Discriminator.forward, so the
# MSE loss compares element-wise instead of broadcasting a (batch_size,)
# target against a (batch_size, 1) prediction.
target_real = torch.ones(batch_size, 1, device=device)
target_fake = torch.zeros(batch_size, 1, device=device)


# Training Loop

In [None]:
# Per-epoch mean losses, collected for the plots and .npy dumps after training.
losses_G = []
losses_DP = []
losses_DM = []
losses_IP = []
losses_IM = []

best_loss = 999.9

for epoch in range(epochs):
    loss_G_sum = 0
    loss_DP_sum = 0
    loss_DM_sum = 0
    loss_IP_sum = 0
    loss_IM_sum = 0
    
    n_iterations = 0
    
    for batch in train_dataloader:
        # Set model input (copied into the pre-allocated buffers)
        real_P = Variable(input_P.copy_(batch['Photo']))
        real_M = Variable(input_M.copy_(batch['Monet']))
        
        ###### Generators P2M and M2P ######
        optimizer_G.zero_grad()

        # Identity loss
        # G_P2M(M) should equal M if real M is fed
        same_M = netG_P2M(real_M)
        loss_identity_M = criterion_identity(same_M, real_M)*5.0
        # G_M2P(P) should equal P if real P is fed
        same_P = netG_M2P(real_P)
        loss_identity_P = criterion_identity(same_P, real_P)*5.0
        
        # GAN loss: each generator tries to make the discriminator of its
        # target domain predict "real".
        fake_monet = netG_P2M(real_P)
        pred_fake = netD_M(fake_monet)
        loss_GAN_P2M = criterion_GAN(pred_fake, target_real)

        fake_photo = netG_M2P(real_M)
        pred_fake = netD_P(fake_photo)
        loss_GAN_M2P = criterion_GAN(pred_fake, target_real)
        
        # Cycle loss: translate to the other domain and back, then compare
        # with the starting image. The previous code applied the same
        # generator twice (e.g. G_M2P(G_M2P(M)) compared with P), which does
        # not close the cycle.
        recovered_P = netG_M2P(fake_monet)
        loss_cycle_PMP = criterion_cycle(recovered_P, real_P)*10.0

        recovered_M = netG_P2M(fake_photo)
        loss_cycle_MPM = criterion_cycle(recovered_M, real_M)*10.0
    
        # Total loss
        loss_G = loss_identity_P + loss_identity_M + loss_GAN_P2M + loss_GAN_M2P + loss_cycle_PMP + loss_cycle_MPM
        loss_G.backward()
        
        optimizer_G.step()
        
        
        ###### Discriminator P ######
        optimizer_D_P.zero_grad()

        # Real loss
        pred_real = netD_P(real_P)
        loss_D_real = criterion_GAN(pred_real, target_real)

        # Fake loss: the photo discriminator must see the generated *photo*
        # (previously it was fed the generated Monet). detach() keeps the
        # discriminator update from back-propagating into the generator.
        pred_fake = netD_P(fake_photo.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_P = (loss_D_real + loss_D_fake)*0.5
        loss_D_P.backward()

        optimizer_D_P.step()
        
        ###### Discriminator M ######
        optimizer_D_M.zero_grad()

        # Real loss
        pred_real = netD_M(real_M)
        loss_D_real = criterion_GAN(pred_real, target_real)
        
        # Fake loss: the Monet discriminator must see the generated *Monet*.
        pred_fake = netD_M(fake_monet.detach())
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss
        loss_D_M = (loss_D_real + loss_D_fake)*0.5
        loss_D_M.backward()

        optimizer_D_M.step()
        
        # .item() yields plain Python floats, so no tensors or autograd
        # graphs are kept alive by the running sums.
        loss_G_sum += loss_G.item()
        loss_DP_sum += loss_D_P.item()
        loss_DM_sum += loss_D_M.item()
        loss_IP_sum += loss_identity_P.item()
        loss_IM_sum += loss_identity_M.item()
        
        n_iterations += 1
        
    mean_loss_G = loss_G_sum/n_iterations
    mean_loss_DP = loss_DP_sum/n_iterations
    mean_loss_DM = loss_DM_sum/n_iterations

    # Checkpoint on the epoch mean rather than on the loss of whichever batch
    # happened to come last, and store a float rather than a tensor.
    # NOTE(review): the ``epoch > 10`` warm-up never triggers when ``epochs``
    # is 11 or less (it is set to 6 above) - confirm the intended threshold.
    if epoch > 10 and best_loss > mean_loss_G:
        torch.save(netG_P2M.state_dict(), './Photo2Monet_Gen.pt')
        best_loss = mean_loss_G
        print("Model saved E: {}".format(epoch))
    losses_G.append(mean_loss_G)
    losses_DP.append(mean_loss_DP)
    losses_DM.append(mean_loss_DM)
    losses_IP.append(loss_IP_sum/n_iterations)
    losses_IM.append(loss_IM_sum/n_iterations)
    
    # Report the epoch means (the previous version printed last-batch tensors).
    print("Epoch: {} lossGen: {} lossDP: {} lossDM: {}".format(epoch,mean_loss_G,mean_loss_DP,mean_loss_DM))

# Plot Losses

In [None]:
# One (values, plot title, output file) entry per recorded loss curve.
loss_curves = (
    (losses_G, "Loss Generator", "./lossG.npy"),
    (losses_DP, "Loss Discriminator Photos", "./lossDP.npy"),
    (losses_DM, "Loss Discriminator Monet", "./lossDM.npy"),
    (losses_IP, "Loss Identity Photos", "./lossIP.npy"),
    (losses_IM, "Loss Identity Monet", "./lossIM.npy"),
)

# Show each curve, then store its values next to the notebook.
for curve_values, curve_title, curve_path in loss_curves:
    plt.plot(curve_values)
    plt.title(curve_title)
    plt.show()
    np.save(curve_path, curve_values)


# Reverse

In [None]:
import shutil

# Directory the generated images are written to. Replaces the notebook-only
# ``! mkdir ../images`` shell escape and does not fail when the directory
# already exists.
output_dir = "../images"
os.makedirs(output_dir, exist_ok=True)

#load the best network
# Every checkpoint file in the working directory is loaded in listing order,
# so the last one listed wins. If there is none, the in-memory weights are used.
for file in os.listdir('./'):
    if file.endswith('.pt'):
        print(file)
        netG_P2M.load_state_dict(torch.load(os.path.join('./', file), map_location=device))

netG_P2M.eval()
transform = ImgTransform(mean=0.5,std=0.5)

# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for i, img in enumerate(photo_files, start=1):
        # The context manager closes the file once the pixels are in the tensor.
        with Image.open(img) as photo:
            imgP_trans = transform(photo)
        imgP_trans = imgP_trans[None,:,:,:]  # add the batch dimension
        imgP_trans = imgP_trans.to(device)
        fakeM = netG_P2M(imgP_trans)
        fakeM = torch.squeeze(fakeM)
        fakeM = reverse_normalize(fakeM)
        fakeM = fakeM.transpose(1,2,0) # bring color channel to last
        if i < 10:
            # plot the first 9 photo / generated-image pairs
            plt.subplot(121)
            plt.imshow(reverse_normalize(torch.squeeze(imgP_trans)).transpose(1,2,0))
            plt.subplot(122)
            plt.imshow(fakeM)
            plt.show()

        Image.fromarray(fakeM).save(os.path.join(output_dir, str(i) + ".jpg"))

# Zip exactly the directory the images were written to.
shutil.make_archive("/kaggle/working/images",'zip',os.path.abspath(output_dir))