# Objective:

This is a first attempt at creating a simple GAN model following the documentation of *Deep Learning with Pytorch* by Vishnu Subramanian and this Medium article: https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f

Certain model parameters and structure will be based off of the pix2pix model implemented in *Image-to-Image Translation with Conditional Adversarial Networks* by Isola, et al.  


In [1]:
import pandas as pd 
import numpy as np

from pathlib import Path
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import itertools
import time 
import cv2
import os

## 1. Load re-sized images and their corresponding sketches 

In [2]:
!pwd

/home/ec2-user/couch_gan


In [4]:
# assign paths 
real_img_path = Path('./data/og_resized')
sketch_img_path = Path('./data/sketch_resized')
fake_img_path = Path('./data/fake')

In [5]:
# get files into list  
real_img_files = [f for f in real_img_path.iterdir()]
sketch_img_files = [f for f in sketch_img_path.iterdir()]

## 2. Get image names of training and test files (assigned by Kaggle)

In [6]:
test_file_names = [f'00000{i}.jpg' for i in range(203, 303)]

In [7]:
# assign paths to appropriate lists 
train_real_paths = []
train_sketch_paths = []
test_real_paths = []
test_sketch_paths = []

for f in real_img_files:
    if f.parts[-1] in test_file_names:
        test_real_paths.append(f)
    else:
        train_real_paths.append(f)

for f in sketch_img_files:
    if f.parts[-1] in test_file_names:
        test_sketch_paths.append(f)
    else:
        train_sketch_paths.append(f)        

In [8]:
# sanity check that the order of real and sketch paths are the same 
counter = 0
for r, f in zip(train_real_paths, train_sketch_paths):
    if r.parts[-1] == r.parts[-1]:
        counter += 1
assert(counter == 900)

counter = 0
for r, f in zip(test_real_paths, test_sketch_paths):
    if r.parts[-1] == r.parts[-1]:
        counter += 1
assert(counter == 100)

## 3. Couch Dataset  and Dataloaders 

Note: No data augmentation 

In [9]:
def getX(path):
    x = cv2.imread(str(path)).astype(np.float32)  # reading image 
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)/255  # convert from BGR to RGB
    return x 

In [10]:
def normalize(im):
    """Normalizes images """
    imagenet_stats = np.array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]])
    return (im - imagenet_stats[0])/imagenet_stats[1]

In [11]:
class CouchDataset(Dataset):
    def __init__(self, real_files, sketch_files):
        self.real_files = real_files
        self.sketch_files = sketch_files
        
    def __len__(self):
        return len(self.real_files)  # should be same number as sketch files 
        
    def __getitem__(self, idx):
        real_path = self.real_files[idx]
        sketch_path = self.sketch_files[idx]
        img_name = sketch_path.parts[-1]
        
        x_real = getX(real_path)
        x_sketch = getX(sketch_path)
        
        # normalize 
        x_real = normalize(x_real)
        x_sketch = normalize(x_sketch)
        
        # roll axis (channels, height, width)
        x_real = np.rollaxis(x_real, 2)
        x_sketch = np.rollaxis(x_sketch, 2)
        
        return x_sketch, x_real
        

In [12]:
# make datasets 
train_couch_ds = CouchDataset(real_files=train_real_paths, sketch_files=train_sketch_paths)
test_couch_ds = CouchDataset(real_files=test_real_paths, sketch_files=test_sketch_paths)



In [13]:
# inspect content 
x_s, x_r = train_couch_ds[20]


In [14]:
x_s, x_r

