In [1]:
import torch
import torch.nn as nn
import numpy as np
import random
import matplotlib.pyplot as plt
from import_utils import *

gan_file_path = '../models/DCGAN.py'

# import_names is provided by the star-import of import_utils. It is called with the
# model file path, the name 'DCGAN' and the list of names to fetch; its result is
# unpacked, in the same order, into the three model classes used by this notebook.
_wanted_names = ('Discriminator', 'Generator', 'Encoder')
Discriminator, Generator, Encoder = import_names(gan_file_path, 'DCGAN', list(_wanted_names))

In [2]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import torchvision.utils as vutils

# --- Configuration -----------------------------------------------------------
SEED = 0xDEADF00D
IMAGE_SIZE = 64        # side length (pixels) of the square images fed to the models
TEST_SIZE = 0.2        # NOTE(review): not referenced in the visible cells

BATCH_SIZE = 16
NUM_EPOCHS = 10000
LR = 0.0002            # learning rate shared by all Adam optimizers
LAMBDA = 0.3           # weight of the L1 reconstruction term in the generator/encoder loss

EMBEDDING_SIZE = 100   # NOTE(review): not referenced in the visible cells

# Deterministic seeding is currently disabled; uncomment to make runs repeatable.
# torch.backends.cudnn.deterministic = True
# torch.manual_seed(SEED)
# torch.cuda.manual_seed_all(SEED)
# np.random.seed(SEED)
# random.seed(SEED)


def make_face_dataset(data_path, image_size=IMAGE_SIZE):
    """Build an ImageFolder over ``data_path`` with the shared preprocessing.

    Both face datasets must be preprocessed identically because they pass through
    the same encoder and discriminator, so the pipeline is defined in one place.

    Args:
        data_path: root directory in the layout ImageFolder expects.
        image_size: target side length; images are resized, then centre-cropped
            to ``image_size`` x ``image_size``.

    Returns:
        An ImageFolder dataset yielding normalised 3-channel tensors.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 on each of the three channels
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return ImageFolder(data_path, transform=preprocessing)


def make_face_loader(dataset, batch_size=BATCH_SIZE, num_workers=3):
    """Wrap ``dataset`` in a shuffling DataLoader with the notebook's settings."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)


# --- Data --------------------------------------------------------------------
my_data_path = '../../data/faces'
my_dataset = make_face_dataset(my_data_path)
my_data_loader = make_face_loader(my_dataset)

den_data_path = '../../data/den_faces'
den_dataset = make_face_dataset(den_data_path)
den_data_loader = make_face_loader(den_dataset)

In [None]:
# Preview one shuffled batch from the first face dataset as an image grid.
# Index [0] keeps the first element of the loader's batch (the images).
batch = list(next(iter(my_data_loader))[0])

grid_side = int(np.sqrt(BATCH_SIZE))  # images per grid row
preview_grid = vutils.make_grid(batch, nrow=grid_side, normalize=True)

plt.figure(figsize=(6,6))
# Move the channel axis last, which is the layout imshow expects.
plt.imshow(np.transpose(preview_grid, [1,2,0]));

In [None]:
# Preview one shuffled batch from the second face dataset as an image grid.
# Index [0] keeps the first element of the loader's batch (the images).
batch = list(next(iter(den_data_loader))[0])

grid_side = int(np.sqrt(BATCH_SIZE))  # images per grid row
preview_grid = vutils.make_grid(batch, nrow=grid_side, normalize=True)

plt.figure(figsize=(6,6))
# Move the channel axis last, which is the layout imshow expects.
plt.imshow(np.transpose(preview_grid, [1,2,0]));

In [5]:
from torch.optim import Adam

USE_CUDA = torch.cuda.is_available()
# Tensor type used by `.type(DTYPE)` casts elsewhere in the notebook.
DTYPE = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor


