In [6]:
# general imports
import numpy as np
import time
import os
import math
import matplotlib.pyplot as plt

# torch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

import torchvision.utils as vutils

# tensorboardX
from tensorboardX import SummaryWriter

In [9]:
# TODO: Download scripts and data
# Create the dataset directory tree. os.makedirs(..., exist_ok=True) is
# idempotent, so this cell survives Restart & Run All, and it does not depend
# on a shell `mkdir` (which errors when the directory already exists).
# This cell must run BEFORE the cell that builds the ImageFolder dataset.
# NOTE: ImageFolder expects the images to live in per-class sub-directories
# of data/images; an empty directory will still fail to load.
os.makedirs("data/images", exist_ok=True)

from format import print_iter

In [8]:
# Data-loading configuration.
# The Generator in this notebook upsamples 1x1 -> 4 -> 8 -> 16 -> 32 -> 64, and
# the training loop concatenates real and generated batches, so the real
# images must be resized to the same 64x64 resolution.
image_size = (64, 64)
batch_size = 16
num_workers = 4

# Resize/crop to a fixed size, convert to a tensor, then normalize each channel
# with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1], the range of the Generator's
# Tanh output).
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = ImageFolder("data/images", transform=preprocess)
# Fixed: the class imported above is `DataLoader` (capital L); `Dataloader`
# raised a NameError.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)


FileNotFoundError: [Errno 2] No such file or directory: 'data/images'

In [None]:
def get_padding(output_dim, input_dim, kernel_size, stride):
    """
    Compute the padding of a convolutional layer along a single dimension.

    The value is ((output_dim - 1) * stride - input_dim + kernel_size) // 2,
    clamped so that it is never negative.

    All arguments should be integers. Call this once per dimension.

    Returns the computed padding, or 0 when the computed value would be negative
    (i.e. the requested dimensions cannot be reached by adding padding).
    """
    span = (output_dim - 1) * stride
    raw_padding = (span - input_dim + kernel_size) // 2

    # Negative padding is not meaningful; fall back to no padding.
    return max(raw_padding, 0)

def gen_block(input_channels, output_channels, kernel_size, stride, padding):
    """
    Build one generator stage as a list of three modules:
    a bias-free transposed convolution, batch normalization over
    `output_channels`, and an in-place ReLU.

    Returns a plain Python list so callers can concatenate several stages
    before wrapping them in nn.Sequential.
    """
    upsample = nn.ConvTranspose2d(
        input_channels,
        output_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return [upsample, nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True)]
    
class Generator(nn.Module):
    """
    DCGAN-style generator: maps latent noise to an image.

    Arguments:
        channels: number of channels of the generated image.
        input_size: number of channels of the latent input.
        output_dim: base number of feature maps; the hidden stages use
            8x, 4x, 2x and 1x this value.

    For an input of shape (N, input_size, 1, 1) the output has shape
    (N, channels, 64, 64) and passes through Tanh.
    """

    def __init__(self, channels=3, input_size=100, output_dim=64):
        super(Generator, self).__init__()
        self.channels = channels
        self.input_size = input_size
        # Fixed: the constructor argument is `output_dim`; the original read an
        # undefined name `output_size` and raised NameError.
        self.output_size = output_dim
        # Fixed: build_layers is a method, so it must be called through self.
        self.layers = self.build_layers()

    def forward(self, x):
        # No squeeze() here: it would silently drop the batch dimension when
        # N == 1 (or the channel dimension when channels == 1).
        return self.layers(x)

    def build_layers(self):
        """Assemble the stack of transposed-convolution stages as nn.Sequential."""
        layers = []
        in_c = self.input_size
        out_c = self.output_size * 8

        # dim: out_c x 4 x 4
        layers += gen_block(in_c, out_c, 4, 1, 0)
        in_c = out_c
        out_c = self.output_size * 4

        # dim: out_c x 8 x 8
        layers += gen_block(in_c, out_c, 4, 2, 1)
        in_c = out_c
        out_c = self.output_size * 2

        # dim: out_c x 16 x 16
        layers += gen_block(in_c, out_c, 4, 2, 1)
        in_c = out_c
        out_c = self.output_size

        # dim: out_c x 32 x 32
        layers += gen_block(in_c, out_c, 4, 2, 1)
        in_c = out_c
        out_c = self.channels

        # dim: out_c x 64 x 64 (no batch norm on the output stage)
        layers += [nn.ConvTranspose2d(in_c, out_c, 4, 2, 1), nn.Tanh()]

        return nn.Sequential(*layers)

