# GAN Basic
* Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks(https://arxiv.org/pdf/1511.06434.pdf)

<img src="https://github.com/GunhoChoi/PyTorch-FastCampus/raw/1e9ba63e9ccffa28637f9c441d6d5faf35696268/09_GAN/0_GAN/GAN.png" width="50%">

### 1. Import required libraries

In [1]:
import numpy as np
import torch
from torch import nn
import torch.utils as utils
import torch.nn.init as init
import torchvision.utils as v_utils
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict

### 2. HyperParameter Setting

In [2]:
epoch = 100
batch_size = 500
learning_rate = 0.002
z_size = 100
middle_size= 200

### 3. Data Setting

In [3]:
# Download Data
mnist_train=torchvision.datasets.MNIST(root="./", 
                          train=True, 
                          transform=transforms.ToTensor(),
                          download=True)
# Set Data Loader
train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                           batch_size=batch_size,
                                           shuffle=True)

### 4. Generator

In [4]:
# OrderedDict를 쓴다. 
class Generator(nn.Module): # 가짜 이미지를 만든다. 
    def __init__(self):
        super(Generator,self).__init__()
        self.layer1 = nn.Sequential(
                nn.Linear(in_features=50, out_features=middle_size),
                #('bn1', nn.BatchNorm1d(middle_size)),
                nn.ReLU()
        )
        
        self.layer2 = nn.Sequential(
                nn.Linear(middle_size, 784),
                nn.Tanh()
            )
        
    def forward(self, z):
        print(self.layer1[0].weight.type())
        print(">>>>",z_size)
        print(z.size())
        print(self.layer1)
        
        z = z.view(-1,100)
        print(z.size())
        out = self.layer1(z.cpu().float())
        out = self.layer2(out)
        print(out.size())
        out = out.view(batch_size,1,28,28)
        
        return out

### 5. Discriminator

In [5]:
class Discriminator(nn.Module): # 얘는 구별한다. 
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Linear(784, middle_size),
            nn.LeakyReLU(),
        )
        
        self.layer2 = nn.Sequential(
            nn.Linear(middle_size,1),
            nn.Sigmoid(),
        )
        
    def forward(self, x):
        out = x.view(batch_size, -1)
        print("*****")
        print(out.size())
        out = self.layer1(out.cpu())
        out = self.layer2(out.cpu())
        print(out.size())
        print("*****")
        return out

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
generater = Generator().to(device)
discriminater = Discriminator().to(device)
print(next(generater.parameters()).is_cuda)
print(next(discriminater.parameters()).is_cuda)
torch.cuda.device_count()

True
True


1

### 모델 안에는 state_dict()라는 것이 있다.
### 6. Check layers

In [7]:
# gen_params = generater.state_dict().keys()
# dis_params = discriminater.state_dict().keys()

# for i in gen_params:
#     print(i)
# print("-----------------------")    
# for j in dis_params:
#     print(j)

### 7. Set Loss function & Optimizer

In [8]:

loss_func = nn.MSELoss()
gen_optim = torch.optim.Adam(generater.parameters(), 
                             lr=learning_rate,betas=(0.5, 0.999))
dis_optim = torch.optim.Adam(discriminater.parameters(), 
                             lr=learning_rate, betas=(0.5 ,0.999))

In [9]:
print(device)
ones_label  = torch.ones(batch_size,1).requires_grad_(True).to(device)
zeros_label = torch.zeros(batch_size,1).requires_grad_(True).to(device)

cuda:0


In [10]:
try :
    generater, discriminater = torch.load('./model/vanilla_gan.pkl')
    print("\n--------model restored---------\n")
except:
    print("\n--------model not restored---------\n")
    pass


--------model restored---------



### 8. Train Model

* Model 안의 레이어는 torch.FloatTensor이다. 이러한 모델에 인자를 넣어주는데 넣어주는 것들을 .requires_grad_(True)로 만들고 쿠다로 해야 이게 GPU로 돌아간다.  

In [11]:
for i in range(epoch):
    for j,(image, label) in enumerate(train_loader):   
        
        # discriminator
        dis_optim.zero_grad()
        # GPU로 돌리자 @@
        #z=init.normal(torch.Tensor(batch_size,z_size), mean=0, std=0.1).to('cuda')
        z = init.normal(torch.Tensor(batch_size,z_size), mean=0, std=0.1).requires_grad_(True).to(device)
        gen_fake = generater(z)            # 만든 가짜 이미지
        dis_fake = discriminater(gen_fake).requires_grad_(True).to(device)
        print("&&&>",image.size())
        dis_real = discriminater(image).requires_grad_(True).to(device)
        
        dis_loss = torch.sum(loss_func(dis_fake, zeros_label)) + torch.sum(loss_func(dis_real, ones_label))
        dis_loss.backward(retain_graph=True)
        dis_optim.step()
        
        # generator 
        gen_optim.zero_grad()
        z = init.normal(torch.Tensor(batch_size,z_size), mean=0, std=0.1).requires_grad_(True).to(device)
        gen_fake = generater(z)
        dis_fake = discriminater(gen_fake).requires_grad_(True).to(device)
        
        gen_loss = torch.sum(loss_func(dis_fake, ones_label)) # fake classified as real
        gen_loss.backward()
        gen_optim.step()
        
        if j % 100 == 0:
            print(gen_loss, dis_loss)
            torch.save( [generater, discriminater], "./model/vanilla_gan.pkl")
            
            print("{}th iteration gen_loss: {} dis_loss: {}".format(i, gen_loss.data, dis_loss.data))
            v_utils.save_image(gen_fake.data[0:25],"./result/gen_{}_{}.png".format(i,j), nrow=5)
        

torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
tensor(92.5797, device='cuda:0', grad_fn=<SumBackward0>) tensor(215.6323, device='cuda:0', grad_fn=<ThAddBackward>)
0th iteration gen_loss: 92.57969665527344 dis_loss: 215.63233947753906
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500

  
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size

torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_feat

torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_feat

torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_feat

torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_feat

torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_feat

torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_feat

torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_features=100, out_features=200, bias=True)
  (act1): ReLU()
)
torch.Size([500, 100])
torch.Size([500, 784])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
&&&> torch.Size([500, 1, 28, 28])
*****
torch.Size([500, 784])
torch.Size([500, 1])
*****
torch.FloatTensor
>>>> 100
torch.Size([500, 100])
Sequential(
  (fc1): Linear(in_feat