# In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import pickle
import os
import imageio
from matplotlib import pyplot as plt
import numpy as np
import torchvision.utils as vutils

class Generator(nn.Module):
    """Fully connected generator: 100-d latent vectors -> 1x28x28 images.

    Three hidden Linear layers (256, 512, 1024 units), each followed by
    LeakyReLU(0.2), then a Linear projection to 28*28 values squashed by
    Tanh, so every output pixel lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        hidden = [100, 256, 512, 1024]
        stack = []
        # Layers are created in the same order as a hand-written list, so the
        # Sequential indices (main.0, main.2, ...) and init RNG draws match.
        for width_in, width_out in zip(hidden[:-1], hidden[1:]):
            stack.append(nn.Linear(in_features=width_in, out_features=width_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.append(nn.Linear(in_features=hidden[-1], out_features=28 * 28))
        stack.append(nn.Tanh())
        self.main = nn.Sequential(*stack)

    def forward(self, inputs):
        """Map a batch of latent vectors to images of shape (-1, 1, 28, 28)."""
        pixels = self.main(inputs)
        return pixels.view(-1, 1, 28, 28)


class Discriminator(nn.Module):
    """Fully connected discriminator scoring 28x28 images as real or fake.

    Inputs are flattened with ``view(-1, 28*28)``, so e.g. a (B, 1, 28, 28)
    batch yields a (B, 1) tensor of probabilities from the final Sigmoid.
    Hidden layers: 1024 -> 512 -> 256 units, each with LeakyReLU(0.2)
    followed by Dropout (default p=0.5).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Dropout is deliberately NOT in-place: the in-place LeakyReLU before
        # it saves its own output for the backward pass, and an in-place
        # Dropout overwriting that tensor makes autograd raise
        # "modified by an inplace operation" when backward() runs.
        self.main = nn.Sequential(
            nn.Linear(in_features=28*28, out_features=1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(),
            nn.Linear(in_features=1024, out_features=512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(),
            nn.Linear(in_features=512, out_features=256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(),
            nn.Linear(in_features=256, out_features=1),
            nn.Sigmoid())

    def forward(self, inputs):
        """Flatten each image to 784 values and return its real-probability."""
        inputs = inputs.view(-1, 28*28)
        return self.main(inputs)

def show_generated_data(real_data, fake_data):
    """Draw up to 64 real and 64 generated images side by side.

    Each batch is tiled with ``vutils.make_grid`` (padding=5, normalize=True),
    moved to CPU and shown in its own axis-less subplot of one 15x5 figure:
    real images on the left, fake images on the right.
    """
    plt.figure(figsize=(15, 5))
    panels = (("Real Images", real_data), ("Fake Images", fake_data))
    for slot, (caption, batch) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.axis("off")
        plt.title(caption)
        tiled = vutils.make_grid(batch[:64], padding=5, normalize=True).cpu()
        # make_grid returns (C, H, W); imshow wants channels last.
        plt.imshow(np.transpose(tiled, (1, 2, 0)))




def _tile_square(data):
    """Min-max normalize a stack of images and tile it into a square grid.

    Args:
        data: array of shape (N, H, W) or (N, H, W, ...), or a list of such
            arrays (concatenated along axis 0).

    Returns:
        Array of shape (n * (H + 1), n * (W + 1)) + trailing dims, with
        n = ceil(sqrt(N)). Values are scaled to [0, 1]; each tile gets a
        1-pixel separator of value 1 on its bottom and right edge, and unused
        grid slots are filled with 1. A constant input maps to all zeros.

    Raises:
        ValueError: if the stack has fewer than 3 dimensions or no images.
    """
    if isinstance(data, list):
        data = np.concatenate(data)
    data = np.asarray(data)
    if data.ndim < 3:
        raise ValueError(
            "expected an image stack of shape (N, H, W, ...), got {}".format(data.shape))
    if data.shape[0] == 0:
        raise ValueError("cannot tile an empty image stack")

    lo, hi = data.min(), data.max()
    if hi > lo:
        data = (data - lo) / (hi - lo)
    else:
        # Constant input: plain min-max scaling would be 0/0 -> all-NaN image.
        data = np.zeros(data.shape)

    n = int(np.ceil(np.sqrt(data.shape[0])))
    # Pad up to n*n tiles and add a 1-px border after each tile's rows/cols.
    padding = ((0, n ** 2 - data.shape[0]), (0, 1), (0, 1)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant', constant_values=1)

    # (n*n, H', W', ...) -> (n, n, H', W', ...) -> (n, H', n, W', ...)
    data = data.reshape((n, n) + data.shape[1:])
    data = data.transpose((0, 2, 1, 3) + tuple(range(4, data.ndim)))
    return data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])


def square_plot(data, path):
    """Save a stack of images as one square grid image at *path*.

    *data* is an (N, H, W[, ...]) array or a list of such arrays; it is
    normalized and tiled by ``_tile_square`` and written with
    ``plt.imsave(..., cmap='gray')``.
    """
    plt.imsave(path, _tile_square(data), cmap='gray')


# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

mnist = datasets.MNIST(root='/content/gdrive/My Drive/AI/data', download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=60, shuffle=True)

use_gpu = torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')

leave_log = True
if leave_log:
    result_dir = '/content/gdrive/My Drive/AI/GAN_generated_images'
    # makedirs also creates missing parent directories; exist_ok makes reruns safe.
    os.makedirs(result_dir, exist_ok=True)

G = Generator().to(device)
D = Discriminator().to(device)

criterion = nn.BCELoss()

G_optimizer = Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
D_optimizer = Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

if leave_log:
    train_hist = {'D_losses': [], 'G_losses': []}
    generated_images = []

# Fixed batch of 25 latent vectors. Plain tensors replace
# Variable(..., volatile=True), which is a deprecated no-op in modern PyTorch.
z_fixed = torch.randn(5 * 5, 100).to(device)

num_epochs = 100
for epoch in range(num_epochs):

    if leave_log:
        D_losses = []
        G_losses = []

    for real_data, _ in dataloader:
        batch_size = real_data.size(0)
        real_data = real_data.to(device)
        target_real = torch.ones(batch_size, 1, device=device)
        target_fake = torch.zeros(batch_size, 1, device=device)

        # Discriminator step: push D(real) -> 1 and D(fake) -> 0.
        D_result_from_real = D(real_data)
        D_loss_real = criterion(D_result_from_real, target_real)

        z = torch.randn(batch_size, 100, device=device)
        # detach(): this step only updates D, so skip backprop through G.
        fake_data = G(z).detach()
        D_result_from_fake = D(fake_data)
        D_loss_fake = criterion(D_result_from_fake, target_fake)

        D_loss = D_loss_real + D_loss_fake

        D.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        if leave_log:
            # .item() replaces .data[0], which fails on 0-dim loss tensors.
            D_losses.append(D_loss.item())

        # Generator step: make D score fresh fakes as real.
        z = torch.randn(batch_size, 100, device=device)
        fake_data = G(z)
        D_result_from_fake = D(fake_data)
        G_loss = criterion(D_result_from_fake, target_real)

        G.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if leave_log:
            G_losses.append(G_loss.item())

    if leave_log and D_losses:
        # Keep per-epoch mean losses; the per-batch lists are reset each epoch.
        train_hist['D_losses'].append(sum(D_losses) / len(D_losses))
        train_hist['G_losses'].append(sum(G_losses) / len(G_losses))


torch.save(G.state_dict(), "/content/gdrive/My Drive/AI/gan_generator.pkl")
torch.save(D.state_dict(), "/content/gdrive/My Drive/AI/gan_discriminator.pkl")