# Import PyTorch

In [0]:
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch
from torch.autograd import Variable

import os
from torchvision import datasets
import numpy as np
from matplotlib import pyplot

# Let's define some values and hyperparameters



In [0]:
img_shape = (1, 28, 28)
learning_rate = 0.0002
betas = (0.5, 0.999)
z_dimension = 100
epochs = 200
batch_size = 64

# is CUDA there?

In [0]:
cuda = True if torch.cuda.is_available() else False

# Let's get a new dataset, MNIST

In [4]:
os.makedirs("data/minst", exist_ok=True)
os.makedirs("images", exist_ok=True)
data_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
    "data/mnist",
    train=True,
    download=True,
    transform=transforms.Compose(
    [transforms.Resize(28), transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])]
    ),
    ),
    batch_size=batch_size,
    shuffle=True,
  )

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:01, 8520987.42it/s]                            


Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 133356.51it/s]           
  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:00, 2205734.98it/s]                            
0it [00:00, ?it/s]

Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 50618.28it/s]            


Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


# Design the GAN

In [0]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
    
    self.model = nn.Sequential(
      nn.Linear(100, 128),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(128,256),
      nn.BatchNorm1d(256, 0.8),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(256, 512),
      nn.BatchNorm1d(512, 0.8),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(512, 1024),
      nn.BatchNorm1d(1024, 0.8),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(1024, int(np.prod(img_shape))),
      nn.Tanh()
    )
    
  def forward(self, z):
    img = self.model(z)
    img = img.view(img.size(0), *img_shape)
    return img

class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    
    self.model = nn.Sequential(
      nn.Linear(int(np.prod(img_shape)), 512),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(512, 256),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(256, 1),
      nn.Sigmoid(),
    )
    
  def forward(self, img):
    img_flat = img.view(img.size(0), -1)
    validity = self.model(img_flat)
    return validity
  
generator = Generator()
discriminator = Discriminator()

if cuda:
  generator.cuda()
  discriminator.cuda()

# Add the Model and the Optimizers

In [0]:
loss = torch.nn.BCELoss()

if cuda:
  loss.cuda()
  
g_optimizer = torch.optim.Adam(generator.parameters(),
                              lr=0.0002,
                              betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(),
                              lr=0.0002,
                              betas=(0.5, 0.999))

# Let's train the model!

In [7]:
float_tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

for epoch in range(epochs):
  for i, (imgs, _) in enumerate(data_loader):
    real = Variable(float_tensor(
            imgs.size(0), 1).fill_(1.0),
                    requires_grad=False)
    fake = Variable(float_tensor(
            imgs.size(0), 1).fill_(0.0),
                    requires_grad=False)
    real_images = Variable(imgs.type(float_tensor))
    
    # Train the Generator
    
    g_optimizer.zero_grad()
    
    z = Variable(float_tensor(
        np.random.normal(0, 1, 
                         (imgs.shape[0],
                        100))))
    generated_images = generator(z)
    g_loss = loss(discriminator(generated_images), real)
    g_loss.backward()
    g_optimizer.step()
    
    # Train the Discriminator
    
    d_optimizer.zero_grad()
    
    real_loss = loss(discriminator(real_images), real)
    fake_loss = loss(discriminator(generated_images.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    d_optimizer.step()
    
    # print training progress
    #print("Epoch %d, Batch %d, D loss: %f, G loss %f" %
     #    (epoch, i, d_loss.item(), g_loss.item()))
    
    batches_done = epoch * len(data_loader) + i
    if batches_done % 500 == 0:
      
      print("Epoch %d, Batch %d" % (epoch, i))
      save_image(generated_images.data[:25], "images/%d.png" % batches_done, 
                 nrow=5, 
                 normalize=True)
     

Epoch 0, Batch 0
Epoch 0, Batch 500
Epoch 1, Batch 62
Epoch 1, Batch 562
Epoch 2, Batch 124
Epoch 2, Batch 624
Epoch 3, Batch 186
Epoch 3, Batch 686
Epoch 4, Batch 248
Epoch 4, Batch 748
Epoch 5, Batch 310
Epoch 5, Batch 810
Epoch 6, Batch 372
Epoch 6, Batch 872
Epoch 7, Batch 434
Epoch 7, Batch 934
Epoch 8, Batch 496
Epoch 9, Batch 58
Epoch 9, Batch 558
Epoch 10, Batch 120
Epoch 10, Batch 620
Epoch 11, Batch 182
Epoch 11, Batch 682
Epoch 12, Batch 244
Epoch 12, Batch 744
Epoch 13, Batch 306
Epoch 13, Batch 806
Epoch 14, Batch 368
Epoch 14, Batch 868
Epoch 15, Batch 430
Epoch 15, Batch 930
Epoch 16, Batch 492
Epoch 17, Batch 54
Epoch 17, Batch 554
Epoch 18, Batch 116
Epoch 18, Batch 616
Epoch 19, Batch 178
Epoch 19, Batch 678
Epoch 20, Batch 240
Epoch 20, Batch 740
Epoch 21, Batch 302
Epoch 21, Batch 802
Epoch 22, Batch 364
Epoch 22, Batch 864
Epoch 23, Batch 426
Epoch 23, Batch 926
Epoch 24, Batch 488
Epoch 25, Batch 50
Epoch 25, Batch 550
Epoch 26, Batch 112
Epoch 26, Batch 612
Epoch