def _on_device(module):
    """Return ``module`` moved to the GPU when CUDA is available, else unchanged."""
    return module.cuda() if USE_CUDA else module


# Two generators (one per face dataset) share a single discriminator and encoder.
my_gen = _on_device(Generator())
den_gen = _on_device(Generator())
discr = _on_device(Discriminator())
enc = _on_device(Encoder())

# One Adam optimizer per network, all using the same learning rate.
opt_my_g, opt_den_g, opt_d, opt_e = (
    Adam(network.parameters(), lr=LR)
    for network in (my_gen, den_gen, discr, enc)
)

In [6]:
def get_grad_norm(model):
    """Return the global L2 norm of the gradients of ``model``'s parameters.

    The result is the square root of the sum of squared per-parameter gradient
    L2 norms. Parameters without a gradient (``p.grad is None``, e.g. before the
    first backward pass or for parameters that did not take part in it) are
    skipped instead of raising ``AttributeError``.

    Args:
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        float: the combined gradient norm; ``0.0`` if no parameter has a gradient.
    """
    squared_sum = 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        squared_sum += p.grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5

def plot_results():
    """Draw a grid comparing real images with what the generators produce from them.

    Takes the first ``int(sqrt(BATCH_SIZE))`` images of a fresh batch from each
    loader, encodes them with ``enc`` and decodes every code with both generators.
    The grid is filled in this order, ``2 * int(sqrt(BATCH_SIZE))`` images per row:
    my originals (twice), ``my_gen`` then ``den_gen`` outputs for my codes,
    den originals (twice), ``den_gen`` then ``my_gen`` outputs for den codes.
    """
    n_shown = int(np.sqrt(BATCH_SIZE))

    # A new iterator per call yields one batch; [0] keeps the images only.
    my_batch = next(iter(my_data_loader))[0][:n_shown].type(DTYPE)
    den_batch = next(iter(den_data_loader))[0][:n_shown].type(DTYPE)

    # Visualisation only, so no gradients are tracked.
    with torch.no_grad():
        my_code = enc(my_batch)
        den_code = enc(den_batch)
        my_fake_batch = torch.cat([my_gen(my_code), den_gen(my_code)]).cpu()
        den_fake_batch = torch.cat([den_gen(den_code), my_gen(den_code)]).cpu()
        my_batch = my_batch.cpu()
        den_batch = den_batch.cpu()

    pieces = [my_batch, my_batch, my_fake_batch,
              den_batch, den_batch, den_fake_batch]
    grid = vutils.make_grid(torch.cat(pieces), nrow=2 * n_shown, normalize=True)

    plt.figure(figsize=(12,8))
    # Move the channel axis last, which is the layout imshow expects.
    plt.imshow(np.transpose(grid, [1,2,0]))
    plt.show()
    
# Loss modules shared by every training step: binary cross-entropy scores the
# discriminator outputs against their labels, and L1 (mean absolute error) measures
# how far a reconstruction is from its input image.
criterion, mae_loss = nn.BCELoss(), nn.L1Loss()

