In [None]:
# coding=utf8

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# NOTE(review): this selects the Tk GUI backend after pyplot has already been
# imported. Presumably figures then open in Tk windows rather than inline in the
# notebook, and a display is required -- confirm this is intended.
matplotlib.use('TkAgg')

# Use the SimHei font so the Chinese plot titles used later in this file can be
# rendered. With a CJK font the Unicode minus glyph is presumably missing, hence
# axes.unicode_minus is switched off -- TODO confirm.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Seed Python's `random` module and torch with the same fixed value.
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)

In [None]:
# Hyperparameters and run configuration
dataroot = './data/CelebA/'  # dataset root; the checkpoint cells also read/write netD.pth / netG.pth here
batch_size = 512
image_size = 64  # images are resized and center-cropped to this size
nc = 3  # number of image channels
nz = 100  # size of the latent vector
ngf = 64  # base number of feature-map channels in the generator (a channel width, not an image size)
ndf = 64  # base number of feature-map channels in the discriminator
epochs = 5
lr = 0.0002  # learning rate passed to both Adam optimizers
beta1 = 0.5  # first beta coefficient passed to both Adam optimizers

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
# Load the dataset (CelebA) from `dataroot`.
# Preprocessing: resize, center-crop to image_size x image_size, convert to a
# tensor, then normalize every channel with mean 0.5 and std 0.5.
celeba_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = dset.ImageFolder(root=dataroot, transform=celeba_transform)

# Shuffled mini-batches of `batch_size` images.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Preview some training data: a grid built from the first 64 images of one batch.
batch = next(iter(dataloader))
# Slice before building the grid. Only 64 images are displayed, so there is no
# reason to copy the whole batch to `device` and back to the CPU just to plot.
preview_grid = vutils.make_grid(batch[0][:64], padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns a channels-first image; matplotlib wants channels last.
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))
plt.show()

In [None]:
# Weight initialization function
def weights_init(m):
    """Initialize a single module in place; meant to be passed to ``nn.Module.apply``.

    Modules are matched by class name:

    * names containing ``Conv``: weight drawn from N(0.0, 0.02);
    * names containing ``BatchNorm``: weight drawn from N(1.0, 0.02), bias set to 0.

    Any other module is left untouched. Parameters that are absent (for example
    the ``weight``/``bias`` of a batch-norm layer created with ``affine=False``,
    which are ``None``) are skipped instead of raising ``AttributeError``.

    Args:
        m: the module to initialize.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [None]:
# Generator
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent tensor to an image with transposed convolutions.

    A latent input of shape (N, latent_dim, 1, 1) is expanded through
    4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 feature maps; the final Tanh keeps the
    output of shape (N, out_channels, 64, 64) inside [-1, 1].

    Args:
        latent_dim: channel count of the latent input. ``None`` (default) uses
            the notebook-level ``nz``.
        feature_maps: base channel width of the hidden layers. ``None``
            (default) uses the notebook-level ``ngf``.
        out_channels: channel count of the generated image. ``None`` (default)
            uses the notebook-level ``nc``.

    The layer order inside ``self.main`` is fixed, so state_dict keys
    (``main.0.weight`` ...) stay compatible with previously saved checkpoints.
    """

    def __init__(self, latent_dim=None, feature_maps=None, out_channels=None):
        super(Generator, self).__init__()
        # Resolve defaults at construction time so `Generator()` keeps using
        # the notebook hyperparameters exactly as before.
        if latent_dim is None:
            latent_dim = nz
        if feature_maps is None:
            feature_maps = ngf
        if out_channels is None:
            out_channels = nc

        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # state size. (feature_maps*8) x 4 x 4
            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # state size. (feature_maps*4) x 8 x 8
            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # state size. (feature_maps*2) x 16 x 16
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # state size. (feature_maps) x 32 x 32
            nn.ConvTranspose2d(feature_maps, out_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (out_channels) x 64 x 64
        )

    def forward(self, x):
        """Run the latent batch ``x`` through the layer stack and return the images."""
        return self.main(x)


# Build the generator on the selected device.
netG = Generator().to(device)
netG.apply(weights_init)  # run the custom weight initialization on the network's modules

