In [None]:
import copy
import os
import random

import numpy as np
import matplotlib.pyplot as plt
#import seaborn as sns
import torch
import torch.nn as nn
import torchvision.transforms as tt
from PIL import Image
from tqdm.notebook import tqdm

In [None]:
# Load a random subset of the real photos as (1, 3, H, W) float tensors in [0, 1] on the GPU.
# NOTE(review): the generators end in tanh (range [-1, 1]) while ToTensor yields [0, 1];
# consider adding tt.Normalize((0.5,) * 3, (0.5,) * 3) to every loading cell (and undoing it
# before display).
PHOTO_DIR = '/kaggle/input/gan-getting-started/photo_jpg'
N_TRAIN_PHOTOS = 300

photos = []
transform = tt.ToTensor()
# os.listdir order is arbitrary; sort first so the sample is reproducible under a fixed seed.
names = random.sample(sorted(os.listdir(PHOTO_DIR)), N_TRAIN_PHOTOS)
for f in names:
    with Image.open(os.path.join(PHOTO_DIR, f)) as img:
        # convert('RGB') guarantees 3 channels even if a file is grayscale or has alpha.
        tensor = transform(img.convert('RGB')).unsqueeze(0).to('cuda')
    photos.append(tensor)

In [None]:
# Load every Monet painting as a (1, 3, H, W) float tensor in [0, 1] on the GPU.
MONET_DIR = '/kaggle/input/gan-getting-started/monet_jpg'

monet = []
transform = tt.ToTensor()
# Sorted so the painting order (and hence the photo/painting pairing during training)
# does not depend on the filesystem's listing order.
for f in sorted(os.listdir(MONET_DIR)):
    with Image.open(os.path.join(MONET_DIR, f)) as img:
        # convert('RGB') guarantees 3 channels even if a file is grayscale or has alpha.
        tensor = transform(img.convert('RGB')).unsqueeze(0).to('cuda')
    monet.append(tensor)

In [None]:
class Downsample(nn.Module):
    """Encoder unit: stride-2 ``Conv2d`` -> optional ``InstanceNorm2d`` -> ``LeakyReLU(0.2)``.

    The convolution uses stride 2, padding 1 and no bias, so with ``size == 4`` an
    even spatial size is halved.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        size: square kernel size of the convolution.
        apply_instancenorm: if False, no ``instance_norm`` submodule is created and the
            normalisation step is skipped in ``forward``.
    """

    def __init__(self, in_channels, out_channels, size, apply_instancenorm=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            size,
            stride=2,
            padding=1,
            bias=False,
        )
        self.apply_instancenorm = apply_instancenorm
        if self.apply_instancenorm:
            self.instance_norm = nn.InstanceNorm2d(out_channels, affine=False)
        # In-place activation: overwrites the (normalised) conv output buffer.
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        features = self.conv(x)
        if not self.apply_instancenorm:
            return self.leaky_relu(features)
        return self.leaky_relu(self.instance_norm(features))

In [None]:
class Upsample(nn.Module):
    """Decoder unit: stride-2 ``ConvTranspose2d`` -> ``InstanceNorm2d`` -> optional ``Dropout(0.5)`` -> ``ReLU``.

    The transposed convolution uses stride 2, padding 1, no output padding and no bias,
    so with ``size == 4`` the spatial size doubles.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the transposed convolution.
        size: square kernel size of the transposed convolution.
        apply_dropout: whether ``forward`` applies the dropout layer (which is always
            registered, so it still follows ``train()`` / ``eval()``).
    """

    def __init__(self, in_channels, out_channels, size, apply_dropout=False):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            size,
            stride=2,
            padding=1,
            output_padding=0,
            bias=False,
        )
        self.apply_dropout = apply_dropout
        self.instance_norm = nn.InstanceNorm2d(out_channels, affine=False)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        hidden = self.instance_norm(self.conv_transpose(x))
        if self.apply_dropout:
            hidden = self.dropout(hidden)
        return self.relu(hidden)

