In [None]:
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.insert(0, '../')

import metrics as mt
import settings as opt
import weights
from preprocess.dataset import ValenceArousalWithClassesDataset

In [None]:
# Compute device: the first CUDA GPU when PyTorch can see one, the CPU otherwise
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Experiment selection

# Dataset to train on; one of: affective, augmented_affective, cifar10
dataset_choice = "affective"

# Model variant to train; one of: PAGAN, PAGAN_D, PAGAN_SN
# (picks the discriminator class and names the output folders)
model_choice = "PAGAN"

import model.pagan as gan

In [None]:
# Output locations are namespaced per run as <base><model>/<dataset>/
gen_imgs_path = opt.outi + model_choice + '/' + dataset_choice + '/'
save_model_checkpoints = opt.outc + model_choice + '/' + dataset_choice + '/'

# Create the output directories. exist_ok=True makes re-running the cell a no-op,
# while genuine failures (e.g. missing permissions) raise here instead of being
# swallowed and only surfacing later when the first image/checkpoint is saved.
os.makedirs(gen_imgs_path, exist_ok=True)
os.makedirs(save_model_checkpoints, exist_ok=True)

In [None]:
# Image preprocessing: resize and centre-crop to opt.image_size, convert to a tensor,
# then normalise each of the three channels as (x - 0.5) / 0.5
pre_process = transforms.Compose([
    transforms.Resize(opt.image_size),
    transforms.CenterCrop(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build the selected dataset. n_classes is the number of labels of that dataset;
# the training loop samples generator labels from range(n_classes).
if dataset_choice == "cifar10":
    dataset = dset.CIFAR10(
        root="../data/cifar10",
        download=True,
        transform=pre_process
    )
    n_classes = 10
elif dataset_choice == "affective":
    dataset = ValenceArousalWithClassesDataset(csv_file=opt.annonations_file,
                                               root_dir=opt.all_images_path,
                                               transform=pre_process)
    n_classes = 13
elif dataset_choice == "augmented_affective":
    dataset = ValenceArousalWithClassesDataset(csv_file=opt.augmented_annonations_file,
                                               root_dir=opt.augmented_images_path,
                                               transform=pre_process)
    n_classes = 13
else:
    # Fail fast: a mistyped choice used to fall through to the augmented dataset
    # and write its results into a folder named after the typo.
    raise ValueError(
        "Unknown dataset_choice {!r}; expected one of: "
        "affective, augmented_affective, cifar10".format(dataset_choice)
    )

# Create the dataloader (shuffled; the last incomplete batch is dropped)
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4,
                        drop_last=True)

In [None]:
# Create the GAN model

# Generator. It is built with the n_classes of the selected dataset (not the
# static opt.n_classes) so that its label range always matches the labels the
# training loop samples with np.random.randint(0, n_classes, ...).
generator = gan.Generator(opt.latent_dim, opt.num_channels, opt.ngf, n_classes).to(device)
generator.apply(weights.weights_init)

# Discriminator variant selected by model_choice
if model_choice == "PAGAN":
    discriminator = gan.Discriminator(opt.num_channels, opt.ndf, n_classes).to(device)
elif model_choice == "PAGAN_D":
    discriminator = gan.DiscriminatorDropout(opt.num_channels, opt.ndf, n_classes).to(device)
elif model_choice == "PAGAN_SN":
    discriminator = gan.DiscriminatorSN(opt.num_channels, opt.ndf, n_classes).to(device)
else:
    # Fail fast: a mistyped choice used to silently select DiscriminatorSN.
    raise ValueError(
        "Unknown model_choice {!r}; expected one of: PAGAN, PAGAN_D, PAGAN_SN".format(
            model_choice)
    )

discriminator.apply(weights.weights_init)

# Loss function: binary cross-entropy on the discriminator output
adversarial_loss = nn.BCELoss()

# Scalar conventions for real and fake labels. The training loop takes its
# targets from gan.add_channel, so these constants are informational only.
# Note that no label smoothing is applied: real_label is exactly 1.0.
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(discriminator.parameters(), lr=opt.lr_D, betas=(opt.beta1, opt.beta2))
optimizerG = optim.Adam(generator.parameters(), lr=opt.lr_G, betas=(opt.beta1, opt.beta2))

In [None]:
####################################################################################################
# PA-GAN hyper-parameters

# Current augmentation level, and the global step at which it was last raised;
# both start at zero
augmentation_level = last_augmentation_step = 0

# Number of steps over which p ramps up to its maximum of 0.5
tr = 5000

In [None]:
# Training Loop


def _sample_augmentation_bits(batch_size, level, p):
    """Draw the per-sample augmentation bits for one training batch.

    Args:
        batch_size: number of samples in the batch.
        level: current augmentation level, i.e. the number of bit columns.
        p: probability that the newest (last) bit of a sample is 1.

    Returns:
        None when level is 0. Otherwise a float array of shape (batch_size, level):
        the first level - 1 columns are uniform random bits and the last column
        is 1 with probability p.
    """
    if level <= 0:
        return None
    columns = []
    if level > 1:
        # Bits of the already established levels are drawn uniformly from {0, 1}
        columns.append(np.random.randint(0, 2, size=(batch_size, level - 1)))
    # The newest bit is switched on with probability p only
    columns.append((np.random.rand(batch_size, 1) < p).astype(np.float64))
    return np.concatenate(columns, axis=1)


# Lists to keep track of progress
G_losses = []
D_losses = []

fid_score_history = []
kid_score = []  # KID values since the last augmentation-level increase
kid_score_history = []

fake_images_list = []

global_step = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(opt.epochs):

    ##############################
    ### COMPUTE METRICS (on every 5th epoch, starting with epoch 0)
    if epoch % 5 == 0:
        print("Global step: {}. Computing metrics...".format(global_step))

        # Draw opt.fid_batch distinct real images from the dataset
        samples = random.sample(range(len(dataset)), opt.fid_batch)
        real_samples = torch.stack([dataset[s][0] for s in samples], dim=0).to(device)

        # Generate the same number of fake images, one mini-batch at a time
        with torch.no_grad():
            noise = torch.FloatTensor(
                np.random.normal(0, 1, (opt.fid_batch, opt.latent_dim))).to(device)
            gen_labels = torch.LongTensor(
                np.random.randint(0, n_classes, opt.fid_batch)).to(device)

            # Step through the samples in chunks of opt.batch_size. Iterating
            # opt.fid_batch times (as before) fed the generator empty slices once
            # all samples had been consumed.
            fake_batches = []
            for start in tqdm(range(0, opt.fid_batch, opt.batch_size),
                              desc="Generating fake images"):
                stop = start + opt.batch_size
                fake_batches.append(generator(noise[start:stop], gen_labels[start:stop]))
            fake_samples = torch.cat(fake_batches, dim=0).to(device)

        print("Computing KID and FID...")
        kid, fid = mt.compute_metrics(real_samples, fake_samples)

        print("FID: {:.4f}".format(fid))
        print("KID: {:.4f}".format(kid))

        fid_score_history.append(fid)
        kid_score_history.append(kid)

        # Raise the augmentation level when the new KID is not at least 5% below
        # the mean of the two previous evaluations:
        # kid >= 0.95 * (a + b) / 2, written as (a + b) * 19 / 40.
        if len(kid_score) >= 2 and kid >= (kid_score[-1] + kid_score[-2]) * 19 / 40:
            augmentation_level += 1
            last_augmentation_step = global_step

            # The discriminator input grows by one channel per augmentation level.
            # NOTE(review): conv1 is replaced by a freshly initialised layer, so the
            # weights learned for the existing input channels are discarded -
            # confirm this is intended.
            augmented_channels = opt.num_channels + augmentation_level
            discriminator.main.conv1 = nn.Conv2d(augmented_channels, opt.ndf, 4, 2, 1,
                                                 bias=False).to(device)
            # Keep the bookkeeping attribute equal to conv1's real input size
            # (adding augmentation_level to the previous value over-counted from
            # the second augmentation onwards).
            discriminator.num_channels = augmented_channels
            discriminator.main.conv1.apply(weights.weights_init)
            # The parameter set changed, so the optimizer has to be rebuilt
            optimizerD = optim.Adam(discriminator.parameters(), lr=opt.lr_D,
                                    betas=(opt.beta1, opt.beta2))

            print("Augmentation level increased to {}".format(augmentation_level))

            # Start a fresh KID window for the new level
            kid_score = []
        else:
            kid_score.append(kid)

    ##############################

    # For each batch in the dataloader
    for i, data in enumerate(dataloader):

        # Move the image batch to the device as float32 (labels in data[1] are unused)
        real_images = data[0].to(device).float()
        batch_size = real_images.size(0)

        # Sample latent vectors and random class labels, then generate a fake batch
        noise = torch.FloatTensor(
            np.random.normal(0, 1, (batch_size, opt.latent_dim))).to(device)
        gen_labels = torch.LongTensor(np.random.randint(0, n_classes, batch_size)).to(device)
        fake_images = generator(noise, gen_labels)

        ###################################
        # (1) Update D network
        discriminator.zero_grad()

        # p ramps linearly from 0 to 0.5 over tr steps after the last level increase
        p = min(0.5 * (global_step - last_augmentation_step) / tr, 0.5)
        augmentation_bits = _sample_augmentation_bits(batch_size, augmentation_level, p)

        # gan.add_channel returns the batch to feed D together with its BCE targets
        real_images, real_targets = gan.add_channel(real_images, augmentation_bits, real=True)

        # Real batch: forward, loss, backward
        output = discriminator(real_images).view(-1)
        errD_real = adversarial_loss(output, real_targets)
        errD_real.backward()

        # Fake batch, using the same augmentation bits
        fake_images, fake_targets = gan.add_channel(fake_images, augmentation_bits, real=False)

        # detach() keeps the discriminator loss from back-propagating into G
        output = discriminator(fake_images.detach()).view(-1)
        errD_fake = adversarial_loss(output, fake_targets)
        # Gradients accumulate on top of those from the real batch
        errD_fake.backward()

        # Total discriminator error (for logging) and parameter update
        errD = errD_real + errD_fake
        optimizerD.step()

        ###################################
        # (2) Update G network
        generator.zero_grad()
        # D was just updated, so run the fake batch through it again
        output = discriminator(fake_images).view(-1)
        # The generator is trained towards the flipped fake targets
        errG = adversarial_loss(output, 1 - fake_targets)
        errG.backward()
        optimizerG.step()

        # Save Losses for plotting later (one value per iteration)
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        global_step += 1

        # Output training stats
        if i % opt.log_interval == 0:
            print(
                "[{}/{}][{}/{}] Loss_D: {:.4f} Loss_G: {:.4f}".format(
                    epoch,
                    opt.epochs,
                    i,
                    len(dataloader),
                    errD.item(),
                    errG.item()
                )
            )

    # Save a grid of 64 generated images after each epoch; no autograd graph is needed
    with torch.no_grad():
        noise = torch.FloatTensor(np.random.normal(0, 1, (64, opt.latent_dim))).to(device)
        gen_labels = torch.LongTensor(np.random.randint(0, n_classes, 64)).to(device)
        fake_images = generator(noise, gen_labels).cpu()

    vutils.save_image(fake_images,
                      "%s/fake_%s_epoch_%03d.png" % (gen_imgs_path, dataset_choice, epoch),
                      normalize=True)

    # Save generator and discriminator weights on every 20th epoch
    if epoch % 20 == 0:
        torch.save(generator.state_dict(),
                   "%s/netG_%s_epoch_%d.pth" % (save_model_checkpoints, dataset_choice, epoch))
        torch.save(discriminator.state_dict(),
                   "%s/netD_%s_epoch_%d.pth" % (save_model_checkpoints, dataset_choice, epoch))

print("Training is finished!")

In [None]:
# Generator / discriminator loss curves
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator loss during training")
ax.plot(G_losses, label="Generator")
ax.plot(D_losses, label="Discriminator")
# One loss value is appended per training batch, so the x axis counts
# iterations, not epochs
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# FID results
# The training loop evaluates the metrics on every 5th epoch (0, 5, 10, ...),
# so plot each value against the epoch it was measured at rather than its index
fid_epochs = list(range(0, 5 * len(fid_score_history), 5))

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("FID score")
ax.plot(fid_epochs, fid_score_history)
ax.set_xlabel("epochs")
ax.set_ylabel("FID value")
plt.show()

In [None]:
# KID results
# The training loop evaluates the metrics on every 5th epoch (0, 5, 10, ...),
# so plot each value against the epoch it was measured at rather than its index
kid_epochs = list(range(0, 5 * len(kid_score_history), 5))

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("KID score")
ax.plot(kid_epochs, kid_score_history)
ax.set_xlabel("epochs")
ax.set_ylabel("KID value")
plt.show()

In [None]:
# Raw FID values, one entry per metric evaluation in the training loop
print(fid_score_history)

In [None]:
# Raw KID values, one entry per metric evaluation in the training loop
print(kid_score_history)

In [None]:
# Raw generator losses, one entry per training iteration (this list can be long)
print(G_losses)

In [None]:
# Raw discriminator losses, one entry per training iteration (this list can be long)
print(D_losses)