# RL and Advanced DL: Домашнее задание 3

In [2]:
from glob import glob
import itertools
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from PIL import Image
from torchvision import transforms, utils
from IPython.display import display

## Часть первая: перенос стиля при помощи pix2pix
В этом задании мы будем реализовывать перенос стиля на спаренных данных (paired data). Из стандартных датасетов я выбрал самый маленький — давайте реализуем перенос стиля из масок сегментации в “настоящие фотографии” для конкретного случая фасадов домов.

В первой части подразумевается реализация архитектуры pix2pix, т.е. генератора и дискриминатора, которые получают на вход условие в виде изображения исходного стиля (подробнее см. лекцию и собственно Isola et al., 2016).

In [16]:
def encoder_layer(in_channel, out_channel, batchnorm=True):
    """Build one encoder stage: 3x3 convolution, optional norm, LeakyReLU.

    Args:
        in_channel: number of channels in the incoming feature map.
        out_channel: number of channels produced by the convolution.
        batchnorm: when true, an affine ``nn.InstanceNorm2d`` is placed
            between the convolution and the activation. Despite the flag's
            name, the normalisation applied is instance norm.

    Returns:
        An ``nn.Sequential`` with the stage's modules in application order.
    """
    # A 3x3 kernel with padding=1 leaves the spatial size unchanged.
    modules = [nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)]
    if batchnorm:
        modules.append(nn.InstanceNorm2d(out_channel, affine=True))
    modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
    return nn.Sequential(*modules)

def decoder_layer(in_channel, out_channel, dropout=True):
    """Build one decoder stage: 3x3 conv, instance norm, optional dropout, LeakyReLU.

    Args:
        in_channel: number of channels in the incoming feature map.
        out_channel: number of channels produced by the convolution.
        dropout: when true, a channel-wise ``nn.Dropout2d(p=0.5)`` is inserted
            after the normalisation and before the activation.

    Returns:
        An ``nn.Sequential`` with the stage's modules in application order.
    """
    # Convolution and affine instance norm are present in both variants.
    modules = [
        nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
        nn.InstanceNorm2d(out_channel, affine=True),
    ]
    if dropout:
        modules.append(nn.Dropout2d(p=0.5, inplace=True))
    # The stage uses a leaky activation (slope 0.2) rather than a plain ReLU.
    modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
    return nn.Sequential(*modules)

