In [14]:
!/opt/bin/nvidia-smi

!rm -rf /content/sample_data

!rm -rf /content/img_CVAE-WGANGP

Wed Dec  2 01:10:31 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P0    28W /  70W |   2631MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
+-------

In [15]:
from __future__ import print_function
import argparse
import os,time
import random
import math
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
import torch.nn.functional as F
import scipy.stats
import imageio
import matplotlib.pyplot as plt

In [16]:
class VAE(nn.Module):
    def __init__(self):

        super(VAE, self).__init__()
        
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3,16,kernel_size=5,stride=2,padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(16,64,kernel_size=5,stride=2,padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(64,32,kernel_size=5,stride=1,padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2,inplace=True),
        )

        self.encoder_fc1=nn.Linear(32*25,nz)
        self.encoder_fc2=nn.Linear(32*25,nz)
        self.Sigmoid = nn.Sigmoid()
       
        self.decoder_fc=nn.Linear(nz+10,3 * 64 * 64)
        self.br=nn.Sequential(
            nn.BatchNorm2d(3),
            nn.ReLU(True),
        )
        self.gen = nn.Sequential(
            nn.Conv2d(3,64,5,stride=1,padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.Conv2d(64,32,5,stride=1,padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.Conv2d(32,3,4,stride=2,padding=1),
            nn.Tanh(),
        )
    
    def decoder(self,z):
        x = self.decoder_fc(z)
        x=x.view(x.shape[0],3,64,64)
        x=self.br(x)
        x = nn.functional.dropout(x, p=0.5, training=self.training)
        x=self.gen(x)
        return x

    def noise_reparameterize(self,mean,logvar):
        eps = torch.randn(mean.shape).to(device)
        z = mean + eps * torch.exp(logvar)
        return z

    def encoder(self,x):
        out1, out2 = self.encoder_conv(x), self.encoder_conv(x)
        mean = self.encoder_fc1(out1.view(out1.shape[0], -1))
        logstd = self.encoder_fc2(out2.view(out2.shape[0], -1))
        z = self.noise_reparameterize(mean, logstd)
        return z,mean,logstd

    def forward(self, x):
        z = self.encoder(x)
        output = self.decoder(z)
        return output

In [17]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(3,32,5,stride=1,padding=1),
            nn.LeakyReLU(0.2,True),
            nn.MaxPool2d((2,2)),

            nn.Conv2d(32,64,5,stride=1,padding=1),
            nn.LeakyReLU(0.2,True),
            nn.MaxPool2d((2,2)),
        )
        self.fc = nn.Sequential(
            # nn.Linear(8*8*64,1024),
            nn.Linear(2304,1024),
            nn.LeakyReLU(0.2,True),
            nn.Linear(1024,1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.dis(x)
        x=x.view(x.size(0),-1)
        x=self.fc(x)
        x = x.view(x.size(0))
        return x

In [18]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(3, 32, 5, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),

            nn.Conv2d(32, 64, 5, stride=1, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),
        )
        self.fc = nn.Sequential(
            # nn.Linear(8 * 8 * 64, 1024),
            nn.Linear(2304, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 10),
            nn.Sigmoid()
        )

    def forward(self, input):
        x = self.dis(input)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x.squeeze(1)

In [19]:
def loss_function(recon_x,x,mean,logstd):
    # BCE = F.binary_cross_entropy(recon_x,x,reduction='sum')
    MSE = MSECriterion(recon_x,x)
    # 因为var是标准差的自然对数，先求自然对数然后平方转换成方差
    var = torch.pow(torch.exp(logstd),2)
    KLD = -0.5 * torch.sum(1+torch.log(var)-torch.pow(mean,2)-var)
    return MSE+KLD

def loss_function(recon_x,x,mean,logstd):
    # BCE = F.binary_cross_entropy(recon_x,x,reduction='sum')
    MSE = MSECriterion(recon_x,x)
    return MSE+js_div(recon_x,x)

def js_div(p_output, q_output, get_softmax=True):
    """
    Function that measures JS divergence between target and output logits:
    """
    KLDivLoss = nn.KLDivLoss(reduction='batchmean')
    if get_softmax:
        p_output = F.softmax(p_output)
        q_output = F.softmax(q_output)
    log_mean_output = ((p_output + q_output )/2).log()
    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2

In [21]:
if __name__ == '__main__':
    batchSize = 128
    nz=100
    nepoch=200
    lambda_ = 10
    
    if not os.path.exists('./img_CVAE-WGANGP'):
        os.mkdir('./img_CVAE-WGANGP')

    random.seed(1)
    torch.manual_seed(1)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cudnn.benchmark = True

    dataset = dset.CIFAR10(root='./data',
                train=True,
                transform=transforms.Compose([transforms.ToTensor()]),
                download=True
                )

    dataloader = torch.utils.data.DataLoader(dataset,
                          batch_size=batchSize,
                          shuffle=True)

    print("=====> Initialization")
    vae = VAE().to(device)
    # vae.load_state_dict(torch.load('./VAE-WGANGP-VAE_v2.pth'))

    # discriminator
    D = Discriminator().to(device)
    # D.load_state_dict(torch.load('./VAE-GAN-Discriminator.pth'))

    # Classifier
    C = Classifier().to(device)
    # C.load_state_dict(torch.load('./CVAE-GAN-Classifier.pth'))
    
    criterion = nn.BCELoss().to(device)
    MSECriterion = nn.MSELoss().to(device)

    optimizerD = optim.Adam(D.parameters(), lr=0.0002,betas=(0.5, 0.999))
    optimizerC = optim.Adam(C.parameters(), lr=0.0002,betas=(0.5, 0.999))
    optimizerVAE = optim.Adam(vae.parameters(), lr=0.0002,betas=(0.5, 0.999))

    # sample_label
    specific_label = torch.zeros((100, 10)).to(device)
    for i in range(10):
      for j in range(10):
        specific_label[10*i+j,i] = 1

    print("=====> Begin training")
    
    start_time = time.time()
    for epoch in range(nepoch):

        # if(epoch%5==0): lambda_ -= 1
        
        epoch_start_time = time.time()

        for i, (data,label) in enumerate(dataloader, 0):
            
            for n in range(3):
              
                # data processing
                data = data.to(device)
                label = label.to(device)
                label_onehot = torch.zeros((data.shape[0], 10)).to(device)
                label_onehot[torch.arange(data.shape[0]), label] = 1
                batch_size = data.shape[0]

                # training C with real
                C.zero_grad()
                output = C(data)
                real_label = label_onehot.to(device)
                errC = criterion(output, real_label)
                errC.backward()
                optimizerC.step()
                
                ###################################################################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###################################################################
                for p in D.parameters():
                  p.requires_grad = True

                # train with real
                D.zero_grad()
                batch_size = data.shape[0]
                real_out = D(data)

                # real_label = torch.ones(batch_size).to(device)  
                # fake_label = torch.zeros(batch_size).to(device)  

                real_label = torch.tensor( [ random.uniform(0.9,1) for _ in range(batch_size) ] ).to(device)
                fake_label = torch.tensor( [ random.uniform(0,0.1) for _ in range(batch_size) ] ).to(device)

                real_data_score = real_out.mean().item()

                # train with fake, taking the noise vector z as the input of D network
                z = torch.randn(batch_size, nz+10).to(device)
                fake_data = vae.decoder(z)
                fake_out = D(fake_data)

                # fake_data_score用来输出查看的，是虚假照片的评分，0最假，1为真
                fake_data_score = fake_out.mean().item()

                alpha = torch.rand((batch_size, 1, 1, 1)).to(device)
                x_hat = alpha * data + (1 - alpha) * fake_data

                pred_hat = D(x_hat)

                gradients = \
                    torch.autograd.grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).to(device),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]
                gradient_penalty = lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()

                d_loss = torch.mean(fake_out) - torch.mean(real_out) + gradient_penalty
                d_loss.backward()
                optimizerD.step()

            ###################################################
            # (2) Update G network which is the decoder of VAE
            ###################################################
            for p in D.parameters():
              p.requires_grad = False
            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            vae.zero_grad()
            recon_data = vae.decoder(z)
            vae_loss = loss_function(recon_data,data,mean,logstd)
            vae_loss.backward(retain_graph=True)
            optimizerVAE.step()
            
            ###############################################
            # (3) Update G network: maximize log(D(G(z)))
            ###############################################
            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            vae.zero_grad()
            recon_data = vae.decoder(z) 
            output = D(recon_data)
            errVAE = torch.mean(-output)
            errVAE.backward()
            optimizerVAE.step()

            ###############################################
            # (4) Update C network
            ###############################################   
            z,mean,logstd = vae.encoder(data)
            z = torch.cat([z,label_onehot],1)
            vae.zero_grad()
            recon_data = vae.decoder(z)         
            output = C(recon_data)
            real_label = label_onehot
            vae_loss3 = criterion(output, real_label)
            vae_loss3.backward()
            optimizerVAE.step()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print('[%d/%d] time: %.2f real_score: %.4f fake_score: %.4f '
              % (epoch+1, nepoch, per_epoch_ptime,real_data_score,fake_data_score,))
    
        sample = torch.randn(100, nz).to(device)
        sample = torch.cat([sample,specific_label],1)
        output = vae.decoder(sample)
        fake_images = make_grid(output.cpu(), nrow=10, normalize=True).detach()
        save_image(fake_images, './img_CVAE-WGANGP/fake_images-{}.png'.format(epoch + 1))

        if(epoch%10==0 and epoch!=0):
          !cp -r /content/img_CVAE-WGANGP/ /content/drive/MyDrive

    end_time = time.time()
    total_time = end_time - start_time
    print("total time: %.2f " % total_time )

