In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 5GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader


%matplotlib inline

In [None]:
# Compose chains several transformations into one. Normalize applies
# (x - mean) / std per channel; with mean = std = 0.5 this rescales the
# tensor values to the range -1 to +1.
mnist_transform = Compose([
    ToTensor(),
    Normalize(mean=(0.5,), std=(0.5,)),
])

mnist = MNIST(root='data', train=True, download=True, transform=mnist_transform)

In [None]:
# Look at the first (image, label) pair: its label, tensor shape and value range.
img, label = mnist[0]

print('label:', label)
print(img.shape)
print(img.min(), img.max())


In [None]:
#a function to denormalize the image

def denorm(img):
    """Map a tensor from the normalised [-1, 1] range back to [0, 1].

    Computes ``(img + 1) / 2`` and clamps the result to [0, 1], so values
    outside [-1, 1] saturate instead of leaving the displayable range.
    """
    return img.add(1).div(2).clamp(min=0, max=1)

In [None]:
# Undo the normalisation, then draw the first channel as a grayscale picture.
img_norm = denorm(img)
first_channel = img_norm[0]
plt.imshow(first_channel, cmap='gray')
print(label)

In [None]:

batch_size = 100

# Serve the dataset in shuffled mini-batches of `batch_size` samples.
dataloader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
device

**Discriminator Network**

In [None]:
image_size = 28 * 28  # 784 values per flattened image
hidden_size = 256

# Discriminator: a three-layer MLP that maps a flattened image to a single
# sigmoid score in (0, 1). Layers are listed in order so that the indices
# inside the Sequential (0, 2, 4 for the Linear layers) stay the same.
discriminator_layers = [
    nn.Linear(in_features=image_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=hidden_size, out_features=1),
    nn.Sigmoid(),
]
D = nn.Sequential(*discriminator_layers)

In [None]:
# Move the discriminator's parameters onto `device`. As the last expression
# of the cell, the value returned by .to() is shown as the cell output.
D.to(device)

**Define Generator Network**

In [None]:
latent_size = 64

# Generator: a three-layer MLP that maps a latent vector to a flattened image.
# The final Tanh keeps outputs in (-1, 1). Layers are listed in order so the
# indices inside the Sequential stay the same.
generator_layers = [
    nn.Linear(in_features=latent_size, out_features=hidden_size),
    nn.ReLU(),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.ReLU(),
    nn.Linear(in_features=hidden_size, out_features=image_size),
    nn.Tanh(),
]
G = nn.Sequential(*generator_layers)


In [None]:
# Smoke-test the untrained G on a random batch of 2 latent vectors.

z_test = torch.randn(2, latent_size)
y = G(z_test)

# Detach from the autograd graph and fold each flat output back into 28x28.
gimg = denorm(y.detach().reshape(-1, 28, 28))

plt.imshow(gimg[0], cmap='gray')


In [None]:
# Move the generator's parameters onto `device`. As the last expression
# of the cell, the value returned by .to() is shown as the cell output.
G.to(device)

**Train Discriminator**

In [None]:
learning_rate = 2e-4

# Binary cross-entropy between D's sigmoid output and the 0/1 target labels.
criterion = nn.BCELoss()

# Adam over the discriminator's parameters only.
d_optimizer = torch.optim.Adam(params=D.parameters(), lr=learning_rate)


In [None]:
def reset_grad():
    """Clear the gradients held by both the discriminator and generator optimizers."""
    for optimizer in (d_optimizer, g_optimizer):
        optimizer.zero_grad()
    
