In [1]:
# Fetch the "gan-getting-started" competition archive with the Kaggle CLI.
# NOTE(review): later cells read images from '../input/gan-getting-started/';
# confirm the downloaded archive is unpacked there (or that the directory is
# already provided by the runtime) before running the rest of the notebook.
!kaggle competitions download -c gan-getting-started

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
import itertools
import time

# **Dataloader**

In [84]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class CycleGanDataset(Dataset):
    """Unpaired photo / Monet image dataset for CycleGAN training.

    Images are read from ``<path>/photo_jpg`` and ``<path>/monet_jpg``. Each
    item pairs the Monet painting at position ``idx`` with a photo drawn
    uniformly at random from *all* available photos, so the two domains stay
    unpaired.

    Args:
        path: Directory that contains the ``photo_jpg`` and ``monet_jpg`` folders.
        size: ``(height, width)`` every image is resized to.

    Attributes:
        photo_idx: dict mapping an integer position to a photo file name.
        monet_idx: dict mapping an integer position to a Monet file name.
    """

    def __init__(self, path, size = (256, 256)):
        self.photo_path = os.path.join(path, 'photo_jpg')
        self.monet_path = os.path.join(path, 'monet_jpg')

        # Resize, convert to a [0, 1] tensor, then rescale every channel to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # position -> file mapping stable between runs.
        self.photo_idx = dict(enumerate(sorted(os.listdir(self.photo_path))))
        self.monet_idx = dict(enumerate(sorted(os.listdir(self.monet_path))))

    def __len__(self):
        """Return the size of the smaller domain (directories are listed once, in __init__)."""
        return min(len(self.photo_idx), len(self.monet_idx))

    def _random_photo_name(self):
        """Pick a photo file name uniformly at random over the whole photo set."""
        # The range must be the number of photos, not the number of Monet
        # paintings, otherwise photos beyond that count are never sampled.
        return self.photo_idx[np.random.randint(len(self.photo_idx))]

    def _load(self, file_path):
        """Open an image as RGB, apply the transform and close the file handle."""
        with Image.open(file_path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return ``(photo_tensor, monet_tensor)``; the photo is chosen at random."""
        photo_file = os.path.join(self.photo_path, self._random_photo_name())
        monet_file = os.path.join(self.monet_path, self.monet_idx[idx])
        return self._load(photo_file), self._load(monet_file)
        
    
        

In [85]:
# Build the unpaired dataset from the competition input folder and wrap it in
# a loader that yields one (photo, monet) pair per batch.
data = CycleGanDataset(path='../input/gan-getting-started/')
loader = DataLoader(dataset=data, batch_size=1)

In [86]:
def unnorm(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo a per-channel ``Normalize(mean, std)`` in place.

    Computes ``img * std + mean`` channel-wise. Works for a single image of
    shape ``(C, H, W)`` and for a batch of shape ``(N, C, H, W)``.

    Args:
        img: Float tensor whose channel axis is third from the end. It is
            modified in place.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided by.

    Returns:
        The same tensor object, now de-normalized.
    """
    # Shape (C, 1, 1) broadcasts over height/width and over an optional
    # leading batch axis, so every channel receives its own mean and std.
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    img.mul_(std_t).add_(mean_t)
    return img

In [87]:
def save_checkpoint(state, save_path):
    """Serialize ``state`` to ``save_path`` with ``torch.save``.

    Args:
        state: Object to persist (for example a dict of state_dicts).
        save_path: Destination file path or file-like object.
    """
    torch.save(obj=state, f=save_path)

In [88]:
# Sanity check: draw the first (photo, Monet) pair produced by the loader.
photo_img, monet_img = next(iter(loader))

# Bring both batches back from the normalized range before plotting.
photo_img = unnorm(photo_img)
monet_img = unnorm(monet_img)

fig = plt.figure(figsize=(10, 10))

# Left panel: the photo, converted from CHW to HWC for imshow.
photo_ax = fig.add_subplot(1, 2, 1)
photo_ax.set_title('Photo')
photo_ax.imshow(photo_img[0].permute(1, 2, 0))

# Right panel: the Monet painting.
monet_ax = fig.add_subplot(1, 2, 2)
monet_ax.set_title('Monet')
monet_ax.imshow(monet_img[0].permute(1, 2, 0))

 # **Generator**

In [89]:
class ResidualBlock(nn.Module):
    """Residual unit that keeps both channel count and spatial size.

    The inner branch is: reflection pad -> 3x3 conv -> instance norm -> ReLU
    -> reflection pad -> 3x3 conv -> instance norm. Its output is added to
    the block input.

    Args:
        in_channels: Number of channels of the input (and of the output).
    """

    def __init__(self, in_channels):
        super().__init__()
        # Reflection padding of 1 compensates for the unpadded 3x3 convolutions.
        stages = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3),
            nn.InstanceNorm2d(in_channels),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

