In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import sys
from torchvision import datasets, transforms
import pickle
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.datasets import fetch_20newsgroups
import os
import random
import time
sys.path.append('/home/hao/Research/probtorch/')
from probtorch.util import expand_inputs
import probtorch
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 0.4.0a0+265e1a9 cuda: True


In [2]:
# model parameters
# Width of the single hidden layer used by both Encoder and Decoder.
NUM_HIDDEN = 500
# Dimensionality of the latent variable 'z'.
NUM_LATENTS = 50
# Input/output width of the networks; matches len(voca) reported after loading.
VOCABULARY_SIZE = 2000
# Number of training documents; matches len(train_dataset) reported after loading.
NUM_DOCS = 11268

# training parameters
# Leading (sample) dimension of the Decoder's prior tensors; also forwarded
# to the encoder/decoder calls as `num_samples`.
NUM_SAMPLES = 1
# Rows per minibatch; train() skips a trailing batch that is smaller than this.
BATCH_SIZE = 500
NUM_EPOCHS = 1000
LEARNING_RATE = 5e-5
# ALPHA, BETA and BIAS are referenced only by the commented-out
# probtorch.objectives.marginal.elbo variants of elbo().
ALPHA = 1.0
BETA = (1.0, 1.0, 0.0, 1.0, 0.0)
BIAS = (NUM_DOCS - 1) / (BATCH_SIZE - 1) 
# Use the GPU whenever PyTorch reports one.
CUDA = torch.cuda.is_available()

# train() only runs its training loop while this is False; no restore logic
# is visible in this part of the notebook.
RESTORE = False

## dataset

In [3]:
def shuffler(train_dataset):
    """Return a row-permuted copy of the 2-D array ``train_dataset``.

    The permutation is produced by ``np.random.shuffle`` on an index vector,
    so it is governed by NumPy's global random state.
    """
    num_rows = train_dataset.shape[0]
    row_order = np.arange(num_rows)
    np.random.shuffle(row_order)
    shuffled = train_dataset[row_order, :]
    return shuffled

def NUM_ITERS(train_dataset):
    """Return the number of minibatches needed to cover every row.

    This is ``train_dataset.shape[0] / BATCH_SIZE`` rounded up, i.e. a
    trailing partial batch counts as one extra iteration.
    """
    num_rows = train_dataset.shape[0]
    leftover = num_rows % BATCH_SIZE
    full_batches = int((num_rows - leftover) / BATCH_SIZE)
    # A non-empty remainder needs one more (partial) batch.
    return full_batches if leftover == 0 else full_batches + 1

In [4]:
def data_set(data_url):
    """Parse a bag-of-words text file into per-document count dictionaries.

    Every line is split on whitespace. The first token is ignored
    (presumably a document label or id -- TODO confirm against the data
    files) and each remaining token must have the form ``<word_id>:<count>``.
    Word ids are shifted down by one so that they start at 0. Documents whose
    total count is zero (including blank lines) are dropped.

    Args:
        data_url: path of the text file to read.

    Returns:
        A pair ``(data, word_count)``: ``data`` is a list of dicts mapping
        0-based word index to count, and ``word_count`` is the parallel list
        of total counts per kept document.

    Raises:
        IndexError: if a token contains no ``:`` separator.
        ValueError: if an id or count is not an integer literal.
    """
    data = []
    word_count = []
    # The context manager closes the file even when a malformed token raises
    # part-way through parsing.
    with open(data_url) as fin:
        for line in fin:
            doc = {}
            count = 0
            for id_freq in line.split()[1:]:
                items = id_freq.split(':')
                freq = int(items[1])
                # ids in the file start at 1; Python indices start at 0
                doc[int(items[0]) - 1] = freq
                count += freq
            if count > 0:
                data.append(doc)
                word_count.append(count)
    return data, word_count