def train_discriminator(images):
    """Run one optimisation step of the discriminator ``D``.

    The loss is the sum of the BCE loss on a batch of real images (target 1)
    and on an equally sized batch of generated images (target 0).

    Args:
        images: batch of real images, already on ``device`` and shaped to
            match D's input (the training loop flattens them to one row per
            image before calling this function).

    Returns:
        Tuple ``(d_loss, d_real_out, d_fake_out)``: the combined loss tensor
        and D's raw outputs on the real and on the generated batch.
    """
    # Size labels and noise from the batch that was actually passed in rather
    # than the global `batch_size`, so a smaller final batch cannot cause a
    # shape mismatch inside the loss.
    n_samples = images.size(0)
    real_labels = torch.ones(n_samples, 1).to(device)
    fake_labels = torch.zeros(n_samples, 1).to(device)

    # Loss on real images: D is trained to answer 1.
    d_real_out = D(images)
    d_real_loss = criterion(d_real_out, real_labels)

    # Loss on generated images: D is trained to answer 0.
    z = torch.randn(n_samples, latent_size).to(device)
    # Only D is updated in this step, so cut the graph at G's output: this
    # avoids back-propagating through G for gradients that would be discarded.
    fake_images = G(z).detach()
    d_fake_out = D(fake_images)
    d_fake_loss = criterion(d_fake_out, fake_labels)

    d_loss = d_real_loss + d_fake_loss

    # Clear stale gradients, back-propagate, and update D's weights.
    reset_grad()
    d_loss.backward()
    d_optimizer.step()

    return d_loss, d_real_out, d_fake_out
    
       

**Train Generator**

In [None]:
# Adam over the generator's parameters only, using the shared learning rate.
g_optimizer = torch.optim.Adam(params=G.parameters(), lr=learning_rate)

In [None]:
def train_generator():
    """Run one optimisation step of the generator ``G``.

    Generates `batch_size` images from random latent vectors and scores them
    with ``D`` against a target of 1, so the loss falls as D is fooled.

    Returns:
        Tuple ``(g_loss, fake_images)``: the loss tensor and the generated batch.
    """
    # Sample latent noise and turn it into a batch of fake images.
    noise = torch.randn(batch_size, latent_size).to(device)
    fake_images = G(noise)

    # The target is "real" (all ones) for every generated image.
    targets = torch.ones(batch_size, 1).to(device)
    d_verdict = D(fake_images)
    g_loss = criterion(d_verdict, targets)

    # Clear stale gradients, back-propagate, and update G's weights only.
    reset_grad()
    g_loss.backward()
    g_optimizer.step()

    return g_loss, fake_images
        

In [None]:
# creates a Samples directory

import os

sample_dir = 'samples'

# Create the output folder on the first run; on later runs just report it.
if os.path.exists(sample_dir):
    print("directory already exists")
else:
    os.makedirs(sample_dir)
    

In [None]:
#save a real image
from IPython.display import Image
from torchvision.utils import save_image

#save the first batch of images

# Take only the first mini-batch and write it out as a 10-column image grid.
images, _ = next(iter(dataloader))
print(images.shape)

real_images_path = os.path.join(sample_dir, 'real_images.png')
images = images.reshape(images.size(0), 1, 28, 28)
save_image(denorm(images), real_images_path, nrow=10)

Image(real_images_path)



In [None]:
# Draw one fixed batch of latent vectors up front. save_fake_images() feeds
# these same vectors to G every time it is called, so the grids saved after
# each epoch can be compared directly.

sample_vectors = torch.randn(batch_size, latent_size).to(device)

def save_fake_images(index):
    """Generate images from the fixed ``sample_vectors`` and save them as a PNG grid.

    The grid is written to ``sample_dir`` as ``fake_images-NNNN.png`` with 10
    images per row, and the file name is printed.

    Args:
        index: integer used for the zero-padded 4-digit suffix of the file name.
    """
    # This is pure inference: disable autograd so no graph is built for a
    # forward pass that is never back-propagated.
    with torch.no_grad():
        fake_images = G(sample_vectors)
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    fake_fname = 'fake_images-{0:0=4d}.png'.format(index)
    print('Saving', fake_fname)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=10)
    
# Before training: save a baseline grid (index 0) from the untrained generator
# and display the resulting file inline.
save_fake_images(0)
Image(os.path.join(sample_dir, 'fake_images-0000.png'))

In [None]:
#training loop

num_epochs = 300
total_step = len(dataloader)

# One entry is appended to each list every 200 training steps.
d_losses, g_losses, real_scores, fake_scores = [], [], [], []