(array([[[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]],
 
        [[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]],
 
        [[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]]]),
 array([[[ 0.87450981,  0.89019608,  0.9137255 , ...,  0.32549024,
           0.3176471 ,  0.3176471 ],
         [ 0.88235295,  0.89803922,  0.9137255 , ..., -0.09019607,
          -0.03529412,  0.0196079 ],
         [ 0.84313726

In [15]:
# 3 channels for RGB
# 240 rows (height)
# 400 columns (width)
x_s.shape, x_r.shape

((3, 240, 400), (3, 240, 400))

## 4. Model  

In [53]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import torch.nn as nn
import torch.nn.functional as F
import torch

from torchvision.models import resnet18

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

In [18]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [27]:
# based on Bicycle GAN 
## https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/bicyclegan/models.py

## U-Net Generator 
class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 3, stride=2, padding=1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_size, 0.8))
        layers.append(nn.LeakyReLU(0.2))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
    
class UNetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetUp, self).__init__()
        self.model = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_size, out_size, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_size, 0.8),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x
    
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        channels, self.h, self.w = img_shape

        self.fc = nn.Linear(latent_dim, self.h * self.w)

        self.down1 = UNetDown(channels + 1, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)
        self.down5 = UNetDown(512, 512)
        self.down6 = UNetDown(512, 512)
        self.down7 = UNetDown(512, 512, normalize=False)
        self.up1 = UNetUp(512, 512)
        self.up2 = UNetUp(1024, 512)
        self.up3 = UNetUp(1024, 512)
        self.up4 = UNetUp(1024, 256)
        self.up5 = UNetUp(512, 128)
        self.up6 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(128, channels, 3, stride=1, padding=1), nn.Tanh()
        )

    def forward(self, x, z):
        # Propogate noise through fc layer and reshape to img shape
        z = self.fc(z).view(z.size(0), 1, self.h, self.w)
        d1 = self.down1(torch.cat((x, z), 1))
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        u1 = self.up1(d7, d6)
        u2 = self.up2(u1, d5)
        u3 = self.up3(u2, d4)
        u4 = self.up4(u3, d3)
        u5 = self.up5(u4, d2)
        u6 = self.up6(u5, d1)

        return self.final(u6)    





In [67]:
# encoder 
class Encoder(nn.Module):
    def __init__(self, latent_dim, input_shape):
        super(Encoder, self).__init__()
        resnet18_model = resnet18(pretrained=False)
        self.feature_extractor = nn.Sequential(*list(resnet18_model.children())[:-3])

        self.pooling = nn.AvgPool2d(kernel_size=8, stride=8, padding=0)
        # Output is mu and log(var) for reparameterization trick used in VAEs
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

    def forward(self, img):
        return self.feature_extractor(img)
#         out = self.feature_extractor(img)
#         out = self.pooling(out)
#         out = out.view(out.size(0), -1)
#         mu = self.fc_mu(out)
#         logvar = self.fc_logvar(out)
#         return mu, logvar