def load_voca(filename):
    """Read a vocabulary file into a ``word -> index`` dictionary.

    The word is the first space-separated field of each stripped line; its
    index is the 1-based line number. If a word occurs on several lines the
    last occurrence wins.

    Note that these indices start at 1, whereas ``data_set`` produces 0-based
    word indices; in this notebook only ``len(voca)`` is consumed.

    Args:
        filename: path of the vocabulary file.

    Returns:
        dict mapping each word to its 1-based line number.
    """
    voca = dict()
    # Use a context manager so the file handle is always released.
    with open(filename) as fp:
        for value, line in enumerate(fp, start=1):
            word = line.strip().split(' ')[0]
            voca[word] = value
    return voca

def convert_np(train_dataset, voca):
    """Densify a list of ``{word_index: count}`` dicts into a 2-D array.

    The result has one row per document and ``len(voca)`` columns; entries
    that do not appear in a document's dict stay at zero.
    """
    shape = (len(train_dataset), len(voca))
    dense = np.zeros(shape)
    for row, doc in enumerate(train_dataset):
        for word_index, freq in doc.items():
            dense[row, word_index] = freq
    return dense

In [5]:
# Parse the training documents and the vocabulary (20News_* files; note the
# absolute, machine-specific paths).
train_dataset, train_count = data_set('/home/hao/Research/nips2018/nvdm/20News_train.txt')
voca = load_voca('/home/hao/Research/nips2018/nvdm/20News_voca.txt')
# Rebind train_dataset from the list of sparse dicts to a dense
# (documents x vocabulary) count matrix.
train_dataset = convert_np(train_dataset, voca)
# Cell result: (number of documents, vocabulary size).
len(train_dataset), len(voca)

(11268, 2000)

## encoder

In [6]:
class Encoder(nn.Module):
    """Inference network producing the approximate posterior over 'z'.

    A single ReLU hidden layer feeds two linear heads: one for the mean and
    one for the log standard deviation of a Normal over the latents.
    """

    def __init__(self, vocabulary_size=VOCABULARY_SIZE,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        """Build the hidden layer and the two latent heads.

        Args:
            vocabulary_size: width of the input count vectors.
            num_hidden: width of the hidden layer.
            num_latents: dimensionality of the latent variable.
        """
        # Zero-argument super(): super(self.__class__, self) recurses without
        # end as soon as this class is subclassed.
        super().__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(vocabulary_size, num_hidden),
            nn.ReLU())
        self.latent_mean = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        self.latent_log_std = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    @expand_inputs
    def forward(self, documents, num_samples=None):
        """Return a probtorch Trace holding the Normal node 'z'.

        Args:
            documents: float tensor whose last dimension is
                ``vocabulary_size``.
            num_samples: not read in this body; presumably consumed by the
                ``expand_inputs`` decorator to add a sample dimension --
                TODO confirm against probtorch.util.
        """
        q = probtorch.Trace()
        hidden = self.enc_hidden(documents)
        latents_mean = self.latent_mean(hidden)
        # The head predicts log-std; exponentiate to get a positive std.
        latents_std = torch.exp(self.latent_log_std(hidden))
        q.normal(latents_mean, latents_std, name='z')
        return q

## decoder

