In [1]:
import os
import re
import math
import glob
import itertools

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision
from torchvision import transforms

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class Dataset(Dataset):
    def __init__(self, path_A, path_B):
        self.to_tensor = transforms.Compose([
                                transforms.Resize((128, 128)), 
                                transforms.ToTensor()])
        self.re_number = re.compile('[0-9]')
        self.files_A = self._get_files_path(path_A)
        self.files_B = self._get_files_path(path_B)
        
    def __len__(self):
        return len(self.files_A)
    
    def _file_numbers(self, x):
        filename = x.split('/')[-1]
        file_number = int(''.join(self.re_number.findall(filename)))
        return file_number
    
    def _get_files_path(self, path):
        path_files = glob.glob(path)
        path_files = sorted(path_files, key=self._file_numbers)
        return path_files
    
    def __getitem__(self, idx):
        A = self.to_tensor(Image.open(self.files_A[idx]))
        B = self.to_tensor(Image.open(self.files_B[idx]))
        return A, B

In [5]:
train_path_A = 'maps/trainA/*'
train_path_B = 'maps/trainB/*'
val_path_A = 'maps/valA/*'
val_path_B = 'maps/valB/*'

In [6]:
train_dataset = Dataset(train_path_A, train_path_B)
val_dataset = Dataset(val_path_A, val_path_B)

In [7]:
path_genererated_X = 'fake_X'
path_genererated_Y = 'fake_Y'

<center> Additional Functions

In [8]:
def plot(x_fake, y_fake, epoch, batch_idx):
    torchvision.utils.save_image(x_fake.data, f'{path_genererated_X}/{epoch}_{batch_idx}.png', 
                                 normalize=True)
    torchvision.utils.save_image(y_fake.data, f'{path_genererated_Y}/{epoch}_{batch_idx}.png', 
                                 normalize=True)

In [9]:
def weights_init(m):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean=0.0, std=0.02)

<center> Generator Layers

In [10]:
class DKLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DKLayer, self).__init__()
        self.refl_padding = nn.ReflectionPad2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
        self.instance_norm = nn.InstanceNorm2d(out_channels)
        
    def forward(self, batch):
        batch = self.refl_padding(batch)
        batch = self.conv(batch)
        batch = self.instance_norm(batch)
        batch = F.relu(batch)
        return batch

In [11]:
class ResidualBlock(nn.Module):
    def __init__(self, n_channels):
        super(ResidualBlock, self).__init__()
        self.refl_padding_1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=0)
        self.refl_padding_2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=0)
        
    def forward(self, batch):
        padded_batch = self.refl_padding_1(batch)
        conv_batch = self.conv1(padded_batch)
        padded_batch = self.refl_padding_2(conv_batch)
        conv_batch = self.conv2(padded_batch)
        out = batch + conv_batch
        return out

In [12]:
class Generator(nn.Module):
    def __init__(self, num_res=6):
        super(Generator, self).__init__()
        self.num_res = num_res
        #make layers
        self.refl_padding_1 = nn.ReflectionPad2d(3)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=0)
        self.instance_norm1 = nn.InstanceNorm2d(32)
        self.dk_layer_1 = DKLayer(32, 64)
        self.dk_layer_2 = DKLayer(64, 128)
        self.res_blocks = nn.ModuleList([ResidualBlock(128) for _ in range(num_res)])
        
        self.conv_trans_1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, 
                                               output_padding=1)
        self.instance_norm2 = nn.InstanceNorm2d(64)
        self.conv_trans_2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, 
                                               output_padding=1)
        self.instance_norm3 = nn.InstanceNorm2d(32)
        
        self.refl_padding_2 = nn.ReflectionPad2d(3)
        self.conv2 = nn.Conv2d(32, 3, kernel_size=7, stride=1, padding=0)
        #init weights
        self.apply(weights_init)
        
    def forward(self, batch):
        #print('IN', batch.shape)
        batch = self.refl_padding_1(batch)
        batch = self.conv1(batch)
        batch = self.instance_norm1(batch)
        batch = F.relu(batch)
        #print('R', batch.shape)
        batch = self.dk_layer_1(batch)
        #print('DK1', batch.shape)
        batch = self.dk_layer_2(batch)
        #print('DK2', batch.shape)
        #res blocks
        for i in range(self.num_res):
            batch = self.res_blocks[i](batch)
        #print('RES', batch.shape)
        #deconvolutions
        batch = self.conv_trans_1(batch)
        batch = self.instance_norm2(batch)
        batch = F.relu(batch)
        #print('DEC1', batch.shape)
        batch = self.conv_trans_2(batch)
        batch = self.instance_norm3(batch)
        batch = F.relu(batch)
        #print('DEC2', batch.shape)
        batch = self.refl_padding_2(batch)
        batch = self.conv2(batch)
        batch = torch.tanh(batch)
        return batch

<center> Discriminator Layers

In [13]:
class ConvolutionNorm(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2):
        super(ConvolutionNorm, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=4, padding=1)
        self.instance_norm = nn.InstanceNorm2d(out_channels)
    def forward(self, batch, use_norm=True):
        batch = self.conv(batch)
        if use_norm:
            batch = self.instance_norm(batch)
        batch = F.leaky_relu(batch, 0.2)
        return batch

