In [0]:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image

import matplotlib.pyplot as plt
import numpy as np

from tqdm.notebook import tqdm # progress bar

import torch.backends.cudnn as cudnn # tuning
from IPython.display import clear_output

In [0]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Loading Data

In [0]:
# Loading and transforming data.
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 rescales them
# to [-1, 1], the same range the generator's Tanh output produces.
transform = transforms.Compose(
    [
     transforms.ToTensor(),
     # mean/std are per-channel sequences. MNIST has a single channel, so pass
     # 1-tuples: "(0.5)" is just the float 0.5, not a tuple.
     transforms.Normalize((0.5,), (0.5,))
      ])

# Training split: shuffled mini-batches of 32 images.
trainset = tv.datasets.MNIST(root='./data',  train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)

# Test split: mini-batches of 10 images.
testset = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10, shuffle=True, num_workers=2)


# Model

In [0]:
class DiscriminatorNet(torch.nn.Module):
    """Fully-connected discriminator for flattened 784-value images.

    Three hidden stages (Linear -> LeakyReLU(0.2) -> Dropout(0.3)) of widths
    1024, 512 and 256, followed by a single sigmoid unit, so the output is one
    value in (0, 1) per input row.
    """

    def __init__(self):
        super().__init__()
        input_dim, output_dim = 784, 1

        def hidden_stage(fan_in, fan_out):
            # One hidden stage: affine map, leaky rectifier, dropout.
            return nn.Sequential(
                nn.Linear(fan_in, fan_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            )

        # Attribute names and creation order are kept so that parameter
        # initialisation and state_dict keys match existing checkpoints.
        self.layer1 = hidden_stage(input_dim, 1024)
        self.layer2 = hidden_stage(1024, 512)
        self.layer3 = hidden_stage(512, 256)
        self.out = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a (batch, 784) tensor to a (batch, 1) tensor of sigmoid outputs."""
        for stage in (self.layer1, self.layer2, self.layer3, self.out):
            x = stage(x)
        return x
    

In [0]:
def images_to_vectors(images):
    """Flatten each item of a batch into a row of 784 values (a view, no copy)."""
    batch_size = images.shape[0]
    return images.view(batch_size, 784)

def vectors_to_images(vectors):
    """Reshape rows of 784 values into single-channel 28x28 images (a view, no copy)."""
    batch_size = vectors.shape[0]
    return vectors.view(batch_size, 1, 28, 28)

In [0]:
class GeneratorNet(torch.nn.Module):
    """Fully-connected generator: 100-value latent vector -> 784-value image vector.

    Three hidden stages (Linear -> LeakyReLU(0.2)) of widths 256, 512 and 1024,
    followed by a Linear + Tanh head, so every output value lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        latent_dim, image_dim = 100, 784

        # Attribute names and creation order are kept so that parameter
        # initialisation and state_dict keys match existing checkpoints.
        self.layer1 = nn.Sequential(nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2))
        self.layer2 = nn.Sequential(nn.Linear(256, 512), nn.LeakyReLU(0.2))
        self.layer3 = nn.Sequential(nn.Linear(512, 1024), nn.LeakyReLU(0.2))
        self.out = nn.Sequential(nn.Linear(1024, image_dim), nn.Tanh())

    def forward(self, x):
        """Map a (batch, 100) latent tensor to a (batch, 784) tensor of Tanh outputs."""
        hidden = self.layer3(self.layer2(self.layer1(x)))
        return self.out(hidden)
    


In [0]:
# Noise
def noise(size):
    """Sample latent vectors for the generator.

    Args:
        size: Number of latent vectors to draw.

    Returns:
        Tensor of shape (size, 100) with standard-normal entries, moved to the
        notebook-level `device`.
    """
    # The deprecated torch.autograd.Variable wrapper is dropped: plain tensors
    # already take part in autograd. Sampling stays on the CPU and is then moved,
    # so the random stream is the same as before.
    return torch.randn(size, 100).to(device)

# Train

