In [1]:
# importing modules


import pickle
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image

In [2]:
# hyperparameter settings


batch_size = 100
num_threads = 8

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [3]:
# dataloaders settings


def MNIST_loader():
    train_data = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_threads)
    test_data = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_threads)
    return train_loader, test_loader
  


In [4]:
# several utils : mostly are for the visualizations


def plot_ELBO_curve(generated, lat_dim, train_list, test_list, ylim):
    fig = plt.figure()       
    plt.plot(list(range(1, 1600*60000, 60000)), train_list, 'r-', label='train ELBO')
    plt.plot(list(range(1, 1600*60000, 60000)), test_list, 'r--', label='test ELBO')
    plt.legend()
    plt.xscale('log')
    plt.xlim(10**5, 10**8)
    plt.ylim(ylim)
    plt.savefig('plots/'+generated+'/loss_curves_{}.png'.format(lat_dim))
    plt.close()

def plot_test_images(sample_hat, generated, lat_dim, height, width, n=10):
    sample_hat = sample_hat.view(n*n, 1, height, width)
    grid = save_image(sample_hat, fp="plots/"+generated+"/test_images_{}.png".format(lat_dim), nrow=n, padding=0)
    
def plot_latent_space(model, epoch, generated, height, width, intNumX1, intNumX2):
    z1s = np.linspace(0.0, 1.0, intNumX1, dtype=np.float32)
    z2s = np.linspace(0.0, 1.0, intNumX2, dtype=np.float32)
    dist = torch.distributions.normal.Normal(loc=0, scale=1)    
    canvas = np.zeros((height * intNumX1, width * intNumX2))

    for j, z2 in enumerate(z2s):
        for i, z1 in enumerate(z1s):
            z = torch.tensor([[z1, z2]]).to(device)
            z_hat = dist.icdf(z)
            x_hat = model(z_hat).detach().to(torch.device("cpu")).numpy()
            canvas[height*(intNumX2-j-1):height*(intNumX2-j), width*i:width*(i+1)] = x_hat.reshape(height, width)

    plt.figure(figsize=(8, 8))
    plt.imshow(canvas, cmap='gray')
    plt.tight_layout()
    plt.savefig("plots/"+generated+"/latent_spaces_%d.png" % (epoch), bbox_inches='tight')
    plt.close()


In [15]:
# The model class of Stick Breaking Variational autoencoder


class SBVAE(nn.Module):

    def __init__(self, generated, inp_dim, hid_dim, mode='learn'):
        super(SBVAE, self).__init__()
        self.generated = generated
        self.mode = mode
        
        self.prior_alpha = torch.Tensor([1]).to(device)
        self.prior_beta = torch.Tensor([5]).to(device)
        
        self.x_to_h = nn.Linear(inp_dim, hid_dim)
        self.h_to_a = nn.Linear(hid_dim, lat_dim-1)
        self.h_to_b = nn.Linear(hid_dim, lat_dim-1)
        
        self.z_to_h = nn.Linear(lat_dim, hid_dim)
        self.h_to_x = nn.Linear(hid_dim, inp_dim)
        self.h_to_x_mu = nn.Linear(hid_dim, inp_dim)
        self.h_to_x_log_var = nn.Linear(hid_dim, inp_dim)
        
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

    def beta_fn(self, m, n):
        output = torch.exp(torch.lgamma(m)+torch.lgamma(n)-torch.lgamma(m+n))
        return output
    
    def digamma_expansion(self, b):
        output = torch.log(b) - 1/(2*b) - 1/(12*b**2)
        return output
    
    def compute_D_KL(self, a, b):
        D_KL = 0
        for idx in range(1, 51):
            D_KL += (self.prior_beta-1)*b * 1/(idx+a*b) * self.beta_fn(idx/(a), b)
        D_KL += torch.log(a*b) + torch.log(self.beta_fn(self.prior_alpha, self.prior_beta)) - (b-1)/(b)
        D_KL += (a-self.prior_alpha)/(a) * (-0.57721 - self.digamma_expansion(b) - 1/(b))
        return D_KL
    
    
    def encoderNet(self, x):
        h = self.tanh(self.x_to_h(x))
        a = self.softplus(self.h_to_a(h))
        b = self.softplus(self.h_to_b(h))
        
        dist = torch.distributions.uniform.Uniform(0.01, 0.99)
        uniform = dist.sample(a.size()).to(device)
        v = (1-(uniform**(1/(b))))**(1/(a))
        batch, _ = v.size()
        ones = torch.ones(size=(batch, 1)).to(device)
        v = torch.cat((v, ones), 1)

        self.kl = self.compute_D_KL(a,b).mean(dim=0).sum()
        return v

    def decoderNet(self, v):
        pi = torch.ones(size=v.size()).to(device)
        for idx in range(lat_dim):
            product = 1
            for sub_idx in range(idx):
                product *= 1-v[:,sub_idx]
            pi[:,idx] = v[:,idx] * product    
        
        h_ = self.tanh(self.z_to_h(pi))
        if self.generated == 'Bernoulli':
            x_p = self.sigmoid(self.h_to_x(h_))
            return x_p

        elif self.generated == 'Gaussian':
            x_mu = self.sigmoid(self.h_to_x_mu(h_))
            x_log_var = self.h_to_x_log_var(h_)
            x_std = torch.exp(0.5 * x_log_var)
            return x_mu, x_std
            
    def forward(self, x):
        if self.mode == 'learn':
            v = self.encoderNet(x)
            if self.generated == 'Bernoulli':
                x_p = self.decoderNet(v)
                self.recon = -(x * torch.log(x_p+1e-5) + (1 - x) * torch.log(1 - x_p+1e-5)).mean(dim=0).sum()
            elif self.generated == 'Gaussian':
                x_mu, x_std = self.decoderNet(v)
                self.dist = torch.distributions.normal.Normal(loc=x_mu, scale=x_std)
                self.recon = -self.dist.log_prob(x).mean(dim=0).sum()
            return self.recon, self.kl

        elif self.mode == 'generate':
            v = x
            if self.generated == 'Bernoulli':
                x_p = self.decoderNet(v)
                x_hat = x_p
            elif self.generated == 'Gaussian':
                x_mu, x_std = self.decoderNet(v)
                x_hat = x_mu
            return x_hat
        

