In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Utility for saving an entire batch of images to a single image file.
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime 
import os

In [2]:
# Preprocessing applied to every MNIST image:
#   ToTensor()  -> float tensor with pixel values in [0, 1]
#   Normalize() -> (x - 0.5) / 0.5 per pixel, so
#                  0 maps to (0 - 0.5) / 0.5 = -1 and 1 maps to (1 - 0.5) / 0.5 = 1.
# Real images therefore live in [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [3]:
# MNIST training split under ./MNIST (fetched on first run), with each image
# passed through `transform` so pixel values are in [-1, 1].
train_dataset = torchvision.datasets.MNIST(
    root='.', train=True, download=True, transform=transform,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [4]:
# The MNIST training split has 60,000 images; fail fast on a partial/corrupt download.
num_train = len(train_dataset)
assert num_train == 60_000, f"expected 60000 MNIST training images, got {num_train}"
num_train

60000

In [5]:
# Shuffled mini-batches of `batch_size` flattened-later MNIST images.
batch_size = 128
data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    # Use the variable (not a literal) so the loader stays in sync with the
    # training cell, which sizes its real/fake label buffers from `batch_size`
    # and slices them to each batch's length.
    batch_size=batch_size,
    shuffle=True,
)

In [6]:
# Discriminator: an MLP from a flattened 28x28 image (784 features) to one raw
# score per image. There is deliberately no final sigmoid -- the loss used for
# training (nn.BCEWithLogitsLoss) applies it internally.
D = nn.Sequential(*[
    nn.Linear(784, 512), nn.LeakyReLU(0.2),   # hidden layer 1
    nn.Linear(512, 256), nn.LeakyReLU(0.2),   # hidden layer 2
    nn.Linear(256, 1),                        # output logit
])

In [7]:
# Generator: an MLP from a latent noise vector to a flattened 28x28 image.
latent_dim = 100


def _g_block(in_features, out_features):
    """One hidden stage of the generator: Linear -> LeakyReLU(0.2) -> BatchNorm1d."""
    # NOTE(review): in PyTorch, BatchNorm momentum is the weight of the *new*
    # batch statistic (running = (1 - m) * running + m * batch), the inverse of
    # Keras's convention. If 0.7 was ported from a Keras example, the PyTorch
    # equivalent would be 0.3. It only affects the running stats used in eval
    # mode -- confirm intent.
    return [
        nn.Linear(in_features, out_features),
        nn.LeakyReLU(0.2),
        nn.BatchNorm1d(out_features, momentum=0.7),
    ]


G = nn.Sequential(
    *_g_block(latent_dim, 256),
    *_g_block(256, 512),
    *_g_block(512, 1024),
    nn.Linear(1024, 784),
    nn.Tanh(),  # outputs in [-1, 1], matching the normalized real images
)

In [8]:
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Module.to() moves the parameters and returns the module; rebind both names
# to the on-device models.
D, G = D.to(device), G.to(device)

In [9]:
# D outputs raw logits, so use the numerically stable sigmoid + BCE combination.
criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer per network, identical settings (lr 2e-4, beta1 0.5).
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (D, G)
)

In [10]:
def scale_img(img):
    """Map values from [-1, 1] (the generator's Tanh range) back to [0, 1] for saving."""
    return (img + 1) / 2

In [11]:
# Create the folder that receives one sample-grid PNG per training epoch.
# exist_ok=True makes the cell idempotent on re-run and avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs("gan_images", exist_ok=True)

In [None]:
# Training Phase
#
# Each batch: one discriminator update, then `g_steps_per_batch` generator
# updates. D and G are not re-initialised here, so re-running this cell
# continues training from their current weights.

# Label buffers; sliced per batch because the last batch can be smaller than batch_size.
ones_ = torch.ones(batch_size, 1, device=device)
zeros_ = torch.zeros(batch_size, 1, device=device)

# Per-batch losses, for plotting later.
d_losses = []
g_losses = []

num_of_epochs = 200
# Generator updates per discriminator update.
g_steps_per_batch = 2

# Make sure BatchNorm uses batch statistics even if an earlier cell called eval().
D.train()
G.train()