for epoch in range(num_epochs):

    for i, (images, _) in enumerate(dataloader):

        # Flatten every image to one row. Use the batch's own length instead
        # of the global `batch_size` so a smaller final batch cannot break
        # the reshape.
        images = images.reshape(images.size(0), -1).to(device)

        # One discriminator step followed by one generator step.
        d_loss, real_score, fake_score = train_discriminator(images)
        g_loss, _ = train_generator()

        # Record and report the metrics every 200 steps.
        if (i + 1) % 200 == 0:
            # Convert each metric to a Python float once and reuse it.
            d_loss_value = d_loss.item()
            g_loss_value = g_loss.item()
            real_mean = real_score.mean().item()
            fake_mean = fake_score.mean().item()

            d_losses.append(d_loss_value)
            g_losses.append(g_loss_value)
            real_scores.append(real_mean)
            fake_scores.append(fake_mean)
            # Report the epoch 1-based, matching the 1-based step counter and
            # the index used for the saved sample grid.
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, d_loss_value, g_loss_value,
                          real_mean, fake_mean))

    # Sample and save images once per epoch
    save_fake_images(epoch + 1)

In [None]:
# Save the model checkpoints (weights only, via state_dict)
for network, ckpt_path in ((G, 'G.ckpt'), (D, 'D.ckpt')):
    torch.save(network.state_dict(), ckpt_path)

In [None]:
# Show the sample grid saved with index 25, i.e. after the 25th epoch.
# The path is built from `sample_dir` so it follows the folder the grids
# were written to.

Image(os.path.join(sample_dir, 'fake_images-0025.png'))

In [None]:
# Show the sample grid saved with index 125, i.e. after the 125th epoch.
# The path is built from `sample_dir` so it follows the folder the grids
# were written to.

Image(os.path.join(sample_dir, 'fake_images-0125.png'))

In [None]:
# Show the sample grid saved with index 200, i.e. after the 200th epoch.
# The path is built from `sample_dir` so it follows the folder the grids
# were written to.

Image(os.path.join(sample_dir, 'fake_images-0200.png'))

In [None]:
# Show the sample grid saved with index 300, i.e. after the 300th epoch.
# The path is built from `sample_dir` so it follows the folder the grids
# were written to.

Image(os.path.join(sample_dir, 'fake_images-0300.png'))

In [None]:
#### combine the images to a video

import cv2
import os
from IPython.display import FileLink

vid_fname = 'gans_training.avi'

# The file names carry zero-padded indices, so a plain lexicographic sort
# puts the grids in the order in which they were saved.
files = sorted(
    os.path.join(sample_dir, f) for f in os.listdir(sample_dir) if 'fake_images' in f
)

out = None
try:
    for fname in files:
        frame = cv2.imread(fname)
        if frame is None:
            # cv2.imread returns None for a file it cannot decode; skip it
            # rather than passing None to the writer.
            print('Skipping unreadable image', fname)
            continue
        if out is None:
            # Size the video from the first frame instead of hard-coding it.
            # frame.shape is (height, width, channels) while VideoWriter
            # expects (width, height).
            height, width = frame.shape[:2]
            out = cv2.VideoWriter(vid_fname, cv2.VideoWriter_fourcc(*'MP4V'), 8, (width, height))
        out.write(frame)
finally:
    # Always finalise the file, even if reading or writing a frame failed.
    if out is not None:
        out.release()

FileLink(vid_fname)

In [None]:
# Plot the losses. The training loop appends one value every 200 steps,
# so the x-axis counts logging intervals, not epochs.

fig, ax = plt.subplots()
ax.plot(d_losses, '-', label='Discriminator')
ax.plot(g_losses, '-', label='Generator')
ax.set_xlabel('logging step (every 200 batches)')
ax.set_ylabel('loss')
ax.legend()
ax.set_title('Losses');




In [None]:
# Plot the mean discriminator outputs on real and generated batches. The
# training loop appends one value every 200 steps, so the x-axis counts
# logging intervals, not epochs.

fig, ax = plt.subplots()
ax.plot(real_scores, '-', label='Real Score')
ax.plot(fake_scores, '-', label='Fake score')
ax.set_xlabel('logging step (every 200 batches)')
ax.set_ylabel('score')
ax.legend()
ax.set_title('Scores');