images = []
for e in range(nepoch):
    img_name = './img_CVAE-WGANGP/fake_images-' + str(e+1) + '.png'
    images.append(imageio.imread(img_name))
imageio.mimsave('./generation_animation.gif', images, fps=2)

# torch.save(vae.state_dict(), './CVAE-WGANGP-VAE.pth')
# torch.save(D.state_dict(),'./CVAE-WGANGP-Discriminator.pth')

!cp -r /content/img_CVAE-WGAN-GP/ /content/drive/MyDrive
!cp -r /content/generation_animation.gif /content/drive/MyDrive

Files already downloaded and verified
=====> Initialization
=====> Begin training
[1/200] time: 191.70 real_score: 0.4432 fake_score: 0.5923 
[2/200] time: 192.58 real_score: 0.4868 fake_score: 0.4857 
[3/200] time: 193.31 real_score: 0.5519 fake_score: 0.4309 
[4/200] time: 194.70 real_score: 0.5113 fake_score: 0.3356 
[5/200] time: 194.58 real_score: 0.5314 fake_score: 0.3438 
[6/200] time: 194.97 real_score: 0.5203 fake_score: 0.2955 
[7/200] time: 195.21 real_score: 0.5953 fake_score: 0.3360 
[8/200] time: 195.21 real_score: 0.6361 fake_score: 0.3156 
[9/200] time: 194.49 real_score: 0.6599 fake_score: 0.3582 
[10/200] time: 194.96 real_score: 0.6758 fake_score: 0.3719 
[11/200] time: 194.49 real_score: 0.6444 fake_score: 0.3811 
[12/200] time: 190.54 real_score: 1.0000 fake_score: 1.0000 
[13/200] time: 188.26 real_score: 1.0000 fake_score: 1.0000 
[14/200] time: 188.19 real_score: 1.0000 fake_score: 1.0000 
[15/200] time: 188.08 real_score: 1.0000 fake_score: 1.0000 


KeyboardInterrupt: ignored

In [None]:
!cp -r /content/img_CVAE-WGANGP/ /content/drive/MyDrive
!cp -r /content/generation_animation.gif /content/drive/MyDrive