In [7]:
class Decoder(nn.Module):
    """Generative network: a fixed Normal prior over 'z' and a softmax
    decoder from latents to log word probabilities.
    """

    def __init__(self, vocabulary_size=VOCABULARY_SIZE,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS,
                       num_samples=NUM_SAMPLES,
                       batch_size=BATCH_SIZE):
        """Build the decoder layers and the fixed prior tensors.

        Args:
            vocabulary_size: width of the reconstructed log-probability vector.
            num_hidden: width of the hidden layer.
            num_latents: dimensionality of the latent variable.
            num_samples: leading dimension of the prior tensors.
            batch_size: second dimension of the prior tensors. Because the
                prior shape is fixed here, forward() is only used with full
                batches of exactly this size.
        """
        # Zero-argument super(): super(self.__class__, self) recurses without
        # end as soon as this class is subclassed.
        super().__init__()
        self.dec_hidden = nn.Sequential(
            nn.Linear(num_latents, num_hidden),
            nn.ReLU())
        self.dec_document = nn.Sequential(
                           nn.Linear(num_hidden, vocabulary_size),
                           nn.LogSoftmax(dim=-1))
        # Prior parameters: all-zero first argument and all-one second
        # argument of p.normal. Despite its name, prior_cov is passed in the
        # same position where the encoder passes a standard deviation.
        self.prior_mean = torch.zeros((num_samples, batch_size, num_latents))
        self.prior_cov = torch.ones((num_samples, batch_size, num_latents))
        # Only move the priors to the GPU when the notebook runs on CUDA; an
        # unconditional .cuda() fails on CPU-only machines and disagrees with
        # the `if CUDA` guards used for the modules and the data.
        if CUDA:
            self.prior_mean = self.prior_mean.cuda()
            self.prior_cov = self.prior_cov.cuda()

    def forward(self, documents, q=None, num_samples=None):
        """Score ``documents`` under the generative model.

        Args:
            documents: tensor of word counts that is multiplied elementwise
                with the decoder's log-probabilities.
            q: encoder trace; its 'z' node supplies the latent value, so it
                must be provided even though it defaults to None.
            num_samples: accepted for call-site symmetry with the encoder;
                not used in this body.

        Returns:
            probtorch.Trace with the prior node 'z' and the loss node
            'documents'.
        """
        p = probtorch.Trace()
        # Tie the prior node to the value recorded in the encoder trace.
        latents = p.normal(self.prior_mean, self.prior_cov, value=q['z'], name='z')
        hidden = self.dec_hidden(latents)
        documents_recon = self.dec_document(hidden)
        p.loss(cross_entropy, documents_recon, documents, name='documents')
        return p


In [8]:
def cross_entropy(x_recon, x):
    """Return the negated sum over the last axis of ``x_recon * x``.

    In this notebook ``x_recon`` holds log-probabilities from the decoder and
    ``x`` holds word counts, so the result is one value per document.
    """
    weighted = x_recon * x
    total = weighted.sum(-1)
    return -total

## initialization

In [9]:
def initialize():
    """Create a fresh Encoder/Decoder pair and one Adam optimizer over both.

    Both modules are moved to the GPU when ``CUDA`` is set. The optimizer
    receives the encoder parameters followed by the decoder parameters and
    uses ``LEARNING_RATE``.

    Returns:
        Tuple ``(enc, dec, optimizer)``.
    """
    enc, dec = Encoder(), Decoder()
    if CUDA:
        for module in (enc, dec):
            module.cuda()
    trainable = list(enc.parameters()) + list(dec.parameters())
    optimizer = torch.optim.Adam(trainable, lr=LEARNING_RATE)
    return enc, dec, optimizer

In [10]:
def elbo(q, p):
    """Evaluate probtorch's Monte Carlo ELBO for traces ``q`` and ``p``.

    The objective is called twice with ``beta=1.0``: once with its default
    reduction and once with ``reduce=False``. When ``NUM_SAMPLES`` is None
    there is no sample dimension and the batch is dimension 0; otherwise
    samples are dimension 0 and the batch is dimension 1.

    Returns:
        Tuple ``(ave_elbo, elbos)``: the reduced and the unreduced result.
    """
    if NUM_SAMPLES is None:
        dims = dict(sample_dim=None, batch_dim=0)
    else:
        dims = dict(sample_dim=0, batch_dim=1)
    objective = probtorch.objectives.montecarlo.elbo
    ave_elbo = objective(q, p, beta=1.0, **dims)
    elbos = objective(q, p, beta=1.0, reduce=False, **dims)
    return ave_elbo, elbos

In [11]:
# def elbo(q, p):
#     if NUM_SAMPLES is None:
#         ave_elbo = probtorch.objectives.marginal.elbo(q, p, sample_dim=None, batch_dim=0, alpha=ALPHA, bias=BIAS)
#         return ave_elbo
#     else:
#         ave_elbo = probtorch.objectives.marginal.elbo(q, p, sample_dim=0, batch_dim=1, alpha=ALPHA, bias=BIAS)
#         return ave_elbo

