<a href="https://colab.research.google.com/github/nathanbarry474/google-colab-notebooks/blob/master/MNISTGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Importing the libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, models
import torchvision
from torchvision.utils import make_grid

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Sample grids produced during training are saved under ./output;
# exist_ok makes re-running this cell harmless.
os.makedirs(name='output', exist_ok=True)

img_shape = (1, 28, 28)  # one MNIST image as (channels, height, width)

In [None]:
# Creating the Generator
class Generator(nn.Module):
  """MLP generator that maps a latent noise vector to an image.

  Three Linear -> BatchNorm1d -> LeakyReLU(0.2) stages
  (latent_dim -> 128 -> 512 -> 1024) are followed by a Linear projection
  to the flattened image size and a tanh, so every output value lies in
  [-1, 1].

  Args:
    latent_dim: Length of each input noise vector. Defaults to 100.
    out_shape: Shape of one generated image, e.g. (channels, height, width).
      Defaults to (1, 28, 28), a single MNIST digit.
  """

  def __init__(self, latent_dim=100, out_shape=(1, 28, 28)):
    super(Generator, self).__init__()
    self.out_shape = tuple(out_shape)
    # Attribute names fc*/in* are kept unchanged so existing state_dicts
    # still load. Note that the in* layers are BatchNorm1d, not InstanceNorm.
    self.fc1 = nn.Linear(latent_dim, 128)
    self.in1 = nn.BatchNorm1d(128)
    self.fc2 = nn.Linear(128, 512)
    self.in2 = nn.BatchNorm1d(512)
    self.fc3 = nn.Linear(512, 1024)
    self.in3 = nn.BatchNorm1d(1024)
    # The final layer emits exactly one value per output pixel/channel.
    self.fc4 = nn.Linear(1024, torch.Size(self.out_shape).numel())

  def forward(self, x):
    """Map noise of shape (batch, latent_dim) to images of shape (batch, *out_shape)."""
    x = F.leaky_relu(self.in1(self.fc1(x)), 0.2)
    x = F.leaky_relu(self.in2(self.fc2(x)), 0.2)
    x = F.leaky_relu(self.in3(self.fc3(x)), 0.2)
    x = torch.tanh(self.fc4(x))
    return x.view(x.shape[0], *self.out_shape)

In [None]:
# Creating the Discriminator
class Discriminator(nn.Module):
  """MLP discriminator that scores how "real" each image in a batch looks.

  Each input is flattened and passed through Linear -> LeakyReLU(0.2) stages
  (in_features -> 512 -> 256 -> 128). A final Linear -> sigmoid then yields
  one probability in [0, 1] per sample, the form F.binary_cross_entropy
  expects.

  Args:
    in_features: Number of values in one flattened input image. Defaults to
      28*28, a single-channel MNIST digit.
  """

  def __init__(self, in_features=28*28):
    super(Discriminator, self).__init__()
    self.fc1 = nn.Linear(in_features, 512)
    self.fc2 = nn.Linear(512, 256)
    self.fc3 = nn.Linear(256, 128)
    self.fc4 = nn.Linear(128, 1)

  def forward(self, x):
    """Return a (batch, 1) tensor of real-probabilities for the batch x."""
    # Flatten every dimension except the batch dimension,
    # e.g. (B, 1, 28, 28) -> (B, 784).
    x = x.view(x.size(0), -1)
    x = F.leaky_relu(self.fc1(x), 0.2)
    x = F.leaky_relu(self.fc2(x), 0.2)
    x = F.leaky_relu(self.fc3(x), 0.2)
    return torch.sigmoid(self.fc4(x))

In [None]:
# Build both networks using their default (MNIST-sized) architectures.
generator, discriminator = Generator(), Discriminator()

In [None]:
# Loading the dataset: the MNIST training split, with each tensor passed
# through Normalize((0.5,), (0.5,)), i.e. (x - 0.5) / 0.5.
bs = 64  # mini-batch size

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

trainset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
trainloader = DataLoader(dataset=trainset, batch_size=bs, shuffle=True)

In [None]:
# Check for GPU and move both networks onto it when one is available.
# Losses are computed functionally with F.binary_cross_entropy, so there is
# no loss module that also needs moving.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)
print(torch.cuda.is_available())

False


