In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision
from torchvision import transforms

import base

In [None]:
# Display the installed PyTorch version (useful when trying to reproduce results).
torch.__version__

### 数据

In [None]:
# On-disk locations used by this notebook (all under ./pth).
PATH_G, PATH_D, PATH_T = (
    './pth/mnist_GAN(G).pth',  # generator weights (state_dict)
    './pth/mnist_GAN(D).pth',  # discriminator weights (state_dict)
    './pth/test_input.pt',     # fixed latent batch used for progress plots
)

In [None]:
# ToTensor yields values in [0, 1]; Normalize with mean=std=0.5 maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)]
)

In [None]:
# MNIST training split, downloaded into ./data on first use.
train_data_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Shuffled mini-batches of 64 training images.
dataloader = torch.utils.data.DataLoader(
    dataset=train_data_set,
    batch_size=64,
    shuffle=True,
)

### 生成器

In [None]:
class Generator(nn.Module):
    """MLP generator mapping 100-d noise vectors to 28x28 images.

    Architecture: 100 -> 256 -> 512 -> 784 with ReLU between hidden layers and a
    final Tanh, so outputs lie in [-1, 1]. The input's last dimension must be
    100; the output is reshaped to ``(-1, 28, 28)``.
    """

    def __init__(self):
        super().__init__()
        widths = [100, 256, 512]
        layers = []
        # Hidden layers: Linear + ReLU for each consecutive pair of widths.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output layer: flattened 28x28 image squashed into [-1, 1].
        layers += [nn.Linear(widths[-1], 28 * 28), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x).view(-1, 28, 28)

### 判别器

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator scoring 28x28 images with a probability in (0, 1).

    Inputs are flattened to ``(-1, 784)``; the network is
    784 -> 512 -> 256 -> 1 with LeakyReLU, Dropout(0.4) after the first
    hidden layer, and a final Sigmoid. Output shape is ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()
        stages = [
            [nn.Linear(28 * 28, 512), nn.LeakyReLU(), nn.Dropout(p=0.4)],
            [nn.Linear(512, 256), nn.LeakyReLU()],
            [nn.Linear(256, 1), nn.Sigmoid()],
        ]
        # Flatten the stages into one Sequential so module indices stay 0..6.
        self.main = nn.Sequential(*(layer for stage in stages for layer in stage))

    def forward(self, x):
        flat = x.view(-1, 28 * 28)
        return self.main(flat)

### 初始化模型,优化器,损失函数

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Build both networks on `device`; the generator is constructed first so
# parameter initialization consumes the RNG in the same order as before.
gen, dis = Generator().to(device), Discriminator().to(device)

In [None]:
# One Adam optimizer per network, both with learning rate 1e-3.
g_optim = optim.Adam(params=gen.parameters(), lr=1e-3)
d_optim = optim.Adam(params=dis.parameters(), lr=1e-3)

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid probabilities, averaged over the batch.
loss_func = nn.BCELoss(reduction='mean')

### 绘图

In [None]:
def gen_image_plot(model, t_input, grid=4):
    """Show the images ``model`` produces for ``t_input`` in a ``grid`` x ``grid`` figure.

    Args:
        model: image generator, called as ``model(t_input)``. Its output is
            expected to be in [-1, 1] (the Generator above ends in Tanh) and is
            rescaled to [0, 1] for display.
        t_input: batch of latent vectors. Only the first ``grid * grid`` samples
            are drawn; if there are fewer, the remaining cells stay empty.
        grid: number of rows and columns (default 4, i.e. 16 images).
    """
    with torch.no_grad():  # display only: no autograd graph needed
        images = model(t_input).detach().cpu().numpy()
    # Keep the batch axis even for a single sample (np.squeeze would drop it)
    # while dropping a size-1 channel axis if the model returns one.
    images = images.reshape(len(images), images.shape[-2], images.shape[-1])

    fig, axes = plt.subplots(grid, grid, figsize=(grid, grid), squeeze=False)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx < len(images):
            # Fixed vmin/vmax so the [-1, 1] -> [0, 1] mapping is respected
            # instead of each image being auto-stretched; gray for digit images.
            ax.imshow((images[idx] + 1) / 2, cmap='gray', vmin=0, vmax=1)
    plt.show()

