In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter

In [2]:
from pytorch_visualize import *

In [3]:
# import sys
# !{sys.executable} -m pip install https://download.pytorch.org/whl/cu75/torch-0.1.12.post1-cp36-cp36m-linux_x86_64.whl

In [4]:
import numpy as np
import math
from pprint import pprint, pformat
import pickle
import argparse
import os
import math
import matplotlib.pyplot as plt

In [5]:
class ProdLDA(nn.Module):
    """ProdLDA topic model (Srivastava & Sutton, 2017) written as a VAE over word counts.

    Encoder: input -> en1_fc -> softplus -> en2_fc -> softplus -> dropout, followed by
    two linear + BatchNorm heads giving the posterior mean and log-variance over
    `num_topic` latent dimensions. A reparameterised Gaussian sample is turned into
    topic proportions with a softmax, and the decoder (linear + BatchNorm + softmax)
    maps those to a distribution over the `num_input` vocabulary words.

    Args:
        net_arch: config object providing num_input, en1_units, en2_units, num_topic,
            variance (prior variance of every latent dimension) and init_mult
            (decoder weights are drawn from U(0, init_mult) unless it is 0).
    """

    def __init__(self, net_arch):
        super(ProdLDA, self).__init__()
        ac = net_arch
        self.net_arch = net_arch
        # encoder
        self.en1_fc     = nn.Linear(ac.num_input, ac.en1_units)             # num_input -> en1_units
        self.en2_fc     = nn.Linear(ac.en1_units, ac.en2_units)             # en1_units -> en2_units
        self.en2_drop   = nn.Dropout(0.2)
        self.mean_fc    = nn.Linear(ac.en2_units, ac.num_topic)             # en2_units -> num_topic
        self.mean_bn    = nn.BatchNorm1d(ac.num_topic)                      # bn for mean
        self.logvar_fc  = nn.Linear(ac.en2_units, ac.num_topic)             # en2_units -> num_topic
        self.logvar_bn  = nn.BatchNorm1d(ac.num_topic)                      # bn for logvar
        # z
        self.p_drop     = nn.Dropout(0.2)
        # decoder
        self.decoder    = nn.Linear(ac.num_topic, ac.num_input)             # num_topic -> num_input
        self.decoder_bn = nn.BatchNorm1d(ac.num_input)                      # bn for decoder
        # prior mean and variance as constant buffers (part of the state, not trained)
        prior_mean   = torch.Tensor(1, ac.num_topic).fill_(0)
        prior_var    = torch.Tensor(1, ac.num_topic).fill_(ac.variance)
        prior_logvar = prior_var.log()
        self.register_buffer('prior_mean',    prior_mean)
        self.register_buffer('prior_var',     prior_var)
        self.register_buffer('prior_logvar',  prior_logvar)
        # initialize decoder weight (init_mult == 0 keeps nn.Linear's default init)
        if ac.init_mult != 0:
            self.decoder.weight.data.uniform_(0, ac.init_mult)
        # remove BN's scale parameters; only the BN bias (shift) stays trainable
        for bn in (self.logvar_bn, self.mean_bn, self.decoder_bn):
            bn.register_parameter('weight', None)

    def forward(self, input, compute_loss=False, avg_loss=True):
        """Encode `input`, sample topic proportions, and reconstruct the word distribution.

        Args:
            input: word-count tensor, presumably (batch, num_input) as built from the
                bag-of-words rows in this notebook -- TODO confirm for other callers.
            compute_loss: if True, also return the loss computed by `self.loss`.
            avg_loss: forwarded to `self.loss` as `avg` (batch mean vs per-document).

        Returns:
            recon -- rows are softmax distributions over the vocabulary -- or
            (recon, loss) when compute_loss is True.
        """
        # compute posterior
        en1 = F.softplus(self.en1_fc(input))                            # en1_fc   output
        en2 = F.softplus(self.en2_fc(en1))                              # encoder2 output
        en2 = self.en2_drop(en2)
        posterior_mean   = self.mean_bn  (self.mean_fc  (en2))          # posterior mean
        posterior_logvar = self.logvar_bn(self.logvar_fc(en2))          # posterior log variance
        posterior_var    = posterior_logvar.exp()
        # take sample: standard-normal noise shaped like the posterior mean
        eps = Variable(input.data.new().resize_as_(posterior_mean.data).normal_()) # noise
        z = posterior_mean + posterior_var.sqrt() * eps                 # reparameterization
        # Explicit dim=1 (topic axis): the implicit-dim softmax is deprecated and
        # picked dim=1 for 2-D input anyway, so the result is unchanged.
        p = F.softmax(z, dim=1)                                         # mixture probability
        p = self.p_drop(p)
        # do reconstruction: per-document distribution over the vocabulary
        recon = F.softmax(self.decoder_bn(self.decoder(p)), dim=1)

        if compute_loss:
            return recon, self.loss(input, recon, posterior_mean, posterior_logvar, posterior_var, avg_loss)
        else:
            return recon

    def loss(self, input, recon, posterior_mean, posterior_logvar, posterior_var, avg=True):
        """Negative ELBO: reconstruction negative log-likelihood plus KL(q || prior).

        Returns the batch mean when `avg` is True (used for training), otherwise one
        value per document (used for per-document perplexity at test time).
        """
        # NL: -sum_w count_w * log recon_w per document; 1e-10 guards against log(0)
        NL  = -(input * (recon+1e-10).log()).sum(1)
        # KLD between the diagonal Gaussians q = N(mean, var) and the prior, see
        # Section 3.3 of Akash Srivastava and Charles Sutton, 2017,
        # https://arxiv.org/pdf/1703.01488.pdf
        prior_mean   = Variable(self.prior_mean).expand_as(posterior_mean)
        prior_var    = Variable(self.prior_var).expand_as(posterior_mean)
        prior_logvar = Variable(self.prior_logvar).expand_as(posterior_mean)
        var_division    = posterior_var  / prior_var
        diff            = posterior_mean - prior_mean
        diff_term       = diff * diff / prior_var
        logvar_division = prior_logvar - posterior_logvar
        # put KLD together
        KLD = 0.5 * ( (var_division + diff_term + logvar_division).sum(1) - self.net_arch.num_topic )
        # loss
        loss = (NL + KLD)
        # in training mode, return averaged loss. In testing mode, return individual loss
        if avg:
            return loss.mean()
        else:
            return loss



