# Generative Adversarial Networks

Example of GAN output: https://thispersondoesnotexist.com/ (every face shown there is generated by a GAN)

In [None]:
A deep learning architecture that consists of 2 neural networks competing against each other.

The competition is a zero-sum game: one network's gain is the other network's loss.

In [None]:
applications

- image synthesis and generation
- image-to-image translation
- text-to-image synthesis
- style transfer

In [None]:
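1 - Generative
GANs are generative models: they learn to produce new samples that resemble the training data. Two kinds of model are involved:
- generative (the generator creates samples)
- discriminative (the discriminator classifies samples as real or fake)

2 - Adversarial
the two networks are trained against each other; "adversarial" describes this dynamic between them

3 - Networks
both parts are deep neural networks (for image data, typically convolutional and deconvolutional, i.e. transposed-convolution, layers)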

<img src='g1.jpg' /> 

In [None]:
- the generator tries to maximize the probability of the discriminator making a mistake
(i.e. it tries to maximize the discriminator's loss)

- the discriminator estimates the probability that a sample came from the training data
  rather than from the generator
    


<img src='g2.png' />
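In the implementation below, "maximizing the probability of the discriminator making a mistake" is done by scoring generated images with BCE against the "real" label, i.e. the generator minimizes -log D(G(z)). This is the non-saturating variant suggested in the original GAN paper instead of directly minimizing log(1 - D(G(z))): early in training, when the discriminator confidently rejects fakes, log(1 - D(G(z))) gives the generator almost no gradient. The sketch below compares the two gradients with respect to the discriminator's logit `a` (where D = sigmoid(a)); the logit value is hypothetical, chosen so that D(G(z)) is about 0.01.

```python
import math

def sigmoid(a):
    """Logistic function, as applied by the discriminator's final nn.Sigmoid()."""
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)): the generator's term in the minimax objective."""
    return -sigmoid(a)

def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)): BCE of D(G(z)) against the 'real' label."""
    return -(1.0 - sigmoid(a))

# Hypothetical logit for a fake sample the discriminator confidently rejects
a = -4.6  # sigmoid(-4.6) is roughly 0.01
print(f"D(G(z)) = {sigmoid(a):.3f}")
print(f"saturating gradient:     {grad_saturating(a):.3f}")
print(f"non-saturating gradient: {grad_non_saturating(a):.3f}")
```

Both gradients point the same way (they push D(G(z)) up), but the non-saturating one is roughly 100 times larger here, which is why the training loop uses the `valid` labels for the generator loss.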

In [None]:
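D(x) -> the discriminator's prediction (probability that the sample is real) on real data x

1 - D(G(z)) -> D(G(z)) is the discriminator's prediction on fake data G(z) (produced by the generator from noise z), so 1 - D(G(z)) is the probability it assigns to that sample being fake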

<img src='g3.png' />
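In the original GAN formulation these two quantities combine into the minimax value function V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))], which the discriminator maximizes and the generator minimizes. The sketch below estimates V from a handful of made-up discriminator outputs (illustrative numbers, not from a trained model) and checks that binary cross-entropy, with label 1 for real and label 0 for fake, is exactly the negated value, so minimizing the BCE maximizes V.

```python
import math

def bce(p, label):
    """Binary cross-entropy for one prediction p in (0, 1) and a 0/1 label."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

def value_fn(d_real, d_fake):
    """Sample estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator outputs D(x) on real samples
    d_fake: discriminator outputs D(G(z)) on generated samples
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# Hypothetical discriminator outputs, for illustration only
d_real = [0.9, 0.8, 0.95]  # D(x): D is fairly sure these are real
d_fake = [0.1, 0.3, 0.2]   # D(G(z)): D is fairly sure these are fake

v = value_fn(d_real, d_fake)

# BCE(D(x), 1) = -log D(x) and BCE(D(G(z)), 0) = -log(1 - D(G(z))),
# so the mean BCE over both groups is exactly -V
d_loss = (sum(bce(p, 1) for p in d_real) / len(d_real)
          + sum(bce(p, 0) for p in d_fake) / len(d_fake))
print(f"V = {v:.4f}, BCE-based discriminator loss = {d_loss:.4f}")
```

The implementation later averages the real and fake losses (`(real_loss + fake_loss)/2`), which only rescales this quantity by 1/2.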

In [None]:
types of GANs

1 - Vanilla GAN
2 - Conditional GAN
3 - Deep Convolutional GAN
4 - Super Resolution GAN

# Implementation

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
class Generator(nn.Module):
    def __init__(self,latent_dim, img_shape):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128,784),
            nn.Tanh()
        )
        self.img_shape = img_shape
    
    # view() reshapes the tensor without copying memory, similar to numpy's reshape().
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128,1),
            nn.Sigmoid()
        )
        self.img_shape = img_shape
    
    # view() with -1 lets PyTorch infer that dimension from the others:
    # here each (1, 28, 28) image is flattened into 784 values.
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
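The comments in both classes describe `view()`. A small standalone check of the two behaviours they mention (inferring a dimension with -1, and sharing memory with the original tensor), using a zero-filled tensor shaped like a batch of MNIST images; the batch size of 4 is arbitrary:

```python
import torch

# A fake batch of 4 single-channel 28x28 "images" (values are arbitrary)
imgs = torch.zeros(4, 1, 28, 28)

# -1 lets PyTorch infer that dimension: 1 * 28 * 28 = 784 values per image,
# which is what Discriminator.forward does before its first nn.Linear(784, 128)
flat = imgs.view(imgs.size(0), -1)
print(flat.shape)

# Reshaping back, as Generator.forward does with *img_shape
img_shape = (1, 28, 28)
restored = flat.view(flat.size(0), *img_shape)
print(restored.shape)

# view() does not copy: writing through the view changes the original tensor
flat[0, 0] = 1.0
print(imgs[0, 0, 0, 0].item())
```

`view()` requires a memory layout compatible with the new shape (e.g. a contiguous tensor); `reshape()` behaves the same way but falls back to copying when a view is not possible.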

In [3]:
latent_dim = 100
img_shape = (1,28,28)
batch_size = 64
epochs = 100
lr = 0.0002

generator = Generator(latent_dim, img_shape)
discriminator = Discriminator(img_shape)

adversarial_loss = nn.BCELoss() # Binary Cross Entropy
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
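Both networks are small. As a quick sanity check on their size, an `nn.Linear(in, out)` layer has `in * out` weights plus `out` biases, while `LeakyReLU`, `Tanh` and `Sigmoid` have no parameters. The arithmetic in plain Python:

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected (nn.Linear) layer."""
    return n_in * n_out + n_out

latent_dim = 100
# Generator: Linear(100, 128) -> LeakyReLU -> Linear(128, 784) -> Tanh
g_params = linear_params(latent_dim, 128) + linear_params(128, 784)
# Discriminator: Linear(784, 128) -> LeakyReLU -> Linear(128, 1) -> Sigmoid
d_params = linear_params(784, 128) + linear_params(128, 1)

print(g_params)  # 114064
print(d_params)  # 100609
```

With the models defined above, `sum(p.numel() for p in generator.parameters())` (and likewise for `discriminator`) should report the same numbers.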

In [4]:
# data

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)
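Two small numeric checks on this data pipeline, in plain Python. First, `ToTensor()` scales pixels to [0, 1] and `Normalize((0.5,), (0.5,))` then computes `(x - 0.5) / 0.5`, so the inputs end up in [-1, 1], the same range as the generator's final `Tanh`. Second, the MNIST training split has 60,000 images, so with `batch_size = 64` one epoch has 938 batches; because `DataLoader` keeps the incomplete last batch by default (`drop_last=False`), that batch holds only 32 images, which is why the training loop reads the size from `imgs.shape[0]`.

```python
import math

def normalize(x, mean=0.5, std=0.5):
    """Per-pixel formula applied by transforms.Normalize((0.5,), (0.5,))."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, from [-1, 1] back to [0, 1] (e.g. before saving images)."""
    return y * std + mean

# ToTensor() gives pixels in [0, 1]: black, mid-grey, white
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))

# Batches per epoch for the 60,000-image MNIST training split
n_train, batch_size = 60_000, 64
n_batches = math.ceil(n_train / batch_size)
last_batch = n_train - (n_batches - 1) * batch_size
print(n_batches, last_batch)  # 938 32
```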


In [None]:
# training

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # the last batch can be smaller than batch_size, so use the actual size
        cur_batch = imgs.shape[0]
        
        valid = torch.ones(cur_batch, 1)   # label 1 = real
        fake = torch.zeros(cur_batch, 1)   # label 0 = fake
        
        # generator: wants the discriminator to label generated images as real
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch, latent_dim)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()
        
        # discriminator: real images -> 1, generated images -> 0
        # detach() keeps this loss from sending gradients into the generator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        
        if i % 100 == 0:
            print(f"Epoch {epoch}/{epochs} [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
    
    # every 10 epochs, save a 5x5 grid of samples (once per epoch, not once per batch)
    if epoch % 10 == 0:
        with torch.no_grad():
            z = torch.randn(25, latent_dim)
            samples = generator(z).view(-1, 28, 28).numpy()

        plt.figure(figsize=(5, 5))
        for k in range(samples.shape[0]):
            plt.subplot(5, 5, k + 1)
            # Tanh outputs lie in [-1, 1]; fix the color range so images are comparable
            plt.imshow(samples[k], cmap='gray', vmin=-1, vmax=1)
            plt.axis('off')
        plt.savefig(f'gan_generated_image_{epoch}.png')
        plt.close()
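When reading the printed losses, one reference point is the situation where the discriminator can no longer tell real from fake and outputs 0.5 for every sample (the theoretical optimum of the original GAN game). Binary cross-entropy at p = 0.5 equals ln 2 for either label, so with this loop's definitions both `d_loss` and `g_loss` would sit near 0.693. In practice GAN losses typically oscillate rather than settle, so treat this only as a rough guide:

```python
import math

def bce(p, label):
    """Binary cross-entropy for one prediction p in (0, 1) and a 0/1 label."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

p = 0.5  # discriminator output when it is maximally unsure
d_loss = (bce(p, 1) + bce(p, 0)) / 2   # same averaging as the training loop
g_loss = bce(p, 1)                     # the generator is scored against "valid" labels

print(round(d_loss, 4), round(g_loss, 4))  # 0.6931 0.6931
```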