In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# from pytorch_visualize import *

In [3]:
# import sys
# !{sys.executable} -m pip install https://download.pytorch.org/whl/cu75/torch-0.1.12.post1-cp36-cp36m-linux_x86_64.whl

In [4]:
import numpy as np
import math
import pickle
import argparse
import os
import math
import matplotlib.pyplot as plt

In [5]:
from collections import OrderedDict

In [6]:
def to_onehot(data, min_length):
    """Return the bag-of-words count vector of one document.

    Despite the name, the result is not one-hot: entry ``k`` is the number of
    times integer word id ``k`` occurs in ``data``.

    Args:
        data: 1-D sequence of non-negative integer word ids.
        min_length: minimum length of the result (normally the vocabulary size).

    Returns:
        Integer numpy array of length ``max(min_length, max(data) + 1)``.
        Ids >= ``min_length`` therefore silently lengthen the vector.
    """
    word_counts = np.bincount(data, minlength=min_length)
    return word_counts

In [7]:
def print_perp(model, input_=None):
    """Print and return the approximate perplexity of ``model`` on a count matrix.

    The per-document loss returned by ``model(..., compute_loss=True,
    avg_loss=False)`` is divided by that document's total word count. The
    ratios are averaged over documents and the mean is exponentiated. For
    ProdLDA the loss is the negative ELBO (reconstruction NLL + KL). Because its
    forward pass samples ``z`` even in eval mode, the estimate varies slightly
    between calls.

    Args:
        model: module whose forward accepts ``compute_loss`` / ``avg_loss`` and
            returns ``(recon, per_document_loss)`` when ``compute_loss=True``.
        input_: float tensor of word counts, one row per document. Defaults to
            the global test set ``tensor_te``. Every row must have a non-zero sum.

    Returns:
        The approximated perplexity as a Python float.
    """
    if input_ is None:
        input_ = tensor_te
    model.eval()                        # switch to testing mode
    with torch.no_grad():               # evaluation only: skip building the autograd graph
        _, loss = model(input_, compute_loss=True, avg_loss=False)
        counts = input_.sum(1)          # number of words in each document
        avg = (loss / counts).mean().item()
    perplexity = math.exp(avg)
    print('The approximated perplexity is: ', perplexity)
    return perplexity

def visualize():
    """Render ``model``'s forward pass on the first 10 test documents to ``pytorch_model.png``.

    NOTE(review): ``register_vis_hooks``, ``remove_vis_hooks`` and
    ``save_visualization`` are only provided by the commented-out
    ``from pytorch_visualize import *`` cell. Enable that import before calling
    this, otherwise it raises NameError.

    Side effect: stores the reconstruction in the global ``recon``.
    """
    global recon
    input_ = tensor_te[:10]
    register_vis_hooks(model)
    try:
        recon = model(input_, compute_loss=False)
    finally:
        # detach the hooks even if the forward pass raises, so they do not
        # stay registered on `model`
        remove_vis_hooks()
    save_visualization('pytorch_model', 'png')

In [8]:
dataset_tr = 'data/20news_clean/train.txt.npy'
dataset_te = 'data/20news_clean/test.txt.npy'
vocab_path = 'data/20news_clean/vocab.pkl'
# Each file presumably holds an object array of per-document word-id arrays (each
# element is iterated as a document in the conversion cell). numpy >= 1.16.3 refuses
# to load object arrays unless allow_pickle=True. Only use this on trusted local
# files: unpickling can execute arbitrary code.
data_tr = np.load(dataset_tr, encoding="latin1", allow_pickle=True)
data_te = np.load(dataset_te, encoding="latin1", allow_pickle=True)
# Same trust caveat applies to the vocabulary pickle; `with` guarantees the handle is closed.
with open(vocab_path, 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)   # presumably {word: integer id}
vocab_size = len(vocab)

In [9]:
#--------------convert to bag-of-words count vectors---------------
print ('Converting data to one-hot representation')
# Drop empty documents: their word count would be 0, and print_perp divides by it.
# Test the length rather than np.sum(doc). A sum of word ids is also 0 for a
# non-empty document made only of word id 0, which has its own column.
data_tr = np.array([to_onehot(doc.astype('int'), vocab_size) for doc in data_tr if doc.size > 0])
data_te = np.array([to_onehot(doc.astype('int'), vocab_size) for doc in data_te if doc.size > 0])
#--------------print the data dimensions--------------------------
print ('Data Loaded')
print ('Dim Training Data', data_tr.shape)
print ('Dim Test Data', data_te.shape)

Converting data to one-hot representation
Data Loaded
Dim Training Data (11258, 1995)
Dim Test Data (7487, 1995)


In [10]:
#--------------make tensor datasets-------------------------------
# float32 torch copies of the integer count matrices (train first, then test)
tensor_tr, tensor_te = (torch.from_numpy(split).float() for split in (data_tr, data_te))