In [7]:
def train_step(batch, _discr, _gen, _enc, _opt_d, _opt_g, _opt_e):
    """Run one adversarial update on a single batch.

    First the discriminator is updated to separate real images from
    reconstructions ``_gen(_enc(real))``; then generator and encoder are updated
    together to fool the discriminator while staying close (L1) to the input.

    Uses the module-level ``DTYPE``, ``criterion``, ``mae_loss`` and ``LAMBDA``.

    Args:
        batch: tensor of real images; cast with ``.type(DTYPE)`` before use.
        _discr, _gen, _enc: discriminator, generator and encoder networks.
        _opt_d, _opt_g, _opt_e: their optimizers, in the same order.

    Returns:
        tuple: ``(loss_d, loss_g, reconstruction_loss)`` as detached scalar
        tensors (call ``.item()`` to obtain Python floats).
    """
    real_batch = batch.type(DTYPE)
    n_samples = real_batch.shape[0]

    # ---- Discriminator step ----
    _discr.zero_grad()
    # reshape(-1) rather than squeeze(): squeeze() turns a batch of one into a
    # 0-d tensor, which no longer matches the (1,)-shaped labels in BCELoss.
    out_real = _discr(real_batch).reshape(-1)
    # Detached: this step only updates the discriminator, so there is no need to
    # backpropagate through the generator and encoder.
    fake_batch = _gen(_enc(real_batch)).detach()
    out_fake = _discr(fake_batch).reshape(-1)

    # Noisy targets: real labels lie in [0.995, 1.0], fake labels in [0, 0.005).
    real_labels = .995 + torch.rand(n_samples).type(DTYPE) * .005
    fake_labels = torch.rand(n_samples).type(DTYPE) * .005

    loss_d = criterion(out_real, real_labels) + criterion(out_fake, fake_labels)
    loss_d.backward()
    _opt_d.step()

    # ---- Generator + encoder step ----
    _gen.zero_grad()
    _enc.zero_grad()

    reconstruction = _gen(_enc(real_batch))
    out_discr = _discr(reconstruction).reshape(-1)

    # The generator is rewarded when the discriminator labels its output as real.
    loss_g = criterion(out_discr, real_labels)
    reconstruction_loss = mae_loss(reconstruction, real_batch)
    f_loss = LAMBDA * reconstruction_loss + loss_g
    f_loss.backward()
    _opt_g.step()
    _opt_e.step()

    # Detach so callers that keep the losses do not hold on to autograd state.
    return loss_d.detach(), loss_g.detach(), reconstruction_loss.detach()


In [None]:
from time import time
from tqdm import tqdm
from IPython import display

# Per-epoch mean losses; enc_loss stores the L1 reconstruction loss from train_step.
den_gen_loss, my_gen_loss, discr_loss, enc_loss = [], [], [], []
start_time = time()

for epoch in tqdm(range(NUM_EPOCHS)):
    tmp_den_gen_loss, tmp_my_gen_loss, tmp_discr_loss, tmp_enc_loss = [], [], [], []

    # zip stops at the shorter loader, so an epoch covers as many batches as the
    # smaller of the two datasets provides.
    for (my_batch, _), (den_batch, _) in zip(my_data_loader, den_data_loader):
        # Each generator gets its own step; discriminator and encoder are shared,
        # so both are updated twice per iteration.
        my_d_loss, my_g_loss, my_e_loss = train_step(my_batch, discr, my_gen,
                                                     enc, opt_d, opt_my_g, opt_e)
        den_d_loss, den_g_loss, den_e_loss = train_step(den_batch, discr, den_gen,
                                                        enc, opt_d, opt_den_g, opt_e)

        tmp_enc_loss.append((my_e_loss.item() + den_e_loss.item()) / 2)
        tmp_discr_loss.append((my_d_loss.item() + den_d_loss.item()) / 2)
        tmp_my_gen_loss.append(my_g_loss.item())
        tmp_den_gen_loss.append(den_g_loss.item())

    my_gen_loss.append(np.mean(tmp_my_gen_loss))
    den_gen_loss.append(np.mean(tmp_den_gen_loss))
    discr_loss.append(np.mean(tmp_discr_loss))
    enc_loss.append(np.mean(tmp_enc_loss))

    # Replace the previous epoch's output with fresh samples and loss curves.
    display.clear_output(wait=True)
    plot_results()
    plt.plot(my_gen_loss, label='My generator loss')
    plt.plot(den_gen_loss, label='Den generator loss')
    plt.plot(discr_loss, label='Discriminator loss')
    plt.plot(enc_loss, label='Reconstruction loss')
    plt.xlabel('Epoch')
    plt.ylabel('Mean loss')
    # Without legend() the labels above are never drawn.
    plt.legend()
    plt.show()

print(f'Training took {time() - start_time} seconds')