In [6]:
def to_onehot(data, min_length):
    """Turn a document's array of word ids into a bag-of-words count vector.

    Despite the name this is not one-hot: entry j is the number of times word
    id j occurs in `data`. The result has length `min_length`, or more if some
    id in `data` is >= min_length.
    """
    word_counts = np.bincount(data, minlength=min_length)
    return word_counts

In [7]:
def print_perp(model, data=None):
    """Print and return the approximate perplexity of `model` on a document set.

    Perplexity = exp(mean over documents of loss_d / n_d), where loss_d is the
    per-document loss returned by the model (reconstruction NL + KLD) and n_d is
    the document's total word count.

    Args:
        model: module whose call `model(x, compute_loss=True, avg_loss=False)`
            returns (recon, per_document_loss).
        data: word-count tensor with one row per document; every row must have a
            non-zero total. Defaults to the global `tensor_te`, so existing
            `print_perp(model)` calls keep working.

    Returns:
        The perplexity as a Python float.
    """
    if data is None:
        data = tensor_te
    model.eval()                        # switch to testing mode
    # Evaluation needs no gradients; skipping graph construction avoids holding
    # activations for the whole document set in memory.
    with torch.no_grad():
        _, loss = model(Variable(data), compute_loss=True, avg_loss=False)
        counts = data.sum(1)            # words per document
        avg = (loss.data / counts).mean()
    perplexity = math.exp(float(avg))
    print('The approximated perplexity is: ', perplexity)
    return perplexity

