# Relativistic average GAN (RaSGAN) for toy sinusoid sequences

## Imports

In [1]:
import torch
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from viz import updatable_display2

import torch.nn as nn
import torch.nn.functional as F
from torch import nn, optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

from generators import GumbelSAGenerator
from discriminators import GumbelSADiscriminator
from utils import img2vec,vec2img,sample_sequence_noise,true_target,fake_target,onehot,num_parameters

from timeSeries import Sinusoids
from visualize import plotSamples

# Use a GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(device)

cuda


## Dataset loading

In [2]:
batch_size = 64
num_steps = 3
dataset_size = 10000
num_classes = 10

# Two independent Sinusoids datasets with identical settings; the first feeds
# training, the second is kept as a separate loader named valid_data_loader.
data_loader, valid_data_loader = (
    DataLoader(
        Sinusoids(num_steps, virtual_size=dataset_size, quantization=num_classes),
        batch_size=batch_size,
        shuffle=True,
    )
    for _ in range(2)
)

## Training setup

In [3]:
# Hyper-parameters
lr = 1e-4
dropout_prob = 0.3
noise_dim = 100
output_size = num_classes

# Fixed noise batch, reused whenever progress samples are plotted so that
# successive plots are directly comparable.
num_test_samples = 100
test_noise = (
    sample_sequence_noise(num_test_samples, num_steps, noise_dim, device)
    .view(num_test_samples, noise_dim, num_steps)
)

# Initialize models
generator = GumbelSAGenerator(
    input_size=noise_dim, hidden_size=128, output_size=output_size, device=device
).to(device)
discriminator = GumbelSADiscriminator(
    input_size=output_size, hidden_size=128, output_size=1, num_embeddings=6
).to(device)

# Optimizers and losses
g_optimizer = optim.Adam(generator.parameters(), lr=lr)
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
loss_fun = nn.BCEWithLogitsLoss()  # adversarial loss, applied to raw (logit) scores
pretrain_loss_fun = nn.NLLLoss()   # likelihood loss used by generator pretraining

# Loggers for the adversarial and the pretraining phases
dis = updatable_display2(['train'], ["epoch", "d_error", "g_error", "beta"])
pretrain_dis = updatable_display2(['train'], ["pretrain epoch", "pretrain_g_error"])

# Total number of adversarial epochs to train
num_epochs = 200


print(generator)
print()
print(discriminator)
np_g = num_parameters(generator)
np_d = num_parameters(discriminator)
print(
    f"Number of parameters for G: {np_g}\n"
    f"Number of parameters for D: {np_d}\n"
    f"Number of parameters in total: {np_g + np_d}"
)