In [None]:
class Generator(nn.Module):
    """U-Net style image-to-image generator (used for both translation directions).

    Eight ``Downsample`` stages each halve the spatial size. Seven ``Upsample`` stages
    each double it, and after every one of them the encoder activation of matching
    resolution is concatenated along the channel dimension (dim=1). A final transposed
    convolution maps to 3 channels, and ``tanh`` bounds the output to [-1, 1].

    Shape comments use PyTorch's (bs, C, H, W) layout and assume a 3x256x256 input
    (TODO confirm against the data). Other sizes must let every decoder output match its
    skip's spatial size, otherwise ``torch.cat`` fails.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, apply_instancenorm) per encoder stage; kernel size 4.
        encoder_specs = [
            (3, 64, False),     # -> (bs, 64, 128, 128)
            (64, 128, True),    # -> (bs, 128, 64, 64)
            (128, 256, True),   # -> (bs, 256, 32, 32)
            (256, 512, True),   # -> (bs, 512, 16, 16)
            (512, 512, True),   # -> (bs, 512, 8, 8)
            (512, 512, True),   # -> (bs, 512, 4, 4)
            (512, 512, True),   # -> (bs, 512, 2, 2)
            (512, 512, False),  # -> (bs, 512, 1, 1); instance norm over a 1x1 map would zero it
        ]
        self.down_stack = nn.ModuleList(
            [Downsample(c_in, c_out, 4, apply_instancenorm=norm) for c_in, c_out, norm in encoder_specs]
        )

        # (in_channels, out_channels, apply_dropout) per decoder stage; from the second stage
        # on, in_channels include the concatenated skip connection. Shapes are after concat.
        decoder_specs = [
            (512, 512, True),    # -> (bs, 1024, 2, 2)
            (1024, 512, True),   # -> (bs, 1024, 4, 4)
            (1024, 512, True),   # -> (bs, 1024, 8, 8)
            (1024, 512, False),  # -> (bs, 1024, 16, 16)
            (1024, 256, False),  # -> (bs, 512, 32, 32)
            (512, 128, False),   # -> (bs, 256, 64, 64)
            (256, 64, False),    # -> (bs, 128, 128, 128)
        ]
        self.up_stack = nn.ModuleList(
            [Upsample(c_in, c_out, 4, apply_dropout=drop) for c_in, c_out, drop in decoder_specs]
        )

        self.last = nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1)  # -> (bs, 3, 256, 256)
        self.tanh = nn.Tanh()

    def forward(self, x):
        encoder_outputs = []
        for stage in self.down_stack:
            x = stage(x)
            encoder_outputs.append(x)

        # The bottleneck is not a skip: pair decoder stages with the remaining encoder
        # outputs from deepest to shallowest.
        skip_connections = encoder_outputs[-2::-1]

        for stage, skip in zip(self.up_stack, skip_connections):
            x = torch.cat([stage(x), skip], dim=1)

        return self.tanh(self.last(x))

In [None]:
# Critic mapping an image to a spatial grid of unbounded scores (no sigmoid; the training
# loop regresses them onto 1/0 with MSELoss).
# NOTE: this is a single module *instance*, not a class; `Discriminator.to(...)` returns
# this same object every time.
# Shape comments use (bs, C, H, W) and assume a 3x256x256 input (TODO confirm).
Discriminator = nn.Sequential(
    Downsample(3, 64, 4, apply_instancenorm=False),  # (bs, 64, 128, 128)
    Downsample(64, 128, 4),                          # (bs, 128, 64, 64)
    Downsample(128, 256, 4),                         # (bs, 256, 32, 32)
    nn.ZeroPad2d(padding=1),                         # (bs, 256, 34, 34)
    nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=0, bias=False),  # (bs, 512, 31, 31)
    nn.InstanceNorm2d(512, affine=False),
    # torch's default slope (0.01), unlike the 0.2 used inside the Downsample blocks.
    nn.LeakyReLU(negative_slope=0.01),
    nn.ZeroPad2d(padding=1),                         # (bs, 512, 33, 33)
    nn.Conv2d(512, 1, kernel_size=4, stride=1),      # (bs, 1, 30, 30)
)


In [None]:
monet_generator = Generator().to('cuda')  # transforms photos to Monet-esque paintings
photo_generator = Generator().to('cuda')  # transforms Monet paintings to be more like photos

# `Discriminator` is one nn.Sequential *instance* and Module.to() returns `self`, so calling
# `Discriminator.to('cuda')` twice would make both names alias a single network (shared
# weights, and two optimizers stepping the same parameters). Give the photo critic its own
# deep copy with freshly initialised convolution weights.
monet_discriminator = Discriminator.to('cuda')  # differentiates real vs. generated Monet paintings
photo_discriminator = copy.deepcopy(monet_discriminator)  # differentiates real vs. generated photos
for layer in photo_discriminator.modules():
    if isinstance(layer, nn.Conv2d):
        layer.reset_parameters()
assert photo_discriminator is not monet_discriminator

In [None]:
# One Adam optimizer per network. torch's default betas are spelled out so they are easy to tune.
# NOTE(review): the CycleGAN reference setup trains with lr=2e-4 and betas=(0.5, 0.999).
LEARNING_RATE = 1e-3
ADAM_BETAS = (0.9, 0.999)

optimizer_mg, optimizer_pg, optimizer_md, optimizer_pd = [
    torch.optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
    for net in (monet_generator, photo_generator, monet_discriminator, photo_discriminator)
]

# Least-squares GAN objective: critic scores are regressed onto 1 (real) / 0 (fake).
criterion = nn.MSELoss()

In [None]:
# Each step first updates both generators on one combined objective (critics frozen), then
# both critics on real images vs. *detached* fakes. This avoids backpropagating through
# weights an optimizer already modified in place, and keeps each loss's gradients confined
# to the networks it is meant to train.
NUM_EPOCHS = 40
CYCLE_LAMBDA = 1.0  # weight of the cycle-consistency term (the CycleGAN paper uses 10)
# Photos and paintings are unpaired; walk index-aligned pairs over the shorter list.
steps_per_epoch = min(len(photos), len(monet))

generators = (monet_generator, photo_generator)
discriminators = (monet_discriminator, photo_discriminator)

for epoch in range(NUM_EPOCHS):
    for net in generators + discriminators:
        net.train()
    print(f'Epoch {epoch + 1}')
    monet_epoch_losses = []
    disc_epoch_losses = []
    progress = tqdm(range(steps_per_epoch))
    for i in progress:
        real_photo = photos[i]
        real_monet = monet[i]

        # --- Generator update ---------------------------------------------------------
        # Frozen critics: the adversarial terms only produce generator gradients (the
        # generated images still carry gradient back into the generators).
        for net in discriminators:
            net.requires_grad_(False)

        fake_monet = monet_generator(real_photo)
        fake_photo = photo_generator(real_monet)
        cycled_photo = photo_generator(fake_monet)
        cycled_monet = monet_generator(fake_photo)

        disc_fake_monet = monet_discriminator(fake_monet)
        disc_fake_photo = photo_discriminator(fake_photo)

        monet_generator_loss = criterion(disc_fake_monet, torch.ones_like(disc_fake_monet))
        photo_generator_loss = criterion(disc_fake_photo, torch.ones_like(disc_fake_photo))
        cycle_loss = (cycled_photo - real_photo).abs().mean() + (cycled_monet - real_monet).abs().mean()

        # A single backward pass: each generator gets its own adversarial gradient plus the
        # cycle gradient exactly once, all computed with the pre-update weights.
        generator_loss = monet_generator_loss + photo_generator_loss + CYCLE_LAMBDA * cycle_loss
        optimizer_mg.zero_grad()
        optimizer_pg.zero_grad()
        generator_loss.backward()
        optimizer_mg.step()
        optimizer_pg.step()

        # --- Discriminator update -----------------------------------------------------
        for net in discriminators:
            net.requires_grad_(True)

        # detach(): the critic loss must not backpropagate into the generators.
        disc_real_monet = monet_discriminator(real_monet)
        disc_fake_monet = monet_discriminator(fake_monet.detach())
        disc_real_photo = photo_discriminator(real_photo)
        disc_fake_photo = photo_discriminator(fake_photo.detach())

        monet_discriminator_loss = (
            0.5 * criterion(disc_real_monet, torch.ones_like(disc_real_monet))
            + 0.5 * criterion(disc_fake_monet, torch.zeros_like(disc_fake_monet))
        )
        photo_discriminator_loss = (
            0.5 * criterion(disc_real_photo, torch.ones_like(disc_real_photo))
            + 0.5 * criterion(disc_fake_photo, torch.zeros_like(disc_fake_photo))
        )

        optimizer_md.zero_grad()
        optimizer_pd.zero_grad()
        (monet_discriminator_loss + photo_discriminator_loss).backward()
        optimizer_md.step()
        optimizer_pd.step()

        monet_epoch_losses.append(monet_generator_loss.item())
        disc_epoch_losses.append(monet_discriminator_loss.item())
        progress.set_postfix(
            monet_gen=sum(monet_epoch_losses) / len(monet_epoch_losses),
            monet_disc=sum(disc_epoch_losses) / len(disc_epoch_losses),
        )

    # Visual check: translate one random training photo with dropout disabled.
    monet_generator.eval()
    test = random.choice(photos)
    with torch.no_grad():
        out = monet_generator(test)
    fig, (ax_photo, ax_monet) = plt.subplots(1, 2, figsize=(8, 4))
    ax_photo.imshow(test.squeeze(0).permute(1, 2, 0).cpu().numpy())
    ax_photo.set_title('Input photo')
    # tanh output can dip below 0 while the inputs are in [0, 1]; clamp for display only.
    ax_monet.imshow(out.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy())
    ax_monet.set_title(f'Monet-style output (epoch {epoch + 1})')
    for ax in (ax_photo, ax_monet):
        ax.axis('off')
    plt.show()
    plt.close(fig)

In [None]:
# Pickles the whole module object (class references included), so loading 'model2' needs the
# Generator / Upsample / Downsample definitions available under the same names.
torch.save(obj=monet_generator, f='model2')

In [None]:
!pip install gdown

In [None]:
import gdown
#!gdown 15N3WHOQjYFxK-ovG-5CjV2mt76YOTd30

In [None]:
# Load every file in photo_jpg with the same preprocessing as the training cells:
# (1, 3, H, W) float tensors in [0, 1] on the GPU.
# NOTE(review): everything is moved to the GPU up front; for the full photo directory this is
# a large allocation. Consider keeping tensors on the CPU and moving one at a time at inference.
PHOTO_DIR = '/kaggle/input/gan-getting-started/photo_jpg'

# Sorted for a deterministic order; all_photo_names[k] is the file behind all_photos[k],
# so generated images can be mapped back to their source files.
all_photo_names = sorted(os.listdir(PHOTO_DIR))
all_photos = []
transform = tt.ToTensor()
for f in tqdm(all_photo_names):
    with Image.open(os.path.join(PHOTO_DIR, f)) as img:
        # convert('RGB') guarantees 3 channels even if a file is grayscale or has alpha.
        tensor = transform(img.convert('RGB')).unsqueeze(0).to('cuda')
    all_photos.append(tensor)

# Inference on these photos is done on Google Colab, not in this notebook.

In [None]:
# Shell out to the gdown CLI to fetch the Google Drive file with this ID; per the inline note it
# holds the results of the inference run outside this notebook.
!gdown 1FumiaE5HDXx44hD1RRNVq90DQZjGmJuy #load the results

In [None]:
# Re-saves the same generator to 'model2' (identical to the earlier save cell), pickling the
# whole module object rather than just its state_dict.
torch.save(obj=monet_generator, f='model2')