def visualize():
    """Render the model's forward pass on the first 10 test documents.

    Uses `register_vis_hooks`, `remove_vis_hooks` and `save_visualization`
    from the star import of `pytorch_visualize`, plus the globals `model` and
    `tensor_te`. The output is saved via save_visualization('pytorch_model', 'png'),
    which needs the Graphviz `dot` executable on PATH. The reconstruction is
    stored in the global `recon`.
    """
    global recon
    input = Variable(tensor_te[:10])
    register_vis_hooks(model)
    try:
        recon = model(input, compute_loss=False)
    finally:
        # Remove the hooks even if the forward pass raises, so the model is not
        # left instrumented for later calls.
        remove_vis_hooks()
    save_visualization('pytorch_model', 'png')

In [8]:
# ---- Load the pre-processed 20 Newsgroups splits and the vocabulary ----
# NOTE(review): np.load(allow_pickle=True) and pickle.load can execute arbitrary
# code from a crafted file -- only load these files from a trusted source.
DATA_DIR = 'data/20news_clean'
dataset_tr = os.path.join(DATA_DIR, 'train.txt.npy')
# Each split presumably stores an object array (one variable-length array of word
# ids per document); NumPy >= 1.16.3 refuses object arrays unless allow_pickle=True.
data_tr = np.load(dataset_tr, encoding="latin1", allow_pickle=True)
dataset_te = os.path.join(DATA_DIR, 'test.txt.npy')
data_te = np.load(dataset_te, encoding="latin1", allow_pickle=True)
vocab_path = os.path.join(DATA_DIR, 'vocab.pkl')
# `with` closes the file handle as soon as the vocabulary is read.
with open(vocab_path, 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)          # word -> integer id
vocab_size = len(vocab)

In [9]:
#--------------convert to one-hot representation------------------
#--------------convert to bag-of-words count vectors---------------
# to_onehot turns each document's array of word ids into a length-vocab_size
# vector of word counts. Documents with no words are dropped so every row has a
# non-zero total count (print_perp divides by it). The filter checks the number
# of word ids, not their sum: a document made only of word id 0 sums to 0 but is
# not empty.
# NOTE: this cell overwrites data_tr/data_te, so re-run the loading cell before
# re-running it.
print ('Converting data to one-hot representation')
data_tr = np.array([to_onehot(doc.astype('int'), vocab_size) for doc in data_tr if len(doc) != 0])
data_te = np.array([to_onehot(doc.astype('int'), vocab_size) for doc in data_te if len(doc) != 0])
#--------------print the data dimensions--------------------------
print ('Data Loaded')
print ('Dim Training Data', data_tr.shape)
print ('Dim Test Data', data_te.shape)

Converting data to one-hot representation
Data Loaded
Dim Training Data (11258, 1995)
Dim Test Data (7487, 1995)


In [10]:
#--------------make tensor datasets-------------------------------
# float32 torch copies of the train/test count matrices, as consumed by the model
tensor_tr, tensor_te = [torch.from_numpy(counts).float() for counts in (data_tr, data_te)]

In [11]:
class args():
    """Hyper-parameters for ProdLDA and its training loop.

    Used directly as a class-level attribute namespace in this notebook (it is
    not instantiated).
    """
    # --- network architecture ---
    num_input = 1995        # vocabulary width; replaced by data_tr.shape[1] before building the model
    en1_units = 100
    en2_units = 100
    num_topic = 50
    # --- prior / initialisation ---
    variance  = 0.995       # prior variance of every latent dimension
    init_mult = 1.0         # decoder weights ~ U(0, init_mult); 0 keeps nn.Linear's default init
    # --- optimisation ---
    learning_rate = 0.002
    momentum      = 0.99    # passed to Adam as beta1
    batch_size    = 200
    num_epoch     = 80
    # only referenced by commented-out GPU code
    nogpu = True

In [12]:
# The architecture config is the `args` class itself (an alias, not a copy), so
# this also updates args.num_input. The vocabulary width is taken from the data
# rather than from the hard-coded default.
args.num_input = data_tr.shape[1]
net_arch = args
model = ProdLDA(net_arch)

In [13]:
# device = torch.device("cpu")

In [14]:
# model.to(device)