In [None]:
# Creating the optimizers: one Adam instance per network, both using the
# same learning rate and momentum coefficients.
G_optimizer = torch.optim.Adam(
  generator.parameters(),
  lr=2e-4,
  betas=(0.4, 0.999),
)
D_optimizer = torch.optim.Adam(
  discriminator.parameters(),
  lr=2e-4,
  betas=(0.4, 0.999),
)

In [None]:
# Legacy float-tensor constructor: the CUDA variant when a GPU is present,
# otherwise the CPU one.
if torch.cuda.is_available():
  Tensor = torch.cuda.FloatTensor
else:
  Tensor = torch.FloatTensor

In [None]:
# Training for loop
epochs = 2

# Run on whatever device the generator's parameters live on, so the targets,
# noise and real images are all created next to the models. Creating the
# targets on the CPU while the models are on the GPU makes
# binary_cross_entropy fail with a device mismatch.
device = next(generator.parameters()).device
# Noise size must equal the generator's first Linear layer input size.
latent_dim = generator.fc1.in_features

for epoch in range(epochs):
  for i, (X, _) in enumerate(trainloader):

    # Targets are sized by X because the final batch can be smaller than bs.
    mb_size = X.size(0)
    real = torch.ones(mb_size, 1, device=device)
    fake = torch.zeros(mb_size, 1, device=device)

    real_imgs = X.to(device)

    # ---- Generator step: push D to label generated images as real ----
    G_optimizer.zero_grad()

    # Sample float32 noise directly on the target device.
    G_input = torch.randn(mb_size, latent_dim, device=device)

    # Creating the fake image
    G = generator(G_input)

    G_loss = F.binary_cross_entropy(discriminator(G), real)
    G_loss.backward()
    G_optimizer.step()

    # ---- Discriminator step: real images -> 1, generated images -> 0 ----
    # zero_grad also clears the gradients that G_loss.backward() left on D's params.
    D_optimizer.zero_grad()

    real_loss = F.binary_cross_entropy(discriminator(real_imgs), real)
    # detach() keeps this loss from backpropagating into the generator.
    fake_loss = F.binary_cross_entropy(discriminator(G.detach()), fake)
    D_loss = real_loss + fake_loss

    D_loss.backward()
    D_optimizer.step()

    print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
          (epoch, epochs, i, len(trainloader), D_loss.item(), G_loss.item()))

    # Every 400 batches (counted across epochs), save a 5x5 grid of the
    # first 25 generated samples.
    total_batch = epoch * len(trainloader) + i
    if total_batch % 400 == 0:
      torchvision.utils.save_image(G.detach()[:25], './output/%d.png' % total_batch,
                                   nrow=5, normalize=True)



[Epoch 0/20] [Batch 0/938] [D loss: 1.386544] [G loss: 0.744311]
[Epoch 0/20] [Batch 1/938] [D loss: 1.385953] [G loss: 0.744550]
[Epoch 0/20] [Batch 2/938] [D loss: 1.384624] [G loss: 0.744314]
[Epoch 0/20] [Batch 3/938] [D loss: 1.384537] [G loss: 0.744394]
[Epoch 0/20] [Batch 4/938] [D loss: 1.389061] [G loss: 0.743707]
[Epoch 0/20] [Batch 5/938] [D loss: 1.387528] [G loss: 0.744369]
[Epoch 0/20] [Batch 6/938] [D loss: 1.387593] [G loss: 0.743520]
[Epoch 0/20] [Batch 7/938] [D loss: 1.386473] [G loss: 0.744382]
[Epoch 0/20] [Batch 8/938] [D loss: 1.387883] [G loss: 0.743532]
[Epoch 0/20] [Batch 9/938] [D loss: 1.387345] [G loss: 0.744149]
[Epoch 0/20] [Batch 10/938] [D loss: 1.385799] [G loss: 0.743610]
[Epoch 0/20] [Batch 11/938] [D loss: 1.386323] [G loss: 0.744214]
[Epoch 0/20] [Batch 12/938] [D loss: 1.386132] [G loss: 0.744174]
[Epoch 0/20] [Batch 13/938] [D loss: 1.384783] [G loss: 0.744058]
[Epoch 0/20] [Batch 14/938] [D loss: 1.385975] [G loss: 0.744264]
[Epoch 0/20] [Batch 

KeyboardInterrupt: ignored