In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import os

In [2]:
if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
num_epoch = 100
z_dimension = 100  # noise dimension

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

mnist = datasets.MNIST('./data', transform=img_transform)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True,
                        num_workers=4)

In [3]:
class Discriminator(nn.Module):
    """
    判别网络，用卷积神经网络来识别图片的真伪
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 32, 5, padding=2),  # Conv2d(in_chanel, out_channel, kernel_size)
                                  nn.LeakyReLU(0.2, True),
                                  nn.AvgPool2d(2, stride=2))  # AvgPool2d(pool_size, strid)
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2))  # batch, 64, 7, 7
        
        self.fc = nn.Sequential(
            nn.Linear(64*7*7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid())  # 返回的sigmoid的值，也就是结果为1的概率值
        
    def forward(self, x):
        """
        x = [batch_size, width, leight, in_channel]
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [4]:
class Generator(nn.Module):
    """
    生成网络，输入随机的噪声来生成图片
    """
    def __init__(self, input_size, num_features):
        super(Generator, self).__init__()
        self.fc = nn.Linear(input_size, num_features)  # input_size为噪声的大小，num_featurs为映射后的大小，取3136=1x56x56
        
        self.br= nn.Sequential(nn.BatchNorm2d(1),
                              nn.ReLU(True))
        
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),  # batch, 50, 56, 56
            nn.BatchNorm2d(50),  # 卷积之后使用bn来归一化
            nn.ReLU(True)
        )
        self.downsample2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56
            nn.BatchNorm2d(25),
            nn.ReLU(True)
        )
        self.downsample3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),  # batch, 1, 28, 28
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        x = self.downsample1(x)
        x = self.downsample2(x)
        x = self.downsample3(x)
        return x

In [5]:
D = Discriminator().cuda()  # discriminator model
G = Generator(z_dimension, 3136).cuda()  # generator model

criterion = nn.BCELoss()  # binary cross entropy

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