In [None]:
def discrim_block(input_channels, output_channels, kernel_size, stride, padding):
    """
    Build one discriminator stage as a list of three modules:
    a bias-free convolution, batch normalization over `output_channels`,
    and an in-place LeakyReLU with negative slope 0.2.

    Returns a plain Python list so callers can concatenate several stages
    before wrapping them in nn.Sequential.
    """
    downsample = nn.Conv2d(
        input_channels,
        output_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return [downsample, nn.BatchNorm2d(output_channels), nn.LeakyReLU(0.2, inplace=True)]

class Discriminator(nn.Module):
    """
    DCGAN-style discriminator: maps an image to a probability per sample.

    Arguments:
        channels: number of channels of the input image.
        input_dim: base number of feature maps; the stages use
            1x, 2x, 4x and 8x this value.

    The spatial sizes noted in build_layers assume 64x64 input images, for
    which forward() returns a 1-D tensor of shape (N,) with values in (0, 1).
    """

    # Fixed: this was declared with `def`, which is a SyntaxError for a class.
    def __init__(self, channels=3, input_dim=64):
        super(Discriminator, self).__init__()
        self.channels = channels
        self.input_dim = input_dim
        # Fixed: build_layers is a method, so it must be called through self.
        self.layers = self.build_layers()

    def forward(self, x):
        # view(-1) instead of squeeze(): squeeze() would return a 0-d tensor
        # for a batch of one, which no longer matches a (1,)-shaped label.
        return self.layers(x).view(-1)

    def build_layers(self):
        """Assemble the stack of strided-convolution stages as nn.Sequential."""
        layers = []
        in_c = self.channels
        out_c = self.input_dim

        # dim: out_c x 32 x 32
        # Fixed: stride 2 / padding 1 (was stride 1 / padding 0) so this stage
        # halves 64 -> 32 as intended and the final conv ends at 1 x 1.
        layers += discrim_block(in_c, out_c, 4, 2, 1)
        in_c = out_c
        out_c = self.input_dim * 2

        # dim: out_c x 16 x 16
        layers += discrim_block(in_c, out_c, 4, 2, 1)
        in_c = out_c
        out_c = self.input_dim * 4

        # dim: out_c x 8 x 8
        layers += discrim_block(in_c, out_c, 4, 2, 1)
        in_c = out_c
        out_c = self.input_dim * 8

        # dim: out_c x 4 x 4
        layers += discrim_block(in_c, out_c, 4, 2, 1)
        in_c = out_c

        # dim: 1 x 1 x 1 (single real/fake score per image)
        layers += [nn.Conv2d(in_c, 1, 4, 1, 0), nn.Sigmoid()]

        return nn.Sequential(*layers)
        

In [None]:
# Timestamp used to give each TensorBoard run its own log directory.
start_time = time.strftime("%a_%b_%d_%Y_%H:%M", time.localtime())

gen_input = 100   # number of channels of the latent noise fed to the generator
gen_output = 64   # base feature-map width of the generator

gen = Generator(input_size=gen_input, output_dim=gen_output)
discrim = Discriminator()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device: {}".format(device))
gen.to(device)
discrim.to(device)

learn_rate = 0.001

optG = optim.Adam(gen.parameters(), lr=learn_rate)
optD = optim.Adam(discrim.parameters(), lr=learn_rate)

# The discriminator ends in a Sigmoid, so plain binary cross-entropy applies.
criterion = nn.BCELoss()

# Fixed latent batch, reused after every epoch so the saved sample grids are
# comparable. The first dimension is the number of sample images.
fixed_noise = torch.randn(gen_output, gen_input, 1, 1, device=device)

# Smoothed labels: 0.9 for real and 0.1 for fake instead of hard 1 / 0.
real_label = 0.9
fake_label = 0.1

# Fixed: `summary` was never imported (NameError) and was called with invalid
# input shapes. Printing the modules shows the architecture without requiring
# an extra third-party package.
print(gen)
print(discrim)
writer = SummaryWriter('tensorboard_logs/run_{}'.format(start_time))

In [None]:
epochs = 20
print_step = 50

# One image grid per epoch, generated from fixed_noise.
gen_imgs = []

num_batches = len(dataloader)

for e in range(epochs):
    # Running sums of the per-batch losses for this epoch.
    g_train_loss = 0.0
    d_train_loss = 0.0

    for i, data in enumerate(dataloader):

        # Train Discriminator

        # only need images from data, don't care about class from ImageFolder
        real_images = data[0].to(device)
        b_size = real_images.size(0)
        real_labels = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # get fake data from generator
        noise = torch.randn(b_size, gen_input, 1, 1, device=device)
        fake_images = gen(noise)
        fake_labels = torch.full((b_size,), fake_label, dtype=torch.float, device=device)

        # concat and shuffle real/fake images
        # Fixed: torch.cat takes a sequence of tensors, not separate positional
        # tensors. The fake batch is detached so the discriminator's backward
        # pass does not run through (and free) the generator's graph, which is
        # needed again for the generator update below.
        images = torch.cat((real_images, fake_images.detach()), dim=0)
        labels = torch.cat((real_labels, fake_labels), dim=0)
        shuffle_i = torch.randperm(images.size(0), device=device)
        images = images[shuffle_i]
        labels = labels[shuffle_i]

        # calculate loss and update gradients
        discrim.zero_grad()
        d_output = discrim(images).view(-1)
        d_loss = criterion(d_output, labels)
        d_loss.backward()
        optD.step()

        d_train_loss += d_loss.item()

        # Train Generator
        gen.zero_grad()
        # get new output from the (just updated) discriminator for fake images;
        # not detached here, so gradients flow back into the generator
        d_output = discrim(fake_images).view(-1)
        # calculate the Generator's loss based on this, use real_labels since fake images are real for generator
        g_loss = criterion(d_output, real_labels)
        g_loss.backward()
        optG.step()

        g_train_loss += g_loss.item()

        if i % print_step == 0:
            print_iter(curr_epoch=e, epochs=epochs, batch_i=i, num_batches=num_batches, d_loss=d_train_loss/(i+1), g_loss=g_train_loss/(i+1))

    # Epoch averages. Divide by the batch count rather than the leaked loop
    # variable `i`, which is undefined if the dataloader yields no batches.
    seen_batches = max(num_batches, 1)
    print_iter(curr_epoch=e, epochs=epochs, writer=writer, d_loss=d_train_loss/seen_batches, g_loss=g_train_loss/seen_batches)
    # save example images
    gen.eval()
    with torch.no_grad():
        fake_images = gen(fixed_noise).detach().cpu()
    gen.train()
    gen_imgs.append(vutils.make_grid(fake_images, padding=2, normalize=True))

In [None]:
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
# make_grid returns a (C, H, W) tensor; imshow wants (H, W, C), hence the transpose.
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
# Fixed: the training loop stores its per-epoch grids in `gen_imgs`;
# `img_list` was never defined and raised a NameError.
plt.imshow(np.transpose(gen_imgs[-1],(1,2,0)))
plt.show()