# In [7]:
import torch, numpy, matplotlib.pyplot as plt

# Binary classifier that scores whether data points come from the real (original) distribution or the fake (generated) one.
class Discriminator(torch.nn.Module):
    """MLP binary classifier that assigns each input point a sigmoid score.

    Architecture: an input projection ``Linear(input_dim, hidden_dim)`` +
    ``LeakyReLU``, then ``n_layers`` hidden blocks of
    ``Linear(hidden_dim, hidden_dim)`` + ``LeakyReLU``, then a single output
    unit passed through ``Sigmoid``. Which class a score near 1 stands for
    (real vs. generated) is decided by the training loop, which is not shown
    here.

    Args:
        input_dim: size of the last dimension of each input point.
        hidden_dim: width of every hidden layer.
        n_layers: number of hidden blocks after the input projection
            (0 gives input projection -> output directly).
    """

    def __init__(self, input_dim=2, hidden_dim=28, n_layers=3):
        super().__init__()

        def dense(n_in, n_out):
            # Fully connected layer followed by LeakyReLU (default slope).
            return torch.nn.Sequential(torch.nn.Linear(n_in, n_out), torch.nn.LeakyReLU())

        # Submodules are built -- and therefore randomly initialised -- in the
        # order input -> hidden blocks -> output. Keeping this order and these
        # attribute names keeps seeded runs and saved state_dicts compatible.
        self.input = dense(input_dim, hidden_dim)
        self.layers = torch.nn.ModuleList(
            [dense(hidden_dim, hidden_dim) for _ in range(n_layers)]
        )
        self.output = torch.nn.Sequential(torch.nn.Linear(hidden_dim, 1), torch.nn.Sigmoid())

    def forward(self, x):
        """Score ``x`` (shape ``(..., input_dim)``); returns shape ``(..., 1)``."""
        hidden = self.input(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.output(hidden)
  
# Generator: maps random noise vectors to fake (generated) data points.
class Generator(torch.nn.Module):
    """MLP that maps noise vectors to generated (fake) points.

    Architecture: an input projection ``Linear(z_dim, hidden_dim)`` +
    ``LeakyReLU``, then ``n_layers`` hidden blocks of
    ``Linear(hidden_dim, hidden_dim)`` + ``LeakyReLU``, then a plain
    ``Linear(hidden_dim, out_dim)`` with no output activation, so generated
    coordinates are unbounded.

    Args:
        z_dim: size of the last dimension of each noise vector.
        hidden_dim: width of every hidden layer.
        n_layers: number of hidden blocks after the input projection
            (0 gives input projection -> output directly).
        out_dim: size of each generated point.
    """

    def __init__(self, z_dim=1, hidden_dim=28, n_layers=3, out_dim=2):
        super().__init__()

        def dense(n_in, n_out):
            # Fully connected layer followed by LeakyReLU (default slope).
            return torch.nn.Sequential(torch.nn.Linear(n_in, n_out), torch.nn.LeakyReLU())

        # Construction order (input -> hidden blocks -> output) fixes the
        # order in which weights draw from the RNG; attribute names fix the
        # state_dict keys. Both are kept as they were.
        self.input = dense(z_dim, hidden_dim)
        self.layers = torch.nn.ModuleList(
            [dense(hidden_dim, hidden_dim) for _ in range(n_layers)]
        )
        self.output = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Map noise ``x`` (shape ``(..., z_dim)``) to points of shape ``(..., out_dim)``."""
        hidden = self.input(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.output(hidden)