In [20]:
# Function to train SBVAE with MNIST dataset


def MNIST_main(dist='Bernoulli'):
    train_loader, test_loader = MNIST_loader()
    model = SBVAE(generated=dist, inp_dim=784, hid_dim=500).to(device)
    model.load_state_dict(torch.load('models/'+dist+'/SBVAE_%d.pt' % (lat_dim)))
    
    
    #for name, param in model.named_parameters():
    #    if param.requires_grad:
    #        print(name)
    
    train_list, test_list = [], []
    
    for epoch in range(n_epoch):
        #if (epoch+1) % 200 == 0:
        #    lr = lr * 0.1
        #    print('current learning rate is ', lr)
        optimizer = optim.Adagrad(model.parameters(), lr=lr)
        
        train_loss, recon_err, kl_div, test_loss = 0.0, 0.0, 0.0, 0.0
            
        cnt = 0
        for x, y in train_loader:
            cnt += 1
            model.train()
            inputs = x.to(device)
            new_batch_size = x.size()[0]
            inputs = inputs.view(new_batch_size, -1)
            
            recon, kl = model(inputs)
            loss = recon + kl         
 
            #optimizer.zero_grad()
            #loss.backward()
            #optimizer.step()
            
            train_loss += loss * new_batch_size / batch_size
            recon_err += recon * new_batch_size / batch_size
            kl_div += kl * new_batch_size / batch_size
        
        train_loss = train_loss / cnt
        recon_err = recon_err / cnt
        kl_div = kl_div / cnt
        train_list.append(-train_loss.item())
        
        with torch.no_grad():
            cnt = 0
            for x, y in test_loader:
                cnt += 1
                model.eval()
                inputs = x.to(device)
                new_batch_size = x.size()[0]
                inputs = inputs.view(new_batch_size, -1)
            
                recon, kl = model(inputs)
                loss = recon + kl
           
                test_loss += loss * new_batch_size / batch_size

            test_loss = test_loss / cnt
            test_list.append(-test_loss.item())
        
        if (epoch+1) % 1 == 0:
            print('[Epoch %d] train_loss: %.3f, recon: %.3f, kl_div: %.3f, test_loss : %3f'
                  % (epoch+1, train_loss, recon_err, kl_div, test_loss))
             
            #torch.save(model.state_dict(), 'models/'+dist+'/SBVAE_%d.pt' % (lat_dim))
    
    #with open('loss/'+dist+'/train_ELBO_%d.txt' % (lat_dim), 'wb') as f:
    #    pickle.dump(train_list, f)
    #with open('loss/'+dist+'/test_ELBO_%d.txt' % (lat_dim), 'wb') as f:
    #    pickle.dump(test_list, f)
    
    #plot_ELBO_curve(model.generated, lat_dim, train_list, test_list, (-150, -90))
    
    print('start')
    model.mode = 'generate'
    dist = torch.distributions.uniform.Uniform(0.01, 0.99)
    uniform = dist.sample((100,49)).to(device)
    v = (1-(uniform**(1/(model.prior_beta))))**(1/(model.prior_alpha))
    ones = torch.ones(size=(100, 1)).to(device)
    v = torch.cat((v, ones), 1)
    sample_hat = model(v)
    plot_test_images(sample_hat, generated='Bernoulli', lat_dim=10, height=28, width=28, n=10)
    

In [21]:
#Training and visualization can be done with customized setting of learning/hyper parameters


if __name__ == '__main__': 
    lr = 1e-3
    lat_dim = 50
    print(device)
    n_epoch = 1


    MNIST_main()


cuda:1
[Epoch 1] train_loss: 119.085, recon: 101.699, kl_div: 17.386, test_loss : 118.412506
start


RuntimeError: size mismatch, m1: [100 x 51], m2: [50 x 500] at /opt/conda/conda-bld/pytorch_1587428266983/work/aten/src/THC/generic/THCTensorMathBlas.cu:283