# Objective:

This is a first attempt at creating a simple GAN model following the documentation of *Deep Learning with Pytorch* by Vishnu Subramanian and this Medium article: https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f

Certain model parameters and structure will be based off of the pix2pix model implemented in *Image-to-Image Translation with Conditional Adversarial Networks* by Isola, et al.  


In [1]:
import pandas as pd 
import numpy as np

from pathlib import Path
import pickle
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import itertools
import time 
import cv2
import os

## 1. Load re-sized images and their corresponding sketches 

In [2]:
!pwd

/home/ec2-user/couch_gan


In [3]:
# assign paths 
real_img_path = Path('./data/og_resized')
sketch_img_path = Path('./data/sketch_resized')
fake_img_path = Path('./data/fake')

In [4]:
# get files into list  
real_img_files = [f for f in real_img_path.iterdir()]
sketch_img_files = [f for f in sketch_img_path.iterdir()]

## 2. Get image names of training and test files (assigned by Kaggle)

In [5]:
test_file_names = [f'00000{i}.jpg' for i in range(203, 303)]

In [6]:
# assign paths to appropriate lists 
train_real_paths = []
train_sketch_paths = []
test_real_paths = []
test_sketch_paths = []

for f in real_img_files:
    if f.parts[-1] in test_file_names:
        test_real_paths.append(f)
    else:
        train_real_paths.append(f)

for f in sketch_img_files:
    if f.parts[-1] in test_file_names:
        test_sketch_paths.append(f)
    else:
        train_sketch_paths.append(f)        

In [7]:
# sanity check that the order of real and sketch paths are the same 
counter = 0
for r, f in zip(train_real_paths, train_sketch_paths):
    if r.parts[-1] == r.parts[-1]:
        counter += 1
assert(counter == 900)

counter = 0
for r, f in zip(test_real_paths, test_sketch_paths):
    if r.parts[-1] == r.parts[-1]:
        counter += 1
assert(counter == 100)

## 3. Couch Dataset  and Dataloaders 

Note: No data augmentation 

In [8]:
def getX(path):
    x = cv2.imread(str(path)).astype(np.float32)  # reading image 
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)/255  # convert from BGR to RGB
    return x 

In [9]:
def normalize(im):
    """Normalizes images """
    imagenet_stats = np.array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]])
    return (im - imagenet_stats[0])/imagenet_stats[1]

In [10]:
class CouchDataset(Dataset):
    def __init__(self, real_files, sketch_files):
        self.real_files = real_files
        self.sketch_files = sketch_files
        
    def __len__(self):
        return len(self.real_files)  # should be same number as sketch files 
        
    def __getitem__(self, idx):
        real_path = self.real_files[idx]
        sketch_path = self.sketch_files[idx]
        img_name = sketch_path.parts[-1]
        
        x_real = getX(real_path)
        x_sketch = getX(sketch_path)
        
        # normalize 
        x_real = normalize(x_real)
        x_sketch = normalize(x_sketch)
        
        # roll axis (channels, height, width)
        x_real = np.rollaxis(x_real, 2)
        x_sketch = np.rollaxis(x_sketch, 2)
        
        return x_sketch, x_real
        

In [11]:
# make datasets 
train_couch_ds = CouchDataset(real_files=train_real_paths, sketch_files=train_sketch_paths)
test_couch_ds = CouchDataset(real_files=test_real_paths, sketch_files=test_sketch_paths)



In [12]:
# inspect content 
x_s, x_r = train_couch_ds[20]


In [13]:
x_s, x_r