In [15]:
# Adam, with args.momentum used as beta1.
# Commented-out alternative: torch.optim.SGD(model.parameters(), args.learning_rate, momentum=args.momentum)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.momentum, 0.999),
)

In [16]:
# Mini-batch training: every epoch visits the training documents in a fresh
# random order, in batches of args.batch_size.
for epoch in range(args.num_epoch):
    all_indices = torch.randperm(tensor_tr.size(0)).split(args.batch_size)
    loss_epoch = 0.0
    model.train()                   # switch to training mode
    for batch_indices in all_indices:
        # GPU path disabled: if not args.nogpu: batch_indices = batch_indices.cuda()
        input = Variable(tensor_tr[batch_indices])
        recon, loss = model(input, compute_loss=True)
        # optimize
        optimizer.zero_grad()       # clear previous gradients
        loss.backward()             # backprop
        optimizer.step()            # update parameters
        # report: .item() converts the 0-dim loss to a Python float
        # (loss.data[0] raises IndexError on PyTorch >= 0.5)
        loss_epoch += loss.item()
    # report every 5th epoch and always the final one
    if epoch % 5 == 0 or epoch == args.num_epoch - 1:
        print('Epoch {}, loss={}'.format(epoch, loss_epoch / len(all_indices)))

Epoch 0, loss=752.4367258172286
Epoch 5, loss=675.5377936112253
Epoch 10, loss=660.8357276247259
Epoch 15, loss=658.2979083144874
Epoch 20, loss=644.949268541838
Epoch 25, loss=636.1256954795435
Epoch 30, loss=630.2199583890146
Epoch 35, loss=625.2974773206209
Epoch 40, loss=621.7331098589981
Epoch 45, loss=623.0270417865954
Epoch 50, loss=618.1221174273575
Epoch 55, loss=616.5587902403714
Epoch 60, loss=613.2184491073876
Epoch 65, loss=611.7819711785568
Epoch 70, loss=610.1798170658581
Epoch 75, loss=612.7063925224438


In [17]:
def print_perp(model, data=None):
    """Print and return the approximate perplexity of `model` on a document set.

    Perplexity = exp(mean over documents of loss_d / n_d), where loss_d is the
    per-document loss returned by the model (reconstruction NL + KLD) and n_d is
    the document's total word count.

    Args:
        model: module whose call `model(x, compute_loss=True, avg_loss=False)`
            returns (recon, per_document_loss).
        data: word-count tensor with one row per document; every row must have a
            non-zero total. Defaults to the global `tensor_te`, so existing
            `print_perp(model)` calls keep working.

    Returns:
        The perplexity as a Python float.
    """
    if data is None:
        data = tensor_te
    model.eval()                        # switch to testing mode
    # Evaluation needs no gradients; skipping graph construction avoids holding
    # activations for the whole document set in memory.
    with torch.no_grad():
        _, loss = model(Variable(data), compute_loss=True, avg_loss=False)
        counts = data.sum(1)            # words per document
        avg = (loss.data / counts).mean()
    perplexity = math.exp(float(avg))
    print('The approximated perplexity is: ', perplexity)
    return perplexity

def visualize():
    """Render the model's forward pass on the first 10 test documents.

    Uses `register_vis_hooks`, `remove_vis_hooks` and `save_visualization`
    from the star import of `pytorch_visualize`, plus the globals `model` and
    `tensor_te`. The output is saved via save_visualization('pytorch_model', 'png'),
    which needs the Graphviz `dot` executable on PATH. The reconstruction is
    stored in the global `recon`.
    """
    global recon
    input = Variable(tensor_te[:10])
    register_vis_hooks(model)
    try:
        recon = model(input, compute_loss=False)
    finally:
        # Remove the hooks even if the forward pass raises, so the model is not
        # left instrumented for later calls.
        remove_vis_hooks()
    save_visualization('pytorch_model', 'png')

