In [24]:
# Print the GPU / driver status for this runtime.
!/opt/bin/nvidia-smi

# Remove the sample_data directory under /content.
!rm -rf /content/sample_data

# Remove any image directory left by a previous run so the training cell starts clean.
# NOTE(review): this is destructive — previously generated sample grids are lost on re-run.
!rm -rf /content/img_CVAE-WGAN

Sat Nov 28 10:16:17 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P0    30W /  70W |   2991MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
+-------

In [14]:
from __future__ import print_function
import argparse
import os,time
import random
import math
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
import torch.nn.functional as F

import imageio
import matplotlib.pyplot as plt

In [15]:
class VAE(nn.Module):
    """Conditional VAE whose decoder also serves as the GAN generator.

    The encoder maps a 3x32x32 image to a latent Gaussian (mean and log
    standard deviation).  The decoder takes the latent code concatenated with
    a one-hot class vector and produces a 3x32x32 image in [-1, 1] (Tanh).

    Args:
        latent_dim: Size of the latent code.  When ``None`` (the default) the
            notebook-level global ``nz`` is used, which is what the original
            no-argument ``VAE()`` call relied on.
        num_classes: Width of the one-hot condition appended to the latent
            code before decoding (10 in the original code).
    """

    def __init__(self, latent_dim=None, num_classes=10):

        super(VAE, self).__init__()

        if latent_dim is None:
            # Backward compatible with the original, which read the global `nz`.
            latent_dim = nz
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # 3x32x32 -> 16x16x16 -> 64x8x8 -> 32x8x8
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(16,64,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64,32,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2,inplace=True),
        )

        # Two heads on the flattened features: mean and log-std of q(z|x).
        self.encoder_fc1=nn.Linear(32*8*8,latent_dim)
        self.encoder_fc2=nn.Linear(32*8*8,latent_dim)
        # Not used by any method of this class; kept so the attribute still exists.
        self.Sigmoid = nn.Sigmoid()

        # Decoder input is [z, one-hot label]; it is projected to a 3x64x64 map.
        self.decoder_fc=nn.Linear(latent_dim+num_classes,3 * 64 * 64)
        self.br=nn.Sequential(
            nn.BatchNorm2d(3),
            nn.ReLU(True),
        )
        # The final stride-2 convolution halves 64x64 down to 32x32.
        self.gen = nn.Sequential(
            nn.Conv2d(3,64,3,stride=1,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.Conv2d(64,32,3,stride=1,padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.Conv2d(32,3,3,stride=2,padding=1),
            nn.Tanh(),
        )

    def decoder(self,z):
        """Decode ``z`` of shape (batch, latent_dim + num_classes) to (batch, 3, 32, 32)."""
        x = self.decoder_fc(z)
        x=x.view(x.shape[0],3,64,64)
        x=self.br(x)
        x=self.gen(x)
        return x

    def noise_reparameterize(self,mean,logvar):
        """Sample ``mean + eps * exp(logvar)`` with ``eps ~ N(0, I)``.

        Despite its name, ``logvar`` is used as a log *standard deviation*:
        it is exponentiated directly, and ``encoder`` passes its ``logstd``
        output here.  The parameter name is kept for caller compatibility.
        """
        # randn_like samples on mean's device/dtype, so no global `device` is needed.
        eps = torch.randn_like(mean)
        z = mean + eps * torch.exp(logvar)
        return z

    def encoder(self,x):
        """Encode ``x`` and return ``(z, mean, logstd)``.

        The convolutional trunk is evaluated once and shared by both heads.
        The original evaluated it twice on the same input, which produced
        identical features at twice the cost and updated the BatchNorm
        running statistics twice per call.
        """
        features = self.encoder_conv(x)
        features = features.view(features.shape[0], -1)
        mean = self.encoder_fc1(features)
        logstd = self.encoder_fc2(features)
        z = self.noise_reparameterize(mean, logstd)
        return z,mean,logstd

    def forward(self, x, label_onehot=None):
        """Reconstruct ``x`` conditioned on ``label_onehot``.

        Args:
            x: Image batch of shape (batch, 3, 32, 32).
            label_onehot: Optional (batch, num_classes) condition.  When
                omitted, an all-zero condition is used so that
                ``model(x)`` still works.

        Returns:
            Reconstructed batch of shape (batch, 3, 32, 32).
        """
        # encoder() returns a tuple; only the sampled code feeds the decoder.
        z, _, _ = self.encoder(x)
        if label_onehot is None:
            label_onehot = z.new_zeros(z.shape[0], self.num_classes)
        label_onehot = label_onehot.to(device=z.device, dtype=z.dtype)
        output = self.decoder(torch.cat([z, label_onehot], 1))
        return output

In [9]:
class Classifier(nn.Module):
    """Small CNN producing 10 per-class sigmoid scores for a 3x32x32 image.

    Two conv + max-pool stages reduce the input to 64 channels of 8x8, which
    a two-layer MLP maps to 10 outputs squashed independently by a Sigmoid.
    """

    def __init__(self):
        super(Classifier, self).__init__()

        # Feature extractor: each stage halves the spatial size.
        feature_layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d((2, 2)),
        ]
        self.dis = nn.Sequential(*feature_layers)

        # Classification head on the flattened 64x8x8 feature map.
        head_layers = [
            nn.Linear(8 * 8 * 64, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 10),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, input):
        """Return per-class scores of shape (batch, 10)."""
        features = self.dis(input)
        flat = features.view(features.size(0), -1)
        scores = self.fc(flat)
        # squeeze(1) only drops dim 1 when it has size 1, so (batch, 10) is unchanged.
        return scores.squeeze(1)

In [6]:
class Discriminator(nn.Module):
    """Small CNN that scores a 3x32x32 image with a single value in (0, 1).

    NOTE(review): the output passes through a Sigmoid, while the training
    loop in this notebook uses a Wasserstein-style loss (difference of mean
    scores) — confirm the bounded output is intended.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Two conv + max-pool stages: 3x32x32 -> 32x16x16 -> 64x8x8.
        conv_layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d((2, 2)),
        ]
        self.dis = nn.Sequential(*conv_layers)

        # Scoring head: flattened features -> one sigmoid-squashed scalar.
        score_layers = [
            nn.Linear(8 * 8 * 64, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*score_layers)

    def forward(self, x):
        """Return one score per sample as a 1-D tensor of length ``batch``."""
        n_samples = x.size(0)
        features = self.dis(x)
        scores = self.fc(features.view(n_samples, -1))
        # (batch, 1) -> (batch,)
        return scores.view(n_samples)

In [7]:
def loss_function(recon_x,x,mean,logstd):
    """VAE objective: mean-squared reconstruction error plus KL divergence.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch, same shape as ``recon_x``.
        mean: Mean of the approximate posterior q(z|x).
        logstd: Natural log of the posterior standard deviation.

    Returns:
        Scalar tensor ``MSE + KLD`` where ``MSE`` is averaged over all
        elements and ``KLD`` is summed over batch and latent dimensions.

    NOTE(review): the two terms use different reductions (mean vs. sum), so
    the KL term carries far more weight than the reconstruction term.  This
    is preserved from the original; confirm it is intended.
    """
    # Same default (mean) reduction as nn.MSELoss, without relying on a
    # criterion object defined in a later notebook cell.
    MSE = F.mse_loss(recon_x, x)
    # log(var) = log(exp(logstd)^2) = 2*logstd.  Using the closed form avoids
    # log(0) = -inf when exp(logstd)^2 underflows for very negative logstd.
    log_var = 2 * logstd
    KLD = -0.5 * torch.sum(1 + log_var - torch.pow(mean, 2) - torch.exp(log_var))
    return MSE+KLD

In [25]:
if __name__ == '__main__':
    batchSize = 128
    imageSize = 28
    nz=100
    nepoch= 200
    lambda_ = 10
    c = 0.005

    if not os.path.exists('./img_CVAE-WGAN'):
        os.mkdir('./img_CVAE-WGAN')

    random.seed(1)
    torch.manual_seed(1)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cudnn.benchmark = True

    dataset = dset.CIFAR10(root='./data',
                train=True,
                transform=transforms.Compose([transforms.ToTensor()]),
                download=True
                )

    dataloader = torch.utils.data.DataLoader(dataset,
                          batch_size=batchSize,
                          shuffle=True)

    print("=====> Initialization")
    vae = VAE().to(device)
    # vae.load_state_dict(torch.load('./VAE-WGANGP-VAE_v2.pth'))

    # discriminator
    D = Discriminator().to(device)
    # D.load_state_dict(torch.load('./VAE-GAN-Discriminator.pth'))

    # Classifier
    C = Classifier().to(device)
    # C.load_state_dict(torch.load('./CVAE-GAN-Classifier.pth'))
    
    criterion = nn.BCELoss().to(device)
    MSECriterion = nn.MSELoss().to(device)

    # best=0.0001,0.0004,0.0008
    optimizerD = optim.Adam(D.parameters(), lr=0.0002,betas=(0.5, 0.999))
    optimizerC = optim.Adam(C.parameters(), lr=0.0002,betas=(0.5, 0.999))
    optimizerVAE = optim.Adam(vae.parameters(), lr=0.0002,betas=(0.5, 0.999))

    # sample_label
    specific_label = torch.zeros((100, 10)).cuda()
    for i in range(10):
      for j in range(10):
        specific_label[10*i+j,i] = 1

    print("=====> Begin training")
    
    start_time = time.time()
    for epoch in range(nepoch):

        if(epoch%5==0): lambda_ -= 1
        
        epoch_start_time = time.time()

        for i, (data,label) in enumerate(dataloader, 0):

            # data processing
            data = data.to(device)
            label = label.to(device)
            label_onehot = torch.zeros((data.shape[0], 10)).to(device)
            label_onehot[torch.arange(data.shape[0]), label] = 1
            batch_size = data.shape[0]


            for n in range(1):
            # if epoch % 2 == 0:
              

                # training C with real
                C.zero_grad()
                
                output = C(data)
                real_label = label_onehot.to(device)
                errC = criterion(output, real_label)
                
                errC.backward()
                optimizerC.step()
                
                ###################################################################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###################################################################

                # train with real
                batch_size = data.shape[0]

                D.zero_grad()

                real_out = D(data)

                real_label = torch.ones(batch_size).to(device)  
                fake_label = torch.zeros(batch_size).to(device)  

                real_data_score = real_out.mean().item()

                # train with fake, taking the noise vector z as the input of D network
                z = torch.randn(batch_size, nz+10).to(device)
                fake_data = vae.decoder(z)
                fake_out = D(fake_data)

                # fake_data_score用来输出查看的，是虚假照片的评分，0最假，1为真
                fake_data_score = fake_out.mean().item()

                d_loss = torch.mean(fake_out) - torch.mean(real_out)               
                d_loss.backward()
                optimizerD.step()

                for layer in D.dis:
                  if (layer.__class__.__name__ == 'Linear'):
                      layer.weight.requires_grad = False
                      layer.weight.clamp_(-c, c)
                      layer.weight.requires_grad = True

            ###################################################
            # (2) Update G network which is the decoder of VAE
            ###################################################
            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            recon_data = vae.decoder(z)
            
            vae.zero_grad()
            
            vae_loss = loss_function(recon_data,data,mean,logstd)

            vae_loss.backward(retain_graph=True)
            optimizerVAE.step()
            
            ###############################################
            # (3) Update G network: maximize log(D(G(z)))
            ###############################################
            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            recon_data = vae.decoder(z)
            
            vae.zero_grad()

            # real_label = torch.ones(batch_size).to(device)  
            output = D(recon_data)
            errVAE = torch.mean(-output)

            errVAE.backward()
            optimizerVAE.step()

            ###############################################
            # (4) Update C network
            ###############################################   
            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            recon_data = vae.decoder(z) 

            vae.zero_grad()

            output = C(recon_data)
            real_label = label_onehot
            vae_loss3 = criterion(output, real_label)
         
            vae_loss3.backward()
            optimizerVAE.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('[%d/%d] time: %.2f real_score: %.4f fake_score: %.4f '
              % (epoch+1, nepoch, per_epoch_ptime,real_data_score,fake_data_score,))
    
        sample = torch.randn(100, nz).to(device)
        sample = torch.cat([sample,specific_label],1)
        output = vae.decoder(sample)
        fake_images = make_grid(output.cpu(), nrow=10, normalize=False).detach()
        save_image(fake_images, './img_CVAE-WGAN/fake_images-{}.png'.format(epoch + 1))

        if(epoch%10==0 and epoch!=0):
          !cp -r /content/img_CVAE-WGAN/ /content/drive/MyDrive
          
    end_time = time.time()
    total_time = end_time - start_time
    print("total time: %.2f " % total_time )

# Assemble the per-epoch sample grids (epochs 1..nepoch) into an animated GIF.
frame_paths = ['./img_CVAE-WGAN/fake_images-' + str(e+1) + '.png' for e in range(nepoch)]
images = [imageio.imread(frame_path) for frame_path in frame_paths]
imageio.mimsave('./generation_animation.gif', images, fps=2)

# torch.save(vae.state_dict(), './CVAE-WGANGP-VAE.pth')
# torch.save(D.state_dict(),'./CVAE-WGANGP-Discriminator.pth')

Files already downloaded and verified
=====> Initialization
=====> Begin training
[1/200] time: 77.58 real_score: 0.7499 fake_score: 0.7469 
[2/200] time: 80.86 real_score: 0.0677 fake_score: 0.0357 
[3/200] time: 80.10 real_score: 0.6751 fake_score: 0.5915 
[4/200] time: 79.98 real_score: 0.5018 fake_score: 0.1424 
[5/200] time: 79.97 real_score: 0.6490 fake_score: 0.2344 
[6/200] time: 79.90 real_score: 0.1209 fake_score: 0.0250 
[7/200] time: 79.88 real_score: 0.5099 fake_score: 0.1198 
[8/200] time: 79.88 real_score: 0.6208 fake_score: 0.1619 
[9/200] time: 79.89 real_score: 0.9070 fake_score: 0.4438 
[10/200] time: 79.91 real_score: 0.7572 fake_score: 0.1133 
[11/200] time: 79.90 real_score: 0.9321 fake_score: 0.1933 
[12/200] time: 79.90 real_score: 0.7207 fake_score: 0.0334 
[13/200] time: 79.87 real_score: 0.9097 fake_score: 0.0963 
[14/200] time: 79.93 real_score: 0.8032 fake_score: 0.0409 
[15/200] time: 79.96 real_score: 0.8771 fake_score: 0.0948 
[16/200] time: 79.87 real_s

In [26]:
# Copy the generated image directory and the animated GIF into /content/drive/MyDrive.
# NOTE(review): assumes Google Drive is already mounted at /content/drive — the mount
# step is not visible in this notebook.
!cp -r /content/img_CVAE-WGAN/ /content/drive/MyDrive
!cp -r /content/generation_animation.gif /content/drive/MyDrive