In [11]:
class ProdLDA(nn.Module):
    """ProdLDA neural topic model (Srivastava & Sutton, 2017, https://arxiv.org/pdf/1703.01488.pdf).

    A VAE over bag-of-words count vectors of shape (batch, num_input):

    * encoder: ``num_input -> en1_units -> en2_units`` (softplus after each
      linear layer, then dropout), followed by two batch-normalised linear heads
      giving the mean and log-variance of a diagonal Gaussian over
      ``num_topic`` latent dimensions;
    * decoder: ``softmax(z) -> dropout -> linear -> batchnorm -> softmax``, a
      distribution over the ``num_input`` vocabulary words.

    Args:
        net_arch: configuration object read through attributes: ``num_input``,
            ``en1_units``, ``en2_units``, ``num_topic``, ``drop_rate``,
            ``variance`` (variance of the Gaussian prior, the same for every
            topic) and ``init_mult`` (decoder weights are drawn from
            ``U(0, init_mult)``; ``0`` keeps PyTorch's default init).
    """

    def __init__(self, net_arch):
        super(ProdLDA, self).__init__()
        ac = net_arch
        self.net_arch = net_arch
        # encoder trunk, shared by the mean and log-variance heads
        self.en = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(ac.num_input, ac.en1_units)),
            ('act1', nn.Softplus()),
            ('linear2', nn.Linear(ac.en1_units, ac.en2_units)),
            ('act2', nn.Softplus()),
            ('dropout', nn.Dropout(ac.drop_rate))
        ]))
        # posterior mean head
        self.mean = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(ac.en2_units, ac.num_topic)),
            ('batchnorm', nn.BatchNorm1d(ac.num_topic))
        ]))
        # posterior log-variance head
        self.logvar = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(ac.en2_units, ac.num_topic)),
            ('batchnorm', nn.BatchNorm1d(ac.num_topic))
        ]))
        # decoder: latent sample -> topic proportions -> distribution over the vocabulary
        self.de = nn.Sequential(OrderedDict([
            ('act1', nn.Softmax(dim=-1)),
            ('dropout', nn.Dropout(ac.drop_rate)),
            ('linear', nn.Linear(ac.num_topic, ac.num_input)),
            ('batchnorm', nn.BatchNorm1d(ac.num_input)),
            ('act2', nn.Softmax(dim=-1))
        ]))
        # Gaussian prior N(prior_mean, prior_var), each of shape (1, num_topic).
        # They are buffers, not Parameters: they follow .to()/state_dict but are
        # not trainable and are not handed to the optimizer via model.parameters().
        self.register_buffer('prior_mean', torch.zeros(1, ac.num_topic))
        self.register_buffer('prior_var', torch.full((1, ac.num_topic), float(ac.variance)))
        self.register_buffer('prior_logvar', self.prior_var.log())
        # initialize decoder weight
        if ac.init_mult != 0:
            with torch.no_grad():
                self.de.linear.weight.uniform_(0, ac.init_mult)
        # Remove BN's scale: pin gamma to 1 and freeze it, leaving only the learnable shift.
        for component in (self.mean, self.logvar, self.de):
            bn_scale = component.batchnorm.weight
            bn_scale.requires_grad_(False)
            with torch.no_grad():
                bn_scale.fill_(1.0)

    def encode(self, input_):
        """Run the encoder.

        Returns:
            ``(encoded, posterior_mean, posterior_logvar)``. The last two have
            shape (batch, num_topic).
        """
        encoded = self.en(input_)
        posterior_mean = self.mean(encoded)
        posterior_logvar = self.logvar(encoded)
        return encoded, posterior_mean, posterior_logvar

    def decode(self, input_, posterior_mean, posterior_var):
        """Sample ``z ~ N(posterior_mean, posterior_var)`` and decode it.

        ``input_`` is unused; it stays in the signature for caller compatibility.
        The noise takes its dtype and device from ``posterior_mean``.

        Returns:
            Reconstructed distribution over the vocabulary (rows sum to 1).
        """
        eps = torch.randn_like(posterior_mean)            # noise
        z = posterior_mean + posterior_var.sqrt() * eps   # reparameterization
        return self.de(z)

    def forward(self, input_, compute_loss=False, avg_loss=True):
        """Encode ``input_``, sample a latent vector and reconstruct.

        Returns:
            ``recon`` or, when ``compute_loss`` is true, ``(recon, loss)``.
            ``loss`` is the batch mean if ``avg_loss`` is true, otherwise one
            value per document.
        """
        _, posterior_mean, posterior_logvar = self.encode(input_)
        posterior_var = posterior_logvar.exp()
        recon = self.decode(input_, posterior_mean, posterior_var)
        if not compute_loss:
            return recon
        return recon, self.loss(input_, recon, posterior_mean, posterior_logvar, posterior_var, avg_loss)

    def loss(self, input_, recon, posterior_mean, posterior_logvar, posterior_var, avg=True):
        """Negative ELBO: multinomial reconstruction NLL plus KL(posterior || prior).

        Args:
            input_: word counts, (batch, num_input).
            recon: reconstructed word distribution, same shape as ``input_``.
            posterior_mean, posterior_logvar, posterior_var: (batch, num_topic).
            avg: return the batch mean (training) instead of per-document losses.
        """
        # NL; the 1e-10 keeps log() finite where recon underflows to 0
        NL = -(input_ * (recon + 1e-10).log()).sum(1)
        # KLD between diagonal Gaussians, see Section 3.3 of Akash Srivastava and
        # Charles Sutton, 2017, https://arxiv.org/pdf/1703.01488.pdf
        prior_mean   = self.prior_mean.expand_as(posterior_mean)
        prior_var    = self.prior_var.expand_as(posterior_mean)
        prior_logvar = self.prior_logvar.expand_as(posterior_mean)
        var_division    = posterior_var  / prior_var
        diff            = posterior_mean - prior_mean
        diff_term       = diff * diff / prior_var
        logvar_division = prior_logvar - posterior_logvar
        # put KLD together
        KLD = 0.5 * ( (var_division + diff_term + logvar_division).sum(1) - self.net_arch.num_topic )
        # loss
        loss = (NL + KLD)
        # in training mode, return averaged loss. In testing mode, return individual loss
        if avg:
            return loss.mean()
        else:
            return loss

