# Introduction to GANs

In [10]:
import torch
import torch.nn as nn

## Generator

In [11]:
class Generator(nn.Module):
    """Fully connected generator network.

    Three ``gen_block`` stages widen the input from ``in_dim`` to 256, 512 and
    1024 features. A final linear layer then projects to ``out_dim`` and a
    sigmoid bounds every output value to the open interval (0, 1).
    """

    def __init__(self, in_dim, out_dim):
        """Assemble the layer stack.

        Args:
            in_dim: Number of features in each input vector.
            out_dim: Number of features in each generated sample.
        """
        super().__init__()
        # Consecutive pairs of this schedule give each stage's (in, out) width.
        widths = (in_dim, 256, 512, 1024)
        stages = [
            gen_block(n_in, n_out)
            for n_in, n_out in zip(widths, widths[1:])
        ]
        # Output head: project to the sample size, then squash with a sigmoid.
        stages.append(nn.Linear(1024, out_dim))
        stages.append(nn.Sigmoid())
        self.generator = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input batch ``x`` through the generator stack."""
        return self.generator(x)

def gen_block(in_dim, out_dim):
    """Return one generator stage: linear layer, batch norm, in-place ReLU.

    Args:
        in_dim: Number of input features of the linear layer.
        out_dim: Number of output features; also the batch-norm width.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    linear = nn.Linear(in_dim, out_dim)
    norm = nn.BatchNorm1d(out_dim)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(linear, norm, activation)

## Discriminator

In [12]:
class Discriminator(nn.Module):
    """Fully connected discriminator producing one raw score per input.

    Three ``disc_block`` stages narrow the input from ``im_dim`` to 1024, 512
    and 256 features; the last linear layer maps those to a single value. No
    activation is applied after that final layer.
    """

    def __init__(self, im_dim):
        """Assemble the layer stack.

        Args:
            im_dim: Number of features in each (flattened) input sample.
        """
        super().__init__()
        widths = (im_dim, 1024, 512, 256)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(disc_block(n_in, n_out))
        # Scoring head: one output value per sample.
        layers.append(nn.Linear(256, 1))
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the discriminator score for each sample in ``x``."""
        return self.disc(x)

def disc_block(in_dim, out_dim):
    """Return one discriminator stage: linear layer, then LeakyReLU(0.2).

    Args:
        in_dim: Number of input features of the linear layer.
        out_dim: Number of output features of the linear layer.

    Returns:
        An ``nn.Sequential`` holding the two layers in that order.
    """
    layers = (
        nn.Linear(in_dim, out_dim),
        nn.LeakyReLU(0.2),
    )
    return nn.Sequential(*layers)

# Deep convolutional GAN

## Convolutional Generator

In [13]:
class DCGenerator(nn.Module):
    """Convolutional generator built from transposed convolutions.

    The input vector is treated as an ``in_dim``-channel 1x1 image and is
    upsampled by three ``dc_gen_block`` stages (1024, 512, 256 channels),
    followed by a final transposed convolution to ``out_channels`` channels
    and a tanh, so output values lie in (-1, 1).

    With the default ``kernel_size=4`` and ``stride=2`` (and no padding), the
    spatial size grows 1 -> 4 -> 10 -> 22 -> 46.
    """

    def __init__(self, in_dim, kernel_size=4, stride=2, out_channels=3):
        """Assemble the layer stack.

        Args:
            in_dim: Number of features in each input vector; used as the
                channel count of the 1x1 input to the first stage.
            kernel_size: Kernel size shared by every transposed convolution.
            stride: Stride shared by every transposed convolution.
            out_channels: Number of channels in the generated image. Defaults
                to 3, the value that was previously hard-coded; pass 1 for
                single-channel images.
        """
        super().__init__()
        self.in_dim = in_dim
        self.gen = nn.Sequential(
            dc_gen_block(in_dim, 1024, kernel_size, stride),
            dc_gen_block(1024, 512, kernel_size, stride),
            dc_gen_block(512, 256, kernel_size, stride),
            # Final upsampling step maps to the image channel count.
            nn.ConvTranspose2d(256, out_channels, kernel_size, stride=stride),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from the input batch ``x``.

        ``x`` must hold ``in_dim`` values per sample; it is viewed as
        ``(len(x), in_dim, 1, 1)`` before entering the convolution stack.
        """
        x = x.view(len(x), self.in_dim, 1, 1)
        return self.gen(x)
  
def dc_gen_block(in_dim, out_dim, kernel_size, stride):
    """Return one upsampling stage: transposed conv, batch norm, ReLU.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels; also the batch-norm width.
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    upsample = nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride=stride)
    normalize = nn.BatchNorm2d(out_dim)
    return nn.Sequential(upsample, normalize, nn.ReLU())

## Convolutional Discriminator

In [14]:
class DCDiscriminator(nn.Module):
    """Convolutional discriminator built from strided convolutions.

    Two ``dc_disc_block`` stages (512 then 1024 channels) downsample the
    image, and a final convolution reduces it to a single-channel score map.
    The score map is flattened to one row per sample; no activation is
    applied after the final convolution.
    """

    def __init__(self, kernel_size=4, stride=2, in_channels=3):
        """Assemble the layer stack.

        Args:
            kernel_size: Kernel size shared by every convolution.
            stride: Stride shared by every convolution.
            in_channels: Number of channels in the input images. Defaults to
                3, the value that was previously hard-coded; pass 1 for
                single-channel images.
        """
        super().__init__()
        self.disc = nn.Sequential(
            dc_disc_block(in_channels, 512, kernel_size, stride),
            dc_disc_block(512, 1024, kernel_size, stride),
            # Scoring head: collapse to a single output channel.
            nn.Conv2d(1024, 1, kernel_size, stride=stride),
        )

    def forward(self, x):
        """Score the image batch ``x``.

        Returns:
            A 2-D tensor with one row per sample, holding that sample's
            flattened score map.
        """
        x = self.disc(x)
        # Flatten everything except the batch dimension.
        return x.view(len(x), -1)
  
def dc_disc_block(in_dim, out_dim, kernel_size, stride):
    """Return one downsampling stage: conv, batch norm, LeakyReLU(0.2).

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels; also the batch-norm width.
        kernel_size: Kernel size of the convolution.
        stride: Stride of the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    conv = nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride)
    norm = nn.BatchNorm2d(out_dim)
    activation = nn.LeakyReLU(0.2)
    return nn.Sequential(conv, norm, activation)