GumbelSAGenerator(
  (layers): Sequential(
    (0): ConvTranspose1d(100, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): SelfAttention(hidden_size = 256)
    (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose1d(256, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (10): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose1d(128, 10, kernel_size=(3,), stride=(1,), padding=(1,))
  )
  (gumbelsoftmax): GumbelSoftmax()
)

GumbelSADiscriminator(
  (embeddings): ModuleList(
    (0): Linear(in_features=10, out_features=128, bias=False)
    (1): Linea

### Train Generator

In [4]:
def train_generator(real_data,fake_data,optimizer):
    '''
    Perform one generator update with the relativistic average (RaSGAN) loss.

    For each step t along dim 1 of the discriminator output, the real logits
    minus the batch-mean fake logit are scored against fake_target, and the
    fake logits minus the batch-mean real logit are scored against
    true_target. These labels are swapped relative to train_discriminator.
    The per-step BCE losses are averaged over steps, and then the real and
    fake halves are averaged together.

    Uses the module-level `discriminator`, `loss_fun` and `device`.

    Args:
        real_data: real batch fed to the discriminator.
        fake_data: generator output. It must still be attached to the
            generator's autograd graph so the backward pass reaches G.
        optimizer: optimizer over the generator parameters.

    Returns:
        The scalar loss tensor (after backward() has been called on it).
    '''
    batch = fake_data.size(0)
    optimizer.zero_grad()

    # Discriminator output is indexed as [:, t, :], so it is 3-D;
    # presumably (batch, steps, 1) -- TODO confirm against GumbelSADiscriminator.
    scores_real = discriminator(real_data)
    scores_fake = discriminator(fake_data)
    mean_real = torch.mean(scores_real, dim=0)
    mean_fake = torch.mean(scores_fake, dim=0)

    real_terms, fake_terms = [], []
    for t in range(scores_real.size(1)):
        rel_real = scores_real[:, t, :] - mean_fake[t, :]
        rel_fake = scores_fake[:, t, :] - mean_real[t, :]
        real_terms.append(loss_fun(rel_real, fake_target(batch, device)))
        fake_terms.append(loss_fun(rel_fake, true_target(batch, device)))

    loss = (torch.stack(real_terms).mean() + torch.stack(fake_terms).mean()) / 2.0
    loss.backward()
    optimizer.step()
    return loss

### Train Discriminator

In [5]:
def train_discriminator(real_data,fake_data,optimizer):
    '''
    Perform one discriminator update with the relativistic average (RaSGAN) loss.

    For each step t along dim 1 of the discriminator output, the real logits
    minus the batch-mean fake logit are scored against true_target, and the
    fake logits minus the batch-mean real logit are scored against
    fake_target. The per-step BCE losses are averaged over steps, and then
    the real and fake halves are averaged together.

    Uses the module-level `discriminator`, `loss_fun` and `device`.

    Args:
        real_data: real batch.
        fake_data: generated batch (the caller produces it under no_grad, so
            no gradient reaches the generator).
        optimizer: optimizer over the discriminator parameters.

    Returns:
        The scalar loss tensor (after backward() has been called on it).
    '''
    batch = real_data.size(0)
    optimizer.zero_grad()

    # Discriminator output is indexed as [:, t, :], so it is 3-D;
    # presumably (batch, steps, 1) -- TODO confirm against GumbelSADiscriminator.
    scores_real = discriminator(real_data)
    scores_fake = discriminator(fake_data)
    mean_real = torch.mean(scores_real, dim=0)
    mean_fake = torch.mean(scores_fake, dim=0)

    real_terms, fake_terms = [], []
    for t in range(scores_real.size(1)):
        rel_real = scores_real[:, t, :] - mean_fake[t, :]
        rel_fake = scores_fake[:, t, :] - mean_real[t, :]
        real_terms.append(loss_fun(rel_real, true_target(batch, device)))
        fake_terms.append(loss_fun(rel_fake, fake_target(batch, device)))

    loss = (torch.stack(real_terms).mean() + torch.stack(fake_terms).mean()) / 2.0
    loss.backward()
    optimizer.step()
    return loss

### Pretrain Generator

In [6]:
def pretrain_generator(real_data,fake_data,optimizer,loss_fn=None,eps=None):
    '''
    Pretrain the generator by likelihood so it starts from a sensible initialization.

    `fake_data` is treated as per-step class probabilities. Its log is taken,
    and `loss_fn` (NLL by default) is summed over the steps along dim 1,
    using `real_data[:, t]` as the target class indices.

    Args:
        real_data: integer class targets, indexed as real_data[:, t].
        fake_data: generator probabilities, indexed as fake_data[:, t, :].
        optimizer: optimizer over the generator parameters.
        loss_fn: per-step loss taking (log_probs, targets). Defaults to the
            module-level `pretrain_loss_fun`.
        eps: lower clamp applied to the probabilities before the log.
            Defaults to the smallest positive normal value of fake_data's
            dtype, so exact zeros give a large but finite loss instead of
            -inf/NaN gradients.

    Returns:
        The summed loss tensor (after backward() has been called on it).
    '''
    if loss_fn is None:
        loss_fn = pretrain_loss_fun
    if eps is None:
        eps = torch.finfo(fake_data.dtype).tiny

    optimizer.zero_grad()
    # A softmax can underflow to exactly 0; log(0) = -inf would poison the update.
    log_probs = torch.log(fake_data.clamp_min(eps))
    loss = sum(loss_fn(log_probs[:, t, :], real_data[:, t]) for t in range(log_probs.size(1)))
    loss.backward()

    optimizer.step()
    return loss

## Train the model

In [None]:
gen_steps = 1
gen_train_freq = 1
max_temperature = torch.tensor([2.0], device=device)
pretrain_temperature = torch.tensor([1.0], device=device)

epochs_pretrain = 1000
pretrain_step = 0

# Pretrain the generator with the likelihood loss in pretrain_generator
# (no discriminator involved) before adversarial training starts.
for ep in range(epochs_pretrain):
    for real_batch in data_loader:
        N = real_batch.size(0)
        # Drop the trailing dim; presumably Sinusoids yields (batch, steps, 1)
        # integer class indices -- TODO confirm.
        real_data = real_batch.squeeze(2).to(device)

        # Generate fake data (graph kept: the generator is being trained)
        noise = sample_sequence_noise(N, num_steps, noise_dim, device).view(N, noise_dim, num_steps)
        fake_data = generator(noise, pretrain_temperature)

        # Train G
        pretrain_g_error = pretrain_generator(real_data, fake_data, g_optimizer)

        # Log batch error
        pretrain_dis.update(pretrain_step, 'train',
                            {"pretrain epoch": ep, "pretrain_g_error": pretrain_g_error.item()})
        pretrain_step += 1

        if pretrain_step % 200 == 0:
            # Visualisation only: don't build an autograd graph for it.
            with torch.no_grad():
                test_samples = generator(test_noise, pretrain_temperature)
            test_samples_vals = torch.argmax(test_samples, dim=2)
            pretrain_dis.display(scale=True)
            plotSamples(real_data.type(torch.FloatTensor), xu=num_steps, yu=num_classes, title="Real data")
            plotSamples(test_samples_vals, xu=num_steps, yu=num_classes, title="Fake data")
            plt.show()

        del fake_data, real_data
    
global_step = 0
epoch = 0
d_error = 0
g_error = 0
# Defined up front so the final plot works even if training is interrupted
# before the first batch / first epoch.
real_data = None
temperature = max_temperature ** (1 / num_epochs)

# plt.savefig raises if the target directory does not exist.
os.makedirs("Figures", exist_ok=True)


def _sample_noise(n):
    """Draw generator noise shaped (n, noise_dim, num_steps), matching test_noise and pretraining."""
    return sample_sequence_noise(n, num_steps, noise_dim, device).view(n, noise_dim, num_steps)


def _plot_progress(real_onehot, temp):
    """Display the loss curves plus real vs. generated samples (no autograd graph is built)."""
    with torch.no_grad():
        test_samples = generator(test_noise, temp)
    test_samples_vals = torch.argmax(test_samples, dim=2)
    dis.display(scale=True)
    if real_onehot is not None:
        plotSamples(torch.argmax(real_onehot, dim=2), xu=num_steps, yu=num_classes, title="Real data")
    plotSamples(test_samples_vals, xu=num_steps, yu=num_classes, title="Fake data")


# Train adversarially. Interrupting the kernel stops training early; any
# other exception propagates instead of being silently swallowed.
try:
    while epoch < num_epochs:
        # Gumbel-softmax temperature (logged as "beta") grows towards max_temperature.
        temperature = max_temperature ** ((epoch + 1) / num_epochs)
        for real_batch in data_loader:
            N = real_batch.size(0)
            real_data = onehot(real_batch.squeeze(2), num_classes).type(torch.FloatTensor).to(device)

            # 1. Train Discriminator on fake data generated without gradients
            #    (so nothing is backpropagated into the generator)
            with torch.no_grad():
                fake_data = generator(_sample_noise(N), temperature)
            d_error = train_discriminator(real_data, fake_data, d_optimizer)

            # 2. Train Generator every 'gen_train_freq' steps
            if global_step % gen_train_freq == 0:
                for _ in range(gen_steps):
                    fake_data = generator(_sample_noise(N), temperature)
                    g_error = train_generator(real_data, fake_data, g_optimizer).item()

            # Log batch error
            dis.update(global_step, 'train', {"epoch": epoch, "d_error": d_error.item(), "g_error": g_error,
                                              "beta": temperature.item()})
            global_step += 1
            del fake_data

            # Display progress every few batches
            if global_step % 50 == 0:
                _plot_progress(real_data, temperature)
                if epoch % 50 == 0:
                    plt.savefig("Figures/RaSGAN-SAGANCONV-ToyData-Epoch=" + str(epoch) + ".png")
                plt.show()
        epoch += 1
except KeyboardInterrupt:
    print("Training interrupted at epoch {}, step {}".format(epoch, global_step))

# Final snapshot, after either normal completion or a manual interrupt.
_plot_progress(real_data, temperature)
plt.savefig("Figures/RaSGAN-SAGANCONV-ToyData.png")
plt.show()