In [12]:
# def elbo(q, p):
#     if NUM_SAMPLES is None:
#         ave_elbo = probtorch.objectives.marginal.elbo(q, p, sample_dim=None, batch_dim=0, beta=BETA, bias=BIAS)
#         elbos = probtorch.objectives.marginal.elbo(q, p, sample_dim=None, batch_dim=0, beta=BETA, bias=BIAS, reduce=False)
#         return ave_elbo, elbos
#     else:
#         ave_elbo = probtorch.objectives.marginal.elbo(q, p, sample_dim=0, batch_dim=1, beta=BETA, bias=BIAS)
#         elbos = probtorch.objectives.marginal.elbo(q, p, sample_dim=0, batch_dim=1, beta=BETA, bias=BIAS, reduce=False)
#         return ave_elbo, elbos

In [13]:
def train(enc, dec, optimizer):
    """Train the encoder/decoder pair for NUM_EPOCHS epochs.

    Each epoch reshuffles the module-level ``train_dataset``, walks over it
    in slices of BATCH_SIZE rows, takes one optimizer step per full batch on
    the negative ELBO, and prints the epoch's mean ELBO, wall time and
    perplexity.

    Perplexity per epoch is ``exp(-(1 / (N * BATCH_SIZE)) * sum(elbo_d / N_d))``
    where ``N`` is the number of full batches processed and ``N_d`` is the
    total word count of document ``d``.

    Args:
        enc: encoder module, called as ``enc(documents, num_samples=...)``.
        dec: decoder module, called as ``dec(documents, q, num_samples=...)``.
        optimizer: optimizer over the parameters of both modules.

    Returns:
        Tuple ``(elbos_list, perplexities_list)`` with one entry per epoch.
        Both lists are empty when RESTORE is set, because training is then
        skipped.

    Raises:
        ValueError: if the dataset has fewer than BATCH_SIZE rows, so that no
            full batch could ever be processed.
    """
    elbos_list = []
    perplexities_list = []
    if RESTORE:
        # Training is skipped. Still return the documented pair so callers
        # that unpack two values do not fail on an implicit None.
        return elbos_list, perplexities_list

    if train_dataset.shape[0] < BATCH_SIZE:
        # Without this check the epoch averages below divide by zero.
        raise ValueError('train_dataset has %d rows, fewer than BATCH_SIZE=%d'
                         % (train_dataset.shape[0], BATCH_SIZE))

    num_iters = NUM_ITERS(train_dataset)
    for epoch in range(NUM_EPOCHS):
        time_start = time.time()
        train_dataset_shuffled = shuffler(train_dataset)
        epoch_elbo = 0.0
        perplexity = 0.0
        N = 0.0  # number of full batches processed this epoch
        for i in range(num_iters):
            documents = train_dataset_shuffled[i * BATCH_SIZE : (i + 1) * BATCH_SIZE, :]
            if documents.shape[0] != BATCH_SIZE:
                # The Decoder's prior tensors have a fixed batch dimension,
                # so the trailing partial batch is skipped.
                continue
            N += 1
            documents = torch.FloatTensor(documents)
            # Total word count per document; kept on the CPU for the
            # perplexity bookkeeping below.
            N_d = documents.sum(-1)
            if CUDA:
                documents = documents.cuda()
            optimizer.zero_grad()
            q = enc(documents, num_samples=NUM_SAMPLES)
            p = dec(documents, q, num_samples=NUM_SAMPLES)
            ave_elbo, elbos = elbo(q, p)
            loss = -ave_elbo
            loss.backward()
            optimizer.step()
            # Bookkeeping runs on detached CPU copies for both devices.
            # Accumulating only inside `if CUDA:` left the reported ELBO at 0
            # on CPU, and non-detached tensors kept every batch's autograd
            # graph alive through the running perplexity sum.
            epoch_elbo += -loss.detach().cpu().numpy()
            # NOTE(review): assumes elbos broadcasts against N_d of shape
            # (BATCH_SIZE,) -- confirm the reduce=False output shape.
            perplexity += torch.mul(elbos.detach().cpu(), (1 / N_d)).sum()

        average_epoch_elbo = epoch_elbo / N
        elbos_list.append(average_epoch_elbo)
        perplexity = torch.exp((-1 / (N * BATCH_SIZE)) * perplexity)
        perplexities_list.append(perplexity)

        time_end = time.time()
        print('Epoch : %d, ELBO : %f (%ds), Perplexity : %f' % (epoch+1, average_epoch_elbo, time_end - time_start, perplexity))
    return elbos_list, perplexities_list