In [24]:
class DefineUnet(nn.Module):
    """U-Net style generator with six encoder stages and skip connections.

    The encoder alternates conv stages with 2x max pooling; the decoder
    alternates 2x bilinear upsampling with conv stages that consume the
    upsampled features concatenated with the matching encoder output.

    Because the input is pooled five times and upsampled five times, its
    height and width must be divisible by 32 for the skip concatenations
    to line up.

    Args:
        in_image: number of channels of the input image.
        filter_size: base channel width; deeper stages use multiples of it.
        out_image: number of channels of the generated image.
    """

    def __init__(self, in_image=3, filter_size=64, out_image=3):
        super().__init__()
        # Shared, parameter-free down/up-sampling operators.
        self.enc = nn.MaxPool2d(kernel_size=2)
        self.dec = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.enc1 = encoder_layer(in_image, filter_size)
        self.enc2 = encoder_layer(filter_size, filter_size * 2)
        self.enc3 = encoder_layer(filter_size * 2, filter_size * 4)
        self.enc4 = encoder_layer(filter_size * 4, filter_size * 8)
        self.enc5 = encoder_layer(filter_size * 8, filter_size * 8)
        self.enc6 = encoder_layer(filter_size * 8, filter_size * 16)

        # Each decoder stage takes (upsampled features + skip features) channels.
        self.dec1 = decoder_layer(filter_size * 16 + filter_size * 8, filter_size * 8)
        self.dec2 = decoder_layer(filter_size * 8 + filter_size * 8, filter_size * 8)
        self.dec3 = decoder_layer(filter_size * 4 + filter_size * 8, filter_size * 4)
        self.dec4 = decoder_layer(filter_size * 4 + filter_size * 2, filter_size * 2)
        self.dec5 = decoder_layer(filter_size * 2 + filter_size, filter_size)
        # 1x1 projection to image channels. The input width must follow
        # `filter_size` (dec5's output); a hard-coded 64 only worked for the
        # default configuration.
        self.dec6 = nn.Conv2d(filter_size, out_image, kernel_size=1)

    def forward(self, x):
        """Map a conditioning image batch to a generated image batch.

        Args:
            x: input tensor of shape (N, in_image, H, W).

        Returns:
            Tensor of shape (N, out_image, H, W) with values in (0, 1)
            (sigmoid output).
        """
        # Encoder: keep every pre-pooling activation for the skip connections.
        out1 = self.enc1(x)
        x = self.enc(out1)
        out2 = self.enc2(x)
        x = self.enc(out2)
        out3 = self.enc3(x)
        x = self.enc(out3)
        out4 = self.enc4(x)
        x = self.enc(out4)
        out5 = self.enc5(x)
        x = self.enc(out5)
        x = self.enc6(x)  # bottleneck, no pooling afterwards

        # Decoder: upsample, concatenate the same-resolution skip, convolve.
        x = self.dec(x)
        x = self.dec1(torch.cat([x, out5], dim=1))
        x = self.dec(x)
        x = self.dec2(torch.cat([x, out4], dim=1))
        x = self.dec(x)
        x = self.dec3(torch.cat([x, out3], dim=1))
        x = self.dec(x)
        x = self.dec4(torch.cat([x, out2], dim=1))
        x = self.dec(x)
        x = self.dec5(torch.cat([x, out1], dim=1))
        x = torch.sigmoid(self.dec6(x))

        return x
    
    
class Discriminator(nn.Module):
    """Conditional discriminator scoring an (image, image) pair.

    The two images are concatenated along the channel axis, passed through
    four conv stages each followed by 2x max pooling, globally average
    pooled, and projected to two logits per sample.

    Args:
        in_image: total number of channels of the concatenated pair.
        filter_size: base channel width; later stages use multiples of it.
    """

    def __init__(self, in_image=6, filter_size=64):
        super().__init__()
        self.enc = nn.MaxPool2d(kernel_size=2)
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        # Channel widths at the boundaries of the four stages enc1..enc4.
        widths = [in_image] + [filter_size * mult for mult in (1, 2, 4, 8)]
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"enc{idx}", encoder_layer(c_in, c_out))
        self.linear = nn.Linear(widths[-1], 2)

    def forward(self, x1, x2):
        """Return two-class logits of shape (N, 2) for the pair (x1, x2)."""
        features = torch.cat([x1, x2], dim=1)
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            # Every stage is followed by 2x max pooling.
            features = self.enc(stage(features))
        # Global average pooling leaves (N, C, 1, 1); drop the two unit dims.
        pooled = self.pooling(features).squeeze(-1).squeeze(-1)
        return self.linear(pooled)
        

