In [54]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid, save_image

print('PyTorch version:', torch.__version__)
print('torchvision version:', torchvision.__version__)
use_gpu = torch.cuda.is_available()
print('Is GPU available:', use_gpu)

PyTorch version: 1.0.0
torchvision version: 0.2.1
Is GPU available: True


In [55]:
# general settings
device = torch.device('cuda' if use_gpu else 'cpu')

batchsize = 64

data_dir = '../../data/CelebA/celeba-64x64-images-npy/'

output_dir = '../../data/CelebA/celeba-64x64-outputs_VGAN/'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    
save_dir = '../../data/CelebA/celeba-64x64-save/'
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

In [56]:
files = os.listdir(data_dir)
for i, f in enumerate(files):
    if i == 0:
        train_images = np.load(data_dir + f)
    else:
        train_images = np.concatenate((train_images, np.load(data_dir + f)))
train_images = np.transpose(train_images, [0, 3, 1, 2])

In [57]:
train_loader = DataLoader((torch.from_numpy(train_images).float() - 127.5) / 127.5, batch_size=batchsize, shuffle=True)

In [58]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [59]:
class Generator(nn.Module):
    def __init__(self, nz, nglf):
        super(Generator, self).__init__()
        self.ct1 = nn.ConvTranspose2d(nz, nglf*8, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(nglf*8)
        self.rl1 = nn.ReLU(inplace=True)
        
        self.ct2 = nn.ConvTranspose2d(nglf*8, nglf*4, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(nglf*4)
        self.rl2 = nn.ReLU(inplace=True)
        
        self.ct3 = nn.ConvTranspose2d(nglf*4, nglf*2, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(nglf*2)
        self.rl3 = nn.ReLU(inplace=True)
        
        self.ct4 = nn.ConvTranspose2d(nglf*2, nglf, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(nglf)
        self.rl4 = nn.ReLU(inplace=True)
        
        self.ct5 = nn.ConvTranspose2d(nglf, 3, kernel_size=4, stride=2, padding=1, bias=False)
        self.th5 = nn.Tanh()
        
    def forward(self, z):
        out = self.rl1(self.bn1(self.ct1(z)))
        out = self.rl2(self.bn2(self.ct2(out)))
        out = self.rl3(self.bn3(self.ct3(out)))
        out = self.rl4(self.bn4(self.ct4(out)))
        out = self.th5(self.ct5(out))
        return out

In [99]:
class Discriminator(nn.Module):
    def __init__(self, ndff, n_embed_dim):
        super(Discriminator, self).__init__()
        self.n_embed_dim = n_embed_dim
        
        self.cv1 = nn.Conv2d(3, ndff, 4, 2, 1, bias=False)
        self.lr1 = nn.LeakyReLU(0.2, inplace=True)
        
        self.cv2 = nn.Conv2d(ndff, ndff*2, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(ndff*2)
        self.lr2 = nn.LeakyReLU(0.2, inplace=True)
        
        self.cv3 = nn.Conv2d(ndff*2, ndff*4, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(ndff*4)
        self.lr3 = nn.LeakyReLU(0.2, inplace=True)
        
        self.cv4 = nn.Conv2d(ndff*4, ndff*8, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(ndff*8)
        self.lr4 = nn.LeakyReLU(0.2, inplace=True)
        
        self.cv5 = nn.Conv2d(ndff*8, n_embed_dim*2, 4, 1, 0, bias=False)
        
        self.fc6 = nn.Linear(n_embed_dim, 1)
        self.sg6 = nn.Sigmoid()
        
    def forward(self, x, mean_mode=False):
        out = self.lr1(self.cv1(x))
        out = self.lr2(self.bn2(self.cv2(out)))
        out = self.lr3(self.bn3(self.cv3(out)))
        out = self.lr4(self.bn4(self.cv4(out)))
        out = self.cv5(out)
        
        out = out.view(out.size(0), -1)
        mean = out[:, :self.n_embed_dim]
        logvar = out[:, self.n_embed_dim:]
        if mean_mode:
            out = self.sg6(self.fc6(mean))
            return out
        else:    
            z = torch.randn(mean.size(), device=device)
            out = (0.5 * logvar).exp() * z + mean
            out = self.sg6(self.fc6(out))
            return out, mean, logvar

In [104]:
nz = 100
netG = Generator(nz, 64).to(device)
netD = Discriminator(64, 100).to(device)

netG.apply(weights_init)
netD.apply(weights_init)

learning_rate = 0.0002
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(0.5, 0.999))

I_c = 0.5
def VDB_loss(out, label, mean, logvar, beta):
    normal_D_loss = F.binary_cross_entropy(out, label)
    kldiv_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    kldiv_loss = kldiv_loss / out.size(0) - I_c
    final_loss = normal_D_loss + beta * kldiv_loss
    return final_loss, kldiv_loss.detach()

criterionG = nn.BCELoss()
criterionD = VDB_loss

In [105]:
real_label = 1
fake_label = 0

beta = 1.0
alpha = 1e-5

save_image_interval = 1
n_save_image = 25

In [106]:
def train(train_loader, epoch, beta):
    netG.train()
    netD.train()
    
    running_D_loss = 0
    running_G_loss = 0
    
    for real_image in train_loader:
        # prepare for D learning with real image
        optimizerD.zero_grad()
        real_image = real_image.to(device)
        batch_size = real_image.size(0)
        label = torch.full((batch_size,), real_label, device=device)
        
        # train D with real image
        output, mean, logvar = netD(real_image)
        loss_D_real, loss_kldiv_real = criterionD(output, label, mean, logvar, beta)
        loss_D_real.backward()
        running_D_loss += loss_D_real.item()
        
        # prepare for D learning with fake image
        input_noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_image = netG(input_noise)
        label.fill_(fake_label)
        
        # train D with fake image
        output, mean, logvar = netD(fake_image.detach())
        loss_D_fake, loss_kldiv_fake = criterionD(output, label, mean, logvar, beta)
        loss_D_fake.backward()
        running_D_loss += loss_D_fake.item()
        
        optimizerD.step()
        
        loss_kldiv = loss_kldiv_real.item() + loss_kldiv_fake.item()
        beta = max(0.0, beta + alpha * loss_kldiv)
        
        # prepare for G learning
        optimizerG.zero_grad()
        label.fill_(real_label)
        
        # train G
        output, _, _ = netD(fake_image, mean_mode=False)
        loss_G = criterionG(output, label)
        loss_G.backward()
        running_G_loss += loss_G.item()
        
        optimizerG.step()
        
    if epoch % save_image_interval == 0:
        save_image(fake_image[:n_save_image], output_dir + 'generated_images_%d.png' % (epoch), nrow=5, normalize=True)
    
    
    average_D_loss = running_D_loss / len(train_loader)
    average_G_loss = running_G_loss / len(train_loader)
    return average_D_loss, average_G_loss, beta

In [107]:
D_loss_list = []
G_loss_list = []
n_epochs = 100
for epoch in range(n_epochs):
    D_loss, G_loss, beta = train(train_loader, epoch, beta)
    print('epoch[%d/%d] D_loss:%1.4f G_loss:%1.4f Beta:%1.4f' % (epoch+1, n_epochs, D_loss, G_loss, beta))
    D_loss_list.append(D_loss)
    G_loss_list.append(G_loss)

  "Please ensure they have the same size.".format(target.size(), input.size()))
  "Please ensure they have the same size.".format(target.size(), input.size()))


epoch[1/100] D_loss:13.6220 G_loss:0.8584 Beta:1.0547
epoch[2/100] D_loss:2.4249 G_loss:0.6808 Beta:1.0591
epoch[3/100] D_loss:2.2798 G_loss:0.6880 Beta:1.0630


KeyboardInterrupt: 