In [11]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import time

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from numpy import random
from scipy.stats import norm
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import Parameter
from torch.nn import functional as F
import torch.utils.data
from torchvision import datasets, transforms
from torchvision.utils import save_image


## Hyperparameters

In [2]:
# VAE model parameters
NUM_PIXELS = 4096       # flattened image size (images are reshaped to 64 x 64 for display)
NUM_HIDDEN_1 = 32       # output channels of the 1st encoder convolution
NUM_HIDDEN_2 = 32       # output channels of the 2nd encoder convolution
NUM_HIDDEN_3 = 64       # output channels of the 3rd encoder convolution
NUM_HIDDEN_4 = 64       # output channels of the 4th encoder convolution
NUM_HIDDEN_5 = 256      # width of the fully connected layers around the latent code
Z_Dimension = 10        # dimensionality of the latent code z
Filter_Size = 4         # kernel size of every (transposed) convolution
Pooling_Size = 2
Stride_Size = 2         # stride of every (transposed) convolution
# VAE training parameters
BATCH_SIZE = 64
EPOCH = 200
Log_Interval = 10

# Discriminator parameters
Gamma = 35              # weight of the total-correlation term in the VAE loss
DISC_HIDDEN = 1000      # width of each hidden layer of the discriminator
D_BATCH_SIZE = 256      # number of latent samples per discriminator update
Update_frequency = 5    # discriminator updates performed per VAE update

# Path parameters (checkpoint files are tagged with the value of Gamma)
PATH_vae = './factorVAE_dsprite_vae-%02d' % Gamma
PATH_disc = './factorVAE_dsprite_disc-%02d' % Gamma

# Restore: when True, training is skipped and the checkpoints above are loaded instead
Restore = False

# Metric parameters
NUM_Data_Metric = 100
NUM_Factors = 5

# Reconstruction error metric
Rec_loss = []

## Load DSprite Dataset