In [None]:
# Fixed latent batch (16 x 100) reused for every progress plot so epochs are
# visually comparable; it is cached on disk so kernel restarts reuse it too.
try:
    test_input = torch.load(PATH_T, map_location=device)
except FileNotFoundError:
    test_input = torch.randn(16, 100, device=device)
    # torch.save does not create missing parent folders (e.g. ./pth on a
    # fresh checkout), so make sure the folder exists first.
    os.makedirs(os.path.dirname(PATH_T) or '.', exist_ok=True)
    torch.save(test_input, PATH_T)

### 训练

In [None]:
# Per-epoch mean losses appended by `train` (discriminator, generator).
D_loss, G_loss = [], []

In [None]:
def _load_checkpoint():
    """Best-effort restore of gen/dis weights from PATH_G/PATH_D (each independently)."""
    for model, path in ((gen, PATH_G), (dis, PATH_D)):
        try:
            model.load_state_dict(torch.load(path, map_location=device))
        except FileNotFoundError:
            print(f'No checkpoint at {path}; starting this model from scratch.')


def _save_checkpoint():
    """Write gen/dis weights to PATH_G/PATH_D, creating the folder if needed."""
    for model, path in ((gen, PATH_G), (dis, PATH_D)):
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        torch.save(model.state_dict(), path)


def _train_one_epoch():
    """Run one pass over `dataloader`; return the mean (d_loss, g_loss) as floats."""
    count = len(dataloader)
    if count == 0:
        raise ValueError('dataloader yields no batches; cannot average losses')
    # Accumulate on-device to avoid a host sync on every step.
    d_total = torch.zeros((), device=device)
    g_total = torch.zeros((), device=device)
    for img, _ in dataloader:
        img = img.to(device)
        random_noise = torch.randn(img.size(0), 100, device=device)  # 100 = generator input size

        # Discriminator step: real images -> label 1, generated images -> label 0.
        d_optim.zero_grad()
        dis.train(True)
        real_out = dis(img)
        d_real_loss = loss_func(real_out, torch.ones_like(real_out))
        d_real_loss.backward()
        gen_img = gen(random_noise)
        # detach(): this step must not backpropagate into the generator.
        fake_out = dis(gen_img.detach())
        d_fake_loss = loss_func(fake_out, torch.zeros_like(fake_out))
        d_fake_loss.backward()
        d_optim.step()

        # Generator step: push the discriminator to label generated images as 1.
        # The discriminator is switched to eval mode here (dropout disabled).
        g_optim.zero_grad()
        dis.train(False)
        fake_out_g = dis(gen_img)
        g_loss = loss_func(fake_out_g, torch.ones_like(fake_out_g))
        g_loss.backward()
        g_optim.step()

        d_total += (d_real_loss + d_fake_loss).detach()
        g_total += g_loss.detach()
    return d_total.item() / count, g_total.item() / count


def train(epoch, save=True, load=True):
    """Train the GAN for ``epoch`` epochs over the global ``dataloader``.

    Uses notebook globals: ``gen``, ``dis``, ``d_optim``, ``g_optim``,
    ``loss_func``, ``device``, ``test_input``, ``PATH_G``/``PATH_D`` and the
    ``D_loss``/``G_loss`` history lists, to which the mean loss of each epoch
    is appended as a float.

    Args:
        epoch: number of epochs to run; values <= 0 run nothing.
        save: if True, save both models' weights after every epoch.
        load: if True, try once to resume from the saved checkpoints before
            training; missing checkpoint files are reported and skipped.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if load:
        _load_checkpoint()
    for current in range(1, epoch + 1):
        start = time.time()
        d_epoch_loss, g_epoch_loss = _train_one_epoch()
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch:', current)
        print('d-Loss:', d_epoch_loss)
        print('g-Loss:', g_epoch_loss)
        with torch.no_grad():
            gen_image_plot(gen, test_input)
        if save:
            _save_checkpoint()
        print(time.time() - start)
        print()

In [None]:
@base.timer
def main():
    """Train for 100 epochs, resuming from and saving checkpoints."""
    train(100, save=True, load=True)

In [None]:
# In a Jupyter/IPython kernel __name__ is '__main__', so running this cell
# starts training; if the code is imported as a module, nothing runs.
if __name__ == '__main__':
    main()