In [0]:
def real_data_target(size):
    '''
    Target labels for real samples.

    Returns a tensor of ones with shape (size, 1), allocated on the
    notebook-level `device`.
    '''
    # No Variable wrapper (deprecated) and no extra host-to-device copy:
    # the tensor is created directly where it will be used.
    return torch.ones(size, 1, device=device)

def fake_data_target(size):
    '''
    Target labels for generated (fake) samples.

    Returns a tensor of zeros with shape (size, 1), allocated on the
    notebook-level `device`.
    '''
    # No Variable wrapper (deprecated) and no extra host-to-device copy:
    # the tensor is created directly where it will be used.
    return torch.zeros(size, 1, device=device)

In [0]:
def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a real batch and a fake batch.

    Uses the notebook-level `discriminator` network and `loss` criterion.

    Args:
        optimizer: Optimizer to zero and step (the training loop passes the
            discriminator's optimizer).
        real_data: Batch of real samples; scored against a target of ones.
        fake_data: Batch of generated samples; scored against a target of zeros.

    Returns:
        Tuple ``(error_real + error_fake, prediction_real, prediction_fake)``.
    """
    # Reset gradients
    optimizer.zero_grad()

    # 1.1 Train on real data: the discriminator should output 1
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, real_data_target(prediction_real.size(0)))
    error_real.backward()

    # 1.2 Train on fake data: the discriminator should output 0.
    # The target is sized from the fake batch itself (not from real_data), so the
    # loss shapes still match if the two batches ever differ in size.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, fake_data_target(prediction_fake.size(0)))
    error_fake.backward()

    # 1.3 Update weights with the gradients accumulated from both passes
    optimizer.step()

    # Return total error and both prediction tensors
    return error_real + error_fake, prediction_real, prediction_fake


In [0]:
def train_generator(optimizer, fake_data):
    """Run one generator update using already-generated samples.

    The notebook-level `discriminator` scores `fake_data`, and the score is
    compared against the *real* target (ones): the error is low when the
    discriminator is fooled. No noise is sampled here; the caller supplies
    `fake_data`.

    Args:
        optimizer: Optimizer to zero and step (the training loop passes the
            generator's optimizer).
        fake_data: Generated batch. NOTE(review): presumably still attached to
            the generator's graph so the backward pass reaches its weights;
            the training loop passes it un-detached.

    Returns:
        The loss tensor for this step.
    """
    optimizer.zero_grad()

    # Score the fakes and measure how far they are from being judged "real".
    verdict = discriminator(fake_data)
    wanted = real_data_target(verdict.size(0))
    g_loss = loss(verdict, wanted)

    # Backpropagate and apply the update.
    g_loss.backward()
    optimizer.step()
    return g_loss

# Samples

In [0]:
# Fixed latent vectors, reused for the sample grid saved after every epoch.
num_test_samples = 16
test_noise = noise(num_test_samples)
# Save the raw noise as an image grid (each 100-value vector shown as a 10x10 tile).
# The batch dimension follows num_test_samples instead of a second hard-coded 16.
# NOTE(review): this writes under /content/gdrive, so Google Drive must already be mounted.
save_image(test_noise.view(num_test_samples, 1, 10, 10),
                  F"/content/gdrive/My Drive/Colab Notebooks/ComputerVision/MNIST/results/sample_.png")

# Training

In [0]:
# Build both networks and place them on the selected device.
# Module.to() moves the parameters in place and returns the same module.
discriminator = DiscriminatorNet().to(device)
generator = GeneratorNet().to(device)


In [0]:

# Loss function: binary cross-entropy on the discriminator's sigmoid output
loss = nn.BCELoss()

# Optimizers: one Adam instance per network, same learning rate for both
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4)

# Running epoch counter, incremented by the training loop
epoch = 0


In [97]:

# Adversarial training: 100 more epochs over the MNIST training set.
# `epoch` is a running counter defined outside this cell, so re-running the cell
# (or restoring a checkpoint first) continues the numbering instead of resetting it.
for _ in range(100):
    epoch += 1
    # Labels are not used; a distinct name avoids rebinding the outer loop's `_`.
    for real_batch, _labels in tqdm(trainloader):
        batch_size = real_batch.size(0)

        # 1. Train Discriminator
        real_data = images_to_vectors(real_batch).to(device)
        # Fake samples are detached so this step does not backpropagate into the generator.
        fake_data = generator(noise(batch_size)).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer,
                                                                real_data, fake_data)

        # 2. Train Generator
        # Fresh samples, left attached to the graph so gradients reach the generator.
        fake_data = generator(noise(batch_size))
        g_error = train_generator(g_optimizer, fake_data)

    # Log the errors of the last batch of the epoch as plain Python floats
    clear_output()
    print('epoch: {}, discrim: {}, generator: {}'.format(epoch, d_error.item(), g_error.item()))

    # Model checkpoints: sample grid from the fixed noise plus a resumable state file
    with torch.no_grad():
        sample = generator(test_noise.to(device)).cpu()
        save_name = "sample_" + str(epoch) + ".png"
        # vectors_to_images keeps the batch dimension in sync with test_noise
        # instead of repeating a hard-coded 16.
        save_image(vectors_to_images(sample),
                   F"/content/gdrive/My Drive/Colab Notebooks/ComputerVision/MNIST/results/{save_name}")

        # The same file is overwritten every epoch, so it always holds the latest state.
        model_save_name = 'checkpoint_.pt'
        path = F"/content/gdrive/My Drive/Colab Notebooks/ComputerVision/MNIST//{model_save_name}"
        torch.save({
            'discriminator_state_dict': discriminator.state_dict(),
            'generator_state_dict': generator.state_dict(),
            "d_optimizer": d_optimizer.state_dict(),
            "g_optimizer": g_optimizer.state_dict(),
            "test_noise": test_noise,
            "epoch":epoch
            }, path)

epoch: 100, discrim: 1.3043999671936035, generator: 0.9536974430084229


# Test

# Saving

In [11]:
# Mount Google Drive at /content/gdrive (Colab-only helper).
# NOTE(review): cells earlier in this notebook already write files under
# /content/gdrive, so on a fresh runtime this cell has to be run before them.
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
# Persist the final training state to Drive: both networks, both optimizers,
# the fixed test noise and the epoch counter.
model_save_name = 'GAN_MNIST.pt'
path = F"/content/gdrive/My Drive/Colab Notebooks/ComputerVision/MNIST//{model_save_name}"

final_state = {
    'discriminator_state_dict': discriminator.state_dict(),
    'generator_state_dict': generator.state_dict(),
    "d_optimizer": d_optimizer.state_dict(),
    "g_optimizer": g_optimizer.state_dict(),
    "test_noise": test_noise,
    "epoch": epoch,
}
torch.save(final_state, path)

In [0]:
# Restore a saved training state into the already-constructed networks and optimizers.
model_save_name = 'GAN_MNIST.pt'
path = F"/content/gdrive/My Drive/Colab Notebooks/ComputerVision/MNIST/{model_save_name}"
# map_location remaps every stored tensor onto the current device, so a checkpoint
# written on a GPU runtime can also be restored on a CPU-only one.
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load(path, map_location=device)
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
generator.load_state_dict(checkpoint['generator_state_dict'])
d_optimizer.load_state_dict(checkpoint["d_optimizer"])
g_optimizer.load_state_dict(checkpoint["g_optimizer"])
# Fixed noise and epoch counter, so sample grids and numbering continue where they left off.
test_noise = checkpoint["test_noise"]
epoch = checkpoint["epoch"]

# References

https://medium.com/intel-student-ambassadors/mnist-gan-detailed-step-by-step-explanation-implementation-in-code-ecc93b22dc60

https://medium.com/ai-society/gans-from-scratch-1-a-deep-introduction-with-code-in-pytorch-and-tensorflow-cb03cdcdba0f

In [0]:
# https://github.com/soumith/ganhacks