In [13]:
class args:
    """Hyper-parameters, used as a plain attribute namespace (the class itself is passed around)."""
    # --- network architecture ---------------------------------------------
    num_input = 1995      # vocabulary size; overwritten from the data before the model is built
    en1_units = 100       # width of the first encoder layer
    en2_units = 100       # width of the second encoder layer
    num_topic = 50        # number of latent topics
    drop_rate = 0.2       # dropout probability (encoder output and decoder input)
    # --- prior / initialisation --------------------------------------------
    variance = 0.995      # variance of the Gaussian prior on each topic dimension
    init_mult = 1.0       # decoder weights ~ U(0, init_mult); 0 keeps the default init
    # --- optimisation -------------------------------------------------------
    learning_rate = 0.002
    momentum = 0.99       # used as Adam's beta1
    batch_size = 200
    num_epoch = 100
    # --- runtime --------------------------------------------------------------
    nogpu = True          # NOTE(review): not read anywhere in the visible notebook

In [14]:
# Seed torch's global RNG so that, for a top-to-bottom CPU run, weight init, the
# per-epoch shuffles, dropout masks and the reparameterisation noise are reproducible.
SEED = 0
torch.manual_seed(SEED)
net_arch = args  # alias of the class itself, so the next line also sets args.num_input
net_arch.num_input = data_tr.shape[1]
# the count matrix must have exactly one column per vocabulary word
assert net_arch.num_input == vocab_size, (net_arch.num_input, vocab_size)
model = ProdLDA(net_arch)

In [15]:
# device = torch.device("cpu")

In [16]:
# model.to(device)

In [17]:
# vocab_size = 10
# N = 3
# K = 2
# input_.data.new().resize_as_(posterior_mean.data).normal_()
# input_ = torch.randint(0, 10, (3, 10))
# posterior_mean = torch.randint(0, 10, (3, 2))
# posterior_mean
# input_
# input_.data.new().resize_as_(posterior_mean.data).normal_()

In [18]:
# Adam; args.momentum is used as beta1, beta2 is fixed at 0.999
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=args.learning_rate,
    betas=(args.momentum, 0.999),
)

In [19]:
# Minibatch training. Every epoch reshuffles the documents. Every 5th epoch it prints
# the mean over batches of the batch-averaged loss (negative ELBO per document).
for epoch in range(args.num_epoch):
    model.train()                    # dropout on, BatchNorm uses batch statistics
    loss_epoch = 0.0
    # random permutation of document indices, chopped into batch_size chunks
    # (the final chunk holds the remainder and may be smaller)
    all_indices = torch.randperm(tensor_tr.size(0)).split(args.batch_size)
    for batch_indices in all_indices:
        input_ = tensor_tr[batch_indices]
        recon, loss = model(input_, compute_loss=True)
        # standard step: reset gradients, backprop, update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_epoch += loss.item()    # running sum of per-batch mean losses
    if not epoch % 5:
        print('Epoch {}, loss={}'.format(epoch, loss_epoch / len(all_indices)))