In [90]:
def Downsample(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Build a conv -> instance norm -> ReLU stage.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride; values above 1 shrink the spatial size.
        padding: Zero padding applied by the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    norm = nn.InstanceNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)

In [91]:
def Upsample(in_channels, out_channels):
    """Build a transposed-conv -> instance norm -> ReLU stage.

    The transposed convolution uses a 3x3 kernel with stride 2, padding 1 and
    output padding 1, which doubles the height and width of its input.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    deconv = nn.ConvTranspose2d(
        in_channels, out_channels, 3, stride=2, padding=1, output_padding=1
    )
    norm = nn.InstanceNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(deconv, norm, activation)

In [92]:
class Generator(nn.Module):
    """Encoder -> residual blocks -> decoder image-to-image generator.

    The spatial size of the output equals that of the input; the shape
    comments below assume a 3 x 256 x 256 input.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = list()

        # Downsample
        # The 7x7 stem convolution has no padding of its own, so it needs a
        # reflection pad of 3 to keep the spatial size. The pad width must not
        # depend on the number of input channels.
        layers.append(nn.ReflectionPad2d(3))
        layers.append(Downsample(in_channels, 64, 7, 1, 0))  # 3, 256, 256 -> 64, 256, 256
        layers.append(Downsample(64, 128, 3, 2, 1))  # 64, 256, 256 -> 128, 128, 128
        layers.append(Downsample(128, 256, 3, 2, 1))  # 128, 128, 128 -> 256, 64, 64

        # Residual blocks keep the 256 x 64 x 64 shape.
        for _ in range(6):
            layers.append(ResidualBlock(256))

        # Upsample
        layers.append(Upsample(256, 128))  # 256, 64, 64 -> 128, 128, 128
        layers.append(Upsample(128, 64))  # 128, 128, 128 -> 64, 256, 256
        layers.append(nn.ReflectionPad2d(3))
        layers.append(nn.Conv2d(64, out_channels, 7, padding=0))  # 64, 256, 256 -> 3, 256, 256
        # Tanh keeps the output in [-1, 1].
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Translate a batch of images ``x`` of shape (N, in_channels, H, W)."""
        return self.model(x)

# **Discriminator**

In [93]:
class Discriminator(nn.Module):
    """Patch-style discriminator that scores local regions of an image.

    Three stride-2 stages halve the spatial size each time; the two stride-1
    4x4 convolutions with padding 1 each shrink it by one more pixel. For a
    3 x 256 x 256 input the feature maps are 128 -> 64 -> 32 -> 31 -> 30, so
    the output is a 1 x 30 x 30 map of sigmoid scores.

    Args:
        in_channels: Channels of the input image.
        out_channesl: Unused; the (misspelled) name is kept so existing
            callers keep working.
    """

    def __init__(self, in_channels, out_channesl):
        super().__init__()
        # First stage has no normalization and uses a leaky ReLU.
        stem = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two stride-2 stages followed by one stride-1 stage.
        body = [
            Downsample(64, 128, 4, 2, 1),
            Downsample(128, 256, 4, 2, 1),
            Downsample(256, 512, 4, 1, 1),
        ]
        # Single-channel score map squashed to (0, 1).
        head = [
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stem, *body, *head)

    def forward(self, x):
        """Return the per-patch score map for the image batch ``x``."""
        return self.model(x)

In [94]:
def update_req_grad(models, requires_grad=True):
    """Set ``requires_grad`` on every parameter of every model in ``models``.

    Args:
        models: Iterable of modules exposing ``parameters()``.
        requires_grad: Flag written to each parameter.
    """
    every_param = (p for net in models for p in net.parameters())
    for p in every_param:
        p.requires_grad = requires_grad

In [95]:
class lr_sched():
    """Learning-rate multiplier: constant, then linear decay to zero.

    The multiplier is 1.0 up to and including ``decay_epochs`` and then falls
    linearly, reaching 0.0 at ``total_epochs``. It never goes below 0.0.

    Args:
        decay_epochs: Last epoch that still uses the full learning rate.
        total_epochs: Epoch at which the multiplier reaches zero.
    """

    def __init__(self, decay_epochs=100, total_epochs=200):
        self.decay_epochs = decay_epochs
        self.total_epochs = total_epochs

    def step(self, epoch_num):
        """Return the multiplier for ``epoch_num`` (suitable as a LambdaLR callback)."""
        if epoch_num <= self.decay_epochs:
            return 1.0

        span = self.total_epochs - self.decay_epochs
        if span <= 0:
            # No decay window is configured; avoid dividing by zero and treat
            # the schedule as finished.
            return 0.0

        fract = (epoch_num - self.decay_epochs) / span
        # Clamp so that epochs past total_epochs never yield a negative rate.
        return max(0.0, 1.0 - fract)

# **CycleGan**

In [96]:
class CycleGan(object):
    """Trainer for an unpaired photo <-> Monet CycleGAN.

    Holds two generators (photo -> Monet and Monet -> photo) and one
    discriminator per domain. Generators are trained with identity (L1),
    cycle-consistency (L1, weighted by ``lmbda``) and adversarial (MSE)
    losses; discriminators are trained with MSE against all-ones targets for
    real images and all-zeros targets for generated images.

    Args:
        monet_generator: Module mapping photos to Monet-style images.
        photo_generator: Module mapping Monet paintings to photos.
        monet_discriminator: Module scoring Monet-domain images.
        photo_discriminator: Module scoring photo-domain images.
        device: Device every module and batch is moved to.
        epochs: Number of passes over the dataloader.
        lmbda: Weight of the cycle-consistency losses.
        decay_epoch: Epoch after which the learning rate decays linearly.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        device,
        epochs,
        lmbda=10,
        decay_epoch=0
    ):
        self.device = device
        self.epochs = epochs
        self.decay_epoch = decay_epoch
        self.lmbda = lmbda

        self.m_gen = monet_generator.to(self.device)
        self.p_gen = photo_generator.to(self.device)
        self.m_disc = monet_discriminator.to(self.device)
        self.p_disc = photo_discriminator.to(self.device)

        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

        # Each optimizer must own the parameters of the networks it is named
        # after: adam_gen updates both generators, adam_disc both discriminators.
        # NOTE(review): lr=1e-3 is kept from the original code; the CycleGAN
        # paper reports 2e-4 - consider lowering it if training is unstable.
        self.adam_gen = torch.optim.Adam(
            list(self.m_gen.parameters()) + list(self.p_gen.parameters()),
            lr=1e-3, betas=(0.5, 0.999))
        self.adam_disc = torch.optim.Adam(
            list(self.m_disc.parameters()) + list(self.p_disc.parameters()),
            lr=1e-3, betas=(0.5, 0.999))

        gen_lr = lr_sched(self.decay_epoch, self.epochs)
        disc_lr = lr_sched(self.decay_epoch, self.epochs)
        self.gen_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_gen, gen_lr.step)
        self.disc_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_disc, disc_lr.step)

    def _generator_step(self, photo_true, monet_true):
        """Run one optimizer step on both generators.

        Returns:
            ``(loss_value, monet_fake, photo_fake)`` where the fakes are
            detached so they can be reused for the discriminator update.
        """
        # Discriminators only provide the adversarial signal here; freezing
        # them avoids accumulating gradients they will not use.
        update_req_grad([self.m_disc, self.p_disc], False)

        # photo -> Monet -> photo
        monet_fake = self.m_gen(photo_true)
        photo_cycle = self.p_gen(monet_fake)

        # Monet -> photo -> Monet
        photo_fake = self.p_gen(monet_true)
        monet_cycle = self.m_gen(photo_fake)

        # A generator fed an image from its own target domain should return it unchanged.
        monet_idt = self.m_gen(monet_true)
        photo_idt = self.p_gen(photo_true)

        monet_idt_loss = self.l1_loss(monet_true, monet_idt)
        photo_idt_loss = self.l1_loss(photo_true, photo_idt)

        monet_cycle_loss = self.l1_loss(monet_cycle, monet_true) * self.lmbda
        photo_cycle_loss = self.l1_loss(photo_cycle, photo_true) * self.lmbda

        # Adversarial term: generators want their fakes to be scored as real (1).
        monet_disc = self.m_disc(monet_fake)
        photo_disc = self.p_disc(photo_fake)
        monet_gen_loss = self.mse_loss(monet_disc, torch.ones_like(monet_disc))
        photo_gen_loss = self.mse_loss(photo_disc, torch.ones_like(photo_disc))

        total_gen_loss = (monet_idt_loss + photo_idt_loss
                          + monet_gen_loss + photo_gen_loss
                          + photo_cycle_loss + monet_cycle_loss)

        self.adam_gen.zero_grad()
        total_gen_loss.backward()
        self.adam_gen.step()

        return total_gen_loss.item(), monet_fake.detach(), photo_fake.detach()

    def _discriminator_step(self, photo_true, monet_true, monet_fake, photo_fake):
        """Run one optimizer step on both discriminators and return the loss value.

        ``monet_fake`` and ``photo_fake`` must be detached from the generators
        so that this update does not back-propagate into them.
        """
        # Re-enable gradients that the generator step switched off.
        update_req_grad([self.m_disc, self.p_disc], True)

        monet_disc_real = self.m_disc(monet_true)
        monet_disc_fake = self.m_disc(monet_fake)
        photo_disc_real = self.p_disc(photo_true)
        photo_disc_fake = self.p_disc(photo_fake)

        # Real images are pushed towards 1, generated images towards 0.
        monet_disc_true_loss = self.mse_loss(monet_disc_real, torch.ones_like(monet_disc_real))
        monet_disc_fake_loss = self.mse_loss(monet_disc_fake, torch.zeros_like(monet_disc_fake))
        photo_disc_true_loss = self.mse_loss(photo_disc_real, torch.ones_like(photo_disc_real))
        photo_disc_fake_loss = self.mse_loss(photo_disc_fake, torch.zeros_like(photo_disc_fake))

        monet_disc_loss = (monet_disc_true_loss + monet_disc_fake_loss) / 2
        photo_disc_loss = (photo_disc_true_loss + photo_disc_fake_loss) / 2
        total_disc_loss = monet_disc_loss + photo_disc_loss

        self.adam_disc.zero_grad()
        total_disc_loss.backward()
        self.adam_disc.step()

        return total_disc_loss.item()

    def train(self, dataloader):
        """Train for ``self.epochs`` epochs and print the average losses per epoch.

        Args:
            dataloader: Iterable yielding ``(photo, monet)`` tensor batches.
        """
        for epoch in range(self.epochs):
            gen_loss_sum = 0.0
            disc_loss_sum = 0.0
            num_batches = 0

            for photo, monet in tqdm(dataloader, leave=True):
                photo_true = photo.to(self.device)
                monet_true = monet.to(self.device)

                gen_loss, monet_fake, photo_fake = self._generator_step(photo_true, monet_true)
                disc_loss = self._discriminator_step(
                    photo_true, monet_true, monet_fake, photo_fake)

                gen_loss_sum += gen_loss
                disc_loss_sum += disc_loss
                num_batches += 1

            # Advance the learning-rate schedules once per epoch.
            self.gen_lr_sched.step()
            self.disc_lr_sched.step()

            # max(..., 1) keeps an empty dataloader from dividing by zero.
            avg_gen_loss = gen_loss_sum / max(num_batches, 1)
            avg_disc_loss = disc_loss_sum / max(num_batches, 1)

            print("Epoch: (%d) | Generator Loss:%f | Discriminator Loss:%f" %
                                                (epoch+1, avg_gen_loss, avg_disc_loss))

