# Location-Aware Generative Adversarial Network

In [0]:
import torch.nn as nn
import torch.nn.functional as F
import torch

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torchvision.transforms as transforms
from torchvision.utils import save_image

In [0]:
# Source: https://discuss.pytorch.org/t/locally-connected-layers/26979

from torch.nn.modules.utils import _pair

class LocallyConnected2d(nn.Module):
    """2-D locally connected layer: a convolution whose filters are not shared.

    Every output position ``(i, j)`` owns its own ``in_channels x kh x kw``
    filter for each output channel. No padding is applied inside the layer, so
    pad the input beforehand (e.g. with ``nn.ZeroPad2d``) if needed.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced.
        output_size: Spatial size ``(oh, ow)`` of the result, or one int for a
            square output. Must equal ``((H - kh) // sh + 1, (W - kw) // sw + 1)``
            for an ``H x W`` input; this is checked in ``forward``.
        kernel_size: Filter size ``(kh, kw)`` or one int.
        stride: Stride ``(sh, sw)`` or one int.
        bias: If True, learn one bias per output channel *and* output position.
    """

    def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False):
        super(LocallyConnected2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.output_size = tuple(_pair(output_size))
        # Normalise before using the kernel size in the weight shape, so tuple
        # kernels work (``kernel_size**2`` would fail on a tuple).
        self.kernel_size = tuple(_pair(kernel_size))
        self.stride = tuple(_pair(stride))
        oh, ow = self.output_size
        kh, kw = self.kernel_size
        # Shape (1, out, in, oh, ow, kh*kw): one filter per output channel and
        # output position, with the kernel taps flattened into the last axis.
        # NOTE(review): unscaled N(0, 1) init, kept from the original snippet;
        # with a large in_channels*kh*kw fan-in the initial outputs are large.
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, oh, ow, kh * kw)
        )
        if bias:
            self.bias = nn.Parameter(torch.randn(1, out_channels, oh, ow))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, oh, ow)``.

        Raises:
            ValueError: if the channel count or the unfolded spatial size of
                ``x`` does not match the layer's configuration. Without this
                check, a mismatch involving a size-1 dimension would broadcast
                silently and produce a wrong result instead of an error.
        """
        _, channels, height, width = x.size()
        kh, kw = self.kernel_size
        sh, sw = self.stride
        if channels != self.in_channels:
            raise ValueError(
                f"expected {self.in_channels} input channels, got {channels}"
            )
        unfolded_hw = ((height - kh) // sh + 1, (width - kw) // sw + 1)
        if unfolded_hw != self.output_size:
            raise ValueError(
                f"input of size {height}x{width} with kernel {kh}x{kw} and stride "
                f"{sh}x{sw} yields output {unfolded_hw}, but the layer was built "
                f"for output_size={self.output_size}"
            )
        # (N, C, H, W) -> (N, C, oh, ow, kh, kw) -> (N, C, oh, ow, kh*kw)
        patches = x.unfold(2, kh, sh).unfold(3, kw, sw)
        patches = patches.reshape(*patches.shape[:4], kh * kw)
        # (N, 1, C, oh, ow, k) * (1, O, C, oh, ow, k), summed over C and the taps.
        out = (patches.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out = out + self.bias
        return out


In [0]:
class Discriminator(nn.Module):
    """Discriminator built from a conv block and a locally connected block.

    Block III of ``self.model`` is still an empty placeholder.

    Args:
        image_size: Side length of the square, single-channel input images;
            used to size the locally connected layer.
            NOTE(review): the default of 25 is an assumption (the LAGAN paper
            uses 25x25 images) — confirm against the dataset actually loaded.
    """

    def __init__(self, image_size=25):
        super(Discriminator, self).__init__()
        # Conv2d(kernel_size=5) with no padding shrinks each side by 4,
        # ZeroPad2d(2) grows it back by 4, and the 5x5 locally connected layer
        # (stride 1, no padding) shrinks it by 4 again.
        lc_output_size = image_size - 4
        self.model = nn.Sequential(
                        # Block I
                        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
                        nn.LeakyReLU(negative_slope=0.3, inplace=True),
                        nn.Dropout(p=0.2, inplace=True),

                        # Block II
                        nn.ZeroPad2d(padding=2),
                        # The original call lacked a trailing comma (SyntaxError)
                        # and the required output_size/stride arguments.
                        LocallyConnected2d(in_channels=32, out_channels=8,
                                           output_size=lc_output_size,
                                           kernel_size=5, stride=1),
                        nn.LeakyReLU(negative_slope=0.3, inplace=True),
                        # PyTorch's momentum weights the *new* batch statistic:
                        # running = (1 - m) * running + m * batch. 0.01 is the
                        # equivalent of Keras' momentum=0.99, which the original
                        # value (next to Keras' default eps=1e-3) appears to be
                        # copied from; 0.99 would make the running stats track
                        # almost only the last batch.
                        nn.BatchNorm2d(num_features=8, momentum=0.01, eps=1e-3),
                        nn.Dropout(p=0.2, inplace=True),

                        # Block III
                        # TODO: not implemented yet.
                     )