In [None]:
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.insert(0, '../')

import metrics as mt
import settings as opt
import weights
from preprocess.dataset import ValenceArousalWithClassesDataset

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Experiment configuration

# Dataset to train on - one of: affective, augmented_affective, cifar10
dataset_choice = "affective"

# Model variant to train - one of: ACGAN, ACGAN_D, ACGAN_SN
model_choice = "ACGAN"

import model.acgan as gan

In [None]:
# Output locations are namespaced by model and dataset so that different runs do not
# overwrite each other's sample images and checkpoints.
gen_imgs_path = opt.outi + model_choice + '/' + dataset_choice + '/'
save_model_checkpoints = opt.outc + model_choice + '/' + dataset_choice + '/'

# create directories for output
# exist_ok=True makes re-running this cell harmless, while genuine failures
# (missing permissions, a regular file in the way, read-only filesystem) still raise
# here instead of being swallowed and resurfacing later when the first file is saved.
os.makedirs(gen_imgs_path, exist_ok=True)
os.makedirs(save_model_checkpoints, exist_ok=True)

In [None]:
# Image preprocessing: resize and center-crop to a square of side opt.image_size,
# convert to a tensor, then shift/scale each of the three channels with mean 0.5 and
# std 0.5 (which maps ToTensor's [0, 1] range to [-1, 1]).
pre_process = transforms.Compose([
    transforms.Resize(opt.image_size),
    transforms.CenterCrop(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build the dataset selected by dataset_choice and record how many classes it has.
# Every supported choice is matched explicitly so that a typo fails loudly instead of
# silently training on the augmented dataset.
if dataset_choice == "cifar10":
    dataset = dset.CIFAR10(
        root="../data/cifar10",
        download=True,
        transform=pre_process
    )
    n_classes = 10
elif dataset_choice == "affective":
    dataset = ValenceArousalWithClassesDataset(csv_file=opt.annonations_file,
                                               root_dir=opt.all_images_path,
                                               transform=pre_process)
    n_classes = 13
elif dataset_choice == "augmented_affective":
    dataset = ValenceArousalWithClassesDataset(csv_file=opt.augmented_annonations_file,
                                               root_dir=opt.augmented_images_path,
                                               transform=pre_process)
    n_classes = 13
else:
    raise ValueError(
        "Unknown dataset_choice {!r}; expected one of: "
        "'affective', 'augmented_affective', 'cifar10'".format(dataset_choice))

# Create the dataloader
# drop_last=True keeps every batch at exactly opt.batch_size samples, which the
# fixed-size real/fake target tensors used during training rely on.
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4,
                        drop_last=True)

In [None]:
# Method for storing generated images
def generate_imgs(generator, epoch, samples_per_class=10):
    """Render a class-conditional sample sheet and save it as a PNG.

    The saved grid has one row per class and ``samples_per_class`` columns. The same
    noise vectors are reused for every class, so each column shows one latent vector
    rendered under all class labels. Noise is drawn anew on every call.

    Args:
        generator: conditional generator, called as ``generator(noise, labels)``.
        epoch: value embedded in the output file name (``sample_<epoch>.png``).
        samples_per_class: number of images generated per class (grid width).

    Side effects:
        Writes ``sample_<epoch>.png`` into ``gen_imgs_path``. The generator is switched
        to eval mode while sampling and then put back into whatever mode it was in.
    """
    # Tile the noise block once per class and repeat every label samples_per_class
    # times: image i then uses noise[i % samples_per_class] and label
    # i // samples_per_class, which lines up with nrow=samples_per_class below for any
    # number of classes (the layout used to be consistent only when n_classes == 10).
    noise = torch.randn(samples_per_class, opt.latent_dim, device=device)
    noise = noise.repeat(n_classes, 1)
    labels = torch.arange(n_classes, device=device).repeat_interleave(samples_per_class)

    was_training = generator.training
    generator.eval()
    try:
        # No gradients are needed for visualisation; skip building the autograd graph.
        with torch.no_grad():
            fake_imgs = generator(noise, labels)
    finally:
        # Do not leak eval mode into the caller's training loop.
        generator.train(was_training)

    grid = vutils.make_grid(fake_imgs, padding=2, normalize=True,
                            nrow=samples_per_class).cpu()

    vutils.save_image(grid, os.path.join(gen_imgs_path, 'sample_' + str(epoch) + '.png'))

In [None]:
# Create the GAN model
generator = gan.Generator(latent_dim=opt.latent_dim, num_classes=n_classes,
                          num_channels=opt.num_channels)
generator.apply(weights.weights_init)

# Pick the discriminator variant named by model_choice. Every supported name is matched
# explicitly so that a misspelled choice raises instead of silently selecting the
# spectral-norm discriminator.
if model_choice == "ACGAN":
    discriminator = gan.Discriminator(num_channels=opt.num_channels, n_classes=n_classes)
elif model_choice == "ACGAN_D":
    discriminator = gan.DiscriminatorDropout(num_channels=opt.num_channels, n_classes=n_classes)
elif model_choice == "ACGAN_SN":
    discriminator = gan.DiscriminatorSN(num_channels=opt.num_channels, n_classes=n_classes)
else:
    raise ValueError(
        "Unknown model_choice {!r}; expected one of: "
        "'ACGAN', 'ACGAN_D', 'ACGAN_SN'".format(model_choice))

discriminator.apply(weights.weights_init)

# Loss functions
# Binary cross-entropy between the discriminator output and the real/fake targets
adversarial_loss = nn.BCELoss()

# Setup Adam optimizers for both G and D (separate learning rates, shared betas)
optimizerD = optim.Adam(discriminator.parameters(), lr=opt.lr_D, betas=(opt.beta1, opt.beta2))
optimizerG = optim.Adam(generator.parameters(), lr=opt.lr_G, betas=(opt.beta1, opt.beta2))

In [None]:
# Fix images for viz
# NOTE: generate_imgs samples its own noise and labels on every call, so these two
# tensors are not what ends up in the saved sample sheets.
fixed_z = torch.randn(1, n_classes, opt.latent_dim)
fixed_z = torch.repeat_interleave(fixed_z, 10, 0).reshape(-1, opt.latent_dim)
fixed_label = torch.arange(0, n_classes)
fixed_label = torch.repeat_interleave(fixed_label, 10)

# Labels
# BCE targets for one full batch: 1 marks "real", 0 marks "fake"
real_label = torch.ones(opt.batch_size)
fake_label = torch.zeros(opt.batch_size)

# GPU Compatibility
# Everything is placed on the `device` chosen at the top of the notebook so that device
# placement is decided in exactly one place; `.to(device)` is a no-op on CPU-only hosts.
# `is_cuda` is kept because the training loop reads it.
is_cuda = torch.cuda.is_available()
generator, discriminator = generator.to(device), discriminator.to(device)
real_label, fake_label = real_label.to(device), fake_label.to(device)
fixed_z, fixed_label = fixed_z.to(device), fixed_label.to(device)

# Running count of optimisation steps and the number of batches per epoch (for logging)
total_iters = 0
max_iter = len(dataloader)

In [None]:
# Training Loop
#
# Per epoch: (1) every 5th epoch, score the current generator with KID/FID on a random
# subset of real images versus freshly generated ones; (2) run one pass over the
# dataloader, alternating a discriminator step and a generator step per batch;
# (3) save a sample sheet; (4) periodically checkpoint both networks.

# Lists to keep track of progress
G_losses = []  # generator loss, one entry per training iteration
D_losses = []  # discriminator loss, one entry per training iteration

fid_score_history = []  # one entry per metric evaluation (every 5th epoch)
kid_score_history = []

fake_images_list = []

global_step = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(opt.epochs):

    ##############################
    ### COMPUTE METRICS
    if epoch % 5 == 0:
        print("Global step: {}. Computing metrics...".format(global_step))

        # Compare up to 192 real images with the same number of generated ones; capping
        # at len(dataset) keeps random.sample from raising on small datasets.
        n_eval = min(192, len(dataset))

        # get random real samples
        samples = random.sample(range(len(dataset)), n_eval)
        real_samples = torch.stack([dataset[s][0] for s in samples], dim=0).to(device)

        # generate random fake samples
        # Score the generator in eval mode, matching how generate_imgs samples from it.
        generator.eval()
        with torch.no_grad():
            # Draw latents from the same N(0, 1) prior the training step uses; the
            # previous `randn * 2 - 1` sampled N(-1, 4), a distribution the generator
            # is never trained on.
            z = torch.randn(n_eval, opt.latent_dim, device=device)
            gen_labels = torch.randint(0, n_classes, (n_eval,), device=device)

            # split() also yields the final partial chunk, so exactly n_eval images are
            # generated even when opt.batch_size does not divide n_eval or exceeds it.
            z_chunks = z.split(opt.batch_size)
            label_chunks = gen_labels.split(opt.batch_size)
            fake_samples = torch.cat(
                [generator(z_, gen_labels_)
                 for z_, gen_labels_ in tqdm(zip(z_chunks, label_chunks),
                                             total=len(z_chunks),
                                             desc="Generating fake images")],
                dim=0)

        print("Computing KID and FID...")
        kid, fid = mt.compute_metrics(real_samples, fake_samples)

        print("FID: {:.4f}".format(fid))

        print("KID: {:.4f}".format(kid))

        fid_score_history.append(fid)
        kid_score_history.append(kid)

    ##############################

    # (Re-)enter training mode: the metric step above and generate_imgs at the end of
    # the previous epoch may have left the generator in eval mode.
    generator.train()
    discriminator.train()

    for i, data in enumerate(dataloader):

        total_iters += 1

        # Format batch
        x_real = data[0].to(device)
        x_label = data[1].to(device)  # Without condition in GAN, this is unnecessary

        batch_size = x_real.size(0)

        # Latent noise is created directly on the target device
        z_fake = torch.randn(batch_size, opt.latent_dim, device=device)

        # Generate fake data conditioned on the labels of the real batch
        x_fake = generator(z_fake, x_label)

        # Train Discriminator
        # x_fake is detached so this step does not backpropagate into the generator
        fake_out = discriminator(x_fake.detach(), x_label)
        real_out = discriminator(x_real, x_label)

        d_loss = (adversarial_loss(fake_out, fake_label)
                  + adversarial_loss(real_out, real_label)) / 2

        optimizerD.zero_grad()
        d_loss.backward()
        optimizerD.step()

        # Train Generator
        # Re-score the fakes with the just-updated discriminator; the generator is
        # rewarded when they are classified as real.
        fake_out = discriminator(x_fake, x_label)
        g_loss = adversarial_loss(fake_out, real_label)

        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # Save Losses for plotting later
        g_loss_value = g_loss.item()
        d_loss_value = d_loss.item()
        G_losses.append(g_loss_value)
        D_losses.append(d_loss_value)

        global_step += 1

        # Output training stats
        if i % opt.log_interval == 0:
            print("Epoch: {}/{}\titer: {}/{}\ttotal_iters: {}\td_loss:{}\tg_loss:{}".format(
                epoch + 1, opt.epochs, i, max_iter, total_iters,
                round(d_loss_value, 4), round(g_loss_value, 4)))

    # Save a sample sheet after every epoch
    generate_imgs(generator, epoch)

    # Save generator and discriminator weights every 20 epochs, and always after the
    # last epoch so the final state of training is never lost.
    if epoch % 20 == 0 or epoch == opt.epochs - 1:
        torch.save(generator.state_dict(),
                   os.path.join(save_model_checkpoints,
                                "netG_%s_epoch_%d.pth" % (dataset_choice, epoch)))
        torch.save(discriminator.state_dict(),
                   os.path.join(save_model_checkpoints,
                                "netD_%s_epoch_%d.pth" % (dataset_choice, epoch)))

print("Training is finished!")

In [None]:
# results
# G_losses and D_losses receive one value per training batch, so the x axis counts
# training iterations rather than epochs.
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator loss during training")
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# FID results
# The training loop records one score on every 5th epoch (epoch % 5 == 0), so the k-th
# entry belongs to epoch 5 * k; plotting against those epochs keeps the axis label true.
plt.figure(figsize=(10, 5))
plt.title("FID score")
plt.plot(range(0, 5 * len(fid_score_history), 5), fid_score_history)
plt.xlabel("epochs")
plt.ylabel("FID value")
plt.show()

In [None]:
# KID results
# The training loop records one score on every 5th epoch (epoch % 5 == 0), so the k-th
# entry belongs to epoch 5 * k; plotting against those epochs keeps the axis label true.
plt.figure(figsize=(10, 5))
plt.title("KID score")
plt.plot(range(0, 5 * len(kid_score_history), 5), kid_score_history)
plt.xlabel("epochs")
plt.ylabel("KID value")
plt.show()

In [None]:
# Raw FID values appended by the training loop, in the order they were computed
print(fid_score_history)

In [None]:
# Raw KID values appended by the training loop, in the order they were computed
print(kid_score_history)

In [None]:
# Full generator loss history: one value per training iteration (can be a very long list)
print(G_losses)

In [None]:
# Full discriminator loss history: one value per training iteration (can be a very long list)
print(D_losses)