In [None]:
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from torch.autograd import grad
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.insert(0, '../')

import metrics as mt
import settings as opt
import weights
from preprocess.dataset import ValenceArousalWithClassesDataset

In [None]:
# Decide which device we want to run on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from (Python `random` for metric sampling,
# NumPy, and torch) so runs are reproducible, as far as non-deterministic CUDA
# kernels allow.
manual_seed = 999
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)

# Choose model and dataset

# Select dataset to train on - possible choices: affective, augmented_affective, cifar10
dataset_choice = "affective"

# Select model to train on - WGAN, WGAN_D, WGAN_SN
model_choice = "WGAN"

# Fail fast on typos, before any data is loaded or any model is built.
if dataset_choice not in ("affective", "augmented_affective", "cifar10"):
    raise ValueError("Unknown dataset_choice: {!r}".format(dataset_choice))
if model_choice not in ("WGAN", "WGAN_D", "WGAN_SN"):
    raise ValueError("Unknown model_choice: {!r}".format(model_choice))

import model.wgan as gan

In [None]:
# Output locations: generated image grids and model checkpoints, one sub-folder
# per model/dataset combination. os.path.join works whether or not opt.outi /
# opt.outc end with a separator; the final '' keeps a trailing separator, as the
# original string-concatenated paths had.
gen_imgs_path = os.path.join(opt.outi, model_choice, dataset_choice, '')
save_model_checkpoints = os.path.join(opt.outc, model_choice, dataset_choice, '')

# create directories for output; exist_ok makes the cell safe to re-run while
# still surfacing real failures (e.g. permission errors) instead of hiding them
os.makedirs(gen_imgs_path, exist_ok=True)
os.makedirs(save_model_checkpoints, exist_ok=True)