(array([[[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]],
 
        [[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]],
 
        [[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]]]),
 array([[[ 0.87450981,  0.89019608,  0.9137255 , ...,  0.32549024,
           0.3176471 ,  0.3176471 ],
         [ 0.88235295,  0.89803922,  0.9137255 , ..., -0.09019607,
          -0.03529412,  0.0196079 ],
         [ 0.84313726

In [14]:
# 3 channels for RGB
# 240 rows (height)
# 400 columns (width)
x_s.shape, x_r.shape

((3, 240, 400), (3, 240, 400))

## 4. Model  

In [15]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import torch.nn as nn
import torch.nn.functional as F
import torch

from torchvision.models import resnet18

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

In [16]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [17]:
# based on Pix2pix GAN 
## https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/pix2pix/models.py

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)

        return x


class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)




In [18]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels * 2, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

## 5. Training 

In [19]:
# Create sample and checkpoint directories
os.makedirs("checkpoint", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)
osos.makedirs("fake_imgs", exist_ok=True)


#### Debugging 

In [20]:
# create small data loaders 
batch_size = 10
train_dl = DataLoader(train_couch_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_couch_ds, batch_size=batch_size)

In [21]:
train_x_s, train_x_r = next(iter(train_dl))

In [22]:
train_x_s.shape

torch.Size([10, 3, 240, 400])

In [23]:
train_x_r.shape

torch.Size([10, 3, 240, 400])

In [30]:
input_shape = (3, 240, 400)  # TODO: dynamically grab this 

# used default parameters for this 
latent_dim = 8 
lr = 0.01
b1 = 0.5  # adam: decay of first order momentum of gradient
b2 = 0.999  # adam: decay of first order momentum of gradient


In [24]:
# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

In [26]:
# Calculate output of image discriminator (PatchGAN)
img_height = 240
img_width = 400
patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)


In [28]:
# Initialize generator and discriminator
generator = generator.cuda()
discriminator = discriminator.cuda()
criterion_GAN.cuda()
criterion_pixelwise.cuda()

L1Loss()

In [29]:
# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

Discriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): ZeroPad2d(padding=(1, 0, 1, 0), value=0.0)
    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
  )
)

In [31]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))


In [32]:
# Tensor type
Tensor = torch.cuda.FloatTensor 

In [36]:
def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    x_sketch, x_real = next(iter(test_dl))
    x_sketch_var = Variable(x_sketch.type(Tensor))
    x_real_var = Variable(x_real.type(Tensor))
    fake_from_sketch = generator(x_sketch_var)
    
    img_sample = torch.cat((x_sketch_var.data, fake_from_sketch.data, x_real_var.data), -2)
    save_image(img_sample, "fake_imgs/%s/%s.png" % ("pix2pix", batches_done), nrow=5, normalize=True)
    
    
    

In [38]:
train_x_sketch, train_x_real = next(iter(train_dl))    

In [46]:
for x_sketch, x_real in zip(train_x_sketch, train_x_real):
    # model inputs 
    x_sketch_var = Variable(x_sketch.type(Tensor))
    x_real_var = Variable(x_real.type(Tensor))
    
    
    # adveserial ground truths 
    # iffy about size; should it be the number of channels
    valid = Variable(Tensor(np.ones((x_sketch_var.size(0), *patch))), requires_grad=False)
    fake = Variable(Tensor(np.zeros((x_sketch_var.size(0), *patch))), requires_grad=False)
    
    # train generators 
    optimizer_G.zero_grad()
    
    # GAN loss
    fake_sketch = generator(x_real_var)
    pred_fake = discriminator(fake_sketch, x_real_var)
    loss_GAN = criterion_GAN(pred_fake, valid)
    
    # Pixel-wise loss
    loss_pixel = criterion_pixelwise(fake_B, real_B)

    # Total loss
    loss_G = loss_GAN + lambda_pixel * loss_pixel

    loss_G.backward()

    optimizer_G.step()
    
    
    

RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 4, but got 3-dimensional input of size [3, 240, 400] instead

## with new dataset 

In [48]:
class CouchDataset_v2(Dataset):
    def __init__(self, real_files, sketch_files):
        self.real_files = real_files
        self.sketch_files = sketch_files
        
        self.sketch_transform = transforms.Compose(
                [transforms.Resize(input_shape[-2:], Image.BICUBIC),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5,), (0.5,))
                ])
        
        self.real_transform = transforms.Compose(
                [
                    transforms.Resize(input_shape[-2:], Image.BICUBIC),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
        
    def __len__(self):
        return len(self.real_files)  # should be same number as sketch files 
        
    def __getitem__(self, idx):
        real_path = self.real_files[idx]
        sketch_path = self.sketch_files[idx]
        img_name = sketch_path.parts[-1]
        
        x_real = getX(real_path)
        x_sketch = getX(sketch_path)

        x_sketch = self.sketch_transform(x_sketch)
        x_real = self.real_transform(x_real)
        
        return {"sketch": x_sketch, "real": x_real}
                

In [49]:
# make datasets 
train_couch_ds = CouchDataset_v2(real_files=train_real_paths, sketch_files=train_sketch_paths)
test_couch_ds = CouchDataset_v2(real_files=test_real_paths, sketch_files=test_sketch_paths)


In [50]:
train_couch_ds[0]

TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>