In [4]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and report how many files each directory holds.
for folder, _subfolders, folder_files in os.walk('/kaggle/input'):
    file_count = len(folder_files)
    print(folder, file_count)
    
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
from torchvision.utils import make_grid

from PIL import Image
import matplotlib.pyplot as plt
import shutil

# Report the installed PyTorch version, then pick the first GPU when CUDA is
# usable and fall back to the CPU otherwise.
print(torch.__version__)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

Residual block and generator

In [6]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 convolutions with a skip connection.

    The main path is ``pad -> conv1 -> norm -> ReLU -> pad -> conv2 -> norm``
    and its result is added to the unmodified input. Reflection padding of 1
    keeps the spatial size unchanged, so input and output shapes are equal.

    Args:
        in_channels: number of channels of the input (and of the output).

    Note:
        The two convolutions have independent weights (``conv1``/``conv2``).
        Checkpoints saved from a version of this class that only had
        ``conv1`` will lack the ``conv2.*`` keys.
    """

    def __init__(self, in_channels):
        super(ResBlock, self).__init__()
        self.reflect = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3)
        # A second, independent convolution: applying conv1 twice would tie
        # the weights of both layers and halve the block's capacity.
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3)
        self.relu = nn.ReLU(inplace=True)
        # Constructed with default arguments, so it holds no learnable
        # parameters and can safely be applied after both convolutions.
        self.norm = nn.InstanceNorm2d(in_channels)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution main path."""
        residual = self.reflect(x)
        residual = self.conv1(residual)
        residual = self.norm(residual)
        residual = self.relu(residual)
        residual = self.reflect(residual)
        residual = self.conv2(residual)
        residual = self.norm(residual)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Pipeline (spatial size of the output equals that of the input when the
    input height and width are divisible by 4):

    1. ``conv``        - reflection pad 3 + 7x7 conv to ``out_channels``.
    2. ``down_sample`` - two stride-2 3x3 convs, ending at ``4*out_channels``.
    3. ``resnet``      - ``residual_num`` residual blocks.
    4. ``up_sample``   - two (upsample x2 + 3x3 conv) stages back to
                         ``out_channels``.
    5. ``out_put``     - reflection pad 3 + 7x7 conv to ``img_channels``,
                         followed by ``tanh`` (values in [-1, 1]).

    Args:
        img_channels: channels of the input and output images.
        out_channels: channel width after the first convolution.
        residual_num: number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels=3, out_channels=64, residual_num=9):
        super(Generator, self).__init__()
        # Kernel geometry is independent of the number of image channels.
        # Deriving it from img_channels only works for img_channels == 3 and
        # produces wrong output sizes for e.g. grayscale images.
        edge_kernel = 7
        edge_pad = edge_kernel // 2
        inner_kernel = 3

        self.conv = nn.Sequential(
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(img_channels, out_channels, kernel_size=edge_kernel, stride=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.down_sample = nn.Sequential(
            nn.Conv2d(out_channels, out_channels * 2, kernel_size=inner_kernel, stride=2, padding=1),
            nn.InstanceNorm2d(out_channels * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels * 2, out_channels * 4, kernel_size=inner_kernel, stride=2, padding=1),
            nn.InstanceNorm2d(out_channels * 4),
            nn.ReLU(inplace=True),
        )

        self.resnet = nn.Sequential(
            *[ResBlock(out_channels * 4) for _ in range(residual_num)]
        )

        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(4 * out_channels, 2 * out_channels, kernel_size=inner_kernel, stride=1, padding=1),
            nn.InstanceNorm2d(out_channels * 2),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(2 * out_channels, out_channels, kernel_size=inner_kernel, stride=1, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.out_put = nn.Sequential(
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(out_channels, img_channels, kernel_size=edge_kernel),
            nn.Tanh()
        )

    def forward(self, x):
        """Translate a batch of images; returns a tensor shaped like ``x``."""
        x = self.conv(x)
        x = self.down_sample(x)
        x = self.resnet(x)
        x = self.up_sample(x)
        return self.out_put(x)

Discriminator

In [7]:
class Discriminator(nn.Module):
    """Convolutional patch discriminator.

    Four stride-2 4x4 convolutions (reflect padding) widen the channels to
    ``out_channels``, ``2x``, ``4x`` and ``8x``; every stage except the first
    is followed by instance normalisation, and all of them by LeakyReLU(0.2).
    A final asymmetric reflection pad plus 4x4 convolution reduces the result
    to a single-channel score map.

    Args:
        img_channels: channels of the input image.
        out_channels: channel width of the first convolution.
    """

    def __init__(self, img_channels=3, out_channels=64):
        super().__init__()
        stage_widths = [out_channels * factor for factor in (1, 2, 4, 8)]

        layers = []
        previous = img_channels
        for stage, width in enumerate(stage_widths):
            layers.append(
                nn.Conv2d(previous, width, kernel_size=4, stride=2, padding=1, padding_mode="reflect")
            )
            # The very first stage works on raw pixels and is not normalised.
            if stage != 0:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            previous = width

        # Pad left/top by one so the last 4x4 conv keeps the map size.
        layers.append(nn.ReflectionPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(previous, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the single-channel score map for the image batch ``x``."""
        scores = self.model(x)
        return scores

In [8]:
def _on_device(network):
    """Move a freshly built network to the selected device."""
    return network.to(device)


# One generator per translation direction, plus one discriminator per image
# domain (the training loop feeds dis_m Monet images and dis_p photos).
gen_m2p = _on_device(Generator())
dis_m = _on_device(Discriminator())
gen_p2m = _on_device(Generator())
dis_p = _on_device(Discriminator())

In [9]:
# Adam hyper-parameters shared by all four networks.
learning_rate = 0.0002
betas_range = (0.5, 0.999)


def _adam_for(network):
    """Build an Adam optimizer over ``network``'s parameters."""
    return optim.Adam(network.parameters(), lr=learning_rate, betas=betas_range)


gen_m2p_optim = _adam_for(gen_m2p)
gen_p2m_optim = _adam_for(gen_p2m)
dis_m_optim = _adam_for(dis_m)
dis_p_optim = _adam_for(dis_p)

# Squared-error adversarial loss; L1 for the cycle and identity terms.
gan_loss = nn.MSELoss().to(device)
cycle_loss = nn.L1Loss().to(device)
id_loss = nn.L1Loss().to(device)

epoch_num = 50
start_epoch = epoch_num // 5


def lambda_val(epoch):
    """Learning-rate multiplier: 1 up to ``start_epoch``, then linear decay
    that reaches 0 at ``epoch_num``."""
    epochs_into_decay = max(0, epoch - start_epoch)
    return 1 - epochs_into_decay / (epoch_num - start_epoch)


def _scheduler_for(optimizer):
    """Attach the shared decay schedule to ``optimizer``."""
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_val)


gen_m2p_lr = _scheduler_for(gen_m2p_optim)
gen_p2m_lr = _scheduler_for(gen_p2m_optim)
dis_m_lr = _scheduler_for(dis_m_optim)
dis_p_lr = _scheduler_for(dis_p_optim)

Process the dataset

In [10]:
class MonetDataset(Dataset):
    """Paired-by-index dataset of Monet paintings and photos.

    Reads the ``monet_jpg`` and ``photo_jpg`` sub-directories of
    ``data_path`` and exposes one slice of each listing:

    * ``mode=0`` (train): entries ``[0, split)``.
    * ``mode=1`` (test):  entries ``[split, n)`` where ``n`` is the size of
      the smaller of the two listings.

    Listings are sorted so that the train/test split is the same on every
    run (``os.listdir`` returns names in arbitrary order).

    Args:
        data_path: directory containing ``monet_jpg`` and ``photo_jpg``.
        mode: 0 for the training slice, 1 for the test slice.
        transformer: optional callable applied to each PIL image.
        split: index separating the training slice from the test slice.

    Raises:
        ValueError: if ``mode`` is not 0 or 1, or the selected slice of
            either listing is empty.
    """

    def __init__(self, data_path, mode=0, transformer=None, split=250):
        if mode not in (0, 1):
            raise ValueError("mode must be 0 (train) or 1 (test), got %r" % (mode,))

        monet_path = os.path.join(data_path, 'monet_jpg')
        photo_path = os.path.join(data_path, 'photo_jpg')
        monet_list = sorted(os.listdir(monet_path))
        photo_list = sorted(os.listdir(photo_path))
        paired_count = min(len(monet_list), len(photo_list))

        if mode == 0:
            down_bound, up_bound = 0, split
        else:
            down_bound, up_bound = split, paired_count

        self.transformer = transformer
        self.monet_images = [os.path.join(monet_path, name) for name in monet_list[down_bound:up_bound]]
        self.photo_images = [os.path.join(photo_path, name) for name in photo_list[down_bound:up_bound]]
        if not self.monet_images or not self.photo_images:
            raise ValueError(
                "empty split for mode=%r: %d monet / %d photo images selected"
                % (mode, len(self.monet_images), len(self.photo_images))
            )
        # Length of the selected slice, not of the full directory listing,
        # so one pass over the dataset visits each selected image once.
        self.length = max(len(self.monet_images), len(self.photo_images))

    def __len__(self):
        return self.length

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return a 3-channel copy, and close the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``(monet, photo)`` pair at ``idx``.

        The index wraps around independently for each list, so the shorter
        list repeats if the two slices differ in size.
        """
        monet_exp = self._load_rgb(self.monet_images[idx % len(self.monet_images)])
        photo_exp = self._load_rgb(self.photo_images[idx % len(self.photo_images)])

        if self.transformer:
            monet_exp = self.transformer(monet_exp)
            photo_exp = self.transformer(photo_exp)
        return monet_exp, photo_exp

# Dataset location and loader settings.
data_location = '/kaggle/input/gan-getting-started'
img_size = 256
batch_size = 5
worker_num = 2


def _resize_step():
    """Resize every image to img_size x img_size."""
    return transforms.Resize((img_size, img_size))


def _tensor_steps():
    """Convert to a tensor and normalise with mean 0.5 / std 0.5."""
    return [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]


# Training adds random horizontal/vertical flips; testing is deterministic.
transformer_train = transforms.Compose(
    [_resize_step(), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()]
    + _tensor_steps()
)

transformer_test = transforms.Compose([_resize_step()] + _tensor_steps())


def _build_loader(mode, transformer, shuffle):
    """Create a DataLoader over the requested dataset split."""
    dataset = MonetDataset(data_location, mode=mode, transformer=transformer)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=worker_num,
    )


# mode=0 selects the training split, mode=1 the test split.
trainloader = _build_loader(0, transformer_train, True)

testloader = _build_loader(1, transformer_test, False)

Show some of the images

In [11]:
def show_samples(images):
    """Plot a batch of real images next to their translations.

    Four figures are shown, each a single-row grid of at most 5 images per
    row: the real Monet images, ``gen_m2p`` applied to them, the real
    photos, and ``gen_p2m`` applied to them.

    Args:
        images: pair ``(monet_batch, photo_batch)`` of image tensors, as
            yielded by the data loaders.

    Note:
        Both generators are switched to eval mode and left that way.
    """
    monet_real, photo_real = images[0].to(device), images[1].to(device)
    image_num = min(monet_real.size(0), 5)
    gen_m2p.eval()
    gen_p2m.eval()
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        monet_translated = gen_m2p(monet_real)
        photo_translated = gen_p2m(photo_real)
    figure_size = (4 * image_num, 4 * image_num)

    # Explicit (title, tensor) pairs instead of eval() on variable names.
    panels = [
        ("real monet image", monet_real),
        ("m2p image", monet_translated),
        ("real photo image", photo_real),
        ("p2m image", photo_translated),
    ]
    for title, batch in panels:
        cur_image = make_grid(batch, nrow=image_num, normalize=True)
        plt.figure(figsize=figure_size)
        plt.imshow(cur_image.cpu().permute(1, 2, 0))
        plt.axis('off')
        plt.title(title)
        plt.show()

# Grab the first test batch once and keep it in image_to_show so later calls
# to show_samples can display the same images again; show the untrained
# generators' output on it right away.
image_to_show = next(iter(testloader))
show_samples(image_to_show)

def show_samples1(images):
    """Plot a single batch and its ``gen_m2p`` translation.

    Two figures are shown, each a grid of at most 5 images per row: the
    input batch and the output of ``gen_m2p`` for it.

    Args:
        images: batch of image tensors.

    Note:
        Both generators are switched to eval mode and left that way.
    """
    images = images.to(device)
    image_num = min(images.size(0), 5)
    gen_m2p.eval()
    gen_p2m.eval()
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        translated = gen_m2p(images)

    figure_size = (4 * image_num, 4 * image_num)

    # Iterate over the tensors directly instead of eval() on variable names.
    for batch in (images, translated):
        cur_image = make_grid(batch, nrow=image_num, normalize=True)
        plt.figure(figsize=figure_size)
        plt.imshow(cur_image.cpu().permute(1, 2, 0))
        plt.axis('off')
        plt.show()

Train the model

In [12]:
# Weight applied to the identity and cycle-consistency terms of the
# generator loss (the adversarial term has weight 1).
loss_val = 5

# Direct object references, in the order they are stepped, instead of
# looking variables up with eval() on their names.
generators = [gen_m2p, gen_p2m]
gen_optims = [gen_m2p_optim, gen_p2m_optim]
dis_optims = [dis_m_optim, dis_p_optim]
lr_schedulers = [gen_m2p_lr, gen_p2m_lr, dis_m_lr, dis_p_lr]


def _adversarial_loss(prediction, is_real):
    """Compare a discriminator score map with an all-ones / all-zeros target.

    The target is created with the prediction's own shape, so a final batch
    smaller than ``batch_size`` (or a different image size) is handled.
    """
    if is_real:
        target = torch.ones_like(prediction)
    else:
        target = torch.zeros_like(prediction)
    return gan_loss(prediction, target)


for epoch in range(epoch_num):
    for train_data in trainloader:
        monet_real, photo_real = train_data[0].to(device), train_data[1].to(device)

        # ---- Generators -------------------------------------------------
        # show_samples() leaves the generators in eval mode, so re-enable
        # training mode on every step.
        for generator in generators:
            generator.train()
        for optimizer in gen_optims:
            optimizer.zero_grad()
        photo_fake = gen_m2p(monet_real)
        monet_fake = gen_p2m(photo_real)

        # Identity loss: a generator fed an image that is already in its
        # target domain should return it unchanged. (Comparing the
        # translation with its own input would instead penalise the
        # translation itself.) Costs two extra generator forward passes.
        loss_idt = (id_loss(gen_m2p(photo_real), photo_real)
                    + id_loss(gen_p2m(monet_real), monet_real)) * 0.5

        # Adversarial loss: each generator tries to make the discriminator
        # of its target domain score the fake as real.
        loss_GAN = (_adversarial_loss(dis_p(photo_fake), True)
                    + _adversarial_loss(dis_m(monet_fake), True)) * 0.5

        # Cycle-consistency loss and total generator loss.
        m2p2m = gen_p2m(photo_fake)
        p2m2p = gen_m2p(monet_fake)
        loss_cycle = cycle_loss(m2p2m, monet_real) + cycle_loss(p2m2p, photo_real)
        loss_G = loss_val * (loss_idt + loss_cycle) + loss_GAN
        loss_G.backward()

        # ---- Discriminators ---------------------------------------------
        # loss_G.backward() also left gradients on the discriminators;
        # clear them before computing the discriminator losses.
        for optimizer in dis_optims:
            optimizer.zero_grad()
        # Fakes are detached so these losses do not reach the generators.
        loss_D_M = (_adversarial_loss(dis_m(monet_real), True)
                    + _adversarial_loss(dis_m(monet_fake.detach()), False)) * 0.5
        loss_D_M.backward()
        loss_D_P = (_adversarial_loss(dis_p(photo_real), True)
                    + _adversarial_loss(dis_p(photo_fake.detach()), False)) * 0.5
        loss_D_P.backward()

        # Step all four optimizers.
        for optimizer in gen_optims + dis_optims:
            optimizer.step()

    for scheduler in lr_schedulers:
        scheduler.step()

    # The reported values are those of the last batch of the epoch.
    print("epoch" + str(epoch + 1) + " is finished")
    loss_D = (loss_D_P + loss_D_M) * 0.5
    print("Current epoch is: " + str(epoch + 1) + ' of ' + str(epoch_num))
    print("Generators loss:",loss_G.item(), "Loss_idt:" , loss_idt.item())
    print("Loss_GAN:", loss_GAN.item(), "Loss_Cycle:", loss_cycle.item())
    print("Discriminators loss:", loss_D.item(), "D_P:", loss_D_P.item(), "D_M:", loss_D_M.item())
    if (epoch + 1) % 10 == 0:
        show_samples(image_to_show)

In [None]:
# Directory that receives one generated Monet-style image per input photo.
save_dir = '../images'
os.makedirs(save_dir, exist_ok=True)
# This submission part is based on https://www.kaggle.com/code/lmyybh/pytorch-cyclegan
# A few ideas in other parts were also learnt from it, but they are too sparse to cite individually.
# We wrote all of the code based on our own understanding, for our own design.
photo_path = os.path.join(data_location, 'photo_jpg')
files = [os.path.join(photo_path, name) for name in os.listdir(photo_path)]

norm_val = tuple([0.5] * 3)
generate_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_val, norm_val)
])

to_image = transforms.ToPILImage()
gen_p2m.eval()
# Inference only: no autograd graph is needed for any of the photos.
with torch.no_grad():
    for i in range(0, len(files), batch_size):
        batch_files = files[i:i + batch_size]
        imgs = []
        for path in batch_files:
            # Close each file after reading and force 3 channels so the
            # per-channel normalisation always applies.
            with Image.open(path) as img:
                imgs.append(generate_transforms(img.convert('RGB')))
        imgs = torch.stack(imgs, 0).to(device)
        fake_imgs = gen_p2m(imgs).cpu()

        # Invert Normalize(0.5, 0.5): map the tanh output range [-1, 1]
        # back to [0, 255]. Unlike per-image min-max scaling this keeps the
        # contrast the generator produced and cannot divide by zero.
        pixels = (fake_imgs * 0.5 + 0.5).clamp(0, 1).mul(255).round().to(torch.uint8)
        pixels = pixels.permute(0, 2, 3, 1).numpy()  # NCHW -> NHWC

        for path, img_arr in zip(batch_files, pixels):
            img = to_image(np.ascontiguousarray(img_arr))
            _, name = os.path.split(path)
            img.save(os.path.join(save_dir, name))

Compress the generated images for submission and save the model weights

In [None]:
# Zip the directory the generated images were written to. Resolving save_dir
# keeps the archive root in sync with the generation step instead of relying
# on a second hard-coded path.
shutil.make_archive("/kaggle/working/images", 'zip', os.path.abspath(save_dir))

# Save the weights of all four networks in the current directory.
for model_name, model in (
    ("gen_m2p", gen_m2p),
    ("dis_m", dis_m),
    ("gen_p2m", gen_p2m),
    ("dis_p", dis_p),
):
    torch.save(model.state_dict(), './' + model_name + '.pt')