In [None]:
# Image preprocessing: resize and center-crop to opt.image_size, convert to a
# tensor, then normalize each of the 3 channels with mean 0.5 / std 0.5
# (maps ToTensor's [0, 1] range to [-1, 1]).
pre_process = transforms.Compose([
    transforms.Resize(opt.image_size),
    transforms.CenterCrop(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Pick the dataset and its number of classes. Every supported choice gets an
# explicit branch so an unknown value fails loudly instead of silently loading
# the augmented dataset.
if dataset_choice == "cifar10":
    dataset = dset.CIFAR10(
        root="../data/cifar10",
        download=True,
        transform=pre_process
    )
    n_classes = 10
elif dataset_choice == "affective":
    dataset = ValenceArousalWithClassesDataset(csv_file=opt.annonations_file,
                                               root_dir=opt.all_images_path,
                                               transform=pre_process)
    n_classes = 13
elif dataset_choice == "augmented_affective":
    dataset = ValenceArousalWithClassesDataset(csv_file=opt.augmented_annonations_file,
                                               root_dir=opt.augmented_images_path,
                                               transform=pre_process)
    n_classes = 13
else:
    raise ValueError("Unknown dataset_choice: {!r}".format(dataset_choice))

# Create the dataloader; drop_last keeps every training batch at opt.batch_size
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4,
                        drop_last=True)

In [None]:
def gradient_penalty(x, y, f):
    """WGAN-GP gradient penalty.

    Draws one point per sample uniformly on the segment between ``x[i]`` and
    ``y[i]``, evaluates the critic ``f`` there, and penalizes deviations of the
    L2 norm of the critic's input-gradient from 1.

    Args:
        x: batch of samples (the training loop passes real images), shape (N, ...).
        y: batch with the same shape as ``x`` (the training loop passes fakes).
        f: critic, called on the interpolated batch.

    Returns:
        Scalar tensor ``mean((||grad_z f(z)||_2 - 1) ** 2)``. The gradient is
        taken with ``create_graph=True``, so the penalty can be back-propagated
        into ``f``'s parameters.
    """
    # interpolation: one mixing coefficient per sample, broadcast over the
    # remaining dims and created on the inputs' own device/dtype
    shape = [x.size(0)] + [1] * (x.dim() - 1)
    alpha = torch.rand(shape, device=x.device, dtype=x.dtype)
    # detach -> fresh leaf we can differentiate w.r.t., even if the caller's
    # tensors are still attached to a graph
    z = (x + alpha * (y - x)).detach().requires_grad_(True)

    # gradient penalty
    o = f(z)
    g = grad(o, z, grad_outputs=torch.ones_like(o), create_graph=True)[0]
    # reshape (not view): the returned gradient is not guaranteed contiguous
    g = g.reshape(z.size(0), -1)
    gp = ((g.norm(p=2, dim=1) - 1) ** 2).mean()

    return gp

In [None]:
# Create the GAN model

# Create the generator. It is sized with n_classes from the dataset cell
# (10 for cifar10, 13 for the affective datasets) instead of the global
# opt.n_classes, so the model always agrees with the loaded dataset.
generator = gan.Generator(opt.latent_dim, opt.num_channels, opt.ngf, n_classes).to(device)
generator.apply(weights.weights_init)

# Create the discriminator (critic) variant selected by model_choice;
# an unknown choice raises instead of silently falling back to a variant.
if model_choice == "WGAN":
    discriminator = gan.Discriminator(opt.num_channels, opt.ndf, n_classes).to(device)
elif model_choice == "WGAN_D":
    discriminator = gan.DiscriminatorDropout(opt.num_channels, opt.ndf, n_classes).to(device)
elif model_choice == "WGAN_SN":
    discriminator = gan.DiscriminatorSN(opt.num_channels, opt.ndf, n_classes).to(device)
else:
    raise ValueError("Unknown model_choice: {!r}".format(model_choice))

discriminator.apply(weights.weights_init)

# NOTE(review): adversarial_loss, real_label and fake_label are not used by the
# WGAN-GP losses (critic: -wd + gp, generator: -mean critic score) in this
# notebook; kept only in case other code refers to them.
adversarial_loss = nn.BCELoss()
real_label = 0.9  # smoothed "real" label value (unused by WGAN-GP)
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(discriminator.parameters(), lr=opt.lr_D, betas=(opt.beta1, opt.beta2))
optimizerG = optim.Adam(generator.parameters(), lr=opt.lr_G, betas=(opt.beta1, opt.beta2))

In [None]:
# Fixed latent vectors (100 x latent_dim), drawn once on the CPU RNG and moved
# to the training device. Reusing the same vectors for every generator snapshot
# makes images from different epochs directly comparable.
# (torch.autograd.Variable is a deprecated no-op wrapper, so it is not needed.)
z_sample = torch.randn(100, opt.latent_dim).to(device)

In [None]:
# Training Loop (WGAN-GP: one critic update followed by one generator update per
# batch; FID/KID are measured every 5 epochs, before that epoch's updates)

# Lists to keep track of progress
G_losses = []
D_losses = []

fid_score_history = []
kid_score_history = []

fake_images_list = []

global_step = 0

# Weight of the gradient-penalty term in the critic loss
gp_weight = 10.0

print("Starting Training Loop...")
# For each epoch
for epoch in range(opt.epochs):

    ##############################
    ### COMPUTE METRICS
    if epoch % 5 == 0:
        print("Global step: {}. Computing metrics...".format(global_step))

        # get opt.fid_batch random real samples (first item of each dataset entry)
        samples = random.sample(range(len(dataset)), opt.fid_batch)
        real_samples = torch.stack([dataset[s][0] for s in samples], dim=0).to(device)

        # Generate the same number of fake samples in chunks of opt.batch_size.
        # split() yields exactly ceil(fid_batch / batch_size) chunks, rather than
        # looping fid_batch times over mostly empty slices. The generator is
        # called with noise only, as in the training steps below, and under
        # no_grad so no autograd graph is kept for these images.
        noise = torch.randn(opt.fid_batch, opt.latent_dim, device=device)
        with torch.no_grad():
            fake_samples = torch.cat(
                [generator(noise_chunk)
                 for noise_chunk in tqdm(noise.split(opt.batch_size),
                                         desc="Generating fake images")],
                dim=0)

        print("Computing KID and FID...")
        kid, fid = mt.compute_metrics(real_samples, fake_samples)

        print("FID: {:.4f}".format(fid))

        print("KID: {:.4f}".format(kid))

        fid_score_history.append(fid)
        kid_score_history.append(kid)

    ##############################

    # set train
    generator.train()

    for i, (imgs, _) in enumerate(dataloader):
        # real images (the second item from the loader is ignored: unconditional GAN)
        imgs = imgs.to(device)
        bs = imgs.size(0)

        # train D: the fake batch only feeds the critic here, so no graph
        # through the generator is needed
        z = torch.randn(bs, opt.latent_dim, device=device)
        with torch.no_grad():
            f_imgs = generator(z)

        r_logit = discriminator(imgs)
        f_logit = discriminator(f_imgs)

        wd = r_logit.mean() - f_logit.mean()  # Wasserstein-1 Distance
        gp = gradient_penalty(imgs, f_imgs, discriminator)
        d_loss = -wd + gp * gp_weight

        discriminator.zero_grad()
        d_loss.backward()
        optimizerD.step()

        # train G
        z = torch.randn(bs, opt.latent_dim, device=device)
        f_imgs = generator(z)
        f_logit = discriminator(f_imgs)
        g_loss = -f_logit.mean()

        # also clear the critic: g_loss.backward() deposits gradients in its params
        discriminator.zero_grad()
        generator.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # Save Losses for plotting later
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())

        global_step += 1

        # Output training stats
        if i % opt.log_interval == 0:
            print(
                "[{}/{}][{}/{}] Loss_D: {:.4f} Loss_G: {:.4f}".format(
                    epoch,
                    opt.epochs,
                    i,
                    len(dataloader),
                    d_loss.item(),
                    g_loss.item()
                )
            )

    # Save a grid generated from the fixed z_sample vectors after each epoch, so
    # consecutive grids show the same latent points as training progresses
    # (100 vectors -> 10 x 10 grid).
    with torch.no_grad():
        fake_images = generator(z_sample).cpu()

    vutils.save_image(fake_images,
                      "%s/fake_%s_epoch_%03d.png" % (gen_imgs_path, dataset_choice, epoch),
                      normalize=True, nrow=10)

    # Save generator and discriminator weights every 20 epochs and after the final epoch
    if epoch % 20 == 0 or epoch == opt.epochs - 1:
        torch.save(generator.state_dict(),
                   "%s/netG_%s_epoch_%d.pth" % (save_model_checkpoints, dataset_choice, epoch))
        torch.save(discriminator.state_dict(),
                   "%s/netD_%s_epoch_%d.pth" % (save_model_checkpoints, dataset_choice, epoch))

print("Training is finished!")

In [None]:
# results: G_losses / D_losses receive one value per training iteration
# (one critic + one generator update per batch), not one per epoch
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator loss during training")
ax.plot(G_losses, label="Generator")
ax.plot(D_losses, label="Discriminator")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# FID results. The training loop measures metrics at the start of every 5th
# epoch (0, 5, 10, ...), so each value is placed at the epoch it was measured.
# NOTE: keep this stride in sync with the metric interval in the training loop.
fid_epochs = [5 * k for k in range(len(fid_score_history))]
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("FID score")
ax.plot(fid_epochs, fid_score_history, marker="o")
ax.set_xlabel("epochs")
ax.set_ylabel("FID value")
plt.show()

In [None]:
# KID results. The training loop measures metrics at the start of every 5th
# epoch (0, 5, 10, ...), so each value is placed at the epoch it was measured.
# NOTE: keep this stride in sync with the metric interval in the training loop.
kid_epochs = [5 * k for k in range(len(kid_score_history))]
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("KID score")
ax.plot(kid_epochs, kid_score_history, marker="o")
ax.set_xlabel("epochs")
ax.set_ylabel("KID value")
plt.show()

In [None]:
# Raw FID values, one per metric evaluation, in the order they were computed
print(str(fid_score_history))

In [None]:
# Raw KID values, one per metric evaluation, in the order they were computed
print("{}".format(kid_score_history))

In [None]:
# G_losses holds one value per training iteration and can be very long; printing
# it as a NumPy array summarizes (head ... tail) above NumPy's print threshold
# instead of flooding the notebook output.
print(np.asarray(G_losses))

In [None]:
# D_losses holds one value per training iteration and can be very long; printing
# it as a NumPy array summarizes (head ... tail) above NumPy's print threshold
# instead of flooding the notebook output.
print(np.asarray(D_losses))