In [3]:
# Load the dSprites images from `imgs.npy` (expected in the working directory)
# and wrap them in a shuffling mini-batch loader for VAE training.
imgs = np.load('imgs.npy')
(data_size, x, y) = imgs.shape
imgs_Tensor = torch.FloatTensor(imgs)
train_loader = torch.utils.data.DataLoader(
    dataset=imgs_Tensor,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

## VAE Model

In [4]:
class VAE(nn.Module):
    """Convolutional variational autoencoder with a Z_Dimension-dimensional Gaussian latent.

    The encoder stacks four stride-``Stride_Size`` convolutions followed by a
    fully connected layer and two linear heads (mean and log-variance of z).
    The decoder mirrors it with two fully connected layers and four transposed
    convolutions, and emits per-pixel Bernoulli means flattened to
    ``(-1, NUM_PIXELS)``.

    The flatten/unflatten steps hard-code a 2 x 2 spatial map after the fourth
    convolution, which is what a 64 x 64 single-channel input produces with the
    kernel size / stride configured in this notebook.
    """

    def __init__(self):
        super(VAE, self).__init__()
        # Encoder
        self.enc_conv1 = nn.Conv2d(1, NUM_HIDDEN_1, kernel_size=Filter_Size, stride=Stride_Size)
        self.enc_conv2 = nn.Conv2d(NUM_HIDDEN_1, NUM_HIDDEN_2, kernel_size=Filter_Size, stride=Stride_Size)
        self.enc_conv3 = nn.Conv2d(NUM_HIDDEN_2, NUM_HIDDEN_3, kernel_size=Filter_Size, stride=Stride_Size)
        self.enc_conv4 = nn.Conv2d(NUM_HIDDEN_3, NUM_HIDDEN_4, kernel_size=Filter_Size, stride=Stride_Size)
        self.enc_linear1 = nn.Linear(NUM_HIDDEN_4*2*2, NUM_HIDDEN_5)
        self.enc_mu_z = nn.Linear(NUM_HIDDEN_5, Z_Dimension)
        self.enc_logvar_z = nn.Linear(NUM_HIDDEN_5, Z_Dimension)
        # Decoder
        self.dec_linear1 = nn.Linear(Z_Dimension, NUM_HIDDEN_5)
        self.dec_linear2 = nn.Linear(NUM_HIDDEN_5, NUM_HIDDEN_4*2*2)
        self.dec_conv1 = nn.ConvTranspose2d(NUM_HIDDEN_4, NUM_HIDDEN_3, kernel_size=Filter_Size, stride=Stride_Size)
        self.dec_conv2 = nn.ConvTranspose2d(NUM_HIDDEN_3, NUM_HIDDEN_2, kernel_size=Filter_Size, stride=Stride_Size)
        # output_padding=1 makes this layer undo the odd spatial size produced by enc_conv1
        self.dec_conv3 = nn.ConvTranspose2d(NUM_HIDDEN_2, NUM_HIDDEN_1, kernel_size=Filter_Size, stride=Stride_Size,output_padding=1)
        self.dec_conv4 = nn.ConvTranspose2d(NUM_HIDDEN_1, 1, kernel_size=Filter_Size, stride=Stride_Size)
        # Shared parameter-free activations
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def Encoder(self, x):
        """Map a batch of images to the parameters of q(z|x).

        Args:
            x: image batch with one channel, i.e. shaped (batch, 1, H, W).

        Returns:
            (mu_z, logvar_z): mean and log-variance, each (batch, Z_Dimension).
        """
        enc_hidden_1 = self.relu(self.enc_conv1(x))
        enc_hidden_2 = self.relu(self.enc_conv2(enc_hidden_1))
        enc_hidden_3 = self.relu(self.enc_conv3(enc_hidden_2))
        enc_hidden_4 = self.relu(self.enc_conv4(enc_hidden_3))
        # Flatten the (NUM_HIDDEN_4, 2, 2) feature map before the dense layer
        enc_hidden_5 = self.relu(self.enc_linear1(enc_hidden_4.view(-1, NUM_HIDDEN_4*2*2)))
        mu_z = self.enc_mu_z(enc_hidden_5)
        logvar_z = self.enc_logvar_z(enc_hidden_5)
        return mu_z, logvar_z

    def Reparam(self, mu_z, logvar_z):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick).

        Args:
            mu_z: mean of q(z|x).
            logvar_z: log-variance of q(z|x), same shape as ``mu_z``.

        Returns:
            A sample with the shape of ``mu_z``, on the same device as the inputs.
        """
        std = logvar_z.mul(0.5).exp()
        # `new` allocates with the dtype and device of `std`, so the noise already
        # lives next to the inputs; forcing `.cuda()` here would break CPU execution.
        eps = Variable(std.data.new(std.size()).normal_())
        return eps.mul(std).add_(mu_z)

    def Decoder(self, z):
        """Decode latent codes into flattened per-pixel probabilities.

        Args:
            z: latent codes whose last dimension is Z_Dimension.

        Returns:
            Sigmoid outputs reshaped to (-1, NUM_PIXELS).
        """
        dec_hidden_1 = self.relu(self.dec_linear1(z))
        dec_hidden_2 = self.relu(self.dec_linear2(dec_hidden_1))
        # Restore the (NUM_HIDDEN_4, 2, 2) feature map expected by the transposed convolutions
        dec_hidden_3 = self.relu(self.dec_conv1(dec_hidden_2.view(-1, NUM_HIDDEN_4, 2, 2)))
        dec_hidden_4 = self.relu(self.dec_conv2(dec_hidden_3))
        dec_hidden_5 = self.relu(self.dec_conv3(dec_hidden_4))
        x = self.dec_conv4(dec_hidden_5).squeeze(1).view(-1, NUM_PIXELS)
        return self.sigmoid(x)

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            (reconstruction, mu_z, logvar_z, z_sample)
        """
        mu_z, logvar_z = self.Encoder(x)
        z_sample = self.Reparam(mu_z, logvar_z)
        return self.Decoder(z_sample), mu_z, logvar_z, z_sample


## Discriminator Model

