<a href="https://colab.research.google.com/github/dibyanshu2305/Deep_learning_course_notebooks/blob/main/IE643_GAN_MNIST_moodle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# The rest of the notebook is pure PyTorch, so query CUDA through torch instead of
# importing TensorFlow just for this check. An empty string means no GPU is visible.
import torch

device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''
device_name

In [None]:
import torch
import torch.optim as opt
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt

In [None]:
mb_size = 64   # mini-batch size used by trainLoader

def get_indices(dataset, labels=(2, 4, 6)):
    """Return the indices of the samples in ``dataset`` whose label is in ``labels``.

    Args:
        dataset: a dataset exposing its labels as ``dataset.targets``; older
            torchvision releases call this attribute ``train_labels`` instead.
        labels: class labels to keep; defaults to the digits 2, 4 and 6.

    Returns:
        A list of Python ints in increasing order, suitable for
        ``SubsetRandomSampler``.
    """
    targets = dataset.targets if hasattr(dataset, 'targets') else dataset.train_labels
    targets = torch.as_tensor(targets)
    # Vectorised membership test: one tensor comparison per wanted label instead
    # of a Python-level loop over every label in the dataset.
    mask = torch.zeros_like(targets, dtype=torch.bool)
    for label in labels:
        mask |= targets == label
    return torch.nonzero(mask).flatten().tolist()

# Download MNIST if needed, then keep only the training images labelled 2, 4 or 6.
trainData = torchvision.datasets.MNIST('./data/', train=True, download=True,
                                       transform=transforms.ToTensor())

idx = get_indices(trainData)
print(len(idx))

# SubsetRandomSampler yields the indices in `idx` in random order, without replacement.
subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(idx)
trainLoader = torch.utils.data.DataLoader(trainData, batch_size=mb_size, sampler=subset_sampler)

In [None]:
# Number of MNIST training images labelled 2, 4 or 6 (printed by the cell above) = 17718

In [None]:
# Pull a single mini-batch for inspection. In current PyTorch, DataLoader iterators
# have no ``.next()`` method, so use the built-in ``next()`` instead.
dataIter = iter(trainLoader)

imgs, labels = next(dataIter)

In [None]:
imgs.size()  # torch.Size of one mini-batch; expected (64, 1, 28, 28) for MNIST + ToTensor

In [None]:
#visualization of data on a grid
def imshow(imgs, nrow=8, title=None):
    """Display a batch of images as a single grid.

    Args:
        imgs: batch of image tensors in the layout ``torchvision.utils.make_grid``
            expects, e.g. (batch, 1, 28, 28) here.
        nrow: number of images per grid row (``make_grid``'s own default is 8).
        title: optional figure title.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU can still be
    # converted to numpy; make_grid returns a (C, H, W) tensor.
    grid = torchvision.utils.make_grid(imgs.detach().cpu(), nrow=nrow)
    fig, ax = plt.subplots(figsize=(8, 8))
    # matplotlib wants (H, W, C). make_grid expands 1-channel input to 3 channels,
    # so matplotlib ignores cmap for this RGB array; it is kept from the original call.
    ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)), cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
    if title is not None:
        ax.set_title(title)
    plt.show()
    plt.close(fig)  # the training loop calls this every epoch; avoid piling up figures

In [None]:
imshow(imgs=imgs)  # grid view of the first mini-batch drawn from trainLoader

![Generative Adversarial Network](https://www.kdnuggets.com/wp-content/uploads/generative-adversarial-network.png)

In [None]:
Z_dim = 100  # length of the random noise vector z fed to the generator
H_dim = 128  # number of hidden units in both the generator and the discriminator
X_dim = imgs[0].numel()  # pixels in one flattened image (28*28 = 784 for MNIST)

print(Z_dim, H_dim, X_dim)

In [None]:
#neural network for generative network
class Gen(nn.Module):
    """MLP generator: maps a noise vector z to a flattened image.

    Architecture: Linear(z_dim, h_dim) -> ReLU -> Linear(h_dim, x_dim) -> Sigmoid.
    The final sigmoid keeps every output pixel in (0, 1), matching the range of
    ``ToTensor``-normalised MNIST images.

    Args:
        z_dim: noise dimension; defaults to the notebook constant ``Z_dim``.
        h_dim: hidden width; defaults to ``H_dim``.
        x_dim: output size (flattened image); defaults to ``X_dim``.
    """

    def __init__(self, z_dim=None, h_dim=None, x_dim=None):
        super().__init__()
        # Resolve defaults at construction time (not definition time), so the
        # notebook constants are read exactly when the original code read them.
        z_dim = Z_dim if z_dim is None else z_dim
        h_dim = H_dim if h_dim is None else h_dim
        x_dim = X_dim if x_dim is None else x_dim
        self.model = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, x_dim),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Map a (batch, z_dim) noise tensor to a (batch, x_dim) tensor in (0, 1)."""
        return self.model(input)