In [4]:
class Facades(torch.utils.data.Dataset):
    """Dataset of side-by-side image pairs stored as ``*.jpg`` files.

    Each file is split into two 256x256 crops. An item is the tuple
    ``(right_half, left_half)`` as float tensors produced by ``ToTensor``.
    Presumably the right half is the segmentation mask and the left half the
    photo, as in the standard pix2pix facades files; confirm against the data.

    Args:
        root_dir: directory that directly contains the ``.jpg`` files.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # glob() returns paths in arbitrary, filesystem-dependent order; sort
        # so an index always refers to the same image across runs.
        self.image = sorted(glob(self.root_dir + '/*.jpg'))

    def __getitem__(self, idx):
        """Return the ``(right_half, left_half)`` tensors for image ``idx``."""
        to_tensor = transforms.ToTensor()
        # convert() yields an independent in-memory image, so the file handle
        # can be released as soon as the context exits.
        with Image.open(self.image[idx]) as raw:
            img = raw.convert("RGB")
        right_half = img.crop((256, 0, 512, 256))
        left_half = img.crop((0, 0, 256, 256))
        return to_tensor(right_half), to_tensor(left_half)

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image)
    
# Training split: reshuffled every epoch.
train_dataset = Facades(root_dir='/kaggle/working/facades/train')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)

# Held-out split, iterated in a fixed order. It must not point at the training
# directory, otherwise evaluation only shows images the model was trained on.
# NOTE(review): assumes the standard facades layout with a `test` folder next
# to `train` - confirm the path on the target machine.
test_dataset = Facades(root_dir='/kaggle/working/facades/test')
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)


In [None]:
from tqdm.notebook import tqdm

# Hyper-parameters of the adversarial training run.
N_EPOCH = 20
LR = 0.0002
BETA = (0.5, 0.999)
STEP = 200
L1_LAMBDA = 100  # weight of the L1 reconstruction term in the generator loss

# No visible cell defines `device`; pick the GPU when one is available so the
# cell runs on its own.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = DefineUnet().to(device)
gen_optim = torch.optim.Adam(generator.parameters(), lr=LR, betas=BETA)
gen_scheduler = torch.optim.lr_scheduler.StepLR(gen_optim, STEP)
discriminator = Discriminator().to(device)
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=LR, betas=BETA)
disc_scheduler = torch.optim.lr_scheduler.StepLR(disc_optim, STEP)
# NOTE: the schedulers are stepped once per epoch, so with STEP=200 and
# N_EPOCH=20 the learning rate never actually decays during this run.

# Per-batch loss values, appended once per optimisation step.
gen_history = list()
disc_history = list()

for i in tqdm(range(N_EPOCH)):
    for batch in train_dataloader:
        # batch[0] is the conditioning image, batch[1] the target image.
        batch[0] = batch[0].to(device)
        batch[1] = batch[1].to(device)

        # ---- Discriminator step -------------------------------------------
        generator.eval()
        discriminator.train()
        # The fake images are only inputs here: generate them without a graph
        # so the discriminator loss does not back-propagate through (and
        # accumulate gradients in) the generator.
        with torch.no_grad():
            new_image = generator(batch[0])
        preds = discriminator(batch[0], new_image)
        real = discriminator(batch[0], batch[1])

        disc_optim.zero_grad()
        # Class 0 = generated pair, class 1 = real pair.
        fake_labels = torch.zeros(preds.shape[0], dtype=torch.long, device=device)
        real_labels = torch.ones(preds.shape[0], dtype=torch.long, device=device)
        preds_loss = F.cross_entropy(preds, fake_labels)
        real_loss = F.cross_entropy(real, real_labels)

        loss = 0.5 * (real_loss + preds_loss)
        loss.backward()
        disc_optim.step()
        disc_history.append(loss.item())

        # ---- Generator step -----------------------------------------------
        generator.train()
        discriminator.eval()

        new_image = generator(batch[0])
        preds = discriminator(batch[0], new_image)
        # The generator wants its output classified as real (class 1) ...
        loss = F.cross_entropy(preds, real_labels)
        # ... while staying close to the ground-truth image in L1 distance.
        loss = loss + L1_LAMBDA * F.l1_loss(new_image, batch[1])
        gen_optim.zero_grad()
        loss.backward()
        gen_optim.step()
        gen_history.append(loss.item())

    gen_scheduler.step()
    disc_scheduler.step()
        

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Generator loss per training batch (adversarial cross-entropy plus the
# L1_LAMBDA-weighted L1 term), in the order the batches were processed.
plt.plot(gen_history)

In [None]:
# Discriminator loss per training batch (mean of the real-pair and
# generated-pair cross-entropies), in the order the batches were processed.
plt.plot(disc_history)