In [5]:
class Discriminator(nn.Module):
    """Fully connected binary classifier over latent codes.

    Five ReLU hidden layers of width DISC_HIDDEN followed by a single sigmoid
    output unit. Layers are registered as ``linear1`` ... ``linear6``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Z_Dimension -> DISC_HIDDEN (x5) -> 1, registered in order as linear1..linear6
        layer_widths = [Z_Dimension] + [DISC_HIDDEN] * 5 + [1]
        for position, (n_in, n_out) in enumerate(zip(layer_widths[:-1], layer_widths[1:]), 1):
            setattr(self, 'linear%d' % position, nn.Linear(n_in, n_out))
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, z):
        """Return the sigmoid output of the network for each row of ``z``, shaped (batch, 1)."""
        hidden = z
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4, self.linear5):
            hidden = self.relu(layer(hidden))
        return self.sigmoid(self.linear6(hidden))
        

## Initialize VAE and Discriminator

In [6]:
# Build both networks on the GPU and give each its own Adam optimizer.
# The discriminator uses a ten times smaller learning rate than the VAE.
vae = VAE().cuda()
VAE_optimizer = optim.Adam(params=vae.parameters(), lr=1e-3)

disc = Discriminator().cuda()
D_optimizer = optim.Adam(params=disc.parameters(), lr=1e-4)

## Loss Function

In [7]:
# loss function
# Summed (not averaged) binary cross-entropy; shared by the VAE reconstruction
# term and the discriminator loss.
criterion = nn.BCELoss(size_average=False)
def elbo_loss(recon_x, x, mu_z, logvar_z, z_sample):
    """VAE training objective: reconstruction + KL + Gamma-weighted total correlation.

    Args:
        recon_x: decoder output, shaped (batch, NUM_PIXELS).
        x: input images; flattened here to (batch, NUM_PIXELS).
        mu_z: mean of q(z|x).
        logvar_z: log-variance of q(z|x).
        z_sample: latent samples fed to the module-level discriminator ``disc``.

    Returns:
        Scalar loss summed over the batch.
    """
    recon_loss = criterion(recon_x, x.view(-1, NUM_PIXELS))
    # KL(q(z|x) || N(0, I)) = 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1)
    kl_loss = 0.5 * torch.sum(-1. - logvar_z + mu_z.pow(2) + torch.exp(logvar_z))
    # Density-ratio estimate of the total correlation: sum(log D(z) - log(1 - D(z))).
    # The discriminator is evaluated once and reused for both log terms;
    # 1e-6 keeps the logs finite when D saturates at 0 or 1.
    d_z = disc.forward(z_sample)
    tc_loss = Gamma * torch.sum(torch.log(d_z + 1e-6) - torch.log(1. - d_z + 1e-6))
    loss = recon_loss + kl_loss + tc_loss
    return loss

def disc_loss(D_real, D_fake, D_BATCH_SIZE):
    """Summed binary cross-entropy of the discriminator.

    Args:
        D_real: discriminator outputs that should be classified as 1,
            shaped (D_BATCH_SIZE, 1).
        D_fake: discriminator outputs that should be classified as 0,
            shaped (D_BATCH_SIZE, 1).
        D_BATCH_SIZE: number of rows in each of the two inputs.

    Returns:
        Scalar loss: BCE(D_real, 1) + BCE(D_fake, 0), summed over all rows.
    """
    # Allocate the targets with `new` so they inherit the dtype and device of the
    # predictions instead of assuming that everything lives on the GPU.
    ones_label = Variable(D_real.data.new(D_BATCH_SIZE, 1).fill_(1))
    zeros_label = Variable(D_fake.data.new(D_BATCH_SIZE, 1).fill_(0))
    D_loss_real = criterion(D_real, ones_label)
    D_loss_fake = criterion(D_fake, zeros_label)
    D_loss = D_loss_real + D_loss_fake
    return D_loss

def disc_accuracy(D_joint, D_marginal, D_BATCH_SIZE):
    """Count correct discriminator decisions over the first D_BATCH_SIZE rows.

    A joint sample counts as correct when its output is strictly above 0.5,
    a marginal sample when its output is strictly below 0.5.

    Returns:
        The number of correct decisions as a float (at most 2 * D_BATCH_SIZE).
    """
    joint_probs = D_joint.cpu().data.numpy()
    marginal_probs = D_marginal.cpu().data.numpy()

    joint_hits = sum(1 for row in range(D_BATCH_SIZE) if joint_probs[row][0] > 0.5)
    marginal_hits = sum(1 for row in range(D_BATCH_SIZE) if marginal_probs[row][0] < 0.5)
    # Start from 0.0 so the count is returned as a float
    return 0.0 + joint_hits + marginal_hits

## Ancestral Sampling

In [8]:
def ancestral_sampling(batch_x, D_BATCH_SIZE, joint_flag=True):
    """Draw latent samples for discriminator training.

    D_BATCH_SIZE images are drawn with replacement from ``batch_x``, encoded by
    the module-level ``vae`` and passed through its reparameterization step.

    Args:
        batch_x: image batch (indexed along dimension 0) to resample from.
        D_BATCH_SIZE: number of latent samples to return.
        joint_flag: if true, return the samples as they are (samples of the
            joint q(z)); otherwise permute every latent dimension independently
            across the batch, giving samples of the product of marginals.

    Returns:
        A (D_BATCH_SIZE, Z_Dimension) batch of latent codes on the GPU.
    """
    # Both modes start from the same resample-and-encode step
    x_idx = Variable(torch.LongTensor(D_BATCH_SIZE).random_(batch_x.size()[0]))
    x_samples = torch.index_select(batch_x, 0, x_idx).cuda()
    mu_z, logvar_z = vae.Encoder(x_samples)
    z_samples = vae.Reparam(mu_z, logvar_z)

    # sample joint distribution
    if joint_flag:
        return z_samples

    # sample marginal distribution: one independent permutation of the batch
    # per latent dimension, stacked as the columns of a (D_BATCH_SIZE, Z_Dimension) index
    rand_ind = torch.stack([torch.randperm(D_BATCH_SIZE) for _ in range(Z_Dimension)], 1)
    # The shuffle runs on CPU on raw data, which also detaches it from the VAE graph
    z_cpu = z_samples.cpu().data
    # Row i of column j is written to row rand_ind[i, j] of the same column
    marginal_z = torch.zeros(D_BATCH_SIZE, Z_Dimension).scatter_(0, rand_ind, z_cpu)
    return Variable(marginal_z).cuda()

## Training 

In [12]:
# training 
# Alternates one VAE update with `Update_frequency` discriminator updates per
# mini-batch, then reports epoch statistics and shows 16 prior samples.
if not Restore:
    print("Start Training...")

    for epoch in range(EPOCH):
        time_start = time.time()
        VAE_train_loss = 0.0
        D_train_loss = 0.0
        D_train_accuracy = 0.0
        for batch_idx, data in enumerate(train_loader):
            # update VAE
            data = data.unsqueeze(1)  # add the channel dimension expected by the encoder
            data = Variable(data)
            data_vae = data.cuda()
            VAE_optimizer.zero_grad()
            recon_batch, mu_z, logvar_z, z_sample = vae.forward(data_vae)
            VAE_loss = elbo_loss(recon_batch, data_vae, mu_z, logvar_z, z_sample)
            VAE_loss.backward()
            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch way of reading a scalar
            # loss; newer releases require `.item()` - confirm the targeted version.
            VAE_train_loss += VAE_loss.data[0]

            VAE_optimizer.step()
            # update discriminator
            for d_train_idx in range(Update_frequency):
                joint_z = ancestral_sampling(data, D_BATCH_SIZE, joint_flag=True)
                marginal_z = ancestral_sampling(data, D_BATCH_SIZE, joint_flag=False)
                # Clear gradients here: the VAE backward pass above also reaches the discriminator
                D_optimizer.zero_grad()
                D_joint = disc.forward(joint_z)
                D_marginal = disc.forward(marginal_z)
                D_loss = disc_loss(D_joint, D_marginal, D_BATCH_SIZE)
                D_accuracy = disc_accuracy(D_joint, D_marginal, D_BATCH_SIZE)
                D_loss.backward()
                D_train_loss += D_loss.data[0]
                D_train_accuracy += D_accuracy
                D_optimizer.step()

        time_end = time.time()
        print('====> Epoch: %d elbo_Loss : %0.8f' % ((epoch + 1), VAE_train_loss / len(train_loader.dataset)))
        # Each discriminator step scores 2 * D_BATCH_SIZE samples (joint + marginal)
        print('                discriminator_loss : %0.8f' % (D_train_loss / (D_BATCH_SIZE * 2 * Update_frequency * (batch_idx + 1))))
        print('====>           accuracy : %0.8f' % ( D_train_accuracy / (D_BATCH_SIZE * 2 * Update_frequency * (batch_idx + 1))))

        # Decode 16 draws from the N(0, I) prior and show them on a 4 x 4 grid
        samples = Variable(torch.randn(16, Z_Dimension))
        samples = samples.cuda()
        samples = vae.Decoder(samples).cpu()

        samples = samples.data.numpy()
        fig = plt.figure(figsize = (4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace = 0.05, hspace = 0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(64, 64), cmap='Greys_r')
        plt.show()

    # Save only after an actual training run. Saving unconditionally would, with
    # Restore = True, overwrite the stored checkpoints with untrained weights
    # right before they are loaded back.
    torch.save(vae.state_dict(), PATH_vae)
    torch.save(disc.state_dict(), PATH_disc)

Start Training...




KeyboardInterrupt: 

In [None]:
# When restoring, load the saved weights of both networks (VAE first, then discriminator).
if Restore:
    for _module, _checkpoint_path in ((vae, PATH_vae), (disc, PATH_disc)):
        _module.load_state_dict(torch.load(_checkpoint_path))

In [None]:
def plot_reconstruction():
    """Plot latent traversals around the code of one randomly chosen image.

    One image is sampled (with the latent at column -5, i.e. index 1, forced to
    class 2), encoded and sampled once. Then, for each latent dimension, that
    dimension is swept over 10 values from -3.0 in steps of 0.6 while the others
    keep their encoded value, and every decoded image is drawn in a grid with
    one row per latent dimension and one column per swept value.
    """
    num_steps = 10

    latents_sampled = sample_latent(size=1)
    latents_sampled[:, -5] = 2
    # Select the image and encode it
    indices_sampled = latent_to_index(latents_sampled)
    img = imgs[indices_sampled][0]
    img_variable = Variable(torch.FloatTensor(img))
    img_variable = img_variable.unsqueeze(0).unsqueeze(0)  # -> (1, 1, H, W)
    img_variable = img_variable.cuda()
    img_z_mu, img_z_logvar = vae.Encoder(img_variable)
    img_z = vae.Reparam(img_z_mu, img_z_logvar)
    # Fetch the base code once; each traversal step below edits a private copy
    base_code = img_z.cpu().data.numpy()[0]

    # figsize is (width, height): num_steps columns by Z_Dimension rows
    fig = plt.figure(figsize = (num_steps, Z_Dimension))
    gs = gridspec.GridSpec(Z_Dimension, num_steps)
    gs.update(wspace = 0.05, hspace = 0.05)

    for z in range(Z_Dimension):
        for i in range(num_steps):
            sample_i = base_code.copy()
            sample_i[z] = -3.0 + i * 0.6
            sample_i = Variable(torch.FloatTensor(sample_i))
            sample_i = sample_i.unsqueeze(0)  # single-row batch: (1, Z_Dimension)
            sample_i = sample_i.cuda()
            img_i = vae.Decoder(sample_i).cpu()
            img_i = img_i.data.numpy()
            #
            ax = plt.subplot(gs[i + z * num_steps])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(img_i.reshape(64, 64), cmap='Greys_r')
    plt.show()

## Metric

In [None]:
# Number of classes of each ground-truth latent, in storage order.
latents_sizes = np.array([1, 3, 6, 40, 32, 32])
# Mixed-radix place values: the base of a latent is the product of the sizes of
# all latents that come after it (1 for the last one).
latents_bases = np.append(np.cumprod(latents_sizes[::-1])[::-1][1:], 1)


def latent_to_index(latents):
    """Convert latent class indices (last axis ordered like latents_sizes) to flat integer indices."""
    flat_index = np.dot(latents, latents_bases)
    return flat_index.astype(int)

def sample_latent(size=1):
    """Draw `size` random latent class vectors.

    Column k holds integers drawn uniformly from [0, latents_sizes[k]); the
    result is a float array of shape (size, latents_sizes.size).
    """
    # One randint call per latent, in storage order
    columns = [np.random.randint(lat_size, size=size) for lat_size in latents_sizes]
    return np.column_stack(columns).astype(np.float64)

def calculate_std():
    """Per-dimension standard deviation of sampled latent codes over the training set.

    Every batch of ``train_loader`` is encoded by the module-level ``vae`` on the
    GPU and passed through its reparameterization step; the resulting codes are
    collected on the CPU.

    Returns:
        NumPy array with one standard deviation per latent dimension.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    latent_batches = []
    for data in train_loader:
        data = Variable(data).unsqueeze(1)  # add the channel dimension
        data = data.cuda()
        data_z_mu, data_z_logvar = vae.Encoder(data)
        data_z = vae.Reparam(data_z_mu, data_z_logvar)
        latent_batches.append(data_z.cpu().data.numpy())
    if not latent_batches:
        raise ValueError('train_loader produced no batches; cannot compute latent std')
    # Concatenate once at the end: growing an array batch by batch copies all
    # previously collected codes on every iteration.
    full_data_z = np.concatenate(latent_batches, 0)
    std = np.std(full_data_z, axis=0)
    return std

