# Introduction to GANs

In [3]:
import torch
import torch.nn as nn

## Generator

In [4]:
class Generator(nn.Module):
    """Fully connected GAN generator mapping an input vector to a flat sample.

    The network is three ``gen_block`` stages (Linear -> BatchNorm1d -> ReLU)
    whose widths grow as ``hidden_dim``, ``2 * hidden_dim`` and
    ``4 * hidden_dim``, followed by an ``nn.Linear`` projection to ``out_dim``
    and an ``nn.Sigmoid`` that squashes every output into [0, 1].

    Args:
        in_dim: Size of each input vector (presumably the noise dimension).
        out_dim: Size of each generated sample.
        hidden_dim: Width of the first hidden layer. The default of 256
            reproduces the original fixed 256 -> 512 -> 1024 stack exactly.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=256):
        super().__init__()
        # Attribute name ``generator`` is kept so state_dict keys are unchanged.
        self.generator = nn.Sequential(
            gen_block(in_dim, hidden_dim),
            gen_block(hidden_dim, hidden_dim * 2),
            gen_block(hidden_dim * 2, hidden_dim * 4),
            # Project the widest hidden representation to the output size.
            nn.Linear(hidden_dim * 4, out_dim),
            # Bounds outputs to [0, 1]; presumably the real data is scaled to
            # the same range -- confirm against the data pipeline.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the generator on a batch of input vectors.

        Args:
            x: Input tensor whose last dimension is ``in_dim``; assumed to be
                ``(batch, in_dim)`` -- TODO confirm with callers.

        Returns:
            Tensor with last dimension ``out_dim`` and values in [0, 1].
        """
        return self.generator(x)

def gen_block(in_dim, out_dim):
    """Build one generator stage: affine map, batch normalization, ReLU.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features; also the BatchNorm1d width.

    Returns:
        An ``nn.Sequential`` of ``Linear(in_dim, out_dim)``,
        ``BatchNorm1d(out_dim)`` and an in-place ``ReLU``.
    """
    layers = [
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm1d(out_dim),
        # inplace=True overwrites the BatchNorm output rather than
        # allocating a new tensor for the activation.
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

## Discriminator

In [5]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator producing one score per sample.

    Three ``disc_block`` stages (Linear -> LeakyReLU(0.2)) shrink the width as
    ``4 * hidden_dim``, ``2 * hidden_dim``, ``hidden_dim``, and a final
    ``nn.Linear`` maps to a single output. No sigmoid is applied, so the
    output is an unbounded logit; presumably it is paired with a logits-based
    loss such as ``nn.BCEWithLogitsLoss`` -- confirm in the training loop.

    Args:
        im_dim: Size of each (flattened) input sample.
        hidden_dim: Width of the last hidden layer. The default of 256
            reproduces the original fixed 1024 -> 512 -> 256 stack exactly.
    """

    def __init__(self, im_dim, hidden_dim=256):
        super().__init__()
        # Attribute name ``disc`` is kept so state_dict keys are unchanged.
        self.disc = nn.Sequential(
            disc_block(im_dim, hidden_dim * 4),
            disc_block(hidden_dim * 4, hidden_dim * 2),
            disc_block(hidden_dim * 2, hidden_dim),
            # Single raw score (logit) per sample; no output activation.
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Score a batch of samples.

        Args:
            x: Input tensor whose last dimension is ``im_dim``; assumed to be
                ``(batch, im_dim)`` -- TODO confirm with callers.

        Returns:
            Tensor with last dimension 1 holding unbounded logits.
        """
        return self.disc(x)

def disc_block(in_dim, out_dim):
    """Build one discriminator stage: affine map followed by LeakyReLU.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.

    Returns:
        An ``nn.Sequential`` of ``Linear(in_dim, out_dim)`` and a
        (not in-place) ``LeakyReLU`` with negative slope 0.2.
    """
    affine = nn.Linear(in_dim, out_dim)
    activation = nn.LeakyReLU(negative_slope=0.2)
    return nn.Sequential(affine, activation)