In [99]:
# Train on the GPU when CUDA is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One generator and one discriminator per domain, all for 3-channel images.
monet_generator = Generator(in_channels=3, out_channels=3)
photo_generator = Generator(in_channels=3, out_channels=3)
monet_discriminator = Discriminator(3, 3)
photo_discriminator = Discriminator(3, 3)

# Trainer configured for 10 epochs with the default loss weight and decay settings.
cycleGan = CycleGan(
    monet_generator=monet_generator,
    photo_generator=photo_generator,
    monet_discriminator=monet_discriminator,
    photo_discriminator=photo_discriminator,
    device=device,
    epochs=10,
)

In [100]:
# Run the training loop over the photo/Monet loader for the number of epochs
# configured on the trainer; average losses are printed once per epoch.
cycleGan.train(loader)

In [101]:
# Preview both translation directions for the first five batches.
# The dataset yields (photo, monet) pairs in that order, so the batch must be
# unpacked the same way; swapping the names feeds each generator images that
# already belong to its target domain and the output looks like a plain copy.
panel_titles = ('Photo', 'Generated Monet', 'Monet', 'Generated photo')

for photo, monet in itertools.islice(loader, 5):
    photo = photo.to(device)
    monet = monet.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        fake_monet = monet_generator(photo)
        fake_photo = photo_generator(monet)

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, title, batch in zip(axes, panel_titles, (photo, fake_monet, monet, fake_photo)):
        # Map from [-1, 1] back to [0, 1] and go from CHW to HWC for imshow.
        image = (batch[0].detach().cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    plt.show()