# M2177.003100 Deep Learning <br> Assignment #4 Implementing Generative Adversarial Nets - part2 MNIST data

Copyright (C) Data Science Laboratory, Seoul National University. This material is for educational uses only. Some contents are based on the material provided by other paper/book authors and may be copyrighted by them. Written by Chaehun Shin, September 2020

In this notebook, you will learn how to implement Generative Adversarial Networks (GANs). <br>
The goal here is to build GANs that draw handwritten digits (MNIST data). <br> 

**Note**: certain details are missing or ambiguous on purpose, in order to test your knowledge of the related materials. However, if you really feel that something essential is missing and you cannot proceed to the next step, contact the teaching staff with a clear description of your problem.

### Submitting your work:
<font color=red>**DO NOT clear the final outputs**</font> so that TAs can grade both your code and results.  
Once you have completed **all parts**, run the *CollectSubmission.sh* script with your **Student_ID** as the input argument. <br>
This will produce a zip file called *[Your Student_ID].zip*. Please submit this file on ETL. &nbsp;&nbsp; (Usage: ./*CollectSubmission.sh* &nbsp; Student_ID)

### Some helpful tutorials and references for assignment #4-2:
- [1] PyTorch official tutorials. [[link]](https://pytorch.org/tutorials/)
- [2] Stanford CS231n lectures. [[link]](http://cs231n.stanford.edu/)
- [3] Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems. 2014.
- [4] Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014).
- [5] Radford, Alec, Luke Metz, and Soumith Chintala. "Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).

## 0. Download and load MNIST datasets
The MNIST dataset will be downloaded into the 'data/mnist' directory. To save the data somewhere else, set 'mnist_data_dir' to the directory you want. <br>

In [None]:
import os
from torchvision.datasets import MNIST
import torchvision.transforms as T

mnist_data_dir = './data/mnist'
dataset = MNIST(root=mnist_data_dir,
               transform=T.ToTensor(), train=True, download=True)
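# `train_data` / `train_labels` are deprecated in torchvision; use `data` / `targets`
print(dataset.data.shape)     # images: torch.Size([60000, 28, 28])
print(dataset.targets.shape)  # labels: torch.Size([60000])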

## <a name="1"></a> 1. Building a network


In this section, you will implement neural networks for <br>
(1) generator model to draw a digit <br>
(2) discriminator model to distinguish real images from generated images.<br>
You can use any layer functions implemented in the **'torch.nn'** library (abbreviated as **nn**) or the **'torch.nn.functional'** library (abbreviated as **F**).
Just write the code in whatever way you find most clear.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

It is time for a generator model.
You can change anything, including the arguments, if you need to. Feel free to implement it however you like.<br>
**(The output image should lie in the range (0, 1), produced with a Sigmoid function, because the real images are normalized to the range (0, 1).)**
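
To make the output constraint concrete, here is a minimal sketch that only checks shape and value range. `TinyGenerator` is a hypothetical stand-in with a single linear layer, not the intended architecture for this assignment, and the `img_size` argument is an assumption added for the illustration.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in: one linear layer followed by a sigmoid.
# It exists only to illustrate the expected output shape and range.
class TinyGenerator(nn.Module):
    def __init__(self, latent_dim=100, image_dim=1, img_size=28):
        super().__init__()
        self.image_dim = image_dim
        self.img_size = img_size
        self.fc = nn.Linear(latent_dim, image_dim * img_size * img_size)

    def forward(self, z):
        out = torch.sigmoid(self.fc(z))  # squash every pixel into (0, 1)
        return out.view(-1, self.image_dim, self.img_size, self.img_size)

z = torch.randn(4, 100)          # a batch of 4 latent vectors
sample = TinyGenerator()(z)

print(sample.shape)
print(sample.min().item(), sample.max().item())
```

A check like this can be run on any generator you write before starting the training loop.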

In [None]:
class Generator(nn.Module):
    def __init__(self, latent_dim=10, image_dim=1):
        super().__init__()
        ################ ToDo ################

        
    def forward(self, z):
        ################ ToDo ################

        return out
        

Now, it's time for a discriminator model. Again, you can change anything you need. <br>
**(The discriminator should output the probability that its input image is real. This means using the Sigmoid function at the last layer so that the value lies in the range (0, 1).)**

In [None]:
class Discriminator(nn.Module):
    def __init__(self, img_dim=1):  # MNIST images have a single channel
        super().__init__()
        ################ ToDo ################

        
    def forward(self, img):
        ################ ToDo ################

        return out
    

## <a name="2"></a> 2. Build the main part and train

In this section, you will implement the main part: define the criterion variable and the D_loss/G_loss used for training in the ToDo parts (you can use the criterion variable in the losses).
Feel free to set the hyperparameters and fill in the main part.
When you are done, run the following to check your implementations.
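
As one possible reading of the criterion mentioned above, the sketch below assumes `nn.BCELoss` (the notebook does not name a specific loss, so this choice is an assumption). It only shows what the criterion computes on probabilities against all-ones and all-zeros targets; how to combine these terms into D_loss and G_loss is left to the ToDo parts.

```python
import torch
import torch.nn as nn

# Assumption: binary cross-entropy on probabilities, i.e. nn.BCELoss.
criterion = nn.BCELoss()

probs = torch.tensor([0.9, 0.6, 0.2])  # example discriminator outputs in (0, 1)
ones = torch.ones(3)                   # targets meaning "real"
zeros = torch.zeros(3)                 # targets meaning "fake"

loss_as_real = criterion(probs, ones)   # -mean(log(p))
loss_as_fake = criterion(probs, zeros)  # -mean(log(1 - p))

# The same quantities written out by hand.
manual_real = -torch.log(probs).mean()
manual_fake = -torch.log(1 - probs).mean()

print(loss_as_real.item(), manual_real.item())
print(loss_as_fake.item(), manual_fake.item())
```

Note that `nn.BCELoss` expects the input and the target to have the same shape.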

You must show **at least three generated images** (at the beginning, midway through, and at the end of training).

In [None]:
# hyperparameter setting
img_dim = 1
img_size = 28
latent_dim = 100
num_D_updates_per_G_update = 5

batch_size = 128
learning_rate = 1e-4
total_iter = 50000

log_freq = 10
viz_freq = 200

gen_num_samples = 50

In [None]:
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
dataloader_iter = iter(dataloader)

netG = Generator(latent_dim, img_dim).to(device)
netD = Discriminator(img_dim).to(device)

optimG = torch.optim.Adam(netG.parameters(), learning_rate)
optimD = torch.optim.Adam(netD.parameters(), learning_rate)

real_labels = torch.ones(batch_size).to(device)
fake_labels = torch.zeros(batch_size).to(device)

for it in range(total_iter):
    
    # train Discriminator
    for _ in range(num_D_updates_per_G_update):
        try:
            real_imgs, _ = next(dataloader_iter)
        except StopIteration:  # the epoch is exhausted, so restart the loader
            dataloader_iter = iter(dataloader)
            real_imgs, _ = next(dataloader_iter)

        real_imgs = real_imgs.to(device)

        z = torch.randn((batch_size, latent_dim)).to(device)
        fake_imgs = netG(z).detach()
        
        real_probs = netD(real_imgs).squeeze()
        fake_probs = netD(fake_imgs).squeeze()

        ################ ToDo ################
        D_loss = 
        
        optimD.zero_grad()
        D_loss.backward()
        optimD.step()
      
    # train the Generator
    z = torch.randn((batch_size, latent_dim)).to(device)
    fake_imgs = netG(z)
        
    fake_probs = netD(fake_imgs).squeeze()

    ################ ToDo ################
    G_loss = 

    optimG.zero_grad()
    G_loss.backward()
    optimG.step()
    
    if (it+1) % log_freq == 0:
        print("Iter: %05d/%d, Gen loss: %.4f, Dis loss: %.4f"%(it+1, total_iter,
                                                              G_loss.item(),
                                                              D_loss.item()))
    if (it+1) % viz_freq == 0:
        z = torch.randn((gen_num_samples, latent_dim)).to(device)
        with torch.no_grad():
            gen_imgs = netG(z)
        
        gen_imgs = make_grid(gen_imgs, nrow=10).permute(1, 2, 0).cpu().detach().numpy()
        plt.imshow(gen_imgs)
        plt.show()
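
The try/except restart of `dataloader_iter` in the loop above can also be written as a small generator. This is a sketch of an alternative, not part of the assignment template; `infinite_batches` and the toy dataset are hypothetical names used only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def infinite_batches(loader):
    """Yield batches forever, starting a new pass over the loader after each epoch."""
    while True:
        for batch in loader:
            yield batch

# Toy dataset: 10 fake images with dummy labels.
toy = TensorDataset(torch.rand(10, 1, 28, 28), torch.zeros(10, dtype=torch.long))
loader = DataLoader(toy, batch_size=4, shuffle=True, drop_last=True)

batches = infinite_batches(loader)
# Two full batches per epoch, so five draws span three epochs without any StopIteration.
shapes = [next(batches)[0].shape[0] for _ in range(5)]
print(shapes)
```

Because `drop_last=True` discards the incomplete final batch, every batch has exactly `batch_size` items, which keeps it compatible with fixed-size label tensors such as `real_labels`.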