for epoch in range(num_of_epochs):
  for inputs, _ in data_loader:  # digit labels are not needed
    # Flatten the images to 784-vectors and move them to the device.
    n = inputs.size(0)
    inputs = inputs.reshape(n, 784).to(device)

    ones = ones_[:n]
    zeros = zeros_[:n]

    ### Train Discriminator

    # Real images should be scored as real (1).
    real_outputs = D(inputs)
    d_loss_real = criterion(real_outputs, ones)

    # Fake images should be scored as fake (0). detach() keeps d_loss from
    # backpropagating into G: only D is updated in this step, so computing
    # G's gradients here would be wasted work.
    noise = torch.randn(n, latent_dim, device=device)
    fake_images = G(noise).detach()
    fake_output = D(fake_images)
    d_loss_fake = criterion(fake_output, zeros)

    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    ### Train Generator

    for _ in range(g_steps_per_batch):
      noise = torch.randn(n, latent_dim, device=device)
      fake_images = G(noise)
      fake_output = D(fake_images)

      # Target the "real" labels: G is rewarded when D is fooled.
      g_loss = criterion(fake_output, ones)

      # g_loss.backward() populates both D's and G's .grad; clear both first.
      # Only g_optimizer steps, so only G's parameters change.
      d_optimizer.zero_grad()
      g_optimizer.zero_grad()
      g_loss.backward()
      g_optimizer.step()

    d_losses.append(d_loss.item())
    # Loss of the last generator step in this batch.
    g_losses.append(g_loss.item())

  print(f"Epoch: {epoch+1}, d_loss: {d_loss.item()}, g_loss: {g_loss.item()}")

  # Save the last generator batch of the epoch as one image grid; detach()
  # drops the autograd graph so it is not kept alive just for saving.
  samples = fake_images.detach().reshape(-1, 1, 28, 28)
  save_image(scale_img(samples), f"gan_images/{epoch+1}.png")

Epoch: 1, d_loss: 0.690441370010376, g_loss: 0.6716287136077881
Epoch: 2, d_loss: 0.6806662678718567, g_loss: 0.7305930852890015
Epoch: 3, d_loss: 0.6744400262832642, g_loss: 0.7348179817199707
Epoch: 4, d_loss: 0.7066624760627747, g_loss: 0.6802654266357422
Epoch: 5, d_loss: 0.6907883882522583, g_loss: 0.702825665473938
Epoch: 6, d_loss: 0.69249427318573, g_loss: 0.6688544750213623
Epoch: 7, d_loss: 0.6830054521560669, g_loss: 0.7986036539077759
Epoch: 8, d_loss: 0.693386435508728, g_loss: 0.7508489489555359
Epoch: 9, d_loss: 0.6874825358390808, g_loss: 0.7082078456878662
Epoch: 10, d_loss: 0.6889263391494751, g_loss: 0.7369584441184998
Epoch: 11, d_loss: 0.6893898248672485, g_loss: 0.7091339826583862
Epoch: 12, d_loss: 0.6871373057365417, g_loss: 0.7527778148651123
Epoch: 13, d_loss: 0.6859205961227417, g_loss: 0.6909055709838867
Epoch: 14, d_loss: 0.6740624904632568, g_loss: 0.7446452379226685
Epoch: 15, d_loss: 0.6808047294616699, g_loss: 0.7128826379776001
Epoch: 16, d_loss: 0.693

In [None]:
# Per-batch training losses: one point per discriminator update; g_losses holds
# the loss of the last generator step within each batch.
fig, ax = plt.subplots()
ax.plot(g_losses, label="g_losses")
ax.plot(d_losses, label="d_losses")
ax.set(title="GAN training losses", xlabel="batch (iteration)", ylabel="BCE loss")
ax.legend();

In [None]:
# Show the sample grids saved by the training cell at a few checkpoint epochs.
# plt.imread reads the PNGs directly, so no extra image library is needed.
# A checkpoint whose file is missing (e.g. training was interrupted before that
# epoch) is reported and skipped instead of raising FileNotFoundError.
for snapshot_epoch in (1, 50, 100, 150, 200):
    path = f"gan_images/{snapshot_epoch}.png"
    if not os.path.exists(path):
        print(f"Skipping epoch {snapshot_epoch}: {path} not found")
        continue
    fig, ax = plt.subplots()
    ax.imshow(plt.imread(path))
    ax.set_title(f"Generated samples after epoch {snapshot_epoch}")
    ax.axis("off")
    plt.show()