# Style translation

In [None]:
import numpy as np
import pandas as pd
import torch
import torchvision
from torchvision import models, transforms
from tqdm.notebook import tqdm
import os
import shutil
import matplotlib.pyplot as plt
from skimage import io, transform
import itertools
%matplotlib inline

In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
from torch.utils.data import Dataset, DataLoader

class PhotoMonetDataset(Dataset):
    """Unaligned pairs of (photo, Monet painting) images read from two folders.

    The dataset is as long as the larger of the two folders. The smaller
    folder is cycled with a modulo so that every index yields one image from
    each folder.

    Args:
        photo_path: Directory containing the photo image files.
        monet_path: Directory containing the Monet image files.

    Raises:
        ValueError: If either directory does not exist or contains no files.
    """

    def __init__(self, photo_path, monet_path):
        for path in (photo_path, monet_path):
            if not os.path.exists(path):
                raise ValueError(f'{path} does not exist')
        self.photo = photo_path
        self.monet = monet_path
        # Sorted so that a given index always maps to the same files.
        self.photo_names = sorted(os.listdir(self.photo))
        self.monet_names = sorted(os.listdir(self.monet))
        # Fail fast: an empty folder would otherwise surface much later as a
        # ZeroDivisionError from the modulo in __getitem__.
        for path, names in ((self.photo, self.photo_names),
                            (self.monet, self.monet_names)):
            if not names:
                raise ValueError(f'{path} contains no files')

    def __len__(self):
        """Return the size of the larger of the two folders."""
        return max(len(self.photo_names), len(self.monet_names))

    def _pair_indices(self, idx):
        """Map a dataset index to (photo index, monet index).

        The larger folder is indexed directly (so an out-of-range index still
        raises IndexError); the smaller one wraps around.
        """
        n_photo = len(self.photo_names)
        n_monet = len(self.monet_names)
        if n_photo > n_monet:
            return idx, idx % n_monet
        return idx % n_photo, idx

    @staticmethod
    def _load_image(path, filename):
        """Read one image and return it as a tensor rescaled with ``x * 2 - 1``.

        ``transforms.ToTensor`` is expected to give values in [0, 1], which
        this maps to [-1, 1].
        """
        image = io.imread(os.path.join(path, filename))
        return transforms.ToTensor()(image) * 2.0 - 1.0

    def __getitem__(self, idx):
        """Return the ``(photo, monet)`` tensor pair for ``idx``.

        Args:
            idx: Any integer-like index (``int``, NumPy integer, ...).

        Raises:
            TypeError: If ``idx`` cannot be interpreted as an integer.
            IndexError: If ``idx`` is out of range for the larger folder.
        """
        import operator

        # An explicit check instead of `assert`: asserts vanish under -O and
        # `type(idx) == int` rejected valid integer types such as numpy.int64.
        try:
            idx = operator.index(idx)
        except TypeError:
            raise TypeError(
                f'dataset index must be an integer, got {type(idx).__name__}'
            ) from None

        i_photo, i_monet = self._pair_indices(idx)
        photo_name = self.photo_names[i_photo]
        monet_name = self.monet_names[i_monet]

        return (self._load_image(self.photo, photo_name),
                self._load_image(self.monet, monet_name))

In [None]:
# Pair the photo folder with the Monet folder of the "gan-getting-started" input data.
dataset = PhotoMonetDataset(
    monet_path='../input/gan-getting-started/monet_jpg',
    photo_path='../input/gan-getting-started/photo_jpg',
)

# Shuffled mini-batches of 16 pairs, loaded by 8 worker processes.
dataloader = DataLoader(dataset, shuffle=True, batch_size=16, num_workers=8)