In [39]:
# discriminator 
class MultiDiscriminator(nn.Module):
    def __init__(self, input_shape):
        super(MultiDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_filters, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        channels, _, _ = input_shape
        # Extracts discriminator models
        self.models = nn.ModuleList()
        for i in range(3):
            self.models.add_module(
                "disc_%d" % i,
                nn.Sequential(
                    *discriminator_block(channels, 64, normalize=False),
                    *discriminator_block(64, 128),
                    *discriminator_block(128, 256),
                    *discriminator_block(256, 512),
                    nn.Conv2d(512, 1, 3, padding=1)
                ),
            )
            
        self.downsample = nn.AvgPool2d(channels, stride=2, padding=[1, 1], count_include_pad=False)

#         self.downsample = nn.AvgPool2d(in_channels, stride=2, padding=[1, 1], count_include_pad=False)

    def compute_loss(self, x, gt):
        """Computes the MSE between model output and scalar gt"""
        loss = sum([torch.mean((out - gt) ** 2) for out in self.forward(x)])
        return loss

    def forward(self, x):
        outputs = []
        for m in self.models:
            outputs.append(m(x))
            x = self.downsample(x)
        return outputs


## 5. Training 

In [22]:
# Create sample and checkpoint directories
os.makedirs("checkpoint", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)



#### Debugging 

In [68]:
# create small data loaders 
batch_size = 10
train_dl = DataLoader(train_couch_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_couch_ds, batch_size=batch_size)

In [69]:
train_x_s, train_x_r = next(iter(train_dl))

In [70]:
train_x_s.shape

torch.Size([10, 3, 240, 400])

In [71]:
train_x_r.shape

torch.Size([10, 3, 240, 400])

In [72]:
input_shape = (3, 240, 400)  # TODO: dynamically grab this 

# used default parameters for this 
latent_dim = 8 
lr = 0.01
b1 = 0.5  # adam: decay of first order momentum of gradient
b2 = 0.999  # adam: decay of first order momentum of gradient


In [73]:
# Loss functions
mae_loss = torch.nn.L1Loss()

# Initialize generator, encoder and discriminators
generator = Generator(latent_dim, input_shape).cuda()
encoder = Encoder(latent_dim, input_shape).cuda()
D_VAE = MultiDiscriminator(input_shape).cuda()
D_LR = MultiDiscriminator(input_shape).cuda()

In [74]:
# Initialize weights at start of epoch training (epoch = 0)
generator.apply(weights_init_normal)
D_VAE.apply(weights_init_normal)
D_LR.apply(weights_init_normal)

MultiDiscriminator(
  (models): ModuleList(
    (disc_0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2)
      (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2)
      (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2)
      (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.2)
      (11): Conv2d(512, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (disc_1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1

In [75]:
# Optimizers
optimizer_E = torch.optim.Adam(encoder.parameters(), lr=lr, betas=(b1, b2))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_VAE = torch.optim.Adam(D_VAE.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_LR = torch.optim.Adam(D_LR.parameters(), lr=lr, betas=(b1, b2))



In [76]:
Tensor = torch.cuda.FloatTensor

In [77]:
# validation metrics 
def val_metrics(generator, valid_dl):
    # stop training generator
    generator.eval()
    
    for x_sketch, x_real in valid_dl:
        sampled_z = Variable(Tensor(np.random.normal(0, 1, (latent_dim,
                                                            latent_dim))))
        x_sketch_variable = Variable(x_sketch)
        # generate samples 
        fake_from_sketch = geneator(x_sketch_variable, sampled_z)
        
        # concatenate samples horizontally
        fake_from_sketch = torch.cat([x for x in fake_from_sketch.data.cpu()], -1)
        
        img_sample = torch.cat((x_sketch_variable, fake_from_sketch), -1)
        img_sample = img_sample.view(1, *img_sample.shape)
        
        # Concatenate with previous samples vertically
        img_samples = img_sample if img_samples is None else torch.cat((img_samples, img_sample), -2)
    save_image(img_samples, "fake_imgs/%s/%s.png" % ("couch_from_sketch", batches_done), nrow=8, normalize=True)
    generator.train()
        
    

In [78]:
def reparameterization(mu, logvar):
    std = torch.exp(logvar / 2)
    sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim))))
    z = sampled_z * std + mu
    return z

In [79]:
# Adversarial loss
valid = 1
fake = 0

In [81]:
# debugging training 
for x_sketch, x_real in zip(train_x_s, train_x_r):
    x_sketch_variable = Variable(x_sketch)
    x_real_variable = Variable(x_real)
    print(x_sketch_variable)
    print(x_real_variable)
    # train generator and encoder 
#     optimizer_E.zero_grad()
#     optimizer_G.zero_grad()
    
#     # cVAE-GAN
    
#     # Produce output using encoding of real image 
#     mu, logvar = encoder(x_real_variable)
#     encoded_z = reparameterization(mu, logvar)
    
    
    
    
    

tensor([[[1.0000, 0.9843, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 0.9922, 0.9686,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

        [[1.0000, 0.9843, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 0.9922, 0.9686,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

        [[1.0000, 0.9843, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 0.9922, 0.9686,  ..., 1.0000, 1.0000, 1.