In [20]:
import argparse
import torch
import pickle 
import numpy as np 
import os 
import math 
import random 
import sys
import matplotlib.pyplot as plt 
import scipy.io

from torch import nn, optim
from torch.nn import functional as F

sys.path.append('..')
import data
from etm import ETM
from utils import nearest_neighbors, get_topic_coherence, get_topic_diversity

# Seed every random number generator this notebook imports so that a
# Restart & Run All reproduces the same sampled batch. The stdlib `random`
# module is imported above, so it is seeded alongside NumPy and PyTorch.
SEED = 2019
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Seed all visible GPUs, not only the current one.
    torch.cuda.manual_seed_all(SEED)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
## get data
# 1. vocabulary plus the train / validation / test splits
vocab, train, valid, test = data.get_data(os.path.join('../data/20ng'))
vocab_size = len(vocab)

# 2. tokens and counts for each split, together with its number of documents
train_tokens, train_counts = train['tokens'], train['counts']
num_docs_train = len(train_tokens)

valid_tokens, valid_counts = valid['tokens'], valid['counts']
num_docs_valid = len(valid_tokens)

test_tokens, test_counts = test['tokens'], test['counts']
num_docs_test = len(test_tokens)

# The test split also exposes '_1' / '_2' variants of tokens and counts
# (presumably two parts of each test document -- TODO confirm in data.get_data).
test_1_tokens, test_1_counts = test['tokens_1'], test['counts_1']
num_docs_test_1 = len(test_1_tokens)

test_2_tokens, test_2_counts = test['tokens_2'], test['counts_2']
num_docs_test_2 = len(test_2_tokens)

In [21]:
# Number of training documents (same value as num_docs_train).
len(train_tokens)

11214

In [22]:
# Number of validation documents (same value as num_docs_valid).
len(valid_tokens)

100

In [23]:
# Number of test documents (same value as num_docs_test).
len(test_tokens)

7531

In [24]:
# Vocabulary size (same value as vocab_size).
len(vocab)

3072

In [25]:
# Path of the saved ETM checkpoint. The filename spells out the run's settings
# (K=50 topics, Htheta=800, adam, relu, lr=0.005, batch size 1000, rho size 300,
# trainEmbeddings=1), which match the hyperparameter cell that follows.
ckpt = '../results/etm_20ng_K_50_Htheta_800_Optim_adam_Clip_0.0_ThetaAct_relu_Lr_0.005_Bsz_1000_RhoSize_300_trainEmbeddings_1'

In [26]:
# Hyperparameters; they mirror the values spelled out in the checkpoint filename (`ckpt`).
dataset = '20ng'
num_topics = 50
t_hidden_size = 800      # size of the document encoding
optimizer_name = 'adam'  # label only; `optimizer` itself is the Adam object created below
clip = 0
theta_act = 'relu'       # activation, passed to ETM as a string
lr = 0.005               # learning rate
wdecay = 1.2e-6          # weight decay handed to Adam
enc_drop = 0.0           # dropout rate on the encoder
batch_size = 1000
rho_size = 300           # dimension of rho, the word-embedding layer
emb_size = 300           # passed next to rho_size; presumably the size of pre-fit embeddings -- TODO confirm in etm.py
train_embeddings = 1

embeddings = None        # no pre-fit embeddings are supplied
model = ETM(
    num_topics,          # number of topics
    vocab_size,          # vocabulary size, fixes the input width
    t_hidden_size,       # size of the document encoding
    rho_size,            # word-embedding dimension
    emb_size,            # see the note on emb_size above
    theta_act,           # activation function (string)
    embeddings,          # pre-fit embeddings (None here)
    train_embeddings,    # flag: whether embeddings are trained
    enc_drop,            # encoder dropout
).to(device)

# NOTE(review): the next cell rebinds `model` to the object returned by
# torch.load, so this optimizer keeps pointing at the parameters of the freshly
# initialised model built here, not at the loaded checkpoint. That is harmless
# for inference-only use, but rebuild the optimizer before resuming training.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wdecay)

In [27]:
# Restore the trained model from the checkpoint file. The file holds a whole
# pickled module (the loaded object is used directly via .to() / .eval()).
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU machine still loads when only the CPU is available.
# SECURITY: torch.load unpickles arbitrary objects -- only open checkpoints
# that come from a trusted source.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, loading a full module also needs weights_only=False.
with open(ckpt, 'rb') as f:
    model = torch.load(f, map_location=device)
model = model.to(device)
# Switch to evaluation mode; the trailing ';' suppresses the module repr.
model.eval();

In [28]:
# Accumulators and counter, all initialised to zero (not updated in this cell).
acc_loss, acc_kl_theta_loss, cnt = 0, 0, 0

num_docs_train = len(train_tokens)
batch_size = 1000

# Shuffle the training indices, cut them into chunks of `batch_size`,
# and keep the first chunk as the batch to inspect.
indices = torch.split(torch.randperm(num_docs_train), batch_size)
idx = 0
ind = indices[0]

# Build the batch for the selected indices, then divide every row by its sum
# along dim 1 so each row is normalised.
data_batch = data.get_batch(train_tokens, train_counts, ind, vocab_size, device)
sums = data_batch.sum(1, keepdim=True)
normalized_data_batch = data_batch / sums

In [29]:
# Training-set index of the first document in the sampled batch.
ind[0]

tensor(405)

Topic 0: ['writes', 'article', 'org', 'jim', 'virginia', 'stanford', 'computer', 'distribution', 'usa']  
Topic 1: ['mail', 'hp', 'line', 'mark', 'version', 'info', 'netcom', 'wrote', 'phone']  
Topic 2: ['mit', 'ibm', 'mil', 'group', 'apr', 'navy', 'newsgroup', 'marc', 'list']  
Topic 3: ['ftp', 'image', 'information', 'faq', 'pub', 'mail', 'data', 'software', 'graphics']  
Topic 4: ['israel', 'jews', 'turkish', 'israeli', 'people', 'armenian', 'armenians', 'turkey', 'turks']  
Topic 5: ['andrew', 'uiuc', 'cmu', 'colorado', 'university', 'michael', 'cso', 'writes', 'illinois']  
Topic 6: ['drive', 'card', 'scsi', 'video', 'mac', 'system', 'bit', 'disk', 'hard']  
Topic 7: ['good', 'writes', 'time', 'people', 'article', 'make', 'thing', 'things', 'back']  
Topic 8: ['posting', 'host', 'university', 'nntp', 'article', 'writes', 'ca', 'cs', 'distribution']  
Topic 9: ['writes', 'posting', 'article', 'university', 'host', 'nntp', 'ca', 'cs', 'distribution']  
Topic 10: ['good', 'people', 'writes', 'time', 'make', 'thing', 'things', 'article', 'give']  
Topic 11: ['good', 'time', 'writes', 'make', 'people', 'thing', 'back', 'bad', 'things']  
Topic 12: ['writes', 'article', 'org', 'jim', 'usa', 'distribution', 'computer', 'stanford', 'opinions']  
Topic 13: ['posting', 'host', 'nntp', 'ca', 'university', 'cc', 'distribution', 'article', 'cs']  
Topic 14: ['windows', 'dos', 'files', 'access', 'ms', 'program', 'file', 'pc', 'run']  
Topic 15: ['posting', 'host', 'nntp', 'university', 'ca', 'article', 'writes', 'cs', 'distribution']  
Topic 16: ['uk', 'de', 'ac', 'au', 'uni', 'university', 'germany', 'tu', 'australia']  
Topic 17: ['posting', 'host', 'university', 'nntp', 'ca', 'article', 'writes', 'cc', 'cs']  
Topic 18: ['god', 'jesus', 'people', 'christian', 'christians', 'bible', 'church', 'christ', 'faith']  
Topic 19: ['car', 'bike', 'dod', 'cars', 'engine', 'ride', 'road', 'front', 'speed']  
Topic 20: ['writes', 'article', 'good', 'time', 'people', 'thing', 'make', 'back', 'heard']  
Topic 21: ['max', 'ah', 'mr', 'air', 'ma', 'sp', 'cs', 'mi', 'tm']  
Topic 22: ['space', 'nasa', 'gov', 'launch', 'henry', 'toronto', 'earth', 'jpl', 'satellite']  
Topic 23: ['people', 'good', 'time', 'make', 'things', 'thing', 'work', 'real', 'lot']  
Topic 24: ['writes', 'article', 'distribution', 'usa', 'university', 'org', 'computer', 'reply', 'david']  
Topic 25: ['sale', 'price', 'offer', 'shipping', 'printer', 'sell', 'mail', 'condition', 'cd']  
Topic 26: ['posting', 'host', 'nntp', 'ca', 'university', 'cc', 'article', 'distribution', 'cs']  
Topic 27: ['people', 'told', 'home', 'time', 'started', 'left', 'building', 'back', 'day']  
Topic 28: ['people', 'good', 'time', 'make', 'things', 'thing', 'years', 'work', 'put']  
Topic 29: ['south', 'war', 'san', 'information', 'world', 'american', 'los', 'nuclear', 'southern']  
Topic 30: ['good', 'time', 'make', 'back', 'work', 'problem', 'give', 'ago', 'long']  
Topic 31: ['science', 'pitt', 'gordon', 'banks', 'cs', 'geb', 'food', 'article', 'disease']  
Topic 32: ['writes', 'good', 'time', 'article', 'people', 'make', 'thing', 'back', 'give']  
Topic 33: ['posting', 'university', 'host', 'nntp', 'writes', 'ca', 'article', 'cs', 'cc']  
Topic 34: ['health', 'national', 'research', 'university', 'care', 'insurance', 'medical', 'children', 'rate']  
Topic 35: ['good', 'time', 'make', 'problem', 'work', 'thing', 'back', 'level', 'find']  
Topic 36: ['question', 'problem', 'read', 'post', 'point', 'time', 'answer', 'find', 'good']  
Topic 37: ['power', 'water', 'wire', 'ground', 'current', 'circuit', 'high', 'cable', 'signal']  
Topic 38: ['team', 'game', 'games', 'year', 'play', 'hockey', 'season', 'players', 'win']  
Topic 39: ['netcom', 'hp', 'services', 'mail', 'newsreader', 'tin', 'mark', 'phone', 'version']  
Topic 40: ['ohio', 'cwru', 'keith', 'cleveland', 'caltech', 'state', 'freenet', 'acs', 'sgi']  
Topic 41: ['mail', 'software', 'information', 'line', 'find', 'version', 'net', 'number', 'info']  
Topic 42: ['key', 'encryption', 'chip', 'clipper', 'government', 'keys', 'security', 'public', 'law']  
Topic 43: ['writes', 'article', 'university', 'distribution', 'reply', 'usa', 'computer', 'org', 'cs']  
Topic 44: ['people', 'law', 'article', 'rights', 'writes', 'sex', 'men', 'evidence', 'cramer']  
Topic 45: ['book', 'books', 'question', 'theory', 'points', 'point', 'reference', 'find', 'time']  
Topic 46: ['file', 'window', 'program', 'set', 'display', 'color', 'server', 'application', 'code']  
Topic 47: ['good', 'time', 'make', 'thing', 'back', 'people', 'problem', 'work', 'things']  
Topic 48: ['good', 'people', 'time', 'make', 'thing', 'things', 'back', 'give', 'work']  
Topic 49: ['gun', 'people', 'mr', 'guns', 'government', 'president', 'clinton', 'weapons', 'control']  


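A `predict_proba()` method can be as simple as:  
```
def predict_proba(self, normalized_bow):
    # encode() returns several values; the first one holds the per-topic scores
    return F.softmax(self.encode(normalized_bow)[0], dim=-1)
```
Input: 2D PyTorch tensor of normalized bag-of-words rows (one row per document)  
Output: per-topic probabilities for each document; apply `.argmax(dim=-1)` to them to get the single most likely topic  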
  
<font color="purple">Importantly, the choice of final activation determines what kind of problem we can solve.
For instance, softmax suits single-label, multi-class problems, whereas sigmoid can be used for multi-label problems.</font>

In [44]:
# Shape obtained by applying model.alphas to the rho weight matrix; the
# recorded output (3072, 50) corresponds to vocab_size x num_topics.
model.alphas(model.rho.weight).shape

torch.Size([3072, 50])

In [40]:
# Shape of the rho weight matrix; the recorded output (3072, 300) corresponds
# to vocab_size x rho_size.
model.rho.weight.shape

torch.Size([3072, 300])

In [42]:
# Display the rho layer itself (the recorded output shows a bias-free Linear layer).
model.rho

Linear(in_features=300, out_features=3072, bias=False)

In [31]:
# Encode the first document of the batch and keep element [0] of encode()'s
# return value: a vector with one raw score per topic (50 entries here).
model.encode(normalized_data_batch[0])[0]

tensor([-0.1220, -0.2369,  0.1000, -0.1858, -0.0396,  1.5132, -0.2654, -0.1114,
        -0.1370, -0.1132, -0.0564, -0.1072, -0.1702, -0.1346, -0.1902, -0.1288,
        -0.1792, -0.1313,  4.9024, -0.2944, -0.0665, -0.2247, -0.2041, -0.0585,
        -0.1245, -0.1449, -0.1964, -0.0613, -0.0243, -0.1894, -0.1078,  0.5689,
        -0.0595, -0.2002, -0.1898, -0.0687, -0.0373, -0.2176, -0.3027, -0.2731,
        -0.2226, -0.2369, -0.1647, -0.0982, -0.0611, -0.1829, -0.1590, -0.0718,
        -0.0703, -0.1583], device='cuda:0', grad_fn=<AddBackward0>)

In [36]:
# Element-wise sigmoid of the first document's per-topic scores: each topic
# gets an independent value in (0, 1). torch.sigmoid is used because the
# F.sigmoid alias is deprecated.
torch.sigmoid( model.encode(normalized_data_batch[0])[0] )

tensor([0.4695, 0.4411, 0.5250, 0.4537, 0.4901, 0.8195, 0.4340, 0.4722, 0.4658,
        0.4717, 0.4859, 0.4732, 0.4576, 0.4664, 0.4526, 0.4678, 0.4553, 0.4672,
        0.9926, 0.4269, 0.4834, 0.4441, 0.4492, 0.4854, 0.4689, 0.4639, 0.4511,
        0.4847, 0.4939, 0.4528, 0.4731, 0.6385, 0.4851, 0.4501, 0.4527, 0.4828,
        0.4907, 0.4458, 0.4249, 0.4321, 0.4446, 0.4411, 0.4589, 0.4755, 0.4847,
        0.4544, 0.4603, 0.4820, 0.4824, 0.4605], device='cuda:0',
       grad_fn=<SigmoidBackward>)

In [37]:
# Softmax over the first document's per-topic scores, giving one distribution
# across topics. `dim` is passed explicitly: relying on the implicit dimension
# is deprecated and emits a warning. For this 1-D vector dim=-1 is the same
# axis the implicit choice used.
F.softmax( model.encode(normalized_data_batch[0])[0], dim=-1 )

  """Entry point for launching an IPython kernel.


tensor([0.0049, 0.0043, 0.0061, 0.0046, 0.0053, 0.0250, 0.0042, 0.0049, 0.0048,
        0.0049, 0.0052, 0.0049, 0.0046, 0.0048, 0.0045, 0.0048, 0.0046, 0.0048,
        0.7403, 0.0041, 0.0051, 0.0044, 0.0045, 0.0052, 0.0049, 0.0048, 0.0045,
        0.0052, 0.0054, 0.0046, 0.0049, 0.0097, 0.0052, 0.0045, 0.0045, 0.0051,
        0.0053, 0.0044, 0.0041, 0.0042, 0.0044, 0.0043, 0.0047, 0.0050, 0.0052,
        0.0046, 0.0047, 0.0051, 0.0051, 0.0047], device='cuda:0',
       grad_fn=<SoftmaxBackward>)

In [34]:
# Map the token ids of the first document in the sampled batch back to
# vocabulary words; the list is the cell's displayed value.
print("document words")
[vocab[token_id] for token_id in train_tokens[ind[0]][0]]

document words


['modern',
 'law',
 'cs',
 'link',
 'longer',
 'sin',
 'days',
 'uga',
 'athens',
 'kind',
 'paul',
 'day',
 'food',
 'christians',
 'eat',
 'cc',
 'sunday',
 'recognized',
 'totally',
 'question',
 'keeping',
 'personal',
 'jesus',
 'athena',
 'authority',
 'telling',
 'holy',
 'judge',
 'requirements',
 'irrelevant',
 'people',
 'writes',
 'fourth',
 'acts',
 'command',
 'living',
 'jewish',
 'position',
 'article',
 'man',
 'regard',
 'georgia',
 'mailer',
 'romans',
 'relevant',
 'list',
 'jr',
 'problem',
 'university']

In [7]:
# # writing stops to a list in a py file
# with open('../scripts/stops.txt', 'r') as f:
#     stops = f.read()
# with open('stops.py', 'w') as f:
#     f.write('stops = ["' + stops.replace('\n', '",\n"')[:-1] + ']')
# from stops import stops