# GAN Basic

- Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks(https://arxiv.org/pdf/1511.06434.pdf)

<img src="./GAN.png" width="400">

## 1. Import required libraries

In [39]:
# Vanilla GAN with Multi GPUs + Naming Layers using OrderedDict
# Code by GunhoChoi

import torch
import torch.nn as nn
import torch.utils as utils
import torch.nn.init as init
import torchvision.utils as v_utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict

## 2. Hyperparameter setting

In [40]:
# Set Hyperparameters
# change num_gpu to the number of gpus you want to use

epoch = 50
batch_size = 512
learning_rate = 0.0002
num_gpus = 1
z_size = 50
middle_size = 200
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## 3. Data Setting

In [41]:
# Download Data

mnist_train = dset.MNIST("./", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)

# Set Data Loader(input pipeline)

train_loader = torch.utils.data.DataLoader(dataset=mnist_train,batch_size=batch_size,shuffle=True,drop_last=True)

## 4. Generator

In [42]:
# Generator receives random noise z and create 1x28x28 image
# we can name each layer using OrderedDict

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.layer1 = nn.Sequential(OrderedDict([
                        ('fc1',nn.Linear(z_size,middle_size)),
                        ('bn1',nn.BatchNorm1d(middle_size)),
                        ('act1',nn.ReLU()),
        ]))
        self.layer2 = nn.Sequential(OrderedDict([
                        ('fc2', nn.Linear(middle_size,784)),
                        ('tanh', nn.Tanh()),
        ]))
    def forward(self,z):
        out = self.layer1(z)
        out = self.layer2(out)
        out = out.view(batch_size,1,28,28)

        return out

## 5. Discriminator

In [43]:
# Discriminator receives 1x28x28 image and returns a float number 0~1
# we can name each layer using OrderedDict

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.layer1 = nn.Sequential(OrderedDict([
                        ('fc1',nn.Linear(784,middle_size)),
                        ('bn1',nn.BatchNorm1d(middle_size)),
                        ('act1',nn.LeakyReLU()),  
            
        ]))
        self.layer2 = nn.Sequential(OrderedDict([
                        ('fc2', nn.Linear(middle_size,1)),
                        ('act2', nn.Sigmoid()),
        ]))
                                    
    def forward(self,x):
        out = x.view(batch_size, -1)
        out = self.layer1(out)
        out = self.layer2(out)

        return out

## 6. Put instances on Multi-gpu

In [44]:
# Put class objects on Multiple GPUs using 
# torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
# device_ids: default all devices / output_device: default device 0 
# along with .cuda()

generator = nn.DataParallel(Generator()).to(device)
discriminator = nn.DataParallel(Discriminator()).to(device)

## 8. Set Loss function & Optimizer

In [45]:
# loss function, optimizers, and labels for training

loss_func = nn.MSELoss()
gen_optim = torch.optim.Adam(generator.parameters(), lr=learning_rate,betas=(0.5,0.999))
dis_optim = torch.optim.Adam(discriminator.parameters(), lr=learning_rate,betas=(0.5,0.999))

ones_label = torch.ones(batch_size,1).to(device)
zeros_label = torch.zeros(batch_size,1).to(device)

## 9. Restore Model

In [46]:
# model restore if any

try:
    generator, discriminator = torch.load('./model/vanilla_gan.pkl')
    print("\n--------model restored--------\n")
except:
    print("\n--------model not restored--------\n")
    pass


--------model not restored--------



## 10. Train Model

In [48]:
# train

for i in range(epoch):
    for j,(image,label) in enumerate(train_loader):
        image = image.to(device)
        
        # discriminator

        z = init.normal(torch.Tensor(batch_size,z_size),mean=0,std=0.1).to(device)
        gen_fake = generator.forward(z)
        dis_fake = discriminator.forward(gen_fake)
        
        dis_real = discriminator.forward(image)
        dis_loss = torch.sum(loss_func(dis_fake,zeros_label)) + torch.sum(loss_func(dis_real,ones_label))
        dis_loss.backward(retain_graph=True)
        
        dis_optim.step()
        dis_optim.zero_grad()
        
        # generator

        
        z = init.normal(torch.Tensor(batch_size,z_size),mean=0,std=0.1).to(device)
        gen_fake = generator.forward(z)
        dis_fake = discriminator.forward(gen_fake)
        
        gen_loss = torch.sum(loss_func(dis_fake,ones_label)) # fake classified as real
        gen_loss.backward()
        
        gen_optim.step()
        gen_optim.zero_grad()
    
       
        # model save
        if j % 100 == 0:
            print(gen_loss,dis_loss)
        
            print("{}th iteration gen_loss: {} dis_loss: {}".format(i,gen_loss.data,dis_loss.data))
            v_utils.save_image(gen_fake.data[0:25],"gen_{}_{}.png".format(i,j), nrow=5)

  if __name__ == '__main__':


tensor(0.1315, grad_fn=<SumBackward0>) tensor(0.5196, grad_fn=<ThAddBackward>)
0th iteration gen_loss: 0.1315039098262787 dis_loss: 0.5195950865745544
tensor(0.1241, grad_fn=<SumBackward0>) tensor(0.5205, grad_fn=<ThAddBackward>)
0th iteration gen_loss: 0.12408328801393509 dis_loss: 0.5205001831054688
tensor(0.1243, grad_fn=<SumBackward0>) tensor(0.5172, grad_fn=<ThAddBackward>)
1th iteration gen_loss: 0.12426163256168365 dis_loss: 0.5172330737113953
tensor(0.1308, grad_fn=<SumBackward0>) tensor(0.5024, grad_fn=<ThAddBackward>)
1th iteration gen_loss: 0.130750373005867 dis_loss: 0.5023770332336426
tensor(0.1318, grad_fn=<SumBackward0>) tensor(0.4989, grad_fn=<ThAddBackward>)
2th iteration gen_loss: 0.13181491196155548 dis_loss: 0.4989229738712311
tensor(0.1410, grad_fn=<SumBackward0>) tensor(0.4750, grad_fn=<ThAddBackward>)
2th iteration gen_loss: 0.1410215198993683 dis_loss: 0.4750310182571411
tensor(0.1420, grad_fn=<SumBackward0>) tensor(0.4711, grad_fn=<ThAddBackward>)
3th iteration

KeyboardInterrupt: 