In [14]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [276]:
from collections import defaultdict
import cv2
from itertools import chain
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils

from config import *
from htools import hdir
from models import BaseModel, conv_block, Discriminator
from torch_datasets import sketch_dl, photo_dl
from utils import render_samples, show_img

In [17]:
# TEMPORARY: local copy of the settings that come from config, repeated here
# only so they are easy to look up while working in the notebook.

# Data.
bs = 64                 # Batch size (paper uses 128).
img_size = 64           # Size of input (here it's 64 x 64).
nc = 3                  # Number of channels of input image.
workers = 2             # Number of workers for data loader.

# Model widths.
input_c = 100           # Depth of input noise (1 x 1 x noise_dim). AKA nz.
ngf = 64                # Filters in first G layer.
ndf = 64                # Filters in first D layer.

# Optimizer.
lr = 2e-4               # Recommended learning rate of .0002.
beta1 = .5              # Recommended parameter for Adam.

# Output locations.
sample_dir = 'samples'  # Directory to store sample images from G.
weight_dir = 'weights'  # Directory to store model weights.

# Hardware: use the first GPU only when CUDA is available and at least one
# GPU was requested; otherwise fall back to the CPU.
ngpu = 1                # Number of GPUs to use.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [213]:
class ResBlock(nn.Module):
    """Residual block to be used in CycleGenerator.

    The input is passed through `num_layers` conv blocks, each followed by a
    leaky relu, and the result is added to the original input (the skip
    connection). Nothing is applied after the addition, so the relu or leaky
    relu must still be applied on the output by the caller.
    """

    def __init__(self, c_in, num_layers=2, leak=.02):
        """
        Parameters
        -----------
        c_in: int
            # of input channels. Every conv block inside the skip connection
            is built with c_in input and c_in output channels.
        num_layers: int
            Number of conv blocks inside the skip connection (default 2).
            ResNet paper notes that skipping a single layer did not show
            noticeable improvements. Must be at least 1.
        leak: float
            Slope of leaky relu where x < 0. Leak of 0 is regular relu.

        Raises
        -------
        ValueError
            If num_layers is less than 1. With no layers the forward pass
            would silently return x + x instead of a residual update.
        """
        if num_layers < 1:
            raise ValueError(
                f'num_layers must be at least 1, got {num_layers}.'
            )
        super().__init__()
        self.leak = leak
        # Positional args after the channels are 3, 1, 1 (elsewhere in this
        # file conv_block is called with keywords f, stride, pad in that
        # position; presumably a 3x3, stride 1, pad 1 conv - confirm against
        # models.conv_block).
        self.layers = nn.ModuleList([conv_block(False, c_in, c_in, 3, 1, 1)
                                     for _ in range(num_layers)])

    def forward(self, x):
        """Return x plus the output of the activated conv blocks."""
        # NOTE(review): the activation is also applied to the last conv block
        # before the addition, so with leak=0 the residual branch can only
        # add non-negative values - confirm this is intended.
        x_out = x
        for layer in self.layers:
            x_out = F.leaky_relu(layer(x_out), self.leak)
        return x + x_out

In [215]:
class CycleGenerator(BaseModel):
    """CycleGAN Generator: an encoder that downsamples, a residual
    transformer that keeps the shape, and a decoder that upsamples back.
    """

    def __init__(self, img_c=3, ngf=64, leak=.02):
        """
        Parameters
        -----------
        img_c: int
            # of channels of input image (also the # of output channels).
        ngf: int
            # of channels in first convolutional layer.
        leak: float
            Slope of leaky relu where x < 0. Leak of 0 is regular relu.
        """
        super().__init__()
        self.leak = leak
        self.activation = nn.LeakyReLU(self.leak)
        # The same activation module is reused after every layer.
        act = self.activation

        # ENCODER (shape comments assume a 3 x 64 x 64 input and ngf=64).
        self.encoder = nn.Sequential(
            # 3 x 64 x 64 -> 64 x 32 x 32
            conv_block(False, img_c, ngf, f=4, stride=2, pad=1),
            act,
            # 64 x 32 x 32 -> 128 x 16 x 16
            conv_block(False, ngf, ngf*2, 4, 2, 1),
            act,
        )

        # TRANSFORMER: two residual blocks, each 128 x 16 x 16 -> 128 x 16 x 16.
        # ResBlock leaves its output un-activated, so the activation follows.
        self.transformer = nn.Sequential(
            ResBlock(ngf*2, num_layers=2, leak=self.leak),
            act,
            ResBlock(ngf*2, num_layers=2, leak=self.leak),
            act,
        )

        # DECODER (first conv_block argument is True here, False in encoder).
        self.decoder = nn.Sequential(
            # 128 x 16 x 16 -> 64 x 32 x 32
            conv_block(True, ngf*2, ngf, f=4, stride=2, pad=1),
            act,
            # 64 x 32 x 32 -> 3 x 64 x 64
            conv_block(True, ngf, img_c, 4, 2, 1),
            nn.Tanh(),
        )

        # Keeping the three stages in a ModuleList is helpful if we want to
        # use different learning rates per group.
        self.groups = nn.ModuleList(
            [self.encoder, self.transformer, self.decoder]
        )

    def forward(self, x):
        """Run x through the encoder, transformer and decoder in order."""
        out = x
        for stage in self.groups:
            out = stage(out)
        return out

