In [1]:
# Show the GPU attached to this runtime.
!/opt/bin/nvidia-smi

# Delete the /content/sample_data directory (not used by this notebook).
!rm -rf /content/sample_data

# Delete sample images left by a previous run. Presumably this is the same folder the
# training cell writes to as './img_VAE-WGANGP' when the working directory is /content — confirm.
!rm -rf /content/img_VAE-WGANGP

Thu Nov 19 01:57:09 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8     9W /  70W |      0MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [2]:
from __future__ import print_function
import argparse
import os,time
import random
import math
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
import torch.nn.functional as F

import imageio
import matplotlib.pyplot as plt

In [3]:
class VAE(nn.Module):
    """Convolutional VAE for single-channel 28x28 images.

    Encoder: 1x28x28 -> 16x14x14 -> 32x7x7 -> 32x7x7 (conv + BatchNorm +
    LeakyReLU). The flattened features feed two linear heads, which give the
    latent mean and the latent log standard deviation.
    Decoder: latent -> 32x7x7 -> 16x14x14 -> 1x28x28, ending in a Sigmoid, so
    reconstructions lie in [0, 1].

    Args:
        latent_dim: size of the latent vector. When omitted, the notebook-level
            global ``nz`` is read at construction time, so ``VAE()`` behaves
            exactly as before.
    """

    def __init__(self, latent_dim=None):
        super(VAE, self).__init__()
        if latent_dim is None:
            latent_dim = nz  # notebook global set in the training cell
        # Encoder: the two stride-2 convs halve 28 -> 14 -> 7; the last conv keeps 7x7.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.encoder_fc1 = nn.Linear(32 * 7 * 7, latent_dim)  # -> mean
        self.encoder_fc2 = nn.Linear(32 * 7 * 7, latent_dim)  # -> log std
        # Not used by forward(); kept so the module tree stays unchanged.
        self.Sigmoid = nn.Sigmoid()
        self.decoder_fc = nn.Linear(latent_dim, 32 * 7 * 7)
        # Decoder: each ConvTranspose2d(k=4, s=2, p=1) doubles H and W (7 -> 14 -> 28).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 4, 2, 1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 1, 4, 2, 1),
            nn.Sigmoid(),
        )

    def noise_reparameterize(self, mean, logvar):
        """Sample z = mean + exp(logvar) * eps with eps ~ N(0, I).

        Despite its name, ``logvar`` is used as the log *standard deviation*:
        the noise is scaled by exp(logvar), not exp(logvar / 2). forward()
        passes the encoder's log-std head here.
        """
        # randn_like samples on mean's device/dtype, so no global `device` is needed.
        eps = torch.randn_like(mean)
        return mean + eps * torch.exp(logvar)

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Args:
            x: image batch. encoder_fc1/fc2 are sized for a 32x7x7 feature map,
               which a (B, 1, 28, 28) input produces.

        Returns:
            (reconstruction in [0, 1] with shape (B, 1, 28, 28), mean, logstd),
            where mean and logstd have shape (B, latent_dim).
        """
        # Run the encoder once: both heads read the same features. Running it
        # twice doubled the compute and updated BatchNorm running stats twice.
        features = self.encoder(x)
        features = features.view(features.shape[0], -1)
        mean = self.encoder_fc1(features)
        logstd = self.encoder_fc2(features)
        z = self.noise_reparameterize(mean, logstd)

        out = self.decoder_fc(z)
        out = out.view(out.shape[0], 32, 7, 7)
        out = self.decoder(out)
        return out, mean, logstd


In [4]:
class Discriminator(nn.Module):
    """Critic for single-channel 28x28 images, used by the WGAN-GP loop.

    Two conv + LeakyReLU + 2x2 max-pool stages (28 -> 14 -> 7) are followed by
    an MLP that maps the 64x7x7 features to one score per image.

    Args:
        use_sigmoid: squash scores into (0, 1). Leave this False for WGAN-GP.
            The training loop minimizes mean(D(fake)) - mean(D(real)) plus a
            gradient penalty, which needs an unbounded real-valued critic; a
            Sigmoid saturates and bounds that objective. Set True only for a
            standard (BCE) GAN discriminator.
    """

    def __init__(self, use_sigmoid=False):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),

            nn.Conv2d(32, 64, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
        )
        head = [
            nn.Linear(7 * 7 * 64, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        ]
        # The Sigmoid has no parameters, so state_dict keys are the same either way.
        if use_sigmoid:
            head.append(nn.Sigmoid())
        self.fc = nn.Sequential(*head)

    def forward(self, input):
        """Return a tensor of shape (B,) holding one critic score per image."""
        x = self.dis(input)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x.squeeze(1)


In [5]:
def loss_function(recon_x, x, mean, logstd, recon_reduction='sum'):
    """VAE objective: squared reconstruction error plus KL(q(z|x) || N(0, I)).

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target batch.
        mean: latent means.
        logstd: latent log standard deviations (natural log), same shape as ``mean``.
        recon_reduction: reduction for the squared-error term, ``'sum'`` or
            ``'mean'``. The KL term is always summed over every element of
            ``mean``, so ``'sum'`` keeps both terms on the same scale. ``'mean'``
            reproduces the previous weighting, where a per-pixel mean MSE was
            dwarfed by the summed KL term.

    Returns:
        Scalar loss tensor.
    """
    # Functional form, so this no longer depends on a global `MSECriterion`
    # that is only created later in the training cell.
    mse = F.mse_loss(recon_x, x, reduction=recon_reduction)
    # logstd = log(sigma), so log(sigma^2) = 2 * logstd. Computing it directly
    # avoids the exp-then-log round trip, which turns into -inf/nan once
    # exp(logstd) ** 2 under- or overflows.
    log_var = 2 * logstd
    kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return mse + kld

In [None]:
if __name__ == '__main__':
    # ---- Configuration ----
    # VAE() and loss_function may read `nz`, `device` and `MSECriterion` as
    # notebook globals, so they are all defined at this level.
    batchSize = 128
    imageSize = 28
    nz = 100
    nepoch = 20
    lambda_ = 10   # gradient-penalty weight; decremented every 5 epochs below
    n_critic = 1   # critic (D) updates per VAE update

    os.makedirs('./img_VAE-WGANGP', exist_ok=True)

    # random.seed(88)
    # torch.manual_seed(88)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cudnn.benchmark = True

    dataset = dset.MNIST(root='./data',
                train=True,
                transform=transforms.Compose([transforms.ToTensor()]),
                download=True
                )

    dataloader = torch.utils.data.DataLoader(dataset,
                          batch_size=batchSize,
                          shuffle=True)

    print("=====> Initialization")
    vae = VAE().to(device)
    # vae.load_state_dict(torch.load('./VAE-WGANGP-VAE_v2.pth'))

    D = Discriminator().to(device)
    # D.load_state_dict(torch.load('./VAE-GAN-Discriminator.pth'))

    criterion = nn.BCELoss().to(device)  # not used by the WGAN-GP losses below
    MSECriterion = nn.MSELoss().to(device)

    optimizerD = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optimizerVAE = optim.Adam(vae.parameters(), lr=0.0001, betas=(0.5, 0.999))

    print("=====> Begin training")

    start_time = time.time()
    for epoch in range(nepoch):

        # Anneal the gradient-penalty weight: 10 -> 9 at epoch 0, then -1 every 5 epochs.
        if epoch % 5 == 0:
            lambda_ -= 1
        epoch_start_time = time.time()

        for i, (data, _label) in enumerate(dataloader, 0):
            data = data.to(device)
            batch_size = data.shape[0]

            ###################################################################
            # (1) Update critic D: minimize E[D(fake)] - E[D(real)] + lambda_ * GP
            ###################################################################
            for _ in range(n_critic):
                D.zero_grad()
                real_out = D(data)
                real_data_score = real_out.mean().item()

                # Fake images decoded from prior noise. This step only updates D,
                # so generate them without tracking gradients into the VAE.
                with torch.no_grad():
                    z = torch.randn(batch_size, nz, device=device)
                    fake_data = vae.decoder_fc(z).view(z.shape[0], 32, 7, 7)
                    fake_data = vae.decoder(fake_data)
                fake_out = D(fake_data)
                # Mean critic score on fakes (higher = judged more real; unbounded
                # for a WGAN critic). Logged once per epoch.
                fake_data_score = fake_out.mean().item()

                # Gradient penalty on random interpolates between real and fake.
                # x_hat is a leaf, so it must be marked as requiring grad explicitly.
                alpha = torch.rand(batch_size, 1, 1, 1, device=device)
                x_hat = (alpha * data + (1 - alpha) * fake_data).requires_grad_(True)
                pred_hat = D(x_hat)
                gradients = torch.autograd.grad(outputs=pred_hat, inputs=x_hat,
                                                grad_outputs=torch.ones_like(pred_hat),
                                                create_graph=True, retain_graph=True,
                                                only_inputs=True)[0]
                gradient_penalty = lambda_ * ((gradients.reshape(batch_size, -1).norm(2, 1) - 1) ** 2).mean()

                d_loss = fake_out.mean() - real_out.mean() + gradient_penalty
                d_loss.backward()
                optimizerD.step()

            ###################################################
            # (2) Update the VAE on its own objective (reconstruction + KL)
            ###################################################
            recon_data, mean, logstd = vae(data)
            vae.zero_grad()
            vae_loss = loss_function(recon_data, data, mean, logstd)
            # Step (3) re-runs the forward pass, so this graph need not be retained.
            vae_loss.backward()
            optimizerVAE.step()

            ###############################################
            # (3) Adversarial VAE update: maximize the critic score D(vae(x))
            ###############################################
            recon_data, mean, logstd = vae(data)
            vae.zero_grad()
            output = D(recon_data)
            errVAE = -output.mean()
            # This also writes grads into D's parameters; D.zero_grad() in the
            # next critic step clears them before they are used.
            errVAE.backward()
            D_G_z2 = output.mean().item()
            optimizerVAE.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        # Scores come from the last batch of the epoch only.
        print('[%d/%d] time: %.2f real_score: %.4f fake_score: %.4f '
              % (epoch+1, nepoch, per_epoch_ptime,real_data_score,fake_data_score,))

        # Save 80 samples decoded from prior noise, 8 per row.
        with torch.no_grad():
            sample = torch.randn(80, nz).to(device)
            output = vae.decoder_fc(sample)
            output = vae.decoder(output.view(output.shape[0], 32, 7, 7))
        fake_images = make_grid(output.cpu(), nrow=8, normalize=True)
        save_image(fake_images, './img_VAE-WGANGP/fake_images-{}.png'.format(epoch + 1))

    end_time = time.time()
    total_time = end_time - start_time
    print("total time: %.2f " % total_time )

# Stitch the per-epoch sample grids (epochs 1..nepoch) into an animated GIF
# that plays at 2 frames per second.
frame_paths = ['./img_VAE-WGANGP/fake_images-' + str(epoch_no) + '.png'
               for epoch_no in range(1, nepoch + 1)]
images = [imageio.imread(path) for path in frame_paths]
imageio.mimsave('./generation_animation.gif', images, fps=2)

# torch.save(vae.state_dict(), './VAE-WGANGP-VAE.pth')
# torch.save(D.state_dict(),'./VAE-WGANGP-Discriminator.pth')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!
=====> Initialization




  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


=====> Begin training
[1/20] time: 20.56 real_score: 7.3871 fake_score: -0.6042 
[2/20] time: 20.38 real_score: 3.1134 fake_score: -0.1934 
[3/20] time: 20.33 real_score: 0.8831 fake_score: -1.0677 
[4/20] time: 19.87 real_score: -0.8742 fake_score: -2.5293 
[5/20] time: 19.98 real_score: -1.2477 fake_score: -2.9405 
[6/20] time: 19.97 real_score: -0.5816 fake_score: -2.0436 
