In [1]:
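# A first, simple GAN that learns to model a very simple distribution: a 2-D Gaussian.
# It uses small fully connected (not convolutional) networks and is trained as a
# Wasserstein GAN (critic + weight clipping + RMSprop).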

In [2]:
# Imports
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from tqdm import tqdm
import torch.nn.functional as F
from torch.nn.init import kaiming_uniform_
import io
import PIL.Image

from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor

z_width = 5  # size of the latent noise vector z sampled and fed to the generator

In [3]:
# Helper function to save matplotlib plots to a buffer for drawing to tensorboard.

def gen_plot(fig=None, fmt='jpeg'):
    """Render a matplotlib figure into an in-memory image buffer.

    Args:
        fig: Figure to render. ``None`` (default) renders pyplot's current
            figure, exactly like calling ``plt.savefig``.
        fmt: Image format passed to ``savefig`` (default ``'jpeg'``).

    Returns:
        An ``io.BytesIO`` holding the encoded image, rewound to offset 0 so it
        can be handed straight to a reader such as ``PIL.Image.open``.
    """
    buf = io.BytesIO()
    # pyplot.savefig and Figure.savefig accept the same arguments, so either
    # the module (current figure) or an explicit figure can do the saving.
    target = plt if fig is None else fig
    target.savefig(buf, format=fmt)
    buf.seek(0)  # rewind so the consumer reads from the start of the image
    return buf

In [4]:
# Dataset of 'real' data: N points sampled once from a multivariate Gaussian
# (two-dimensional in this notebook). A DataLoader over it serves as the sampler.

class GaussianDataset(Dataset):
    """A fixed set of ``N`` points sampled once from a multivariate Gaussian.

    Args:
        mean: Mean vector of the distribution.
        cov: Covariance matrix of the distribution.
        N: Number of samples to draw (also the dataset length).
        seed: Optional integer seed. When given, a dedicated
            ``np.random.default_rng(seed)`` generator is used, so the samples are
            reproducible independently of the global NumPy state. When ``None``
            (default) the global ``np.random`` state is used, as before.
    """

    def __init__(self, mean, cov, N, seed=None):
        if seed is None:
            samples = np.random.multivariate_normal(mean, cov, N)
        else:
            samples = np.random.default_rng(seed).multivariate_normal(mean, cov, N)
        # float64 array with one row per sample.
        self.data = samples.astype(float)
        self.mean = mean
        self.cov = cov
        self.N = N

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        # Samplers may pass tensor indices; convert to int / list for numpy indexing.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.data[idx]

In [5]:
# Display filled contours of the log-density of a given Gaussian distribution

def display_contours(mean, cov, mn_x, mx_x, mn_y, mx_y, step=0.1):
    """Draw filled contours of a 2-D Gaussian's log-density with pyplot.

    Args:
        mean, cov: Parameters of the multivariate normal distribution.
        mn_x, mx_x, mn_y, mx_y: Plot range; the grid runs from ``mn`` towards
            ``mx`` in increments of ``step`` (arange semantics, so ``mx`` is
            usually excluded).
        step: Grid spacing (default 0.1, the previously hard-coded value).

    Returns:
        The contour set returned by ``plt.contourf``.
    """
    x, y = np.mgrid[mn_x:mx_x:step, mn_y:mx_y:step]
    pos = np.dstack((x, y))  # grid of (x, y) points, shape x.shape + (2,)
    rv = multivariate_normal(mean, cov)
    # logpdf rather than log(pdf): far from the mean the pdf underflows to 0 and
    # log(0) = -inf, which triggers RuntimeWarnings and leaves holes in the plot.
    return plt.contourf(x, y, rv.logpdf(pos), cmap='Greys')

In [6]:
# Sanity-check the dataset: plot its samples on top of the log-density contours.

real_mean = [2, 4]
real_cov = [[3, 1], [1, 12]]
dataset = GaussianDataset(mean=real_mean, cov=real_cov, N=1000)

plt.figure(figsize=(4, 4))
display_contours(real_mean, real_cov, -10, 10, -10, 10)
# One vectorised scatter call instead of one call per point (1000 calls before).
plt.scatter(dataset.data[:, 0], dataset.data[:, 1])
plt.xlim(-10, 10)
plt.ylim(-10, 10);  # trailing ';' suppresses the (-10, 10) repr in the notebook

(-10, 10)

In [7]:
# Generator Network
class Generator(torch.nn.Module):
    """Two-layer MLP mapping latent noise vectors to points in data space.

    Args:
        z_width: Size of the latent input vector (default 5).
        hidden_width: Units in the hidden ReLU layer (default 15).
        out_dim: Size of each generated sample (default 2).
    """

    def __init__(self, z_width=5, hidden_width=15, out_dim=2):
        super().__init__()
        self.lin1 = torch.nn.Linear(z_width, hidden_width)
        self.lin2 = torch.nn.Linear(hidden_width, out_dim)
        kaiming_uniform_(self.lin1.weight, nonlinearity='relu')
        kaiming_uniform_(self.lin2.weight, nonlinearity='relu')

    def forward(self, x):
        """Map latent vectors ``x`` (last dim ``z_width``) to samples (last dim ``out_dim``)."""
        hidden = F.relu(self.lin1(x))
        # No output activation: generated coordinates can take any real value.
        return self.lin2(hidden)

In [8]:
# Critic Network
class Critic(torch.nn.Module):
    """Two-layer MLP assigning one unbounded real-valued score to each input point.

    Args:
        in_dim: Size of each input sample (default 2).
        hidden_width: Units in the hidden ReLU layer (default 25).
    """

    def __init__(self, in_dim=2, hidden_width=25):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, hidden_width)
        self.lin2 = torch.nn.Linear(hidden_width, 1)
        kaiming_uniform_(self.lin1.weight, nonlinearity='relu')
        kaiming_uniform_(self.lin2.weight, nonlinearity='relu')

    def forward(self, x):
        """Score inputs ``x`` (last dim ``in_dim``); output has last dim 1."""
        hidden = F.relu(self.lin1(x))
        # No sigmoid: the critic outputs a raw score, not a probability.
        return self.lin2(hidden)

In [None]:
# Training: Wasserstein GAN with weight clipping. The critic is updated every
# iteration, the generator every `n_critic` iterations.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)  # GaussianDataset samples from the global NumPy RNG

writer = SummaryWriter()

epochs = 50000     # number of training iterations (one critic batch each), not dataset passes
batch_size = 64
n_critic = 5       # critic updates per generator update
clip_value = 0.1   # critic weights are clamped to [-clip_value, clip_value]
plot_every = 100   # log a plot of generated samples every this many iterations

# Create the 'real data' dataset
real_mean = [10, 10]
real_cov = [[0.2, 0], [0, 0.2]]
real_dataset = GaussianDataset(mean=real_mean, cov=real_cov, N=1000)
# num_workers=0: the data is an in-memory array, so worker processes only add
# start-up cost (and may fail to pickle a Dataset defined in a notebook).
# drop_last=True keeps every real batch the same size as the fake batch.
real_dataloader = DataLoader(real_dataset, batch_size=batch_size, shuffle=True,
                             num_workers=0, drop_last=True)


def _infinite_batches(loader):
    """Yield batches from `loader` forever, reshuffling on every pass."""
    while True:
        yield from loader


# One persistent iterator: calling iter(real_dataloader) every step would rebuild
# the iterator each time and only ever return its first batch.
real_batches = _infinite_batches(real_dataloader)

# Create the models
critic = Critic()
C_optimizer = torch.optim.RMSprop(critic.parameters(), lr=0.0005)

gen = Generator(z_width=z_width)  # same latent size as the z sampled below
G_optimizer = torch.optim.RMSprop(gen.parameters(), lr=0.0005)

critic.train()
gen.train()

try:
    for i in tqdm(range(epochs)):
        # Sample from the 'real data'
        X_real = next(real_batches).float()

        # --- Critic step: maximise E[C(real)] - E[C(fake)] by minimising its negative.
        z = torch.randn((batch_size, z_width))
        gen_opt = gen(z)

        C_optimizer.zero_grad()
        real_score = -torch.mean(critic(X_real))
        fake_score = torch.mean(critic(gen_opt.detach()))  # detach: no generator grads here
        c_score = real_score + fake_score  # negative of the critic's Wasserstein estimate
        c_score.backward()
        C_optimizer.step()

        writer.add_scalar('D Score', c_score.item(), i)

        # Weight clipping, as in the original WGAN, to keep the critic roughly Lipschitz.
        for p in critic.parameters():
            p.data.clamp_(-clip_value, clip_value)

        if i % n_critic == 0:
            # --- Generator step: maximise E[C(fake)].
            G_optimizer.zero_grad()
            z = torch.randn((batch_size, z_width))
            gen_opt = gen(z)

            g_loss = -torch.mean(critic(gen_opt))
            writer.add_scalar('G Score', g_loss.item(), i)

            g_loss.backward()
            G_optimizer.step()

        # Show plot every now and then
        if i % plot_every == 0:
            figure = plt.figure(figsize=(4, 4))
            display_contours(real_mean, real_cov, -5, 15, -5, 15)
            datum = gen_opt.detach()
            plt.scatter(datum[:, 0], datum[:, 1])
            # Mark the mean of the generated batch.
            plt.scatter(torch.mean(datum[:, 0]), torch.mean(datum[:, 1]), marker='x')
            plt.xlim(-5, 15)
            plt.ylim(-5, 15)

            plot_buf = gen_plot()
            image = ToTensor()(PIL.Image.open(plot_buf))
            # Close the figure; otherwise one more figure stays open per plot.
            plt.close(figure)
            writer.add_image('Generated Data', image, i)
finally:
    # Flush TensorBoard logs even if the run is interrupted.
    writer.close()

  if __name__ == '__main__':