In [226]:
# class ResNetDiscriminator(BaseModel):
    
#     def __init__(self, img_c=3, ndf=64):
#         super().__init__()
#         self.conv1 = conv_block(False, img_c, ndf, f=4, stride=2, pad=1)
        
#     def forward(self):
#         pass

## TODO: check whether the photo and sketch dataloaders can yield batches of
## different sizes (e.g. a smaller final batch), since training zips them.

In [277]:
def train_cycle_gan(epochs, x_dl, y_dl, lr=2e-4, b1=.5, use_labels=False, 
                    models=None, verbose=True):
    """Train cycleGAN with Adam optimizer. The naming convention G_xy will be
    used to refer to a generator that converts from set x to set y, while
    D_x refers to a discriminator that classifies examples as actually 
    belonging to set x (class 1) or being a model-generated example (class 0).
    
    NOTE: the actual training steps are currently commented out, so for now
    this only builds the models/optimizers and iterates over the batches.
    
    Parameters
    -----------
    epochs: int
        Number of passes over the zipped dataloaders.
    x_dl, y_dl: iterable
        Dataloaders for set x and set y. Every batch is unpacked as an
        (images, labels) pair. The two loaders are zipped together, so each
        epoch ends as soon as the shorter one is exhausted.
    lr: float
        Learning rate passed to both Adam optimizers.
    b1: float
        First beta of both Adam optimizers (the second is fixed at .999).
    use_labels: bool
        Specifies whether to use class labels (i.e. horse, zebra, giraffe). If
        False, D only tries to predict if it is a real or fake example 
        (e.g. photo or sketch2photo). At the moment this only selects the 
        loss: BCELoss when True, MSELoss when False.
    models: sequence or None
        Optional existing models, unpacked as (G_xy, G_yx, D_x, D_y). When
        empty or None, new CycleGenerator and Discriminator instances are 
        created and moved to `device`.
    verbose: bool
        If True (default), print tensor shapes for every batch. These are
        debugging prints; pass False to silence them.
        
    Returns
    --------
    defaultdict(list): Container for per-metric training statistics. Nothing
        is recorded yet, so it is returned empty.
    """
    # Create models.
    if not models:
        G_xy = CycleGenerator().to(device)
        G_yx = CycleGenerator().to(device)
        D_x = Discriminator().to(device)
        D_y = Discriminator().to(device)
    else:
        G_xy, G_yx, D_x, D_y = models
    all_models = (G_xy, G_yx, D_x, D_y)
    # Models should stay in training mode throughout.
    for model in all_models:
        model.train()
    
    # Create optimizers: one shared by both generators, one shared by both
    # discriminators.
    optim_g = torch.optim.Adam(chain(G_xy.parameters(), G_yx.parameters()), 
                               lr, betas=(b1, .999))
    optim_d = torch.optim.Adam(chain(D_x.parameters(), D_y.parameters()),
                               lr, betas=(b1, .999))    
    
    # Define loss function.
    if use_labels:
        criterion = nn.BCELoss(reduction='mean')
    else:
        criterion = nn.MSELoss(reduction='mean')
        
    # Set fixed examples for sample generation. These are left on whatever
    # device the dataloader produced them on; move them to `device` before
    # feeding them to the generators.
    fixed_x = next(iter(x_dl))[0]
    fixed_y = next(iter(y_dl))[0]
    
    # A list factory lets callers (and future logging code) append to a
    # metric without initializing its key first.
    stats = defaultdict(list)
    for epoch in range(1, epochs+1):
        # Re-assert training mode at the start of every epoch.
        for model in all_models:
            model.train()
        
        for i, ((x, x_labels), (y, y_labels)) in enumerate(zip(x_dl, y_dl)):
            x = x.to(device)
            y = y.to(device)
            # NOTE(review): targets are sized from the x batch only. If a y
            # batch has a different size, the commented-out losses on D_y
            # outputs below would not match these targets - confirm before
            # re-enabling.
            batch_len = x.shape[0]
            labels_real = torch.ones(batch_len, device=device)
            labels_fake = torch.zeros(batch_len, device=device)
            
            ### TEST
            if verbose:
                print('fixed shapes', fixed_x.shape, fixed_y.shape)
                print('batch lens', batch_len, y.shape[0])
                print('img shapes', x.shape, y.shape)
                print('label_shapes', x_labels.shape, y_labels.shape)
            
            
