In [14]:
import os
import re
import math
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision
from torchvision import transforms

import numpy as np

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
class Dataset(Dataset):
    def __init__(self, path_A, path_B):
        self.to_tensor = transforms.Compose([transforms.Resize((128, 128)), 
                                transforms.ToTensor()])
        self.re_number = re.compile('[0-9]')
        self.files_A = self._get_files_path(path_A)
        self.files_B = self._get_files_path(path_B)
        
    def __len__(self):
        return len(self.files_A)
    
    def _file_numbers(self, x):
        filename = x.split('/')[-1]
        file_number = int(''.join(self.re_number.findall(filename)))
        return file_number
    
    def _get_files_path(self, path):
        path_files = glob.glob(path)
        path_files = sorted(path_files, key=self._file_numbers)
        return path_files
    
    def __getitem__(self, idx):
        A = self.to_tensor(Image.open(self.files_A[idx]))
        B = self.to_tensor(Image.open(self.files_B[idx]))
        return A, B

In [17]:
train_path_A = 'maps/trainA/*'
train_path_B = 'maps/trainB/*'

In [18]:
train_dataset = Dataset(train_path_A, train_path_B)

In [19]:
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=8)

In [20]:
path_X = 'fake_X'
path_y = 'fake_Y'

In [21]:
def plot(x_fake, y_fake, batch_idx):
    torchvision.utils.save_image(x_fake.data, f'{path_X}/{batch_idx}.png', normalize=True)
    torchvision.utils.save_image(y_fake.data, f'{path_y}/{batch_idx}.png', normalize=True)

In [22]:
class DKLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DKLayer, self).__init__()
        self.refl_padding = nn.ReflectionPad2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2)
        self.instance_norm = nn.InstanceNorm2d(out_channels)
        
    def forward(self, batch):
        batch = self.refl_padding(batch)
        batch = self.conv(batch)
        batch = self.instance_norm(batch)
        batch = F.relu(batch)
        return batch

In [23]:
class ResidualBlock(nn.Module):
    def __init__(self, n_channels):
        super(ResidualBlock, self).__init__()
        self.refl_padding_1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size=3)
        self.refl_padding_2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size=3)
        
    def forward(self, batch):
        padded_batch = self.refl_padding_1(batch)
        conv_batch = self.conv1(padded_batch)
        padded_batch = self.refl_padding_2(conv_batch)
        conv_batch = self.conv2(padded_batch)
        return batch + conv_batch

In [24]:
class Generator(nn.Module):
    def __init__(self, num_res=6):
        super(Generator, self).__init__()
        self.num_res = num_res
        #make layers
        self.refl_padding_1 = nn.ReflectionPad2d(3)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=1)
        self.instance_norm1 = nn.InstanceNorm2d(32)
        self.dk_layer_1 = DKLayer(32, 64)
        self.dk_layer_2 = DKLayer(64, 128)
        self.res_blocks = nn.ModuleList([ResidualBlock(128) for _ in range(num_res)])
        
        self.conv_trans_1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=(2, 2))
        self.instance_norm2 = nn.InstanceNorm2d(64)
        self.conv_trans_2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=(2, 2), output_padding=1)
        self.instance_norm3 = nn.InstanceNorm2d(32)
        
        self.refl_padding_2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(32, 3, kernel_size=7, stride=1)
        self.instance_norm4 = nn.InstanceNorm2d(3)
        
    def forward(self, batch):
        #print('IN', batch.shape)
        batch = self.refl_padding_1(batch)
        batch = self.conv1(batch)
        batch = self.instance_norm1(batch)
        batch = F.relu(batch)
        #print('R', batch.shape)
        batch = self.dk_layer_1(batch)
        #print('DK1', batch.shape)
        batch = self.dk_layer_2(batch)
        #print('DK2', batch.shape)
        #res blocks
        for i in range(self.num_res):
            batch = self.res_blocks[i](batch)
        #print('RES', batch.shape)
        #deconvolutions
        batch = self.conv_trans_1(batch)
        batch = self.instance_norm2(batch)
        batch = F.relu(batch)
        #print('DEC1', batch.shape)
        batch = self.conv_trans_2(batch)
        batch = self.instance_norm3(batch)
        batch = F.relu(batch)
        #print('DEC2', batch.shape)
        batch = self.refl_padding_2(batch)
        batch = self.conv2(batch)
        batch = self.instance_norm4(batch)
        batch = F.relu(batch)
        return batch

