# Import PyTorch

In [0]:
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch
from torch.autograd import Variable

import os
from torchvision import datasets
import numpy as np
from matplotlib import pyplot

# Let's define some values and hyperparameters



In [0]:
img_shape = (1, 28, 28)
learning_rate = 0.0002
betas = (0.5, 0.999)
z_dimension = 100
epochs = 200
batch_size = 64

# is CUDA there?

In [0]:
cuda = True if torch.cuda.is_available() else False

# Let's get a new dataset, CIFAR10

In [0]:
os.makedirs("data/minst", exist_ok=True)
os.makedirs("images", exist_ok=True)
data_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
    "data/mnist",
    train=True,
    download=True,
    transform=transforms.Compose(
    [transforms.Resize(28), transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])]
    ),
    ),
    batch_size=batch_size,
    shuffle=True,
  )


0it [00:00, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz



  0%|          | 0/9912422 [00:00<?, ?it/s][A
  0%|          | 16384/9912422 [00:00<02:08, 77090.42it/s][A
  0%|          | 40960/9912422 [00:01<01:51, 88440.37it/s][A
  1%|          | 98304/9912422 [00:01<01:27, 112634.38it/s][A
  2%|▏         | 212992/9912422 [00:01<01:04, 149337.58it/s][A
  4%|▍         | 425984/9912422 [00:01<00:46, 202333.47it/s][A
  7%|▋         | 655360/9912422 [00:01<00:34, 270561.98it/s][A
 10%|█         | 1024000/9912422 [00:01<00:24, 365706.33it/s][A
 19%|█▊        | 1835008/9912422 [00:02<00:16, 504797.26it/s][A
 33%|███▎      | 3260416/9912422 [00:02<00:09, 701856.15it/s][A
 59%|█████▉    | 5832704/9912422 [00:02<00:04, 981898.21it/s][A
 77%|███████▋  | 7593984/9912422 [00:02<00:01, 1344664.37it/s][A
9920512it [00:02, 3619126.44it/s]                             [A

Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz



0it [00:00, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0/28881 [00:00<?, ?it/s][A
 57%|█████▋    | 16384/28881 [00:00<00:00, 86342.82it/s][A
32768it [00:00, 55949.71it/s]                           [A
0it [00:00, ?it/s][A

Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0/1648877 [00:00<?, ?it/s][A
  1%|          | 16384/1648877 [00:00<00:21, 76824.29it/s][A
  2%|▏         | 40960/1648877 [00:00<00:18, 88135.66it/s][A
  6%|▌         | 98304/1648877 [00:00<00:13, 112240.89it/s][A
 13%|█▎        | 212992/1648877 [00:01<00:09, 148692.09it/s][A
 19%|█▉        | 319488/1648877 [00:01<00:06, 191576.58it/s][A
 34%|███▍      | 565248/1648877 [00:01<00:04, 258095.50it/s][A
 51%|█████     | 843776/1648877 [00:01<00:02, 343929.72it/s][A
 89%|████████▉ | 1466368/1648877 [00:01<00:00, 471108.09it/s][A
1654784it [00:01, 874646.82it/s]                             [A
0it [00:00, ?it/s][A

Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz



  0%|          | 0/4542 [00:00<?, ?it/s][A

Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


# Design the GAN

In [0]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
    
    self.model = nn.Sequential(
      nn.Linear(100, 128),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(128,256),
      nn.BatchNorm1d(256, 0.8),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(256, 512),
      nn.BatchNorm1d(512, 0.8),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(512, 1024),
      nn.BatchNorm1d(1024, 0.8),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(1024, int(np.prod(img_shape))),
      nn.Tanh()
    )
    
  def forward(self, z):
    img = self.model(z)
    img = img.view(img.size(0), *img_shape)
    return img

class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    
    self.model = nn.Sequential(
      nn.Linear(int(np.prod(img_shape)), 512),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(512, 256),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(256, 1),
      nn.Sigmoid(),
    )
    
  def forward(self, img):
    img_flat = img.view(img.size(0), -1)
    validity = self.model(img_flat)
    return validity
  
generator = Generator()
discriminator = Discriminator()

if cuda:
  generator.cuda()
  discriminator.cuda()

# Add the Model and the Optimizers

In [0]:
loss = torch.nn.BCELoss()

if cuda:
  loss.cuda()
  
g_optimizer = torch.optim.Adam(generator.parameters(),
                              lr=0.0002,
                              betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(),
                              lr=0.0002,
                              betas=(0.5, 0.999))

# Let's train the model!

In [0]:
float_tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

for epoch in range(epochs):
  for i, (imgs, _) in enumerate(data_loader):
    real = Variable(float_tensor(
            imgs.size(0), 1).fill_(1.0),
                    requires_grad=False)
    fake = Variable(float_tensor(
            imgs.size(0), 1).fill_(0.0),
                    requires_grad=False)
    real_images = Variable(imgs.type(float_tensor))
    
    # Train the Generator
    
    g_optimizer.zero_grad()
    
    z = Variable(float_tensor(
        np.random.normal(0, 1, 
                         (imgs.shape[0],
                        100))))
    generated_images = generator(z)
    g_loss = loss(discriminator(generated_images), real)
    g_loss.backward()
    g_optimizer.step()
    
    # Train the Discriminator
    
    d_optimizer.zero_grad()
    
    real_loss = loss(discriminator(real_images), real)
    fake_loss = loss(discriminator(generated_images.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    d_optimizer.step()
    
    # print training progress
    #print("Epoch %d, Batch %d, D loss: %f, G loss %f" %
     #    (epoch, i, d_loss.item(), g_loss.item()))
    
    batches_done = epoch * len(data_loader) + i
    if batches_done % 500 == 0:
      
      print("Epoch %d, Batch %d" % (epoch, i))
      save_image(generated_images.data[:25], "images/%d.png" % batches_done, 
                 nrow=5, 
                 normalize=True)
     

Epoch 0, Batch 0
Epoch 0, Batch 500
Epoch 1, Batch 62
Epoch 1, Batch 562
Epoch 2, Batch 124
Epoch 2, Batch 624


KeyboardInterrupt: ignored