In [None]:
# Discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator: strided convolutions that score an image as real or fake.

    An input of shape (N, in_channels, 64, 64) is reduced through
    32x32 -> 16x16 -> 8x8 -> 4x4 feature maps to a single value per image; the
    final Sigmoid keeps the output of shape (N, 1, 1, 1) inside [0, 1].

    Args:
        in_channels: channel count of the input image. ``None`` (default) uses
            the notebook-level ``nc``.
        feature_maps: base channel width of the hidden layers. ``None``
            (default) uses the notebook-level ``ndf``.

    The layer order inside ``self.main`` is fixed, so state_dict keys
    (``main.0.weight`` ...) stay compatible with previously saved checkpoints.
    """

    def __init__(self, in_channels=None, feature_maps=None):
        super(Discriminator, self).__init__()
        # Resolve defaults at construction time so `Discriminator()` keeps
        # using the notebook hyperparameters exactly as before.
        if in_channels is None:
            in_channels = nc
        if feature_maps is None:
            feature_maps = ndf

        self.main = nn.Sequential(
            # input is (in_channels) x 64 x 64
            nn.Conv2d(in_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (feature_maps) x 32 x 32
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (feature_maps*2) x 16 x 16
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (feature_maps*4) x 8 x 8
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (feature_maps*8) x 4 x 4
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the image batch ``x`` through the layer stack and return the scores."""
        return self.main(x)


# Build the discriminator on the selected device and run the custom weight
# initialization on the network's modules.
netD = Discriminator().to(device)
netD.apply(weights_init)

In [None]:
# pretrain: resume from previously saved weights when the checkpoint files exist.
# map_location=device remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
if os.path.exists(dataroot + "netD.pth"):
    netD.load_state_dict(torch.load(dataroot + "netD.pth", map_location=device))
if os.path.exists(dataroot + "netG.pth"):
    netG.load_state_dict(torch.load(dataroot + "netG.pth", map_location=device))

In [None]:
# Loss function and optimizers
criterion = nn.BCELoss()

# Values written into the label tensor for real and generated batches.
real_label, fake_label = 1, 0

# Both networks are optimized with Adam using the same learning rate and betas.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# Fixed batch of 64 latent vectors, fed to the generator during training to
# record how its output changes over time.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

In [None]:
# Training
img_list = []  # image grids generated from fixed_noise at intervals during training
G_losses = []  # generator loss, one entry per iteration
D_losses = []  # discriminator loss, one entry per iteration
iters = 0  # iteration counter that keeps running across epochs

print("Starting Training...")
for epoch in range(epochs):
    for i, data in enumerate(dataloader, 0):
        ######################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ######################
        ## Train with a batch of real images
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)  # actual batch size (the last batch may be smaller)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward and backward
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean discriminator output on the real batch
        ## Train with a batch of fake images
        # generate latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # generate fake image
        fake = netG(noise)
        label.fill_(fake_label)
        # netD forward and backward; detach() keeps this backward pass from
        # propagating gradients into the generator
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # mean discriminator output on fakes, before the D step
        # error: total discriminator loss (gradients of both parts were accumulated above)
        errD = errD_real + errD_fake
        optimizerD.step()

        ######################
        # (2) Update G network: maximize log(D(G(z)))
        ######################
        netG.zero_grad()
        label.fill_(real_label)  # the generator is trained against "real" targets
        # Score the same fake batch again with the just-updated discriminator
        # (no detach here, so gradients reach the generator)
        output = netD(fake).view(-1)
        # G loss
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()  # mean discriminator output on fakes, after the D step
        optimizerG.step()

        ######################
        # Training stats
        ######################
        # Print a progress line every 50 batches.
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch + 1, epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Record both losses for the loss plot.
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Every 500 iterations, and on the very last batch of the last epoch,
        # store a grid of generator outputs for fixed_noise.
        # NOTE(review): this rebinds `fake`; it is reassigned at the start of
        # the next iteration, so the training step is not affected.
        if (iters % 500 == 0) or ((epoch == epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

In [None]:
# Save the state_dict of both networks under `dataroot`.
for net, filename in ((netD, "netD.pth"), (netG, "netG.pth")):
    torch.save(net.state_dict(), dataroot + filename)

In [None]:
# Plot the recorded generator (G) and discriminator (D) losses.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
for loss_history, curve_name in ((G_losses, "G"), (D_losses, "D")):
    loss_ax.plot(loss_history, label=curve_name)
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
# Visualize the generator's training progress: one animation frame per grid
# that was produced from fixed_noise during training.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = []
for grid in img_list:
    # grids are channels-first; imshow expects channels last
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

In [None]:
# Side-by-side comparison of real and generated images
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("真实图片")
# Take the 16 displayed images first; copying the whole batch to `device` and
# back only to plot a few of them is wasted work.
real_grid = vutils.make_grid(real_batch[0][:16], padding=5, normalize=True, nrow=4)
plt.imshow(np.transpose(real_grid, (1, 2, 0)))

# Generate images from random noise
noise = torch.randn(16, nz, 1, 1, device=device)
# Inference only: no_grad avoids building an autograd graph and yields a tensor
# that does not require grad, so it can safely be converted for plotting.
with torch.no_grad():
    fake = netG(noise).cpu()
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("生成图片")
fake_grid = vutils.make_grid(fake, padding=5, normalize=True, nrow=4)
plt.imshow(np.transpose(fake_grid, (1, 2, 0)))

plt.show()