<a href="https://colab.research.google.com/github/liupengzhouyi/LearningPytorch/blob/master/20200331/GANNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN Net

## import package

In [0]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

## setting device

In [0]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Setting hyperparameters

In [0]:
# --- Network dimensions ---
latent_size = 64        # length of the noise vector fed to the generator
hidden_size = 256       # width of the hidden layers in both networks
image_size = 28 * 28    # number of pixels in one flattened 28x28 image

# --- Training schedule ---
num_epochs = 200        # full passes over the training set
batch_size = 100        # samples per mini-batch

# --- Output ---
sample_dir = 'samples'  # directory that receives the saved image grids

## create dir

* 建立文件夹用于存放生成的样本
* os.path.exists(path):测试是否存在某路径，如不存在则用 os.makedirs(name) 建立

In [0]:
# Create the output directory for generated samples. `exist_ok=True` makes the
# cell safe to re-run and avoids the check-then-create race of a separate
# os.path.exists() test.
os.makedirs(sample_dir, exist_ok=True)

## Download Data

In [5]:
# Pre-processing pipeline for MNIST.
# ToTensor() yields pixel values in [0, 1]; Normalize with mean 0.5 / std 0.5
# rescales them to [-1, 1]. That is the range produced by the generator's final
# Tanh layer and the range `denorm` expects when mapping images back to [0, 1]
# for saving, so real and generated images share the same scale.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),   # single channel: one mean/std
])

# MNIST training split; downloaded into `root` on first use.
mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Shuffled mini-batches of `batch_size` samples.
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw
Processing...
Done!


## Create Discriminator

1、定义判别器（Discriminator）

2、LeakyReLU(negative_slope=1e-2, inplace=False):关键参数只有一个斜率，默认值1e-2；本例中使用的斜率为 0.2

3、分别构造了3个线性层实例


> * 1 : 784 -> 256
> * 2 : 256 -> 256
> * 3 : 256 -> 1


In [0]:
# Discriminator: a three-layer MLP that maps a flattened image to a single
# probability-like score in (0, 1) via the final Sigmoid.
D = nn.Sequential(
    # flattened image -> hidden features
    nn.Linear(in_features=image_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    # hidden -> hidden
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    # hidden -> one score per sample
    nn.Linear(in_features=hidden_size, out_features=1),
    nn.Sigmoid(),
)

## Create Generator

* 1、定义生成器（Generator）

> * 1 : 64 -> 256
> * 2 : 256 -> 256
> * 3 : 256 -> 784

In [0]:
# Generator: a three-layer MLP that maps a latent noise vector to a flattened
# image; the final Tanh bounds every output value to (-1, 1).
G = nn.Sequential(
    # latent vector -> hidden features
    nn.Linear(in_features=latent_size, out_features=hidden_size),
    nn.ReLU(),
    # hidden -> hidden
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.ReLU(),
    # hidden -> flattened image
    nn.Linear(in_features=hidden_size, out_features=image_size),
    nn.Tanh(),
)

## Move models to device

In [0]:
# Move both networks' parameters onto the selected device.
D, G = D.to(device), G.to(device)

## Create optimizer and loss function

In [0]:
# Binary cross-entropy between the discriminator's output and the 0/1 labels.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=2e-4) for net in (D, G)
)

## Train GAN Net

*   1：前向传播、计算损失
*   2：清零梯度、反向传播、更新权重

---

*    optimizer.zero_grad()
*    loss.backward()
*    optimizer.step()

In [11]:
def denorm(x):
    """Map values from [-1, 1] to [0, 1] and clip anything outside that range.

    Computes ``(x + 1) / 2`` and clamps the result to ``[0, 1]``.
    """
    return ((x + 1) / 2).clamp(min=0, max=1)
    
# Adversarial training. For every mini-batch the discriminator D is updated
# once, then the generator G is updated once. At the end of each epoch a grid
# of generated images is written to `sample_dir`.
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Use the size of the batch actually delivered rather than the
        # configured `batch_size`: the final batch of an epoch is smaller
        # whenever the dataset size is not a multiple of `batch_size`.
        current_batch = images.size(0)
        images = images.reshape(current_batch, -1).to(device)          # flatten each image to one row
        real_labels = torch.ones(current_batch, 1, device=device)      # target 1 for "real"
        fake_labels = torch.zeros(current_batch, 1, device=device)     # target 0 for "fake"

        # ---------------- Train the discriminator ----------------
        # Loss on real images against the "real" labels.
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Loss on generated images against the "fake" labels.
        z = torch.randn(current_batch, latent_size, device=device)     # latent noise, one row per sample
        fake_images = G(z)
        # detach(): this step only updates D, so do not backpropagate into G.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_fake + d_loss_real

        # Clear D's accumulated gradients, backpropagate, update D.
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ------------------ Train the generator ------------------
        # G is trained to make D label its outputs as real, hence the
        # "real" labels as the target for generated images.
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        # Clear G's gradients, backpropagate through D into G, update G only.
        # Gradients this leaves on D are cleared at the start of the next D step.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i % 200 == 0):
            print('Epoch [{}/{}],step[{}/{}],d_loss:{:.4f},g_loss:{:.4f},D(x):{:.2f},D(G(z)):{:.2f}'
            .format(epoch, num_epochs,
                    i+1, total_step,
                    d_loss.item(), g_loss.item(),
                    real_score.mean().item(), fake_score.mean().item()))

    # Save one grid of real images, once, for visual comparison.
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save the last generated batch of this epoch as an image grid.
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

Epoch [0/200],step[1/600],d_loss:0.1607,g_loss:3.5065,D(x):0.93,D(G(z)):0.07
Epoch [0/200],step[201/600],d_loss:0.4003,g_loss:3.6108,D(x):0.88,D(G(z)):0.08
Epoch [0/200],step[401/600],d_loss:0.2325,g_loss:2.9610,D(x):0.96,D(G(z)):0.11
Epoch [1/200],step[1/600],d_loss:0.3478,g_loss:5.2051,D(x):0.92,D(G(z)):0.11
Epoch [1/200],step[201/600],d_loss:0.2985,g_loss:5.1744,D(x):0.90,D(G(z)):0.11
Epoch [1/200],step[401/600],d_loss:0.3826,g_loss:3.4403,D(x):0.86,D(G(z)):0.13
Epoch [2/200],step[1/600],d_loss:0.3951,g_loss:2.8535,D(x):0.87,D(G(z)):0.15
Epoch [2/200],step[201/600],d_loss:0.1811,g_loss:3.6162,D(x):0.92,D(G(z)):0.06
Epoch [2/200],step[401/600],d_loss:0.1968,g_loss:3.8504,D(x):0.93,D(G(z)):0.06
Epoch [3/200],step[1/600],d_loss:0.1520,g_loss:5.4148,D(x):0.95,D(G(z)):0.05
Epoch [3/200],step[201/600],d_loss:0.3160,g_loss:5.4249,D(x):0.94,D(G(z)):0.11
Epoch [3/200],step[401/600],d_loss:0.1872,g_loss:4.6160,D(x):0.96,D(G(z)):0.05
Epoch [4/200],step[1/600],d_loss:0.1322,g_loss:5.3739,D(x):0

## Save model

In [0]:
# Persist only the learned parameters (state dicts) of each network.
torch.save(obj=G.state_dict(), f='G1.ckpt')   # generator weights
torch.save(obj=D.state_dict(), f='D1.ckpt')   # discriminator weights