In [14]:
# Build fresh models plus optimizer, then train; train() prints one log line
# per epoch and returns the per-epoch ELBO and perplexity histories.
enc, dec, optimizer = initialize()
elbos_list, perplexities_list = train(enc, dec, optimizer)

Epoch : 1, ELBO : -766.879367 (0s), Perplexity : 2000.549561
Epoch : 2, ELBO : -752.261000 (0s), Perplexity : 1865.950073
Epoch : 3, ELBO : -745.280118 (0s), Perplexity : 1756.040161
Epoch : 4, ELBO : -744.908231 (0s), Perplexity : 1662.083618
Epoch : 5, ELBO : -732.717521 (0s), Perplexity : 1587.514771
Epoch : 6, ELBO : -735.136366 (0s), Perplexity : 1526.008789
Epoch : 7, ELBO : -732.432506 (0s), Perplexity : 1476.679321
Epoch : 8, ELBO : -725.018668 (0s), Perplexity : 1438.075317
Epoch : 9, ELBO : -723.221308 (0s), Perplexity : 1403.983521
Epoch : 10, ELBO : -720.649925 (0s), Perplexity : 1375.562256
Epoch : 11, ELBO : -719.668815 (0s), Perplexity : 1353.605347
Epoch : 12, ELBO : -719.599307 (0s), Perplexity : 1334.398804
Epoch : 13, ELBO : -718.848378 (0s), Perplexity : 1318.392944
Epoch : 14, ELBO : -716.145716 (0s), Perplexity : 1306.251953
Epoch : 15, ELBO : -713.943481 (0s), Perplexity : 1296.810547
Epoch : 16, ELBO : -717.371471 (0s), Perplexity : 1288.367798
Epoch : 17, ELBO 

Epoch : 133, ELBO : -675.601970 (0s), Perplexity : 1030.721191
Epoch : 134, ELBO : -674.527854 (0s), Perplexity : 1029.824585
Epoch : 135, ELBO : -673.971755 (0s), Perplexity : 1026.882935
Epoch : 136, ELBO : -671.495322 (0s), Perplexity : 1026.622070
Epoch : 137, ELBO : -673.276037 (0s), Perplexity : 1025.784302
Epoch : 138, ELBO : -668.625327 (0s), Perplexity : 1024.693604
Epoch : 139, ELBO : -674.414809 (0s), Perplexity : 1024.629150
Epoch : 140, ELBO : -674.281228 (0s), Perplexity : 1020.719177
Epoch : 141, ELBO : -671.730444 (0s), Perplexity : 1022.135010
Epoch : 142, ELBO : -671.774958 (0s), Perplexity : 1021.280518
Epoch : 143, ELBO : -672.727292 (0s), Perplexity : 1018.609497
Epoch : 144, ELBO : -672.696153 (0s), Perplexity : 1018.506042
Epoch : 145, ELBO : -665.537703 (0s), Perplexity : 1018.467651
Epoch : 146, ELBO : -670.827739 (0s), Perplexity : 1015.789917
Epoch : 147, ELBO : -671.769024 (0s), Perplexity : 1016.104370
Epoch : 148, ELBO : -672.705988 (0s), Perplexity : 1014

