In [None]:
import torch
from torch import nn, optim, autograd

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

# Device-agnostic setup: use the first CUDA GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Define models

In [None]:
class ResBlock(nn.Module):
    """Residual block: ``input + 0.3 * F(input)``.

    ``F`` is ReLU -> 5x5 conv -> ReLU -> 5x5 conv. Both convolutions keep the
    channel count (``n_filter``) and, thanks to ``padding=2``, the spatial size,
    so the output has the same shape as the input.

    Args:
        n_filter: number of input/output channels of the block.
    """

    def __init__(self, n_filter=64):
        super(ResBlock, self).__init__()

        self.res_block = nn.Sequential(
            # The first ReLU must NOT be in-place: it receives the block's input
            # directly, and an in-place op would overwrite that tensor, so the
            # skip connection in forward() would add relu(input) instead of
            # input (and the caller's tensor would be silently modified).
            nn.ReLU(False),
            nn.Conv2d(n_filter, n_filter, 5, padding=2),
            # In-place is safe here: it only touches the intermediate conv output.
            nn.ReLU(True),
            nn.Conv2d(n_filter, n_filter, 5, padding=2),
        )

    def forward(self, input):
        """Return the input plus the residual branch scaled by 0.3."""
        output = self.res_block(input)
        return input + (0.3 * output)

