# Location-Aware Generative Adversarial Network

In [0]:
import torch.nn as nn
import torch.nn.functional as F
import torch

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torchgan.layers import MinibatchDiscrimination1d

import torchvision.transforms as transforms
from torchvision.utils import save_image

In [0]:
import nn_local as nn_

### Discriminator

In [0]:
class Discriminator(nn.Module):
    """Two-headed discriminator for the location-aware GAN.

    A shared trunk feeds two independent ``Linear -> Sigmoid`` heads. The trunk
    is one ordinary convolution, three ``nn_.Conv2dLocal`` layers, average
    pooling, flatten and minibatch discrimination.

    The hard-coded ``in_height``/``in_width`` values and the ``in_features=72``
    (= 8 channels * 3 * 3) of the minibatch layer are consistent with
    single-channel 25x25 inputs. This assumes ``Conv2dLocal`` adds no padding of
    its own — confirm against ``nn_local``.

    Args:
        ngpu: Unused; accepted for interface compatibility.
    """

    def __init__(self, ngpu=1):
        super(Discriminator, self).__init__()

        # NOTE: PyTorch's BatchNorm ``momentum`` is the weight of the *new*
        # batch statistic (running = (1 - m) * running + m * batch). That is the
        # opposite of Keras (running = m * running + (1 - m) * batch), so the
        # Keras default of 0.99 corresponds to momentum=0.01 here. eps=1e-3
        # and the LeakyReLU slope of 0.3 are likewise the Keras defaults.
        #
        # Dropout is NOT in-place. In Block I it directly follows an in-place
        # LeakyReLU, whose backward pass needs its own unmodified output.
        # An in-place dropout there makes backward() raise a RuntimeError.

        # Base deep neural network (spatial sizes assume 25x25 inputs)
        self.common = nn.Sequential(
            # Block I: 25x25 -> 25x25
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.Dropout(p=0.2),

            # Block II: pad to 29x29 -> 13x13
            nn.ZeroPad2d(padding=2),
            nn_.Conv2dLocal(in_height=29, in_width=29, in_channels=32, out_channels=8, kernel_size=5, stride=2),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=8, momentum=0.01, eps=1e-3),
            nn.Dropout(p=0.2),

            # Block III: pad to 17x17 -> 13x13
            nn.ZeroPad2d(padding=2),
            nn_.Conv2dLocal(in_height=17, in_width=17, in_channels=8, out_channels=8, kernel_size=5, stride=1),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=8, momentum=0.01, eps=1e-3),
            nn.Dropout(p=0.2),

            # Block IV: pad to 17x17 -> 7x7
            nn.ZeroPad2d(padding=2),
            nn_.Conv2dLocal(in_height=17, in_width=17, in_channels=8, out_channels=8, kernel_size=5, stride=2),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=8, momentum=0.01, eps=1e-3),
            nn.Dropout(p=0.2),

            # Block V: 7x7 -> 3x3, flattened to 8 * 3 * 3 = 72 features
            nn.AvgPool2d(kernel_size=2),
            nn.Flatten(),

            # Block VI: minibatch discrimination for mode-collapse detection.
            # The heads below expect 72 + 20 = 92 features out of it.
            MinibatchDiscrimination1d(in_features=72, out_features=20),
        )

        # Auxiliary output
        self.auxo = nn.Sequential(
            nn.Linear(in_features=92, out_features=1),
            nn.Sigmoid(),
        )

        # Prime output
        self.prim = nn.Sequential(
            nn.Linear(in_features=92, out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: Image batch; presumably ``(batch, 1, 25, 25)`` — see the
                class docstring.

        Returns:
            A ``(batch, 2)`` tensor of probabilities. Column 0 comes from the
            prime head and column 1 from the auxiliary head.
        """
        features = self.common(input)
        # Each head yields (batch, 1); concatenating along the last dim gives
        # (batch, 2).
        return torch.cat([self.prim(features), self.auxo(features)], dim=-1)

In [0]:
class Generator(nn.Module):
    """Label-conditioned generator producing single-channel 25x25 images.

    The latent vector ``z`` is multiplied element-wise by a learned embedding of
    its class label. The result is projected and reshaped to 128x7x7 (DCGAN
    style), then upsampled through a convolution and three locally-connected
    layers. The spatial sizes below assume ``nn_.Conv2dLocal`` adds no padding
    of its own — confirm against ``nn_local``. The final ReLU makes every output
    pixel non-negative.

    Args:
        latent_dim: Size of ``z`` and of the label embedding.
    """

    def __init__(self, latent_dim):
        super(Generator, self).__init__()

        # NOTE: PyTorch's BatchNorm ``momentum`` is the weight of the *new*
        # batch statistic (running = (1 - m) * running + m * batch). That is the
        # opposite of Keras (running = m * running + (1 - m) * batch), so the
        # Keras default of 0.99 corresponds to momentum=0.01 here.

        # Base upscaler deep neural network
        self.model = nn.Sequential(
            # DCGAN-style projection and reshape: 6272 = 128 * 7 * 7
            nn.Linear(in_features=latent_dim, out_features=6272),
            nn_.Reshape(-1, 128, 7, 7),

            # Block I: 7x7 conv, then upsample to 14x14
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, padding=2),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=64, momentum=0.01, eps=1e-3),
            nn.UpsamplingNearest2d(scale_factor=2),

            # Block II: pad to 18x18 -> 14x14, then upsample to 28x28
            nn.ZeroPad2d(padding=2),
            nn_.Conv2dLocal(in_height=18, in_width=18, in_channels=64, out_channels=6, kernel_size=5, stride=1),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=6, momentum=0.01, eps=1e-3),
            nn.UpsamplingNearest2d(scale_factor=2),

            # Block III: 28x28 -> 26x26
            nn_.Conv2dLocal(in_height=28, in_width=28, in_channels=6, out_channels=6, kernel_size=3, stride=1),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),

            # Block IV: 26x26 -> 25x25, single channel
            nn_.Conv2dLocal(in_height=26, in_width=26, in_channels=6, out_channels=1, kernel_size=2, stride=1),
            nn.ReLU(inplace=True),
        )

        # Learned per-label scaling of the latent vector (2 label values)
        self.aux = nn.Embedding(num_embeddings=2, embedding_dim=latent_dim)

    def forward(self, z, label):
        """Generate images for latent batch ``z`` conditioned on ``label``.

        Args:
            z: Latent noise, presumably ``(batch, latent_dim)`` floats.
            label: Class indices in ``[0, 2)``. Any numeric dtype is accepted;
                it is cast to ``long`` because ``nn.Embedding`` only accepts
                integer indices.

        Returns:
            The output of ``self.model`` for the label-scaled latent batch.
        """
        hadamard_product = torch.mul(self.aux(label.long()), z)
        return self.model(hadamard_product)

### Hyperparameters & Modes

In [0]:
# Training schedule
nb_epochs, batch_size = 10, 64

# Model sizes
latent_size = 200   # length of z and of the generator's label embedding
nb_classes = 2      # number of distinct class labels sampled for the generator

# Adam settings shared by both optimizers.
# NOTE(review): lr=0.05 and beta_1 == beta_2 == 0.999 are unusual for GAN
# training (the DCGAN paper recommends lr=2e-4, beta_1=0.5) — confirm intent.
adam_lr = 5e-2
adam_beta_1 = adam_beta_2 = 0.999

In [0]:
# Run-time switches
cuda = False              # move networks, loss and tensors to the GPU
verbose = False           # print the losses of every batch
sample_count = 25         # generated images saved per sample grid
sample_interval = 1000.0  # save a sample grid every this many batches

### Model Instantiation

In [0]:
# Build both networks (discriminator first), then optionally move them to the GPU.
disc_network, genr_network = Discriminator(), Generator(latent_size)

if cuda:
    # Module.cuda() moves parameters/buffers in place and returns the module.
    disc_network = disc_network.cuda()
    genr_network = genr_network.cuda()

### Cost Function

In [0]:
# Binary cross-entropy on sigmoid probabilities, shared by every loss term.
adv_loss = nn.BCELoss()
if cuda:
    adv_loss = adv_loss.cuda()

### Optimizers

In [0]:
# One Adam optimizer per network (discriminator, then generator), both using
# the same learning rate and betas.
optimizer_d, optimizer_g = (
    torch.optim.Adam(network.parameters(),
                     lr=adam_lr,
                     betas=(adam_beta_1, adam_beta_2))
    for network in (disc_network, genr_network)
)

### Data Loaders

In [0]:
# Float tensor type used to build targets and cast inputs: CUDA or CPU.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [0]:
#### TODO: write the dataloader — the training loop below iterates over a `dataloader` object that is not defined yet

### Training

In [0]:
import os

# save_image() does not create missing directories, so make sure the sample
# output directory exists before training starts.
os.makedirs("images", exist_ok=True)

for epoch in range(nb_epochs):
    # enumerate() supplies the batch index ``i`` used for logging and sampling.
    for i, batch in enumerate(dataloader):
        # NOTE(review): the dataloader is not written yet. Accept either bare
        # image batches or (images, class_labels) pairs, which is what
        # torchvision datasets yield. The real-image auxiliary loss needs labels.
        if isinstance(batch, (list, tuple)):
            imgs, real_labels = batch[0], batch[1]
        else:
            imgs, real_labels = batch, None
        n_batch = imgs.size(0)

        # Adversarial ground truths, shaped (batch, 1) like one head's output
        real = Tensor(n_batch, 1).fill_(1.0)
        fake = Tensor(n_batch, 1).fill_(0.0)

        # Configure input
        real_imgs = imgs.type(Tensor)
        device = real_imgs.device

        # ---- Training of Generator ----
        optimizer_g.zero_grad()

        # Sample Gaussian noise and uniformly distributed class labels. The
        # labels must be integer indices for the generator's nn.Embedding.
        z = torch.randn(n_batch, latent_size, device=device)
        gen_labels = torch.randint(0, nb_classes, (n_batch,), device=device)
        # Float (batch, 1) target for the single-sigmoid auxiliary head. A
        # binary target like this only makes sense for nb_classes == 2.
        gen_aux_target = gen_labels.float().view(-1, 1)

        # Generate a batch of images
        gen_imgs = genr_network(z, gen_labels)

        # disc_network returns (batch, 2): column 0 is the prime (real/fake)
        # head and column 1 the auxiliary head. The generator wants its images
        # judged real and classified as the label it was conditioned on.
        gen_out = disc_network(gen_imgs)
        g_loss = 0.5 * (adv_loss(gen_out[:, :1], real)
                        + adv_loss(gen_out[:, 1:], gen_aux_target))
        g_loss.backward()
        optimizer_g.step()

        # ---- Training of Discriminator ----
        # zero_grad() also discards the gradients that g_loss.backward() left
        # on the discriminator's parameters.
        optimizer_d.zero_grad()

        real_out = disc_network(real_imgs)
        fake_out = disc_network(gen_imgs.detach())

        d_losses = [
            adv_loss(real_out[:, :1], real),            # prime, real images
            adv_loss(fake_out[:, :1], fake),            # prime, generated images
            adv_loss(fake_out[:, 1:], gen_aux_target),  # auxiliary, generated images
        ]
        if real_labels is not None:
            # Auxiliary head on real images; needs class labels from the dataloader.
            real_aux_target = real_labels.to(device).float().view(-1, 1)
            d_losses.append(adv_loss(real_out[:, 1:], real_aux_target))

        # Mean of the individual discriminator losses
        d_loss = sum(d_losses) / len(d_losses)
        d_loss.backward()
        optimizer_d.step()

        if verbose:
            print("[Epoch {}/{}] [Batch {}/{}] [D loss: {:0.8f}] [G loss: {:0.8f}]".format(
                epoch, nb_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            # %-formatting yields one file per sample step. str.format() would
            # leave "%d" in the name and overwrite a single file.
            save_image(gen_imgs.detach()[:sample_count],
                       "images/%d.png" % batches_done,
                       nrow=5,
                       normalize=True)