In [None]:
G: nn.Module = Gen()  # generator: noise z -> flattened image

In [None]:
#neural network for discriminative model
class Dis(nn.Module):
    """MLP discriminator: maps a flattened image to the probability that it is real.

    Architecture: Linear(x_dim, h_dim) -> ReLU -> Linear(h_dim, 1) -> Sigmoid.
    The sigmoid output is what ``F.binary_cross_entropy`` expects as its input.

    Args:
        x_dim: input size (flattened image); defaults to the notebook constant ``X_dim``.
        h_dim: hidden width; defaults to ``H_dim``.
    """

    def __init__(self, x_dim=None, h_dim=None):
        super().__init__()
        # Resolve defaults at construction time so the notebook constants are read
        # exactly when the original code read them.
        x_dim = X_dim if x_dim is None else x_dim
        h_dim = H_dim if h_dim is None else h_dim
        self.model = nn.Sequential(
            nn.Linear(x_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Map a (batch, x_dim) tensor to a (batch, 1) tensor of probabilities."""
        return self.model(input)

In [None]:
D: nn.Module = Dis()  # discriminator: flattened image -> probability it is real

In [None]:
#print the network architecture
# Generator architecture first, then the discriminator, each followed by a newline.
print(G, D, sep='\n')

In [None]:
lr = 1e-3  # Adam learning rate shared by both networks

# One Adam optimizer per network (created in the order G, D):
# g_opt only updates G's weights and d_opt only updates D's.
g_opt, d_opt = (opt.Adam(net.parameters(), lr=lr) for net in (G, D))

In [None]:
N_EPOCHS = 100  # number of passes over the filtered (2/4/6) training subset

for epoch in range(N_EPOCHS):
    G_loss_run = 0.0
    D_loss_run = 0.0
    n_batches = 0

    for X, _ in trainLoader:
        # Flatten each image into one row of X_dim pixels for the MLP discriminator.
        X = X.view(X.size(0), -1)
        # Local batch size: the last batch can be smaller than mb_size. Using a
        # separate name avoids overwriting the notebook's mb_size constant.
        batch_size = X.size(0)

        one_labels = torch.ones(batch_size, 1)
        zero_labels = torch.zeros(batch_size, 1)

        # ---- Discriminator step: minimise -[log D(x) + log(1 - D(G(z)))] ----
        z = torch.randn(batch_size, Z_dim)
        # detach(): this step only updates D, so don't backpropagate into G.
        fake = G(z).detach()
        D_real = D(X)
        D_fake = D(fake)

        D_real_loss = F.binary_cross_entropy(D_real, one_labels)   # -(1/m) sum log D(x)
        D_fake_loss = F.binary_cross_entropy(D_fake, zero_labels)  # -(1/m) sum log(1 - D(G(z)))
        D_loss = D_real_loss + D_fake_loss

        d_opt.zero_grad()
        D_loss.backward()
        d_opt.step()

        # ---- Generator step: non-saturating loss ----
        z = torch.randn(batch_size, Z_dim)

        D_fake = D(G(z))
        # BCE against "real" labels gives -(1/m) sum log D(G(z)), not log(1 - D(G(z))).
        G_loss = F.binary_cross_entropy(D_fake, one_labels)

        g_opt.zero_grad()
        # This backward also fills D's .grad; d_opt.zero_grad() clears it before the
        # next discriminator update, and g_opt.step() only touches G's weights.
        G_loss.backward()
        g_opt.step()

        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()
        n_batches += 1

    n = max(n_batches, 1)  # avoid ZeroDivisionError if the loader yields nothing
    print('Epoch:{},   G_loss:{},    D_loss:{}'.format(epoch, G_loss_run / n, D_loss_run / n))

    if n_batches:
        # Show what the generator produces for the last noise batch of this epoch.
        with torch.no_grad():
            samples = G(z).view(-1, 1, 28, 28)
        imshow(samples)