Epoch 0, loss=744.3255818684896
Epoch 5, loss=677.8445091916803
Epoch 10, loss=661.1274039285224
Epoch 15, loss=650.3963317871094
Epoch 20, loss=648.0864969889323
Epoch 25, loss=644.8758148728756
Epoch 30, loss=635.3875571803043
Epoch 35, loss=630.1503595720258
Epoch 40, loss=626.6581356650904
Epoch 45, loss=623.0975084806744
Epoch 50, loss=620.9949945817914
Epoch 55, loss=617.2237002724096
Epoch 60, loss=616.3330522503769
Epoch 65, loss=615.0380334686814
Epoch 70, loss=613.6359167266311
Epoch 75, loss=613.9560643246299
Epoch 80, loss=611.7194620768229
Epoch 85, loss=611.8647321734512
Epoch 90, loss=612.314353005928
Epoch 95, loss=609.612364651864


In [20]:
def print_perp(model, input_=None):
    """Print and return the approximate perplexity of ``model`` on a count matrix.

    The per-document loss returned by ``model(..., compute_loss=True,
    avg_loss=False)`` is divided by that document's total word count. The
    ratios are averaged over documents and the mean is exponentiated. For
    ProdLDA the loss is the negative ELBO (reconstruction NLL + KL). Because its
    forward pass samples ``z`` even in eval mode, the estimate varies slightly
    between calls.

    Args:
        model: module whose forward accepts ``compute_loss`` / ``avg_loss`` and
            returns ``(recon, per_document_loss)`` when ``compute_loss=True``.
        input_: float tensor of word counts, one row per document. Defaults to
            the global test set ``tensor_te``. Every row must have a non-zero sum.

    Returns:
        The approximated perplexity as a Python float.
    """
    if input_ is None:
        input_ = tensor_te
    model.eval()                        # switch to testing mode
    with torch.no_grad():               # evaluation only: skip building the autograd graph
        _, loss = model(input_, compute_loss=True, avg_loss=False)
        counts = input_.sum(1)          # number of words in each document
        avg = (loss / counts).mean().item()
    perplexity = math.exp(avg)
    print('The approximated perplexity is: ', perplexity)
    return perplexity

def visualize():
    """Render ``model``'s forward pass on the first 10 test documents to ``pytorch_model.png``.

    NOTE(review): ``register_vis_hooks``, ``remove_vis_hooks`` and
    ``save_visualization`` are only provided by the commented-out
    ``from pytorch_visualize import *`` cell. Enable that import before calling
    this, otherwise it raises NameError.

    Side effect: stores the reconstruction in the global ``recon``.
    """
    global recon
    input_ = tensor_te[:10]
    register_vis_hooks(model)
    try:
        recon = model(input_, compute_loss=False)
    finally:
        # detach the hooks even if the forward pass raises, so they do not
        # stay registered on `model`
        remove_vis_hooks()
    save_visualization('pytorch_model', 'png')

In [28]:
def print_top_words(beta, feature_names, n_top_words=10):
    """Print the ``n_top_words`` highest-weighted words of each topic, one topic per line.

    Args:
        beta: 2-D array, one row of per-word weights per topic (rows need ``argsort``).
        feature_names: indexable map from column index to word.
        n_top_words: how many words to show per topic, highest weight first.
    """
    print ('---------------Printing the Topics------------------')
    for topic_weights in beta:
        # ascending argsort, reversed to descending, then keep the first n_top_words
        best_first = topic_weights.argsort()[::-1][:n_top_words]
        print(" ".join(feature_names[j] for j in best_first))
    print ('---------------End of Topics------------------')

In [29]:
# Topic-word matrix: nn.Linear(num_topic, num_input) stores its weight as
# (num_input, num_topic), so transpose it to get one row of word weights per topic.
emb = model.de.linear.weight.data.cpu().numpy().T
# vocab presumably maps word -> integer id; ordering words by id aligns them with emb's columns.
print_top_words(emb, tuple(word for word, _ in sorted(vocab.items(), key=lambda item: item[1])))
print_perp(model)
# visualize()

---------------Printing the Topics------------------
pittsburgh louis jose pp minnesota boston philadelphia calgary la montreal
doctrine bible christian scripture explanation christianity biblical hang revelation tradition
encrypt pgp rsa xlib implementation xt distribution cipher cryptography toolkit
crypto outlet escrow ground motor encryption trip voltage panel hide
baseball hitter fan ball pitch sport braves stanley roger craig
armenian turkish genocide turks greece militia jews organize constitution army
cryptography security social responsibility privacy electronic anonymous rsa threat email
shipping launch sale ice remote speaker iii cd contract sport
hello appreciate hus fax email advance thanks anybody tel institute
shipping speaker oo honda rf sale mw mount ac bike
entry remark contest oname null output winner int rule printf
gateway swap meg ram window isa windows microsoft bus button
lebanon civilian village israel armenia israeli troops soldier lebanese armenians
morality 