In [18]:
def identify_topic_in_line(line, topic_keywords=None):
    """Return the labels whose keyword list has at least one hit in `line`.

    Matching is plain substring search (``word in line``), so a keyword also
    matches inside longer words (e.g. 'libert' matches 'liberty'). Each label
    appears at most once, in the mapping's iteration order.

    Args:
        line: text to scan, e.g. a space-joined list of a topic's top words.
        topic_keywords: mapping label -> list of keywords. Defaults to the global
            `associations`, looked up at call time, so the function no longer
            silently depends on a dict defined in a later cell.

    Returns:
        List of matching labels (possibly empty).
    """
    if topic_keywords is None:
        topic_keywords = associations
    return [topic for topic, keywords in topic_keywords.items()
            if any(word in line for word in keywords)]

In [19]:
def print_top_words(beta, feature_names, n_top_words=10):
    """Print the `n_top_words` highest-weighted words of every topic.

    Args:
        beta: per-topic weight rows (one row per topic, one column per
            vocabulary id); each row must support ``argsort()``.
        feature_names: indexable mapping from vocabulary id to word.
        n_top_words: number of words to show per topic.

    Each topic is printed as two lines: the labels found by
    ``identify_topic_in_line`` joined with '|', then the indented word list.
    """
    print ('---------------Printing the Topics------------------')
    for topic_weights in beta:
        # argsort is ascending: reverse it, then keep the first n ids (largest weights first)
        top_ids = topic_weights.argsort()[::-1][:n_top_words]
        line = " ".join(feature_names[word_id] for word_id in top_ids)
        print('|'.join(identify_topic_in_line(line)))
        print('     {}'.format(line))
    print ('---------------End of Topics------------------')

In [20]:
# Hand-picked keywords for labelling topics by substring match. Labels are padded
# to 5 characters so the printed labels line up; insertion order decides the
# order in which labels are reported.
associations = dict([
    ('jesus', ['prophet', 'jesus', 'matthew', 'christ', 'worship', 'church']),
    ('comp ', ['floppy', 'windows', 'microsoft', 'monitor', 'workstation', 'macintosh',
               'printer', 'programmer', 'colormap', 'scsi', 'jpeg', 'compression']),
    ('car  ', ['wheel', 'tire']),
    ('polit', ['amendment', 'libert', 'regulation', 'president']),
    ('crime', ['violent', 'homicide', 'rape']),
    ('midea', ['lebanese', 'israel', 'lebanon', 'palest']),
    ('sport', ['coach', 'hitter', 'pitch']),
    ('gears', ['helmet', 'bike']),
    ('nasa ', ['orbit', 'spacecraft']),
])

In [21]:
# nn.Linear(num_topic, num_input) stores its weight as (num_input, num_topic);
# transpose so each row holds one topic's weights over the vocabulary.
emb = model.decoder.weight.data.cpu().numpy().T
# Words ordered by their integer id, so position j is the word for column j.
id_to_word = tuple(word for word, _ in sorted(vocab.items(), key=lambda item: item[1]))
print_top_words(emb, id_to_word)
print_perp(model)
# Rendering the graph requires the Graphviz `dot` executable on PATH.
visualize()

---------------Printing the Topics------------------
comp 
     scsus floppy jumper scsi controller microsoft backup cache disk ram
jesus
     turks armenians armenian village troops town massacre jesus azerbaijan turkish

     anonymous electronic digital internet service account site amiga abuse responsibility

     batf compound waco insurance country kid fbi clinton cop assault
polit
     militia sentence constitution amendment arm shall weapon organize regulation states

     amp speaker cd dos shipping remote adapter audio channel manual

     det tor min pit nj cal que mon buf calgary

     gm mw wings st wm mg md vs rangers june
midea
     arab israel arabs palestinian jews israeli jew francisco territory land
gears
     honda bike helmet gear mouse rear ford ride craig front
jesus
     religious atheism catholic doctrine atheist tradition pope church teaching god

     wire connector panel outlet wiring cable connect jumper ground pin

     armenia armenians art massacre montr

ExecutableNotFound: failed to execute ['dot', '-Tpng', '-O', 'pytorch_model'], make sure the Graphviz executables are on your systems' PATH