In [None]:
class ResGenerator(nn.Module):
    """Fully convolutional image-to-image generator.

    Pipeline: 9x9 conv + LeakyReLU -> ``n_residual`` ResBlocks ->
    3x3 conv + BatchNorm -> 3x3 conv + Tanh. Every convolution is padded so the
    spatial size is preserved; the output has as many channels as the input and
    values in (-1, 1) because of the final Tanh.

    Args:
        input_shape: sequence whose element ``[1]`` is the number of image
            channels (the notebook passes ``[batch, channels, height, width]``).
        n_filter: number of feature maps used by all hidden layers.
        n_residual: number of residual blocks.
    """

    def __init__(self, input_shape, n_filter=64, n_residual=7):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(ResGenerator, self).__init__()

        self.input_shape = input_shape
        self.n_filter = n_filter

        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_shape[1], n_filter, 9, padding=4),
            nn.LeakyReLU()
        )

        # Forward n_filter to the blocks; relying on ResBlock's default (64)
        # breaks the network with a channel mismatch for any other width.
        self.residual = nn.Sequential(
            *[ResBlock(n_filter) for _ in range(n_residual)]
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(n_filter, n_filter, 3, padding=1),
            nn.BatchNorm2d(n_filter),
        )

        self.out = nn.Sequential(
            nn.Conv2d(n_filter, self.input_shape[1], 3, padding=1),
            nn.Tanh()
        )

    def forward(self, input):
        """Map a batch of images to a batch of images of the same shape."""
        conv1 = self.conv1(input)
        res = self.residual(conv1)
        conv2 = self.conv2(res)
        out = self.out(conv2)

        return out

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic / discriminator producing one score per image.

    Architecture: a 3x3 conv, then seven (3x3 conv -> BatchNorm -> LeakyReLU)
    stages, all size-preserving, followed by Linear(H*W*n_filter, 1024) and
    Linear(1024, 1). A Sigmoid is appended only when ``wasserstein`` is False.

    Args:
        input_shape: ``[batch, channels, height, width]``; channels, height and
            width size the first conv and the first linear layer.
        wasserstein: if True the output is an unbounded critic score, otherwise
            a probability in (0, 1).
        n_filter: number of feature maps in every convolution.
    """

    def __init__(self, input_shape, wasserstein=True, n_filter=64):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(Discriminator, self).__init__()

        self.input_shape = input_shape

        self.wasserstein = wasserstein
        self.n_filter = n_filter

        # The nesting of Sequentials is kept as-is so parameter names
        # (state_dict keys) do not change.
        self.layers = nn.Sequential(
            nn.Conv2d(input_shape[1], n_filter, 3, padding=1),

            nn.Sequential(
                *[
                    nn.Sequential(
                        nn.Conv2d(n_filter, n_filter, 3, padding=1),
                        nn.BatchNorm2d(n_filter),
                        nn.LeakyReLU()
                    )

                    for _ in range(7)
                ]
            )
        )

        _, _, x, y = input_shape

        # NOTE(review): there is no non-linearity between the two Linear layers;
        # confirm whether that is intentional.
        self.linear = nn.Sequential(
            nn.Linear(x * y * n_filter, 1024),
            nn.Sequential(nn.Linear(1024, 1)) if wasserstein else nn.Sequential(nn.Linear(1024, 1), nn.Sigmoid())
        )

    def forward(self, input):
        """Return a ``(batch, 1)`` tensor of scores for ``input``."""
        features = self.layers(input)

        # Flatten per sample. Keeping the batch dimension fixed means a wrong
        # spatial size fails loudly in the Linear layer instead of being
        # silently folded into the batch dimension by view(-1, ...).
        hid = features.view(features.size(0), -1)
        out = self.linear(hid)

        return out

## Define loss functions

The Wasserstein loss with gradient penalty (WGAN-GP) used below is described in [this paper](https://arxiv.org/abs/1704.00028).

In [None]:
def gradient_penalty(disc, real_data, generated_data, device=None):
    """WGAN-GP gradient penalty of ``disc`` on random real/fake interpolates.

    For every sample a coefficient ``eps ~ U[0, 1)`` is drawn and the critic is
    evaluated at ``eps * real + (1 - eps) * fake``. The penalty is the batch
    mean of ``(||d disc / d interpolate||_2 - 1) ** 2``.

    Args:
        disc: callable mapping a batch to one score per sample.
        real_data: batch of real samples, batch dimension first.
        generated_data: batch of generated samples, same shape as ``real_data``.
        device: device on which the coefficients are sampled. Defaults to the
            device of ``real_data`` so the function works wherever the data is.

    Returns:
        Scalar tensor, differentiable with respect to the parameters of
        ``disc`` (the gradient graph is built with ``create_graph=True``).
    """
    if device is None:
        device = real_data.device

    batch_size = real_data.size(0)

    # One coefficient per sample, shaped (batch, 1, 1, ...) so it broadcasts
    # over all remaining dimensions. This replaces an expand() whose size
    # argument was a float under Python 3 true division.
    eps_shape = [batch_size] + [1] * (real_data.dim() - 1)
    epsilon = torch.rand(eps_shape, device=device)

    interpolates = epsilon * real_data + (1 - epsilon) * generated_data
    # Make the interpolates a leaf: the penalty only needs gradients with
    # respect to them and to the critic, not to whatever produced the inputs.
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = disc(interpolates)

    # grad_outputs of ones yields the per-sample gradient. Differentiating the
    # batch mean instead would scale every gradient by 1 / batch_size.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(batch_size, -1)

    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [None]:
def wasserstein_loss(disc, real_data, fake_data, gp_lambda=10):
    """Critic loss: mean(disc(fake) - disc(real)) + gp_lambda * gradient penalty.

    Args:
        disc: critic network.
        real_data: batch of real samples.
        fake_data: batch of generated samples.
        gp_lambda: weight of the gradient-penalty term.
    """
    critic_gap = (disc(fake_data) - disc(real_data)).mean()
    penalty = gradient_penalty(disc, real_data, fake_data)
    return critic_gap + penalty * gp_lambda

In [None]:
def cycle_loss(inputA, inputB, genA2B, genB2A):
    """Cycle-consistency loss for both translation directions.

    Each input is sent through one generator and back through the other; the
    squared reconstruction errors of the two domains are added element-wise
    and averaged.
    """
    reconA = genB2A(genA2B(inputA))  # A -> B -> A
    reconB = genA2B(genB2A(inputB))  # B -> A -> B

    errA = (reconA - inputA) ** 2
    errB = (reconB - inputB) ** 2

    return (errA + errB).mean()

In [None]:
def generator_combined_loss(inputA, inputB, genA2B, genB2A, discA, discB):
    """Generator objective: fool both critics while keeping cycles consistent.

    Returns ``-mean(discA(genB2A(inputB))) - mean(discB(genA2B(inputA)))``
    plus the cycle-consistency loss of the two generators.
    """
    score_fakeA = discA(genB2A(inputB)).mean()
    score_fakeB = discB(genA2B(inputA)).mean()
    cycle = cycle_loss(inputA, inputB, genA2B, genB2A)
    return -score_fakeA - score_fakeB + cycle

def discriminator_combined_loss(inputA, inputB, discA, discB, genA2B, genB2A):
    """Sum of the Wasserstein critic losses of both domains.

    ``discA`` is scored on real ``inputA`` against ``genB2A(inputB)``, and
    ``discB`` on real ``inputB`` against ``genA2B(inputA)``.

    The fakes are produced under ``torch.no_grad()``: the critic update must not
    build a graph through the generators, otherwise ``backward()`` on this loss
    accumulates gradients in the generator parameters and the next generator
    step would consume them.
    """
    with torch.no_grad():
        genA = genB2A(inputB)
        genB = genA2B(inputA)

    return wasserstein_loss(discA, inputA, genA) + wasserstein_loss(discB, inputB, genB)

## Define hyperparameters

In [None]:
# Shape handed to the models as [batch, channels, height, width]; the models
# read the channel count from index 1 and the spatial size from indices 2-3.
IMG_SHAPE = [1, 1, 32, 32]

# Number of residual blocks in each generator.
N_RESIDUAL = 4

# Keyword arguments shared by the generator and discriminator Adam optimisers.
ADAM_PARAMS = dict(lr=1e-4, betas=(0.5, 0.9))

## Build models

In [None]:
# Move the networks to the `device` selected at the top of the notebook instead
# of hard-coding .cuda(), so the notebook also runs on CPU-only machines.
generatorA2B = ResGenerator(IMG_SHAPE, n_residual=N_RESIDUAL).to(device)
generatorB2A = ResGenerator(IMG_SHAPE, n_residual=N_RESIDUAL).to(device)

discriminatorA = Discriminator(IMG_SHAPE).to(device)
discriminatorB = Discriminator(IMG_SHAPE).to(device)

# One optimiser for both generators and one for both discriminators.
adamGen = optim.Adam(list(generatorA2B.parameters()) + list(generatorB2A.parameters()), **ADAM_PARAMS)
adamDisc = optim.Adam(list(discriminatorA.parameters()) + list(discriminatorB.parameters()), **ADAM_PARAMS)

## Prepare dataset

In [None]:
from lib.pic2tensor import pics2array
import matplotlib.pyplot as plt
%matplotlib inline

import os

In [None]:
x_dir = "datasets/notMNIST_small/A/"
y_dir = "datasets/notMNIST_small/B/"


def _load_image_dir(directory):
    """Load every file of ``directory`` as a float32 array scaled by 1/255.

    A channel axis is inserted at position 1.
    NOTE(review): presumably pics2array returns one (height, width) image per
    file, which would make the result (N, 1, height, width) -- confirm.
    """
    pictures = pics2array(directory, os.listdir(directory), IMG_SHAPE[2:])
    stacked = np.array(pictures, dtype="float32")
    return stacked[:, np.newaxis, :, :] / 255.


X = _load_image_dir(x_dir)
Y = _load_image_dir(y_dir)

## Training models

In [None]:
# Training uses one fixed pair of images. Note the naming: `x` is taken from Y
# and `y` from X, and `x` is what the losses receive as domain "A".
# NOTE(review): BATCH_SIZE below is never used -- confirm that single-sample
# training is intended.
x = torch.tensor(Y[0:1]).to(device)
y = torch.tensor(X[0:1]).to(device)

from tqdm import tqdm
from IPython.display import clear_output

BATCH_SIZE = 8
N_EPOCHS = 5000
N_DISC = 8          # discriminator updates per generator update
DISPLAY_EPOCH = 20  # refresh the plots every DISPLAY_EPOCH epochs
N_PLOTTED = 50      # number of most recent loss values shown

d_loss_hist = list()
g_loss_hist = list()

for epoch in tqdm(range(N_EPOCHS)):

    # train discriminator
    for _ in range(N_DISC):
        # Clear gradients *before* backward: backward() of either loss also
        # writes gradients into the other network's parameters, and stale
        # values would otherwise leak into this step.
        adamDisc.zero_grad()
        d_loss = discriminator_combined_loss(x, y, discriminatorA, discriminatorB, generatorA2B, generatorB2A)

        d_loss.backward()
        adamDisc.step()

    # Store plain floats: keeping the tensors would retain every autograd
    # graph, and CUDA tensors cannot be plotted.
    d_loss_hist.append(d_loss.item())

    # train generator
    adamGen.zero_grad()
    g_loss = generator_combined_loss(x, y, generatorA2B, generatorB2A, discriminatorA, discriminatorB)

    g_loss.backward()
    adamGen.step()

    g_loss_hist.append(g_loss.item())

    if epoch % DISPLAY_EPOCH == 0:
        clear_output()

        # Show the most recent values; slicing [:50] would freeze the plot on
        # the first 50 epochs.
        plt.plot(d_loss_hist[-N_PLOTTED:], label="discriminator loss")
        plt.plot(g_loss_hist[-N_PLOTTED:], label="generator loss")
        plt.legend()
        plt.show()

        # The sample is drawn from Y, so the index must be bounded by len(Y).
        ix = np.random.randint(0, len(Y))
        expl = torch.tensor(Y[ix:ix + 1]).to(device)

        # No autograd graph is needed for a preview.
        with torch.no_grad():
            pic = generatorA2B(expl).cpu().numpy()[0, 0]

        plt.imshow(Y[ix, 0])
        plt.colorbar()
        plt.show()

        plt.imshow(pic)
        plt.colorbar()
        plt.show()

In [None]:
# Translate one random image taken from Y with generatorA2B and show the input
# and the output. The index is bounded by len(Y), the array actually indexed.
ix = np.random.randint(0, len(Y))
expl = torch.tensor(Y[ix:ix + 1]).to(device)

# No autograd graph is needed for a preview.
with torch.no_grad():
    pic = generatorA2B(expl).cpu().numpy()[0, 0]

plt.imshow(Y[ix, 0])
plt.colorbar()
plt.show()

plt.imshow(pic)
plt.colorbar()
plt.show()