Epoch : 265, ELBO : -659.197363 (0s), Perplexity : 910.820496
Epoch : 266, ELBO : -660.754639 (0s), Perplexity : 910.224792
Epoch : 267, ELBO : -659.474776 (0s), Perplexity : 908.460449
Epoch : 268, ELBO : -658.080408 (0s), Perplexity : 909.017273
Epoch : 269, ELBO : -658.042392 (0s), Perplexity : 907.929932
Epoch : 270, ELBO : -660.851895 (0s), Perplexity : 907.387207
Epoch : 271, ELBO : -657.922635 (0s), Perplexity : 906.819275
Epoch : 272, ELBO : -659.613414 (0s), Perplexity : 905.703491
Epoch : 273, ELBO : -656.953078 (0s), Perplexity : 904.706421
Epoch : 274, ELBO : -654.440488 (0s), Perplexity : 905.556274
Epoch : 275, ELBO : -658.795571 (0s), Perplexity : 904.417908
Epoch : 276, ELBO : -658.413336 (0s), Perplexity : 904.170349
Epoch : 277, ELBO : -657.614277 (0s), Perplexity : 902.489197
Epoch : 278, ELBO : -657.097909 (0s), Perplexity : 902.597656
Epoch : 279, ELBO : -656.876631 (0s), Perplexity : 900.393311
Epoch : 280, ELBO : -658.248460 (0s), Perplexity : 900.728210
Epoch : 

Epoch : 398, ELBO : -649.729623 (0s), Perplexity : 831.619385
Epoch : 399, ELBO : -649.856484 (0s), Perplexity : 828.216064
Epoch : 400, ELBO : -649.075994 (0s), Perplexity : 828.466064
Epoch : 401, ELBO : -647.784388 (0s), Perplexity : 830.171265
Epoch : 402, ELBO : -649.598899 (0s), Perplexity : 829.447571
Epoch : 403, ELBO : -648.721605 (0s), Perplexity : 830.329224
Epoch : 404, ELBO : -649.670885 (0s), Perplexity : 829.028809
Epoch : 405, ELBO : -650.112310 (0s), Perplexity : 826.995056
Epoch : 406, ELBO : -648.255316 (0s), Perplexity : 827.094849
Epoch : 407, ELBO : -647.948839 (0s), Perplexity : 827.452209
Epoch : 408, ELBO : -650.520647 (0s), Perplexity : 825.618347
Epoch : 409, ELBO : -649.976463 (0s), Perplexity : 826.032654
Epoch : 410, ELBO : -648.943995 (0s), Perplexity : 824.840027
Epoch : 411, ELBO : -649.024708 (0s), Perplexity : 825.311768
Epoch : 412, ELBO : -649.555223 (0s), Perplexity : 825.388123
Epoch : 413, ELBO : -649.682818 (0s), Perplexity : 825.069763
Epoch : 

Epoch : 531, ELBO : -640.604201 (0s), Perplexity : 775.604797
Epoch : 532, ELBO : -641.527668 (0s), Perplexity : 773.678894
Epoch : 533, ELBO : -639.208324 (0s), Perplexity : 771.372864
Epoch : 534, ELBO : -637.773229 (0s), Perplexity : 772.487427
Epoch : 535, ELBO : -641.581848 (0s), Perplexity : 769.569763
Epoch : 536, ELBO : -643.506181 (0s), Perplexity : 769.380066
Epoch : 537, ELBO : -638.712638 (0s), Perplexity : 772.019043
Epoch : 538, ELBO : -642.823545 (0s), Perplexity : 769.635437
Epoch : 539, ELBO : -643.415852 (0s), Perplexity : 769.539673
Epoch : 540, ELBO : -642.960435 (0s), Perplexity : 768.365967
Epoch : 541, ELBO : -639.654111 (0s), Perplexity : 769.021729
Epoch : 542, ELBO : -641.682040 (0s), Perplexity : 767.689514
Epoch : 543, ELBO : -644.364163 (0s), Perplexity : 767.246033
Epoch : 544, ELBO : -642.568873 (0s), Perplexity : 766.827942
Epoch : 545, ELBO : -642.465915 (0s), Perplexity : 767.787292
Epoch : 546, ELBO : -639.821189 (0s), Perplexity : 767.244934
Epoch : 