In [25]:
class ConvolutionNorm(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvolutionNorm, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, stride=2, kernel_size=4)
        self.instance_norm = nn.InstanceNorm2d(out_channels)
    def forward(self, batch, use_norm=True):
        batch = self.conv(batch)
        if use_norm:
            batch = self.instance_norm(batch)
        batch = F.leaky_relu(batch, 0.2)
        return batch

In [26]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = ConvolutionNorm(3, 64)
        self.conv2 = ConvolutionNorm(64, 128)
        self.conv3 = ConvolutionNorm(128, 256)
        self.conv4 = ConvolutionNorm(256, 512)
        self.conv_out = nn.Conv2d(512, 1, kernel_size=4, stride=4)
        
    def forward(self, batch):
        batch = self.conv1(batch, use_norm=False)
        batch = self.conv2(batch)
        batch = self.conv3(batch)
        outputs = self.conv4(batch)
        logits = self.conv_out(outputs)
        logits = logits.squeeze(3).squeeze(2)
        probs = F.softmax(logits, dim=1).squeeze()
        return probs

In [27]:
def cycle_loss(generator_x, generator_y, x, y):
    fake_x = generator_x(y)
    fake_y = generator_y(x)
    proxy_y  = generator_y(fake_x)
    proxy_x = generator_x(fake_y)
    loss = (proxy_y - y).mean() + (proxy_x - x).mean()
    return loss

In [28]:
def discriminator_loss(discriminator, generator, x, y):
    loss = ((discriminator(y) - 1) ** 2).mean() + (discriminator(generator(x)) ** 2).mean()
    return loss

In [29]:
def generator_loss(discriminator, generator, x):
    return ((discriminator(generator(x)) - 1) ** 2).mean()

In [35]:
def one_step(D_x, D_y, G_x, G_y, x, y, pass_steps=50, lambda_coef=10):
    #Generators
    for _ in range(pass_steps):
        G_x_optimizer.zero_grad()
        G_y_optimizer.zero_grad()
        
        gen_loss_x = generator_loss(D_x, G_x, x)
        gen_loss_y = generator_loss(D_y, G_y, y)
        #cycle loss
        c_loss = cycle_loss(G_x, G_y, x, y)
        gen_loss_x += lambda_coef * c_loss
        gen_loss_y += lambda_coef * c_loss
        #optimization
        #X
        gen_loss_x.backward()
        G_x_optimizer.step()
        #Y
        gen_loss_y.backward()
        G_y_optimizer.step()
    #Discriminators
    #X
    dis_x = discriminator_loss(D_x, G_x, x.detach(), y)
    D_x_optimizer.zero_grad()
    dis_x.backward()
    D_x_optimizer.step()
    #Y
    dis_y = discriminator_loss(D_y, G_y, y.detach(), x)
    D_y_optimizer.zero_grad()
    dis_y.backward()
    D_y_optimizer.step()

In [36]:
num_epochs = 30
batch_size = 4
lr = 0.0002

In [37]:
D_x = Discriminator()
D_y = Discriminator()
G_x = Generator()
G_y = Generator() 

In [38]:
D_y_optimizer = torch.optim.Adam(D_y.parameters(), lr=lr)
D_x_optimizer = torch.optim.Adam(D_x.parameters(), lr=lr)
G_y_optimizer = torch.optim.Adam(G_y.parameters(), lr=lr)
G_x_optimizer = torch.optim.Adam(G_x.parameters(), lr=lr)

In [39]:
for epoch in range(num_epochs):
    for X_batch, Y_batch in train_dataloader:
        #optimize
        one_step(D_x, D_y, G_x, G_y, X_batch, Y_batch, pass_steps=1)
        #plot fake, real
        #with torch.no_grad():
        #    plot(G())

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.