# Location-Aware Generative Adversarial Network

In [0]:
import torch.nn as nn
import torch.nn.functional as F
import torch

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torchgan.layers import MinibatchDiscrimination1d

import torchvision.transforms as transforms
from torchvision.utils import save_image

In [0]:
import nn_local as nn_

## Discriminator

The discriminator network described in the paper is as follows:<br>
<img src="Images/Discriminator.PNG" width=800/>

In [0]:
class Discriminator(nn.Module):
    """LAGAN discriminator: a shared feature trunk feeding two sigmoid heads.

    The declared ``Conv2dLocal`` sizes are consistent with single-channel
    25x25 inputs. Block I's 5x5 conv (padding=2) keeps 25x25, and
    ZeroPad2d(2) then gives the 29x29 map Block II is declared for.
    NOTE(review): confirm 25x25 matches the data fed to this model, and that
    ``nn_local.Conv2dLocal`` uses the usual (H - k) // stride + 1 output size.

    Args:
        ngpu: Accepted for interface compatibility; not used by this module.
    """

    def __init__(self, ngpu=1):
        super(Discriminator, self).__init__()

        # eps=1e-3 with "momentum" 0.99 appears to be ported from Keras'
        # BatchNormalization defaults. PyTorch's ``momentum`` is the weight
        # given to the NEW batch statistic (the inverse of Keras' convention).
        # The equivalent PyTorch value is therefore 1 - 0.99 = 0.01. With 0.99
        # the running stats would track almost only the last training batch.
        bn_momentum = 0.01
        bn_eps = 1e-3

        # Shared trunk (shapes assume a (B, 1, 25, 25) input).
        self.common = nn.Sequential(
            # Block I: (B, 1, 25, 25) -> (B, 32, 25, 25)
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            # Must NOT be in-place. The in-place LeakyReLU above saves its
            # output for backward, and an in-place dropout would overwrite
            # that saved tensor, breaking backward() in training mode.
            nn.Dropout(p=0.2, inplace=False),

            # Block II: pad to 29x29, locally-connected 5x5 / stride 2 -> 13x13
            nn.ZeroPad2d(padding=2),
            nn_.Conv2dLocal(in_height=29, in_width=29, in_channels=32, out_channels=8, kernel_size=5, stride=2),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=8, momentum=bn_momentum, eps=bn_eps),
            # In-place is safe here: BatchNorm's output is not saved for backward.
            nn.Dropout(p=0.2, inplace=True),

            # Block III: pad to 17x17, locally-connected 5x5 / stride 1 -> 13x13
            nn.ZeroPad2d(padding=2),
            nn_.Conv2dLocal(in_height=17, in_width=17, in_channels=8, out_channels=8, kernel_size=5, stride=1),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=8, momentum=bn_momentum, eps=bn_eps),
            nn.Dropout(p=0.2, inplace=True),

            # Block IV: pad to 17x17, locally-connected 5x5 / stride 2 -> 7x7
            nn.ZeroPad2d(padding=2),
            nn_.Conv2dLocal(in_height=17, in_width=17, in_channels=8, out_channels=8, kernel_size=5, stride=2),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=8, momentum=bn_momentum, eps=bn_eps),
            nn.Dropout(p=0.2, inplace=True),

            # Block V: 2x2 average pool (7x7 -> 3x3), flatten to 8 * 3 * 3 = 72
            nn.AvgPool2d(kernel_size=2),
            nn.Flatten(),

            # Block VI: minibatch discrimination, to help detect mode collapse.
            # Presumably concatenates 20 extra features onto the 72 inputs,
            # which matches the 92 input features of both heads below.
            MinibatchDiscrimination1d(in_features=72, out_features=20)
        )

        # Auxiliary output head: one sigmoid score per sample.
        self.auxo = nn.Sequential(
            nn.Linear(in_features=92, out_features=1),
            nn.Sigmoid()
        )

        # Primary output head: one sigmoid score per sample.
        self.prim = nn.Sequential(
            nn.Linear(in_features=92, out_features=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Score a batch of images.

        Returns:
            Tensor of shape (batch, 2): column 0 is the primary head's score,
            column 1 the auxiliary head's score.
        """
        features = self.common(input)
        # Each head yields (batch, 1). The original trailing ``.squeeze(1)``
        # was a no-op on the (batch, 2) result and has been dropped.
        return torch.cat([self.prim(features), self.auxo(features)], dim=-1)

## Generator

The generator network described in the paper is as follows:<br>
<img src="Images/Generator.PNG" width=800/>

In [None]:
class Generator(nn.Module):
    """LAGAN generator: maps a latent batch to single-channel 25x25 images.

    Shape flow: (B, latent_dim) -> Linear -> (B, 128, 7, 7) -> 7x7 conv block
    -> upsample 14x14 -> locally-connected block (14x14) -> upsample 28x28
    -> locally-connected 3x3 (26x26) -> locally-connected 2x2 (25x25).
    The final ReLU makes every output pixel non-negative.
    NOTE(review): the sizes assume ``nn_local.Conv2dLocal`` uses the usual
    (H - k) // stride + 1 output size.

    Args:
        latent_dim: Size of the latent vector fed to the projection layer.
    """

    def __init__(self, latent_dim):
        super(Generator, self).__init__()

        # eps=1e-3 with "momentum" 0.99 appears to be ported from Keras'
        # BatchNormalization defaults. PyTorch's ``momentum`` is the weight of
        # the NEW batch statistic (the inverse of Keras'), so the equivalent
        # value is 1 - 0.99 = 0.01.
        bn_momentum = 0.01
        bn_eps = 1e-3

        # Base upscaling network
        self.mstem = nn.Sequential(
            # DCGAN-style projection and reshape: 6272 = 128 * 7 * 7.
            # The original used an undefined name ``nn_Reshape`` (NameError at
            # construction). nn.Unflatten performs the same
            # (B, 6272) -> (B, 128, 7, 7) reshape and has no parameters.
            nn.Linear(in_features=latent_dim, out_features=6272),
            nn.Unflatten(1, (128, 7, 7)),

            # Block I: size-preserving 5x5 conv, then 7x7 -> 14x14
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, padding=2),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=64, momentum=bn_momentum, eps=bn_eps),
            nn.Upsample(scale_factor=2, mode="nearest"),

            # Block II: pad 14x14 -> 18x18, locally-connected 5x5 -> 14x14,
            # then upsample to 28x28
            nn.ZeroPad2d(padding=2),
            nn_.Conv2dLocal(in_height=18, in_width=18, in_channels=64, out_channels=6, kernel_size=5, stride=1),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),
            nn.BatchNorm2d(num_features=6, momentum=bn_momentum, eps=bn_eps),
            nn.Upsample(scale_factor=2, mode="nearest"),

            # Block III: locally-connected 3x3, 28x28 -> 26x26
            nn_.Conv2dLocal(in_height=28, in_width=28, in_channels=6, out_channels=6, kernel_size=3, stride=1),
            nn.LeakyReLU(negative_slope=0.3, inplace=True),

            # Block IV: locally-connected 2x2 to one channel, 26x26 -> 25x25
            nn_.Conv2dLocal(in_height=26, in_width=26, in_channels=6, out_channels=1, kernel_size=2, stride=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        """Generate images from a latent batch of shape (B, latent_dim)."""
        return self.mstem(input)