L = 100            # images encoded per vote
num_votes  = 500   # number of votes cast
def generate_samples(std):
    """Disentanglement score by majority vote over fixed-factor batches.

    For each vote, one ground-truth factor ``fk`` (columns 1..NUM_Factors of the
    latent class vector) is fixed to a random class, ``L`` images sharing that
    class are encoded, the codes are divided by ``std``, and the latent
    dimension with the smallest variance votes for ``fk``. Each latent dimension
    is then assigned the factor it voted for most often.

    Args:
        std: per-dimension scale used to normalise the latent codes
            (as returned by ``calculate_std``).

    Returns:
        (classifier_samples, prediction, accuracy) where ``classifier_samples``
        is the (Z_Dimension, NUM_Factors) vote matrix, ``prediction`` the
        1-based factor assigned to each latent dimension, and ``accuracy`` the
        fraction of votes that agree with that assignment.
    """
    fk_list = np.random.randint(1, NUM_Factors + 1, num_votes)
    classifier_samples = np.zeros((Z_Dimension, NUM_Factors))
    for i in range(num_votes):
        fk = fk_list[i]
        latents_sampled = sample_latent(size=L)
        # The latent vectors hold class indices (that is what latent_to_index
        # expects), so the fixed factor must be a valid integer class of that
        # factor. A physical value such as a scale in [0.5, 1] or an angle in
        # [0, 2*pi) would produce fractional offsets that spill into the other
        # factors and leave `fk` not actually fixed.
        pf = np.random.randint(latents_sizes[fk])

        latents_sampled[:, fk] = pf
        # Select images
        indices_sampled = latent_to_index(latents_sampled)
        imgs_sampled = imgs[indices_sampled]

        imgs_variable = Variable(torch.FloatTensor(imgs_sampled))
        imgs_variable = imgs_variable.unsqueeze(1)  # add the channel dimension
        imgs_variable = imgs_variable.cuda()
        img_z_mu, img_z_logvar = vae.Encoder(imgs_variable)

        img_z = vae.Reparam(img_z_mu, img_z_logvar)
        img_z = img_z.cpu()
        img_z = img_z.data.numpy()
        # Normalise each dimension so variances are comparable across dimensions
        img_z = np.divide(img_z, std)
        d_min = np.argmin(np.var(img_z, axis =0))
        classifier_samples[d_min, fk-1] += 1
    # Majority vote per latent dimension (factors are reported 1-based)
    prediction = np.argmax(classifier_samples, axis=1) + 1
    fk_dict = {'1' : 'shape', '2' : 'scale', '3' : 'orintation', '4' : 'x-pos', '5' : 'y-pos'}
    correct = 0.0
    for j in range(Z_Dimension):
        print('Z%d ==> %s' % (j+1, fk_dict[str(prediction[j])]))
        correct += classifier_samples[j, prediction[j] - 1]
    accuracy = correct / num_votes
    print('accuracy : %0.8f' % accuracy)
    
    return classifier_samples, prediction, accuracy


In [None]:
# Per-dimension standard deviation of the sampled latent codes over train_loader
std = calculate_std()
# Cast the votes with codes normalised by `std`; prints the factor assigned to
# each latent dimension and the resulting accuracy
classifier_samples, prediction, accuracy = generate_samples(std)