In [14]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = ConvolutionNorm(3, 64)
        self.conv2 = ConvolutionNorm(64, 128)
        self.conv3 = ConvolutionNorm(128, 256)
        self.conv4 = ConvolutionNorm(256, 512, stride=1)
        self.conv_out = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        self.apply(weights_init)
        
    def forward(self, batch):
        batch = self.conv1(batch, use_norm=False)
        batch = self.conv2(batch)
        batch = self.conv3(batch)
        outputs = self.conv4(batch)
        logits = self.conv_out(outputs)
        #logits = logits.squeeze(3).squeeze(2)
        probs = torch.sigmoid(logits)#.squeeze()
        return probs

<center> Loss Functions

In [15]:
'''def cycle_loss(generator_x, generator_y, x, y):
    fake_x = generator_x(y)
    fake_y = generator_y(x)
    proxy_y  = generator_y(fake_x)
    proxy_x = generator_x(fake_y)
    loss = (proxy_y - y).mean() + (proxy_x - x).mean()
    return loss'''

'def cycle_loss(generator_x, generator_y, x, y):\n    fake_x = generator_x(y)\n    fake_y = generator_y(x)\n    proxy_y  = generator_y(fake_x)\n    proxy_x = generator_x(fake_y)\n    loss = (proxy_y - y).mean() + (proxy_x - x).mean()\n    return loss'

In [16]:
cycle_loss = torch.nn.L1Loss()

In [17]:
mse_criterion = nn.MSELoss()

In [18]:
def discriminator_loss(discriminator, real_input, fake_input):
    prob_real = discriminator(real_input)
    real = torch.ones_like(prob_real).to(prob_real.device)
    prob_fake = discriminator(fake_input)
    fake = torch.zeros_like(prob_fake).to(prob_fake.device)
    
    real_loss = mse_criterion(prob_real, real)
    fake_loss = mse_criterion(prob_fake, fake)
    loss = (real_loss + fake_loss) / 2
    return loss

In [19]:
def generator_loss(discriminator, fake_input):
    prob_fake = discriminator(fake_input)
    real = torch.ones_like(prob_fake).to(prob_fake.device)
    loss = mse_criterion(prob_fake, real)
    return loss

<center> Training

In [20]:
def one_step(D_x, D_y, G_x, G_y, x, y, pass_steps=50, lambda_coef=10):
    #Generators
    for _ in range(pass_steps):
        fake_x = G_y(y)
        fake_y = G_x(x)
        rec_x = G_y(fake_y)
        rec_y = G_x(fake_x)
        gen_loss_x = generator_loss(D_y, fake_x)
        gen_loss_y = generator_loss(D_x, fake_y)
        #cycle loss
        c_loss_x = cycle_loss(rec_x, x)
        c_loss_y = cycle_loss(rec_y, y)
        gen_loss = gen_loss_x + gen_loss_y + lambda_coef * (c_loss_x + c_loss_y)
        #optimization
        optimizer_G.zero_grad()
        gen_loss.backward()
        optimizer_G.step()
    #Discriminators
    dis_y = discriminator_loss(D_y, y, fake_x.detach())
    dis_x = discriminator_loss(D_x, x, fake_y.detach())
    #step
    optimizer_D.zero_grad()
    dis_x.backward()
    dis_y.backward()
    optimizer_D.step()

In [21]:
num_epochs = 100
batch_size = 6
lr = 0.0002

In [22]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)

In [23]:
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True, num_workers=6)

In [24]:
D_x = Discriminator().to(device)
D_y = Discriminator().to(device)
G_x = Generator().to(device)
G_y = Generator().to(device)

In [25]:
optimizer_G = torch.optim.Adam(itertools.chain(G_x.parameters(), G_y.parameters()),
                               lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(itertools.chain(D_x.parameters(), D_y.parameters()),
                               lr=lr, betas=(0.5, 0.999))

In [26]:
for X_batch, Y_batch in val_dataloader:
    X_batch, Y_batch = X_batch.to(device), Y_batch.to(device) 
    break

In [None]:
%%time
for epoch in range(num_epochs):
    batch_idx = 0
    for X_batch, Y_batch in train_dataloader:
        #optimize
        one_step(D_x, D_y, G_x, G_y, X_batch.to(device), Y_batch.to(device), pass_steps=5)
        #plot fake, real
        if batch_idx % 100 == 0:
            print(f'  {batch_idx}: pic saved')
            with torch.no_grad():
                for X_batch, Y_batch in val_dataloader:
                    plot(G_x(Y_batch.to(device)), G_y(X_batch.to(device)), epoch, batch_idx)
                    break
        batch_idx += 1



  0: pic saved
  100: pic saved
  0: pic saved
  100: pic saved
  0: pic saved


Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fb285864f98>>
Traceback (most recent call last):
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 399, in __del__
    self._shutdown_workers()
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 378, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 151, in rebuild_storage_fd
    fd = df.detach()
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 87, in get

  100: pic saved


Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fb270170b70>>
Traceback (most recent call last):
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 399, in __del__
    self._shutdown_workers()
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 378, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 151, in rebuild_storage_fd
    fd = df.detach()
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 58, in detach
    return reduction.recv_handle(conn)
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/reduction.py", line 182, in recv_handle
    return recvfd

  0: pic saved
  100: pic saved
  0: pic saved
  100: pic saved


Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fb270170898>>
Traceback (most recent call last):
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 399, in __del__
    self._shutdown_workers()
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 378, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 151, in rebuild_storage_fd
    fd = df.detach()
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 58, in detach
    return reduction.recv_handle(conn)
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/reduction.py", line 182, in recv_handle
    return recvfd

  0: pic saved


Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fb27117def0>>
Traceback (most recent call last):
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 399, in __del__
    self._shutdown_workers()
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 378, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/dulat/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 151, in rebuild_storage_fd
    fd = df.detach()
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/dulat/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 87, in get