Epoch : 664, ELBO : -637.170577 (0s), Perplexity : 728.517090
Epoch : 665, ELBO : -633.856601 (0s), Perplexity : 727.915649
Epoch : 666, ELBO : -636.217676 (0s), Perplexity : 727.245361
Epoch : 667, ELBO : -634.535750 (0s), Perplexity : 726.412903
Epoch : 668, ELBO : -637.299930 (0s), Perplexity : 725.816650
Epoch : 669, ELBO : -634.685569 (0s), Perplexity : 725.355469
Epoch : 670, ELBO : -633.469527 (0s), Perplexity : 725.191162
Epoch : 671, ELBO : -635.264113 (0s), Perplexity : 725.229553
Epoch : 672, ELBO : -633.635315 (0s), Perplexity : 725.236816
Epoch : 673, ELBO : -634.304984 (0s), Perplexity : 724.246704
Epoch : 674, ELBO : -634.267320 (0s), Perplexity : 725.529114
Epoch : 675, ELBO : -637.457819 (0s), Perplexity : 723.827942
Epoch : 676, ELBO : -637.777485 (0s), Perplexity : 725.114441
Epoch : 677, ELBO : -635.854842 (0s), Perplexity : 726.190857
Epoch : 678, ELBO : -637.454226 (0s), Perplexity : 723.580200
Epoch : 679, ELBO : -638.572457 (0s), Perplexity : 723.212463
Epoch : 

Epoch : 797, ELBO : -632.006353 (0s), Perplexity : 692.877380
Epoch : 798, ELBO : -633.238636 (0s), Perplexity : 692.449341
Epoch : 799, ELBO : -632.151148 (0s), Perplexity : 690.959167
Epoch : 800, ELBO : -626.832943 (0s), Perplexity : 691.244507
Epoch : 801, ELBO : -630.537040 (0s), Perplexity : 690.390686
Epoch : 802, ELBO : -626.121135 (0s), Perplexity : 690.940674
Epoch : 803, ELBO : -630.909990 (0s), Perplexity : 690.224487
Epoch : 804, ELBO : -628.610008 (0s), Perplexity : 690.168518
Epoch : 805, ELBO : -630.963769 (0s), Perplexity : 690.829346
Epoch : 806, ELBO : -630.941004 (0s), Perplexity : 689.643494
Epoch : 807, ELBO : -627.995334 (0s), Perplexity : 691.285706
Epoch : 808, ELBO : -634.611256 (0s), Perplexity : 690.200134
Epoch : 809, ELBO : -632.918096 (0s), Perplexity : 688.941772
Epoch : 810, ELBO : -631.222065 (0s), Perplexity : 689.389648
Epoch : 811, ELBO : -631.002078 (0s), Perplexity : 689.395569
Epoch : 812, ELBO : -627.978441 (0s), Perplexity : 687.804077
Epoch : 

Epoch : 931, ELBO : -624.501620 (0s), Perplexity : 663.066162
Epoch : 932, ELBO : -626.174355 (0s), Perplexity : 663.649780
Epoch : 933, ELBO : -628.078380 (0s), Perplexity : 663.258423
Epoch : 934, ELBO : -627.043632 (0s), Perplexity : 662.665710
Epoch : 935, ELBO : -623.292056 (0s), Perplexity : 661.874939
Epoch : 936, ELBO : -628.623624 (0s), Perplexity : 662.715027
Epoch : 937, ELBO : -627.867101 (0s), Perplexity : 661.505493
Epoch : 938, ELBO : -625.293593 (0s), Perplexity : 661.737671
Epoch : 939, ELBO : -627.762859 (0s), Perplexity : 662.154663
Epoch : 940, ELBO : -628.571711 (0s), Perplexity : 661.203064
Epoch : 941, ELBO : -622.647333 (0s), Perplexity : 661.379944
Epoch : 942, ELBO : -627.683186 (0s), Perplexity : 660.207825
Epoch : 943, ELBO : -626.690308 (0s), Perplexity : 661.813721
Epoch : 944, ELBO : -628.886400 (0s), Perplexity : 660.523987
Epoch : 945, ELBO : -624.303286 (0s), Perplexity : 660.171631
Epoch : 946, ELBO : -625.857774 (0s), Perplexity : 659.470947
Epoch : 

In [15]:
dec.para

-78876