In [6]:
# train
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        # =================train discriminator
        real_img = Variable(img).cuda()
        real_label = Variable(torch.ones(num_img)).cuda()
        fake_label = Variable(torch.zeros(num_img)).cuda()

        # compute loss of real_img
        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)  # 利用生成网络生成假图片
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============train generator
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'
                  .format(epoch, num_epoch, d_loss.data[0], g_loss.data[0],
                          real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './dc_img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch+1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/100], d_loss: 0.000971, g_loss: 8.986004 D real: 0.999737, D fake: 0.000707
Epoch [0/100], d_loss: 0.064025, g_loss: 5.765369 D real: 0.988550, D fake: 0.046203
Epoch [0/100], d_loss: 0.517583, g_loss: 2.014836 D real: 0.913938, D fake: 0.271534
Epoch [0/100], d_loss: 0.833227, g_loss: 1.914553 D real: 0.633502, D fake: 0.122991


  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [1/100], d_loss: 0.855469, g_loss: 1.829101 D real: 0.583707, D fake: 0.088974
Epoch [1/100], d_loss: 0.683514, g_loss: 1.154147 D real: 0.641464, D fake: 0.088397
Epoch [1/100], d_loss: 0.463760, g_loss: 2.779452 D real: 0.920519, D fake: 0.266325
Epoch [1/100], d_loss: 0.397184, g_loss: 2.451964 D real: 0.804047, D fake: 0.077926
Epoch [2/100], d_loss: 0.554621, g_loss: 3.787340 D real: 0.966343, D fake: 0.337910
Epoch [2/100], d_loss: 0.343109, g_loss: 2.907626 D real: 0.907131, D fake: 0.141722
Epoch [2/100], d_loss: 0.353770, g_loss: 3.109071 D real: 0.818277, D fake: 0.049098
Epoch [2/100], d_loss: 0.316770, g_loss: 3.797498 D real: 0.935791, D fake: 0.168997
Epoch [3/100], d_loss: 0.341866, g_loss: 3.877265 D real: 0.925136, D fake: 0.176629
Epoch [3/100], d_loss: 0.451903, g_loss: 2.416387 D real: 0.923148, D fake: 0.199185
Epoch [3/100], d_loss: 0.323490, g_loss: 2.557395 D real: 0.925631, D fake: 0.144054
Epoch [3/100], d_loss: 0.253316, g_loss: 2.968826 D real: 0.93521

Epoch [25/100], d_loss: 0.484576, g_loss: 2.316967 D real: 0.804843, D fake: 0.138265
Epoch [25/100], d_loss: 0.442930, g_loss: 2.318432 D real: 0.816298, D fake: 0.104674
Epoch [25/100], d_loss: 0.526719, g_loss: 2.817770 D real: 0.776619, D fake: 0.086164
Epoch [25/100], d_loss: 0.575008, g_loss: 2.834360 D real: 0.801920, D fake: 0.179903
Epoch [26/100], d_loss: 0.517709, g_loss: 2.298701 D real: 0.844067, D fake: 0.194874
Epoch [26/100], d_loss: 0.510721, g_loss: 2.422199 D real: 0.806196, D fake: 0.117715
Epoch [26/100], d_loss: 0.784615, g_loss: 2.584378 D real: 0.913629, D fake: 0.384112
Epoch [26/100], d_loss: 0.538694, g_loss: 1.922727 D real: 0.850854, D fake: 0.193680
Epoch [27/100], d_loss: 0.509394, g_loss: 2.720214 D real: 0.841371, D fake: 0.179714
Epoch [27/100], d_loss: 0.511959, g_loss: 2.376315 D real: 0.849862, D fake: 0.173044
Epoch [27/100], d_loss: 0.475214, g_loss: 2.988818 D real: 0.842785, D fake: 0.123665
Epoch [27/100], d_loss: 0.447422, g_loss: 2.731344 D r

Epoch [49/100], d_loss: 0.436261, g_loss: 2.700859 D real: 0.859238, D fake: 0.130507
Epoch [49/100], d_loss: 0.479139, g_loss: 2.148749 D real: 0.847312, D fake: 0.137440
Epoch [49/100], d_loss: 0.483352, g_loss: 1.952468 D real: 0.843624, D fake: 0.147019
Epoch [49/100], d_loss: 0.464542, g_loss: 3.274563 D real: 0.822932, D fake: 0.088023
Epoch [50/100], d_loss: 0.471760, g_loss: 2.274518 D real: 0.865757, D fake: 0.181012
Epoch [50/100], d_loss: 0.560712, g_loss: 2.617124 D real: 0.906920, D fake: 0.257239
Epoch [50/100], d_loss: 0.620042, g_loss: 2.665453 D real: 0.829888, D fake: 0.213861
Epoch [50/100], d_loss: 0.439736, g_loss: 2.643804 D real: 0.830374, D fake: 0.086072
Epoch [51/100], d_loss: 0.411667, g_loss: 1.907278 D real: 0.859222, D fake: 0.147174
Epoch [51/100], d_loss: 0.377327, g_loss: 3.173761 D real: 0.882921, D fake: 0.153634
Epoch [51/100], d_loss: 0.411039, g_loss: 2.804836 D real: 0.860419, D fake: 0.111495
Epoch [51/100], d_loss: 0.514554, g_loss: 2.255914 D r

Epoch [73/100], d_loss: 0.437860, g_loss: 2.808015 D real: 0.858732, D fake: 0.094690
Epoch [73/100], d_loss: 0.343217, g_loss: 3.143754 D real: 0.879333, D fake: 0.066127
Epoch [73/100], d_loss: 0.372570, g_loss: 3.249603 D real: 0.909081, D fake: 0.126405
Epoch [73/100], d_loss: 0.409506, g_loss: 2.241008 D real: 0.865363, D fake: 0.117869
Epoch [74/100], d_loss: 0.453081, g_loss: 3.057210 D real: 0.833540, D fake: 0.066149
Epoch [74/100], d_loss: 0.357558, g_loss: 3.122952 D real: 0.875552, D fake: 0.080472
Epoch [74/100], d_loss: 0.367603, g_loss: 1.996183 D real: 0.874910, D fake: 0.119747
Epoch [74/100], d_loss: 0.387925, g_loss: 2.442420 D real: 0.867914, D fake: 0.097667
Epoch [75/100], d_loss: 0.382183, g_loss: 2.860499 D real: 0.879270, D fake: 0.103891
Epoch [75/100], d_loss: 0.398482, g_loss: 2.361688 D real: 0.863886, D fake: 0.100752
Epoch [75/100], d_loss: 0.401252, g_loss: 2.630914 D real: 0.884851, D fake: 0.115993
Epoch [75/100], d_loss: 0.348269, g_loss: 2.201433 D r

Epoch [97/100], d_loss: 0.394904, g_loss: 2.613120 D real: 0.868078, D fake: 0.102210
Epoch [97/100], d_loss: 0.370717, g_loss: 2.562987 D real: 0.889415, D fake: 0.124230
Epoch [97/100], d_loss: 0.316521, g_loss: 2.295799 D real: 0.916907, D fake: 0.132357
Epoch [97/100], d_loss: 0.301588, g_loss: 2.911119 D real: 0.896217, D fake: 0.080332
Epoch [98/100], d_loss: 0.366354, g_loss: 2.939683 D real: 0.867228, D fake: 0.061232
Epoch [98/100], d_loss: 0.293405, g_loss: 2.472677 D real: 0.940030, D fake: 0.155018
Epoch [98/100], d_loss: 0.269857, g_loss: 2.983252 D real: 0.905191, D fake: 0.080449
Epoch [98/100], d_loss: 0.360143, g_loss: 2.937191 D real: 0.885197, D fake: 0.095256
Epoch [99/100], d_loss: 0.294742, g_loss: 3.900194 D real: 0.909974, D fake: 0.090328
Epoch [99/100], d_loss: 0.346860, g_loss: 2.656889 D real: 0.900022, D fake: 0.109207
Epoch [99/100], d_loss: 0.302360, g_loss: 4.092312 D real: 0.905461, D fake: 0.058273
Epoch [99/100], d_loss: 0.301657, g_loss: 2.906443 D r