In [2]:
# Print the NVIDIA driver / GPU status table for this runtime.
!/opt/bin/nvidia-smi

# Remove the /content/sample_data directory (not referenced anywhere in this notebook).
!rm -rf /content/sample_data

# Delete /content/img_CVAE-WGAN so a new run starts without old images.
# NOTE(review): the training cell writes to './img_CVAE-WGAN'; this is the same
# directory only if the working directory is /content — confirm.
!rm -rf /content/img_CVAE-WGAN

Wed Nov 25 08:44:34 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [3]:
from __future__ import print_function
import argparse
import os,time
import random
import math
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
import torch.nn.functional as F

import imageio
import matplotlib.pyplot as plt

In [5]:
class VAE(nn.Module):
    """Conditional convolutional VAE for single-channel 28x28 images.

    The encoder maps an image to the parameters of a diagonal Gaussian over a
    ``latent_dim``-sized code.  The decoder consumes that code concatenated with
    a ``num_classes``-wide condition vector (a one-hot label in the training
    cell) and produces a 1x28x28 image through a final ``Tanh``.

    NOTE(review): ``Tanh`` emits values in [-1, 1] while the training cell feeds
    ``ToTensor()`` images, which are in [0, 1] — confirm this is intended.

    Args:
        latent_dim: Size of the latent code.  When ``None`` (the default) the
            module-level ``nz`` is used, which keeps the historical ``VAE()``
            call working.
        num_classes: Width of the condition vector appended to the latent code
            before decoding.
    """

    def __init__(self, latent_dim=None, num_classes=10):

        super(VAE, self).__init__()

        # Backward compatibility: the original class read the global ``nz``.
        if latent_dim is None:
            latent_dim = nz
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Two stride-2 convolutions take 28x28 down to 7x7 (28 -> 14 -> 7);
        # the third convolution keeps the 7x7 resolution.
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(1,16,kernel_size=3,stride=2,padding=1),

            nn.LeakyReLU(0.2,inplace=True),
            nn.BatchNorm2d(16),
            nn.Conv2d(16,32,kernel_size=3,stride=2,padding=1),

            nn.LeakyReLU(0.2,inplace=True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32,32,kernel_size=3,stride=1,padding=1),

            nn.LeakyReLU(0.2,inplace=True),
            nn.BatchNorm2d(32),
        )

        # fc1 -> mean of q(z|x); fc2 -> log of its standard deviation.
        self.encoder_fc1=nn.Linear(32*7*7,latent_dim)
        self.encoder_fc2=nn.Linear(32*7*7,latent_dim)
        # Not used by any method of this class; kept so the attribute still exists.
        self.Sigmoid = nn.Sigmoid()
        self.decoder_fc = nn.Linear(latent_dim+num_classes,32 * 7 * 7)
        # Two stride-2 transposed convolutions bring 7x7 back up to 28x28.
        self.decoder_deconv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 4, 2, 1),
            nn.LeakyReLU(0.2,inplace=True),
            nn.ConvTranspose2d(16, 1, 4, 2, 1),
            nn.Tanh(),
        )

    def noise_reparameterize(self,mean,logvar):
        """Draw ``z = mean + eps * exp(logvar)`` with ``eps ~ N(0, I)``.

        Despite its name, ``logvar`` is used as a log *standard deviation*
        (``exp(logvar)`` is the scale), which is how ``encoder`` and
        ``loss_function`` treat the same quantity.

        The noise is created with ``randn_like`` so it always matches the
        device and dtype of ``mean`` instead of relying on a global ``device``.
        """
        eps = torch.randn_like(mean)
        z = mean + eps * torch.exp(logvar)
        return z

    def encoder(self,x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tuple ``(z, mean, logstd)``, each of shape ``(batch, latent_dim)``;
            ``z`` is a reparameterised sample.
        """
        # Run the convolutional trunk once and share the features between both
        # heads.  Calling it twice produced the same activations at double the
        # cost and updated the BatchNorm running statistics twice per batch.
        features = self.encoder_conv(x)
        flat = features.view(features.shape[0], -1)
        mean = self.encoder_fc1(flat)
        logstd = self.encoder_fc2(flat)
        z = self.noise_reparameterize(mean, logstd)
        return z,mean,logstd

    def decoder(self,z):
        """Decode ``z`` of shape ``(batch, latent_dim + num_classes)`` to images.

        The caller is responsible for concatenating the condition vector onto
        the latent code.  Returns a tensor of shape ``(batch, 1, 28, 28)``.
        """
        out3 = self.decoder_fc(z)
        out3 = out3.view(out3.shape[0], 32, 7, 7)
        out3 = self.decoder_deconv(out3)
        return out3

    def forward(self, x, label_onehot=None):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Tensor of shape ``(batch, 1, 28, 28)``.
            label_onehot: Optional ``(batch, num_classes)`` condition.  When
                omitted, an all-zero condition is used so ``model(x)`` works.

        Returns:
            The reconstruction, shape ``(batch, 1, 28, 28)``.
        """
        # ``encoder`` returns a tuple; only the sampled code feeds the decoder.
        z, _, _ = self.encoder(x)
        if label_onehot is None:
            cond_width = self.decoder_fc.in_features - z.shape[1]
            label_onehot = z.new_zeros(z.shape[0], cond_width)
        output = self.decoder(torch.cat([z, label_onehot], 1))
        return output

In [6]:
class Classifier(nn.Module):
    """Small CNN that scores a 1x28x28 image against 10 classes.

    Two conv + max-pool stages halve the resolution twice (28 -> 14 -> 7),
    followed by a two-layer MLP.  The final ``Sigmoid`` yields ten independent
    scores in [0, 1]; the training cell pairs this with ``nn.BCELoss`` and a
    one-hot target.
    """

    def __init__(self):
        super(Classifier, self).__init__()

        # Feature extractor: 1x28x28 -> 64x7x7.
        conv_stack = [
            nn.Conv2d(1, 32, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(32, 64, 3, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
        ]
        self.dis = nn.Sequential(*conv_stack)

        # Classification head: flattened features -> 10 scores.
        head = [
            nn.Linear(7 * 7 * 64, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 10),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, input):
        """Return per-class scores of shape ``(batch, 10)`` for ``input``."""
        features = self.dis(input)
        flattened = features.view(features.size(0), -1)
        scores = self.fc(flattened)
        # squeeze(1) only drops dim 1 when it has size 1; the head emits 10
        # columns, so this leaves the shape unchanged.
        return scores.squeeze(1)

In [7]:
class Discriminator(nn.Module):
    """Fully connected critic over flattened 28x28 images.

    Produces one score per sample, squashed to [0, 1] by a final ``Sigmoid``.
    The layers live directly inside ``self.dis`` because the training cell
    iterates ``D.dis`` to clamp the weights of every ``Linear`` layer.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        mlp_layers = [
            nn.Linear(784, 256),  # 784 flattened pixels -> 256 hidden units
            nn.LeakyReLU(0.2),    # non-linearity
            nn.Linear(256, 256),  # second hidden layer
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),    # single score per sample
            nn.Sigmoid(),         # squash the score into [0, 1]
        ]
        self.dis = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Flatten ``x`` per sample and return scores of shape ``(batch, 1)``."""
        flat = x.view(x.size(0), -1)
        return self.dis(flat)

In [8]:
def loss_function(recon_x,x,mean,logstd):
    """VAE objective: reconstruction MSE plus KL(q(z|x) || N(0, I)).

    Args:
        recon_x: Reconstructed batch, same shape as ``x``.
        x: Target batch.
        mean: Posterior means, shape ``(batch, latent_dim)``.
        logstd: Natural log of the posterior standard deviations, same shape
            as ``mean``.

    Returns:
        Scalar tensor ``MSE + KLD``.

    NOTE(review): the MSE term is averaged over all elements while the KL term
    is summed over batch and latent dimensions, so the two are on different
    scales.  That weighting is kept unchanged here — confirm it is intended.
    """
    # BCE = F.binary_cross_entropy(recon_x,x,reduction='sum')
    # Mean-reduced MSE, identical to the default ``nn.MSELoss()``; using the
    # functional form removes the dependency on a global criterion object that
    # only exists after the training cell has run.
    MSE = F.mse_loss(recon_x,x)
    # log(var) = log(exp(logstd)^2) = 2*logstd.  Using the closed form avoids
    # exp() underflowing to 0 (log -> -inf) or overflowing (inf - inf -> nan).
    log_var = 2 * logstd
    KLD = -0.5 * torch.sum(1+log_var-torch.pow(mean,2)-torch.exp(log_var))
    return MSE+KLD

In [9]:
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Hyperparameters
    # ------------------------------------------------------------------
    batchSize = 128
    imageSize = 28
    nz=100          # latent size; also read by VAE() when no size is passed
    nepoch= 40
    lambda_ = 10    # decayed below but not read by any loss in this cell
    c = 0.005       # weight-clipping bound for the critic
    n_critic = 1    # classifier/critic updates per generator update

    os.makedirs('./img_CVAE-WGAN', exist_ok=True)

    # random.seed(88)
    # torch.manual_seed(88)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cudnn.benchmark = True

    dataset = dset.MNIST(root='./data',
                train=True,
                transform=transforms.Compose([transforms.ToTensor()]),
                download=True
                )

    dataloader = torch.utils.data.DataLoader(dataset,
                          batch_size=batchSize,
                          shuffle=True)

    print("=====> Initialization")
    vae = VAE().to(device)
    # vae.load_state_dict(torch.load('./VAE-WGANGP-VAE_v2.pth'))

    # discriminator
    D = Discriminator().to(device)
    # D.load_state_dict(torch.load('./VAE-GAN-Discriminator.pth'))

    # Classifier
    C = Classifier().to(device)
    # C.load_state_dict(torch.load('./CVAE-GAN-Classifier.pth'))

    criterion = nn.BCELoss().to(device)
    # Kept as a module-level name: other cells may look it up.
    MSECriterion = nn.MSELoss().to(device)

    # best=0.0001,0.0004,0.0008
    optimizerD = optim.Adam(D.parameters(), lr=0.0004,betas=(0.5, 0.999))
    optimizerC = optim.Adam(C.parameters(), lr=0.0004,betas=(0.5, 0.999))
    optimizerVAE = optim.Adam(vae.parameters(), lr=0.0004,betas=(0.5, 0.999))

    # Fixed conditions for the per-epoch preview grid: 10 rows of 8 images,
    # one digit class per row.
    s_label = []
    for digit in [1,8,0,9,4,0,1,0,5,3]:
        s_label += [digit for _ in range(8)]
    # Built on ``device`` (not a hard-coded .cuda()) so the CPU fallback works.
    specific_label = F.one_hot(torch.tensor(s_label), num_classes=10).float().to(device)

    print("=====> Begin training")

    start_time = time.time()
    for epoch in range(nepoch):

        if(epoch%5==0): lambda_ -= 1

        epoch_start_time = time.time()

        for i, (data,label) in enumerate(dataloader, 0):

            # data processing
            data = data.to(device)
            label = label.to(device)
            batch_size = data.shape[0]
            label_onehot = F.one_hot(label, num_classes=10).float()

            for _ in range(n_critic):

                # training C with real
                C.zero_grad()

                output = C(data)
                errC = criterion(output, label_onehot)

                errC.backward()
                optimizerC.step()

                ###################################################################
                # (1) Update D network (critic): minimise D(fake) - D(real)
                ###################################################################
                D.zero_grad()

                # train with real
                real_out = D(data)
                real_data_score = real_out.mean().item()

                # train with fake: decode a random vector; no_grad keeps the
                # critic's backward pass from running through the decoder.
                # NOTE(review): the condition slots of z are Gaussian noise
                # here rather than a one-hot label — confirm this is intended.
                z = torch.randn(batch_size, nz+10, device=device)
                with torch.no_grad():
                    fake_data = vae.decoder(z)
                fake_out = D(fake_data)

                # Mean critic score on generated images, logged once per epoch
                # (0 = judged most fake, 1 = judged real).
                fake_data_score = fake_out.mean().item()

                d_loss = torch.mean(fake_out) - torch.mean(real_out)
                d_loss.backward()
                optimizerD.step()

                # Weight clipping on the critic's Linear weights (biases are
                # left untouched, as before).
                with torch.no_grad():
                    for layer in D.dis:
                        if isinstance(layer, nn.Linear):
                            layer.weight.clamp_(-c, c)

            ###################################################
            # (2) Update VAE: reconstruction + KL objective
            ###################################################
            vae.zero_grad()

            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            recon_data = vae.decoder(z)

            vae_loss = loss_function(recon_data,data,mean,logstd)

            # The graph is rebuilt for every later step, so it need not be retained.
            vae_loss.backward()
            optimizerVAE.step()

            ###############################################
            # (3) Update VAE as generator: maximise D(G(z))
            ###############################################
            vae.zero_grad()

            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            recon_data = vae.decoder(z)

            output = D(recon_data)
            errVAE = torch.mean(-output)

            errVAE.backward()
            optimizerVAE.step()

            ###############################################
            # (4) Update VAE so that C classifies its
            #     reconstructions as the conditioning label
            ###############################################
            vae.zero_grad()

            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            recon_data = vae.decoder(z)

            output = C(recon_data)
            vae_loss3 = criterion(output, label_onehot)

            vae_loss3.backward()
            optimizerVAE.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('[%d/%d] time: %.2f real_score: %.4f fake_score: %.4f '
              % (epoch+1, nepoch, per_epoch_ptime,real_data_score,fake_data_score,))

        # Preview grid for this epoch; no gradients are needed for sampling.
        with torch.no_grad():
            sample = torch.randn(80, nz, device=device)
            sample = torch.cat([sample,specific_label],1)
            output = vae.decoder(sample)
        fake_images = make_grid(output.cpu(), nrow=8, normalize=True)
        save_image(fake_images, './img_CVAE-WGAN/fake_images-{}.png'.format(epoch + 1))

    end_time = time.time()
    total_time = end_time - start_time
    print("total time: %.2f " % total_time )

# Assemble the per-epoch preview grids into one animated GIF (2 frames/second).
frame_paths = ['./img_CVAE-WGAN/fake_images-' + str(epoch_idx+1) + '.png'
               for epoch_idx in range(nepoch)]
images = [imageio.imread(frame_path) for frame_path in frame_paths]
imageio.mimsave('./generation_animation.gif', images, fps=2)

# torch.save(vae.state_dict(), './CVAE-WGANGP-VAE.pth')
# torch.save(D.state_dict(),'./CVAE-WGANGP-Discriminator.pth')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!
=====> Initialization


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


=====> Begin training
[1/40] time: 17.84 real_score: 0.6050 fake_score: 0.5527 
[2/40] time: 17.48 real_score: 0.6365 fake_score: 0.5845 
[3/40] time: 17.79 real_score: 0.6708 fake_score: 0.5832 
[4/40] time: 17.44 real_score: 0.6954 fake_score: 0.6255 
[5/40] time: 17.46 real_score: 0.7238 fake_score: 0.6356 
[6/40] time: 17.41 real_score: 0.7433 fake_score: 0.6498 
[7/40] time: 17.46 real_score: 0.7532 fake_score: 0.6608 
[8/40] time: 17.36 real_score: 0.7569 fake_score: 0.6604 
[9/40] time: 17.38 real_score: 0.7629 fake_score: 0.6603 
[10/40] time: 17.44 real_score: 0.7592 fake_score: 0.6687 
[11/40] time: 17.36 real_score: 0.7703 fake_score: 0.6521 
[12/40] time: 17.57 real_score: 0.7691 fake_score: 0.6747 
[13/40] time: 17.34 real_score: 0.7737 fake_score: 0.6561 
[14/40] time: 17.42 real_score: 0.7705 fake_score: 0.7127 
[15/40] time: 17.39 real_score: 0.7888 fake_score: 0.7290 
[16/40] time: 17.38 real_score: 0.7918 fake_score: 0.6637 
[17/40] time: 17.55 real_score: 0.7850 fake

In [11]:
# Copy the generated image directory to /content/drive/MyDrive.
# NOTE(review): presumably requires Google Drive to be mounted at /content/drive — confirm.
!cp -r /content/img_CVAE-WGAN/ /content/drive/MyDrive