In [43]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import os

if not os.path.exists('./img'):
    os.mkdir('./img')

In [44]:
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out

In [45]:
batch_size = 128
num_epoch = 100
z_dimension = 100

# Image processing
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# MNIST dataset
mnist = datasets.MNIST(
    root='./data/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)

In [46]:
# Discriminator
class discriminator(nn.Module):
    """
    判别网络事实上就是一个二分类的网络，分辨图片的真伪
    """
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x):
        x = self.dis(x)
        return x


# Generator
class generator(nn.Module):
    """
    生成网络随机一个向量生成一个图片
    """
    def __init__(self):
        super(generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 256), nn.ReLU(True), nn.Linear(256, 784), nn.Tanh())

    def forward(self, x):
        x = self.gen(x)
        return x

In [None]:
D = discriminator()
G = generator()
if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        img = img.view(num_img, -1)
        real_img = Variable(img).cuda()
        real_label = Variable(torch.ones(num_img)).cuda()
        fake_label = Variable(torch.zeros(num_img)).cuda()

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ===============train generator
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)
        
        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                      epoch, num_epoch, d_loss.data[0], g_loss.data[0],
                      real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/100], d_loss: 0.121998, g_loss: 3.546974 D real: 0.977048, D fake: 0.091577
Epoch [0/100], d_loss: 0.058318, g_loss: 4.290877 D real: 0.985306, D fake: 0.042083
Epoch [0/100], d_loss: 0.161965, g_loss: 5.152973 D real: 0.965903, D fake: 0.102542
Epoch [0/100], d_loss: 0.001653, g_loss: 8.397080 D real: 0.999845, D fake: 0.001493


  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [1/100], d_loss: 0.117615, g_loss: 5.177507 D real: 0.957375, D fake: 0.055432
Epoch [1/100], d_loss: 0.028472, g_loss: 6.449814 D real: 0.985970, D fake: 0.012861
Epoch [1/100], d_loss: 0.241403, g_loss: 5.855880 D real: 0.924284, D fake: 0.067448
Epoch [1/100], d_loss: 0.127386, g_loss: 7.272000 D real: 0.959430, D fake: 0.046443
Epoch [2/100], d_loss: 0.572582, g_loss: 4.427216 D real: 0.898730, D fake: 0.238263
Epoch [2/100], d_loss: 0.489846, g_loss: 4.319398 D real: 0.872898, D fake: 0.158038
Epoch [2/100], d_loss: 0.297053, g_loss: 5.099032 D real: 0.936843, D fake: 0.107696
Epoch [2/100], d_loss: 0.856211, g_loss: 3.953827 D real: 0.848682, D fake: 0.348639
Epoch [3/100], d_loss: 0.270003, g_loss: 5.351149 D real: 0.926315, D fake: 0.025403
Epoch [3/100], d_loss: 0.379847, g_loss: 4.761645 D real: 0.882656, D fake: 0.122177
Epoch [3/100], d_loss: 3.827452, g_loss: 0.856467 D real: 0.295213, D fake: 0.658173
Epoch [3/100], d_loss: 1.200500, g_loss: 2.176889 D real: 0.73394

Epoch [25/100], d_loss: 0.442195, g_loss: 2.711858 D real: 0.933837, D fake: 0.217963
Epoch [25/100], d_loss: 0.510269, g_loss: 2.995680 D real: 0.845244, D fake: 0.090710
Epoch [25/100], d_loss: 0.450366, g_loss: 3.307637 D real: 0.917693, D fake: 0.207150
Epoch [25/100], d_loss: 0.368705, g_loss: 3.072907 D real: 0.859574, D fake: 0.076115
Epoch [26/100], d_loss: 0.366540, g_loss: 3.157058 D real: 0.930553, D fake: 0.177950
Epoch [26/100], d_loss: 0.440447, g_loss: 3.956893 D real: 0.877068, D fake: 0.161179
Epoch [26/100], d_loss: 0.659666, g_loss: 4.350977 D real: 0.820283, D fake: 0.146988
Epoch [26/100], d_loss: 0.278245, g_loss: 3.039406 D real: 0.933202, D fake: 0.134388
Epoch [27/100], d_loss: 0.385205, g_loss: 3.183555 D real: 0.902588, D fake: 0.116742
Epoch [27/100], d_loss: 0.489432, g_loss: 3.307326 D real: 0.825440, D fake: 0.104576
Epoch [27/100], d_loss: 0.382388, g_loss: 3.259119 D real: 0.879281, D fake: 0.100267
Epoch [27/100], d_loss: 0.586138, g_loss: 3.190042 D r

Epoch [49/100], d_loss: 0.704936, g_loss: 2.901541 D real: 0.832852, D fake: 0.239864
Epoch [49/100], d_loss: 0.452876, g_loss: 3.441339 D real: 0.825948, D fake: 0.095942
Epoch [49/100], d_loss: 0.619574, g_loss: 2.393365 D real: 0.781770, D fake: 0.181856
Epoch [49/100], d_loss: 0.536264, g_loss: 2.736376 D real: 0.800440, D fake: 0.132543
Epoch [50/100], d_loss: 0.537762, g_loss: 2.226238 D real: 0.903118, D fake: 0.274597
Epoch [50/100], d_loss: 0.852824, g_loss: 2.224986 D real: 0.853746, D fake: 0.351694
Epoch [50/100], d_loss: 0.412315, g_loss: 2.483872 D real: 0.877988, D fake: 0.147682
Epoch [50/100], d_loss: 0.404211, g_loss: 3.341721 D real: 0.883285, D fake: 0.168896
Epoch [51/100], d_loss: 0.412401, g_loss: 2.640524 D real: 0.862663, D fake: 0.158232
Epoch [51/100], d_loss: 0.479909, g_loss: 2.896540 D real: 0.851526, D fake: 0.150938
Epoch [51/100], d_loss: 0.550382, g_loss: 2.536888 D real: 0.828847, D fake: 0.144185
Epoch [51/100], d_loss: 0.665353, g_loss: 2.700093 D r

Epoch [73/100], d_loss: 0.625800, g_loss: 2.762475 D real: 0.790383, D fake: 0.185025
Epoch [73/100], d_loss: 0.726915, g_loss: 1.846308 D real: 0.751871, D fake: 0.202921
Epoch [73/100], d_loss: 0.747400, g_loss: 2.000257 D real: 0.776225, D fake: 0.253975
Epoch [73/100], d_loss: 0.777153, g_loss: 2.332372 D real: 0.753980, D fake: 0.231522
Epoch [74/100], d_loss: 0.795353, g_loss: 2.410901 D real: 0.752318, D fake: 0.207912
Epoch [74/100], d_loss: 0.769414, g_loss: 2.167787 D real: 0.789959, D fake: 0.318568
Epoch [74/100], d_loss: 0.632812, g_loss: 2.801226 D real: 0.756181, D fake: 0.147025
Epoch [74/100], d_loss: 0.484532, g_loss: 3.162210 D real: 0.842913, D fake: 0.150591
Epoch [75/100], d_loss: 0.735673, g_loss: 1.822918 D real: 0.772775, D fake: 0.230203
Epoch [75/100], d_loss: 0.861266, g_loss: 1.780242 D real: 0.750698, D fake: 0.270627
Epoch [75/100], d_loss: 0.538852, g_loss: 2.000544 D real: 0.847184, D fake: 0.217492
Epoch [75/100], d_loss: 0.728914, g_loss: 2.147697 D r