In [None]:
def downsample(in_channels, out_channels, size, apply_instancenorm=True):
    """Build a stride-2 convolution block: Conv2d -> [InstanceNorm2d] -> LeakyReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        size: Convolution kernel size (with ``size=4``, stride 2 and padding 1
            an even spatial dimension is halved).
        apply_instancenorm: Insert an ``InstanceNorm2d`` after the convolution
            when true.

    Returns:
        A ``torch.nn.Sequential`` already moved to the module-level ``device``.
    """
    conv = torch.nn.Conv2d(in_channels, out_channels, size, stride=2, padding=1, bias=False)
    norm = [torch.nn.InstanceNorm2d(out_channels)] if apply_instancenorm else []
    block = torch.nn.Sequential(conv, *norm, torch.nn.LeakyReLU())
    return block.to(device)

In [None]:
def upsample(in_channels, out_channels, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block.

    Layer order: ConvTranspose2d -> InstanceNorm2d -> [Dropout(0.5)] -> ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the transposed convolution.
        size: Kernel size (with ``size=4``, stride 2 and padding 1 the spatial
            dimensions are doubled).
        apply_dropout: Insert a ``Dropout(0.5)`` before the activation when true.

    Returns:
        A ``torch.nn.Sequential`` already moved to the module-level ``device``.
    """
    stages = [
        torch.nn.ConvTranspose2d(in_channels, out_channels, size, stride=2, padding=1, bias=False),
        torch.nn.InstanceNorm2d(out_channels),
    ]
    if apply_dropout:
        stages += [torch.nn.Dropout(0.5)]
    stages += [torch.nn.ReLU()]

    return torch.nn.Sequential(*stages).to(device)

In [None]:
OUTPUT_CHANNELS = 3

class Generator(torch.nn.Module):
    """U-Net style image-to-image generator.

    Eight stride-2 ``downsample`` blocks encode the input, seven stride-2
    ``upsample`` blocks decode it, and each decoder output is concatenated
    (channel-wise) with the encoder activation of matching resolution. A final
    transposed convolution followed by ``tanh`` produces ``OUTPUT_CHANNELS``
    channels in [-1, 1].

    Shape comments below use PyTorch's (batch, channels, height, width) layout
    and assume a 3x256x256 input.
    """

    def __init__(self):
        super().__init__()

        # The stacks MUST be ModuleLists: layers kept in a plain Python list
        # are not registered as sub-modules, so they would be missing from
        # .parameters() (the optimizer would never update them), from
        # .state_dict(), and from .to()/.train()/.eval().
        # bs = batch size
        self.down_stack = torch.nn.ModuleList([
            downsample(3, 64, 4, apply_instancenorm=False), # (bs, 64, 128, 128)
            downsample(64, 128, 4), # (bs, 128, 64, 64)
            downsample(128, 256, 4), # (bs, 256, 32, 32)
            downsample(256, 512, 4), # (bs, 512, 16, 16)
            downsample(512, 512, 4), # (bs, 512, 8, 8)
            downsample(512, 512, 4), # (bs, 512, 4, 4)
            downsample(512, 512, 4), # (bs, 512, 2, 2)
            # No InstanceNorm on the innermost block: normalising a 1x1 map
            # has a single spatial element, which would zero the bottleneck
            # (and recent PyTorch versions reject it with a ValueError).
            downsample(512, 512, 4, apply_instancenorm=False), # (bs, 512, 1, 1)
        ])

        # Shapes are listed after concatenation with the skip connection.
        self.up_stack = torch.nn.ModuleList([
            upsample(512, 512, 4, apply_dropout=True), # (bs, 1024, 2, 2)
            upsample(1024, 512, 4, apply_dropout=True), # (bs, 1024, 4, 4)
            upsample(1024, 512, 4, apply_dropout=True), # (bs, 1024, 8, 8)
            upsample(1024, 512, 4), # (bs, 1024, 16, 16)
            upsample(1024, 256, 4), # (bs, 512, 32, 32)
            upsample(512, 128, 4), # (bs, 256, 64, 64)
            upsample(256, 64, 4), # (bs, 128, 128, 128)
        ])

        self.last = torch.nn.ConvTranspose2d(128, OUTPUT_CHANNELS, 4,
                                             stride=2, padding=1).to(device) # (bs, 3, 256, 256)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: Input batch; the channel arithmetic above requires 3 channels.

        Returns:
            Tensor with ``OUTPUT_CHANNELS`` channels and values in [-1, 1].
        """
        skips = []
        for down in self.down_stack:
            x = down(x)
            skips.append(x)

        # The last encoder output is the bottleneck itself (already in `x`),
        # so it is not used as a skip; the rest are consumed deepest-first.
        for up, skip in zip(self.up_stack, reversed(skips[:-1])):
            x = torch.cat([up(x), skip], dim=1)

        return self.tanh(self.last(x))

In [None]:
class Discriminator(torch.nn.Module):
    """Convolutional critic that scores an image with a one-channel map.

    Three stride-2 ``downsample`` blocks are followed by two zero-padded
    4x4 stride-1 convolutions. The output holds raw scores (no sigmoid is
    applied here). For a 3x256x256 input the arithmetic works out to a
    1x30x30 map per image.
    """

    def __init__(self):
        super().__init__()

        # Strided feature extractor: 3 -> 64 -> 128 -> 256 channels.
        self.down1 = downsample(3, 64, 4, False)
        self.down2 = downsample(64, 128, 4)
        self.down3 = downsample(128, 256, 4)

        # First stride-1 convolution, preceded by one pixel of zero padding.
        self.zero_pad1 = torch.nn.ZeroPad2d(1)
        self.conv = torch.nn.Conv2d(256, 512, 4, stride=1, bias=False)

        self.norm1 = torch.nn.InstanceNorm2d(512)
        self.leaky_relu = torch.nn.LeakyReLU()

        # Final stride-1 convolution down to a single score channel.
        self.zero_pad2 = torch.nn.ZeroPad2d(1)
        self.last = torch.nn.Conv2d(512, 1, 4, stride=1)

    def forward(self, x):
        """Run the layers in definition order and return the score map."""
        pipeline = (
            self.down1, self.down2, self.down3,
            self.zero_pad1, self.conv,
            self.norm1, self.leaky_relu,
            self.zero_pad2, self.last,
        )
        for stage in pipeline:
            x = stage(x)
        return x


In [None]:
def discriminator_loss(real, generated):
    """Binary cross-entropy loss for a discriminator, computed from raw logits.

    Real samples are scored against a target of 1 and generated samples
    against a target of 0; the two mean losses are averaged.

    The loss is evaluated with ``binary_cross_entropy_with_logits`` rather
    than ``Sigmoid`` followed by ``BCELoss``: the fused form stays finite and
    keeps a useful gradient for large-magnitude logits, where the explicit
    sigmoid saturates to exactly 0 or 1.

    Args:
        real: Discriminator logits for real images.
        generated: Discriminator logits for generated images.

    Returns:
        Scalar tensor ``0.5 * (BCE(real, 1) + BCE(generated, 0))``.
    """
    bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits
    real_loss = bce_with_logits(real, torch.ones_like(real))
    generated_loss = bce_with_logits(generated, torch.zeros_like(generated))
    return (real_loss + generated_loss) * 0.5

def generator_loss(generated):
    """Adversarial loss for a generator, computed from raw discriminator logits.

    The generator wants its outputs to be classified as real, so the logits
    are scored against a target of 1. ``binary_cross_entropy_with_logits`` is
    used instead of ``Sigmoid`` + ``BCELoss`` so that the loss stays finite
    and keeps a gradient when the discriminator is very confident (large
    negative logits), which is when the generator needs the signal most.

    Args:
        generated: Discriminator logits for generated images.

    Returns:
        Scalar tensor with the mean binary cross-entropy against all-ones.
    """
    return torch.nn.functional.binary_cross_entropy_with_logits(
        generated, torch.ones_like(generated)
    )

def calc_cycle_loss(real_image, cycled_image, LAMBDA):
    """Cycle-consistency loss: ``LAMBDA`` times the mean absolute difference.

    Args:
        real_image: Original image batch.
        cycled_image: The batch after being translated to the other domain
            and back.
        LAMBDA: Weight applied to the mean absolute error.

    Returns:
        Scalar tensor ``LAMBDA * mean(|real_image - cycled_image|)``.
    """
    mean_abs_error = (real_image - cycled_image).abs().mean()
    return LAMBDA * mean_abs_error

def identity_loss(real_image, same_image, LAMBDA):
    """Identity loss: half of ``LAMBDA`` times the mean absolute difference.

    Args:
        real_image: Image batch that is already in the generator's target domain.
        same_image: The generator's output for ``real_image``.
        LAMBDA: Base weight; the loss uses ``LAMBDA * 0.5``.

    Returns:
        Scalar tensor ``LAMBDA * 0.5 * mean(|real_image - same_image|)``.
    """
    weight = LAMBDA * 0.5
    mean_abs_error = (real_image - same_image).abs().mean()
    return weight * mean_abs_error

In [None]:
class CycleGAN:
    """Trainer holding two generators and two discriminators.

    ``monet_gen`` maps photos to Monet-style images and ``photo_gen`` maps
    Monet images to photos; ``monet_disc`` / ``photo_disc`` score images of
    the respective domain. Both generators share one Adam optimizer and both
    discriminators share another.

    Args:
        lm: Weight passed as ``LAMBDA`` to the cycle and identity losses.
        epochs: Number of passes over ``dataloader`` made by ``train``.
        dataloader: Iterable yielding ``(real_photo, real_monet)`` batches.
        lr: Learning rate for both Adam optimizers. The default equals Adam's
            own default, so existing callers are unaffected.
        betas: Adam ``betas`` for both optimizers (default equals Adam's own).
    """

    def __init__(self, lm, epochs, dataloader, lr=1e-3, betas=(0.9, 0.999)):
        self.lm = lm
        self.epochs = epochs
        self.dataloader = dataloader
        self.monet_gen = Generator().to(device)
        self.photo_gen = Generator().to(device)
        self.monet_disc = Discriminator().to(device)
        self.photo_disc = Discriminator().to(device)
        self.gen_opt = torch.optim.Adam(
            itertools.chain(self.monet_gen.parameters(), self.photo_gen.parameters()),
            lr=lr, betas=betas,
        )
        self.disc_opt = torch.optim.Adam(
            itertools.chain(self.monet_disc.parameters(), self.photo_disc.parameters()),
            lr=lr, betas=betas,
        )

    def generator_pass(self, real_photo, real_monet):
        """Compute the total generator loss for one batch.

        The total is the sum of both adversarial losses, both cycle losses
        and both identity losses.

        Returns:
            Scalar tensor suitable for ``backward()``.
        """
        # photo -> monet -> photo
        fake_monet = self.monet_gen(real_photo)
        cycled_photo = self.photo_gen(fake_monet)

        # monet -> photo -> monet
        fake_photo = self.photo_gen(real_monet)
        cycled_monet = self.monet_gen(fake_photo)

        # Feeding a generator an image of its own target domain should be a no-op.
        same_monet = self.monet_gen(real_monet)
        same_photo = self.photo_gen(real_photo)

        # Only the scores of the fakes enter the generator loss, so the
        # discriminators are not run on the real images here.
        monet_gen_loss = generator_loss(self.monet_disc(fake_monet))
        photo_gen_loss = generator_loss(self.photo_disc(fake_photo))

        total_gen_loss = monet_gen_loss + photo_gen_loss \
            + calc_cycle_loss(real_monet, cycled_monet, self.lm) \
            + calc_cycle_loss(real_photo, cycled_photo, self.lm) \
            + identity_loss(real_monet, same_monet, self.lm) + identity_loss(real_photo, same_photo, self.lm)

        return total_gen_loss

    def discriminator_pass(self, real_photo, real_monet):
        """Compute the averaged loss of both discriminators for one batch.

        Returns:
            Scalar tensor suitable for ``backward()``.
        """
        # The fakes are constants for the discriminator update: generating
        # them without autograd avoids building (and later back-propagating
        # through) the generator graphs.
        with torch.no_grad():
            fake_monet = self.monet_gen(real_photo)
            fake_photo = self.photo_gen(real_monet)

        disc_real_monet = self.monet_disc(real_monet)
        disc_real_photo = self.photo_disc(real_photo)

        disc_fake_monet = self.monet_disc(fake_monet)
        disc_fake_photo = self.photo_disc(fake_photo)

        monet_disc_loss = discriminator_loss(disc_real_monet, disc_fake_monet)
        photo_disc_loss = discriminator_loss(disc_real_photo, disc_fake_photo)

        return (monet_disc_loss + photo_disc_loss) * 0.5

    def train(self):
        """Alternate one generator step and one discriminator step per batch.

        Prints the epoch header and, for every batch, both loss values.
        """
        for epoch in range(self.epochs):
            print(f"epoch {epoch+1} / {self.epochs}")
            for (real_photo, real_monet) in tqdm(self.dataloader):
                real_photo, real_monet = real_photo.to(device), real_monet.to(device)

                # Generator step.
                self.gen_opt.zero_grad()
                gen_loss = self.generator_pass(real_photo, real_monet)
                gen_loss.backward()
                self.gen_opt.step()

                # Discriminator step. zero_grad() also discards the gradients
                # that the generator backward pass left on the discriminators.
                self.disc_opt.zero_grad()
                disc_loss = self.discriminator_pass(real_photo, real_monet)
                disc_loss.backward()
                self.disc_opt.step()

                print(f"gen loss: {gen_loss.item():.5f}, disc loss: {disc_loss.item():.5f}")


In [None]:
# Build the trainer: lm=10 is the weight handed to the cycle and identity losses,
# and train() will make 15 passes over `dataloader`.
gan = CycleGAN(lm=10, epochs=15, dataloader=dataloader)

In [None]:
# Run the training loop; both losses are printed for every batch.
gan.train()

In [None]:
def unnorm(img):
    """Return ``(img + 1) / 2``, which maps values in [-1, 1] to [0, 1]."""
    shifted = img + 1.0
    return shifted * 0.5

plt.figure(figsize=(7, 16))
# Take (up to) five batches from ONE pass over the loader. Calling
# next(iter(dataloader)) inside the loop would build a fresh iterator, and
# restart all of its worker processes, for every single row of the figure.
for i, (photo_img, _) in enumerate(itertools.islice(dataloader, 5)):
    # Only the first image of each batch is displayed, so translate just that one.
    photo_img = photo_img[:1]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred_monet = gan.monet_gen(photo_img.to(device)).cpu()
    photo_img = unnorm(photo_img)
    pred_monet = unnorm(pred_monet)

    # Left column: input photo (CHW -> HWC for matplotlib).
    plt.subplot(5, 2, i*2 + 1)
    plt.imshow(photo_img[0].permute(1, 2, 0))
    plt.title("Input Photo")
    plt.axis("off")

    # Right column: generator output for that photo.
    plt.subplot(5, 2, i*2 + 2)
    plt.imshow(pred_monet[0].permute(1, 2, 0))
    plt.title("Monet-esque Photo")
    plt.axis("off")
plt.show()

In [None]:
# Create the output folder for the generated images. Unlike `! mkdir`, this
# does not fail when the cell is re-run and the folder already exists.
os.makedirs("../images", exist_ok=True)

In [None]:
def unnorm(img):
    """Return ``(img + 1) / 2``, which maps values in [-1, 1] to [0, 1]."""
    return (img + 1.0) * 0.5

trans = transforms.ToPILImage()

# Translate every photo batch and write the results as ../images/image<N>.jpg,
# numbered from 1 in the order the (shuffled) loader yields them.
# NOTE(review): the generator is never switched to eval(), so its Dropout
# layers stay active while generating; call gan.monet_gen.eval() first if
# deterministic output is wanted.
i = 1
# Inference only: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for photos, _ in tqdm(dataloader):
        monets = unnorm(gan.monet_gen(photos.to(device)).cpu())
        for monet in monets:
            img = trans(monet).convert("RGB")
            img.save(f"../images/image{i}.jpg")
            i += 1

In [None]:
# Pack the contents of ../images into images.zip in the current working directory.
shutil.make_archive("images", 'zip', "../images")