#             ##################################################################
#             # Train D_x and D_y.
#             ##################################################################
#             # Train D's on real images.
#             optim_d.zero_grad()
#             pred_x, pred_y = D_x(x), D_y(y)
#             loss_dx = criterion(pred_x, labels_real)
#             loss_dy = criterion(pred_y, labels_real)
#             loss_d_real = loss_dx + loss_dy
#             loss_d_real.backward()
#             optim_d.step()
            
#             # Train D's on fake images.
#             optim_d.zero_grad()
#             x_fake, y_fake = G_yx(y), G_xy(x)
#             pred_x, pred_y = D_x(x_fake), D_y(y_fake)
#             loss_dx_fake = criterion(pred_x, labels_fake)
#             loss_dy_fake = criterion(pred_y, labels_fake)
#             loss_d_fake = loss_dx_fake + loss_dy_fake
#             loss_d_fake.backward()
#             optim_d.step()
            
#             ##################################################################
#             # Train G_xy and G_yx.
#             ################################################################## 
#             # Stage 1: x -> y -> x
#             optim_g.zero_grad()
#             y_fake = G_xy(x)
#             pred_y = D_y(y_fake)
#             loss_g = criterion(pred_y, labels_real)
            
#             x_recon = G_yx(y_fake)
#             pred_x = D_x(x_recon)
#             loss_g_cycle = criterion(pred_x, labels_real)
#             loss_g_total = loss_g + loss_g_cycle
#             loss_g_total.backward()
#             optim_g.step()
            
#             # Stage2: y -> x -> y
#             optim_g.zero_grad()
#             x_fake = G_yx(y)
#             pred_x = D_x(x_fake)
#             loss_g = criterion(pred_x, labels_real)
            
#             y_recon = G_xy(x_fake)
#             pred_y = D_y(y_recon)
#             loss_g_cycle = criterion(pred_y, labels_real)
#             loss_g_total = loss_g + loss_g_cycle
#             loss_g_total.backward()
#             optim_g.step()
        
        # Print progress at the end of each epoch.
        print(f'Epoch [{epoch}/{epochs}]')
    
    return stats

In [251]:
# Trial run: 2 epochs with the photo loader as set x and the sketch loader as
# set y.
train_cycle_gan(2, photo_dl, sketch_dl)

In [265]:
# Build the two generators (x -> y and y -> x) for manual shape checks.
# NOTE(review): `img_c` is not among the constants listed in the config cell
# of this notebook (which defines `nc`); presumably it is provided by
# `from config import *` - confirm.
G_xy = CycleGenerator(img_c, ngf)
G_yx = CycleGenerator()

In [259]:
# NOTE(review): `ndf` is passed as the first positional argument, while
# train_cycle_gan calls Discriminator() with no arguments; confirm the first
# parameter of Discriminator is the filter count and not the image channels.
D = Discriminator(ndf)

In [219]:
# NOTE(review): `G` is not defined in any of the cells shown here (only G_xy
# and G_yx are); presumably left over from an earlier kernel state, so this
# cell may fail on a fresh Restart & Run All - confirm.
len(G.dims())

24

In [261]:
# NOTE(review): `x` is not created in any of the cells shown here; it
# presumably comes from an earlier kernel state - define it explicitly so the
# notebook runs top to bottom.
x.shape

torch.Size([2, 3, 4, 4])

In [272]:
# Forward pass through the x -> y generator; the bare expression on the last
# line displays the shape of the generated batch.
y_hat = G_xy(x)
y_hat.shape

torch.Size([2, 3, 4, 4])

In [271]:
# Complete the cycle (x -> y -> x) and display the shape of the reconstruction.
G_yx(y_hat).shape

torch.Size([2, 3, 4, 4])

In [278]:
# Re-run after redefining train_cycle_gan: 2 epochs, photos as set x and
# sketches as set y. The function prints tensor shapes for every batch.
train_cycle_gan(2, photo_dl, sketch_dl)

fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])
fixed shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
batch lens 64 64
img shapes torch.Size([64, 3, 64, 64]) torch.Size([64, 3, 64, 64])
label_shapes torch.Size([64]) torch.Size([64])


defaultdict(None, {})

In [279]:
# Length of the photo loader (used as set x in the training calls).
len(photo_dl)

196

In [280]:
# Length of the sketch loader (used as set y in the training calls). Because
# train_cycle_gan zips the two loaders, an epoch stops at the shorter one.
len(sketch_dl)

1180