# User Latent Dirichlet Allocation

In [1]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pickle
# Put the parent directory and the current directory on the import path
# (presumably so the project-local `lib` package imported below resolves).
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('.'))
# Turn on autograd anomaly detection for every backward pass in this kernel.
# NOTE(review): PyTorch documents this as a debugging aid that slows execution;
# consider disabling it for the full 100-epoch training run.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f394bef4b50>

In [2]:
%load_ext autoreload
%autoreload 2

from lib.models import InductiveLDA, NeuralLDA, NeuralLLNA
from lib import utils as utils

Import Data

In [3]:
class Corpus:
    """Train/test splits of a bag-of-words corpus together with its vocabulary.

    `datadir` must contain `train.txt.npy`, `test.txt.npy` and `vocab.pkl`.

    Attributes:
        datapaths: paths of the two split files, in (train, test) order.
        vocab: the object unpickled from `vocab.pkl`; its length is used as
            the vocabulary size of both splits.
        train, test: `Data` instances built from the two split files.
    """

    def __init__(self, datadir):
        self.datapaths = [
            os.path.join(datadir, name)
            for name in ('train.txt.npy', 'test.txt.npy')
        ]

        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        vocab_path = os.path.join(datadir, 'vocab.pkl')
        with open(vocab_path, 'rb') as handle:
            self.vocab = pickle.load(handle)

        vocab_size = len(self.vocab)
        train_path, test_path = self.datapaths
        self.train = Data(train_path, vocab_size)
        self.test = Data(test_path, vocab_size)


class Data:
    """One corpus split as a dense document-term count matrix.

    Attributes:
        data: integer array with one row per non-empty document and one column
            per vocabulary id; entry (d, w) counts occurrences of id w in
            document d.
        documents: the non-empty token-id sequences loaded from disk, in the
            same order as the rows of `data`.
    """

    def __init__(self, datapath, vocab_size):
        """Load token-id documents from `datapath` and build the count matrix.

        Args:
            datapath: path of a `.npy` file holding one token-id array per
                document (loaded with pickling enabled, so the file must be
                trusted).
            vocab_size: minimum number of columns of the count matrix.
        """
        raw = np.load(datapath, allow_pickle=True, encoding='bytes')
        # Decide emptiness by length, not by the sum of the ids: a document
        # made only of token id 0 sums to zero but is not empty.
        keep = np.fromiter((len(doc) > 0 for doc in raw), dtype=bool, count=len(raw))
        # Apply the same filter to the raw documents so that documents[i]
        # always corresponds to data[i].
        self.documents = raw[keep]
        rows = [
            np.bincount(np.asarray(doc).astype('int'), minlength=vocab_size)
            for doc in self.documents
        ]
        # Keep a well-formed (0, vocab_size) matrix when nothing is retained.
        self.data = np.array(rows) if rows else np.zeros((0, vocab_size), dtype=np.intp)

    @property
    def size(self):
        """Number of documents (rows of `data`)."""
        return len(self.data)

    def get_batch(self, batch_size, start_id=None):
        """Return a float32 tensor holding a batch of count vectors.

        Args:
            batch_size: number of documents requested.
            start_id: when None, rows are drawn at random with replacement
                (np.random.choice default); otherwise the consecutive rows
                starting at `start_id` are returned.

        Returns:
            Tensor of shape (n, vocab_size). In sequential mode n may be
            smaller than `batch_size` for the last batch of the split.
        """
        if start_id is None:
            batch_idx = np.random.choice(np.arange(self.size), batch_size)
        else:
            # Clip at the end of the split so the final partial batch is
            # returned instead of raising IndexError.
            stop = min(start_id + batch_size, self.size)
            batch_idx = np.arange(start_id, stop)
        batch_data = self.data[batch_idx]
        return torch.from_numpy(batch_data).float()

In [4]:
# Load the train/test splits and the vocabulary stored under ../data/20news.
corpus = Corpus("../data/20news")

In [5]:
# Seed every random number generator the notebook relies on. NumPy alone is
# not enough: weight initialisation, DataLoader shuffling, dropout and the
# sampling done during training all draw from PyTorch's generator.
RANDOM_SEED = 2112
np.random.seed(RANDOM_SEED)
# Trailing semicolon keeps the returned Generator out of the cell output.
torch.manual_seed(RANDOM_SEED);

In [6]:
# Document-term count matrix of the training split (one row per retained document).
X_document = corpus.train.data

In [7]:
# Rows of the count matrix are documents, columns are vocabulary entries.
num_docs, vocab_size = X_document.shape
print(f"Number of documents: {num_docs}, vocab size: {vocab_size}")

Number of documents: 11258, vocab size: 1995


In [8]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that serves the items of an indexable container as-is."""

    def __init__(self, X):
        # Only a reference is stored; the container is not copied.
        self.X = X

    def __len__(self):
        """Number of items in the wrapped container."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the item stored at `index`."""
        return self.X[index]

In [9]:
# Convert the count matrix to a float32 tensor and wrap it for use with a DataLoader.
dataset = Dataset(torch.tensor(X_document).to(torch.float32))

In [25]:
# Model hyperparameters (passed to the model constructor below).
num_topics = 50
prodlda = True
conv = False
# Number of extra hidden layers in the NN block. The encoder and decoder always
# have at least (1 + num_heads) hidden layers; num_layers comes on top of those,
# for a total of (1 + num_layers + num_heads) hidden layers.
num_layers = 0
num_neurons = 100
dropout = True
dropout_rate = 0.25
batch_normalization = True
prior_param = {'alpha': 1.0}
decoder_temperature = 1.0
# NOTE(review): encoder_temperature is defined here but is not among the
# arguments passed when `model` is constructed -- confirm whether it should be.
encoder_temperature = 1.0

In [26]:
# Select the model class to instantiate. The commented lines are the other
# classes imported from lib.models; the NeuralLLNA variant also swaps
# prior_param for a {"mu", "sigma"} dictionary.
# NeuralModel, prior_param = NeuralLLNA, {"mu": 0.0, "sigma": 1.0}
# NeuralModel = NeuralLDA
NeuralModel = InductiveLDA

In [27]:
# Build the model from the hyperparameters above. A copy of prior_param is
# passed so the notebook-level dictionary is not shared with the model.
model = NeuralModel(input_dim=vocab_size,
                    num_topics=num_topics,
                    prior_param=prior_param.copy(),
                    conv=conv,
                    prodlda=prodlda,
                    decoder_temperature=decoder_temperature,
                    num_hidden_layers=num_layers,
                    num_neurons=num_neurons,
                    dropout=dropout,
                    dropout_rate=dropout_rate,
                    batch_normalization=batch_normalization,
                    )

In [28]:
# Training settings forwarded to model.fit below.
lr = 5e-3
batch_size = 200
num_epochs = 100
# NOTE(review): presumably the weight of the KL term in the objective -- confirm in lib.models.
beta = 1.0
learn_prior = True

In [29]:
# Train the model on shuffled mini-batches of the training count matrix,
# using one Monte Carlo sample per step and no TensorBoard logging.
model.fit(lr=lr,
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True),
            epochs = num_epochs,
            beta = beta,
            mc_samples = 1,
            learn_prior = learn_prior,
            tensorboard = False,
            )

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

                                                      
Epoch:   2%|▏         | 2/100 [00:03<01:52,  1.15s/it]             

Iteration: 100 -- ELBO=-5.48e+02 / RLL=-5.42e+02 / KL=6.66e+00


                                                      
Epoch:   4%|▍         | 4/100 [00:05<01:55,  1.20s/it]             

Iteration: 200 -- ELBO=-9.29e+02 / RLL=-9.21e+02 / KL=7.70e+00


                                                      
Epoch:   6%|▌         | 6/100 [00:07<01:54,  1.21s/it]             

Iteration: 300 -- ELBO=-4.83e+02 / RLL=-4.79e+02 / KL=3.94e+00


                                                      
Epoch:   8%|▊         | 8/100 [00:09<01:51,  1.22s/it]    

Iteration: 400 -- ELBO=-7.85e+02 / RLL=-7.81e+02 / KL=4.59e+00


                                                      
Epoch:   9%|▉         | 9/100 [00:12<01:55,  1.27s/it]             

Iteration: 500 -- ELBO=-5.45e+02 / RLL=-5.41e+02 / KL=3.94e+00


                                                       
Epoch:  11%|█         | 11/100 [00:14<02:07,  1.43s/it]            

Iteration: 600 -- ELBO=-6.05e+02 / RLL=-6.01e+02 / KL=4.08e+00


                                                       
Epoch:  13%|█▎        | 13/100 [00:16<01:54,  1.32s/it]            

Iteration: 700 -- ELBO=-6.38e+02 / RLL=-6.33e+02 / KL=5.09e+00


                                                       
Epoch:  15%|█▌        | 15/100 [00:19<01:48,  1.27s/it]   

Iteration: 800 -- ELBO=-5.65e+02 / RLL=-5.61e+02 / KL=4.07e+00


                                                       
Epoch:  16%|█▌        | 16/100 [00:21<01:45,  1.25s/it]            

Iteration: 900 -- ELBO=-4.77e+02 / RLL=-4.73e+02 / KL=3.16e+00


                                                       
Epoch:  18%|█▊        | 18/100 [00:23<01:41,  1.23s/it]            

Iteration: 1000 -- ELBO=-6.41e+02 / RLL=-6.36e+02 / KL=4.77e+00


                                                       
Epoch:  20%|██        | 20/100 [00:25<01:38,  1.23s/it]            

Iteration: 1100 -- ELBO=-6.27e+02 / RLL=-6.21e+02 / KL=6.66e+00


                                                       
Epoch:  22%|██▏       | 22/100 [00:27<01:35,  1.22s/it]   

Iteration: 1200 -- ELBO=-7.25e+02 / RLL=-7.18e+02 / KL=7.35e+00


                                                       
Epoch:  23%|██▎       | 23/100 [00:29<01:34,  1.23s/it]            

Iteration: 1300 -- ELBO=-6.74e+02 / RLL=-6.66e+02 / KL=8.53e+00


                                                       
Epoch:  25%|██▌       | 25/100 [00:31<01:32,  1.23s/it]            

Iteration: 1400 -- ELBO=-6.51e+02 / RLL=-6.44e+02 / KL=7.59e+00


                                                       
Epoch:  27%|██▋       | 27/100 [00:34<01:29,  1.23s/it]            

Iteration: 1500 -- ELBO=-5.00e+02 / RLL=-4.93e+02 / KL=7.39e+00


                                                       
Epoch:  29%|██▉       | 29/100 [00:36<01:27,  1.23s/it]   

Iteration: 1600 -- ELBO=-6.03e+02 / RLL=-5.95e+02 / KL=7.86e+00


                                                       
Epoch:  30%|███       | 30/100 [00:38<01:26,  1.24s/it]            

Iteration: 1700 -- ELBO=-6.26e+02 / RLL=-6.16e+02 / KL=1.01e+01


                                                       
Epoch:  32%|███▏      | 32/100 [00:40<01:24,  1.25s/it]            

Iteration: 1800 -- ELBO=-7.64e+02 / RLL=-7.53e+02 / KL=1.05e+01


                                                       
Epoch:  34%|███▍      | 34/100 [00:42<01:21,  1.24s/it]            

Iteration: 1900 -- ELBO=-5.63e+02 / RLL=-5.54e+02 / KL=9.55e+00


                                                       
Epoch:  36%|███▌      | 36/100 [00:45<01:19,  1.24s/it]   

Iteration: 2000 -- ELBO=-6.07e+02 / RLL=-5.96e+02 / KL=1.08e+01


                                                       
Epoch:  38%|███▊      | 38/100 [00:47<01:17,  1.24s/it]            

Iteration: 2100 -- ELBO=-4.49e+02 / RLL=-4.40e+02 / KL=9.07e+00


                                                       
Epoch:  39%|███▉      | 39/100 [00:49<01:15,  1.24s/it]            

Iteration: 2200 -- ELBO=-7.49e+02 / RLL=-7.35e+02 / KL=1.39e+01


                                                       
Epoch:  41%|████      | 41/100 [00:51<01:13,  1.25s/it]            

Iteration: 2300 -- ELBO=-5.82e+02 / RLL=-5.71e+02 / KL=1.11e+01


                                                       
Epoch:  43%|████▎     | 43/100 [00:53<01:10,  1.24s/it]           

Iteration: 2400 -- ELBO=-5.94e+02 / RLL=-5.82e+02 / KL=1.18e+01


                                                       
Epoch:  44%|████▍     | 44/100 [00:56<01:09,  1.25s/it]            

Iteration: 2500 -- ELBO=-6.24e+02 / RLL=-6.13e+02 / KL=1.11e+01


                                                       
Epoch:  46%|████▌     | 46/100 [00:58<01:09,  1.28s/it]            

Iteration: 2600 -- ELBO=-6.96e+02 / RLL=-6.83e+02 / KL=1.24e+01


                                                       
Epoch:  48%|████▊     | 48/100 [01:00<01:06,  1.27s/it]            

Iteration: 2700 -- ELBO=-5.63e+02 / RLL=-5.52e+02 / KL=1.10e+01


                                                       
Epoch:  50%|█████     | 50/100 [01:02<01:02,  1.26s/it]           


Iteration: 2800 -- ELBO=-4.71e+02 / RLL=-4.61e+02 / KL=1.05e+01


                                                       , 46.91it/s][A
Epoch:  52%|█████▏    | 52/100 [01:05<01:00,  1.25s/it]            

Iteration: 2900 -- ELBO=-5.03e+02 / RLL=-4.92e+02 / KL=1.04e+01


                                                       
Epoch:  53%|█████▎    | 53/100 [01:07<00:59,  1.26s/it]            

Iteration: 3000 -- ELBO=-5.62e+02 / RLL=-5.50e+02 / KL=1.18e+01


                                                       
Epoch:  55%|█████▌    | 55/100 [01:09<00:56,  1.26s/it]            

Iteration: 3100 -- ELBO=-6.15e+02 / RLL=-6.03e+02 / KL=1.21e+01


                                                       
Epoch:  57%|█████▋    | 57/100 [01:11<00:53,  1.25s/it]           

Iteration: 3200 -- ELBO=-5.78e+02 / RLL=-5.66e+02 / KL=1.17e+01


                                                       
Epoch:  59%|█████▉    | 59/100 [01:13<00:51,  1.26s/it]            

Iteration: 3300 -- ELBO=-5.91e+02 / RLL=-5.80e+02 / KL=1.14e+01


                                                       
Epoch:  60%|██████    | 60/100 [01:15<00:50,  1.26s/it]            

Iteration: 3400 -- ELBO=-5.93e+02 / RLL=-5.82e+02 / KL=1.07e+01


                                                       
Epoch:  62%|██████▏   | 62/100 [01:18<00:47,  1.26s/it]            

Iteration: 3500 -- ELBO=-5.45e+02 / RLL=-5.34e+02 / KL=1.06e+01


                                                       
Epoch:  64%|██████▍   | 64/100 [01:20<00:45,  1.25s/it]           

Iteration: 3600 -- ELBO=-4.77e+02 / RLL=-4.66e+02 / KL=1.06e+01


                                                       
Epoch:  66%|██████▌   | 66/100 [01:22<00:42,  1.26s/it]            

Iteration: 3700 -- ELBO=-5.12e+02 / RLL=-5.01e+02 / KL=1.09e+01


                                                       
Epoch:  67%|██████▋   | 67/100 [01:24<00:41,  1.26s/it]            

Iteration: 3800 -- ELBO=-7.38e+02 / RLL=-7.27e+02 / KL=1.18e+01


                                                       
Epoch:  69%|██████▉   | 69/100 [01:26<00:38,  1.25s/it]            

Iteration: 3900 -- ELBO=-4.88e+02 / RLL=-4.77e+02 / KL=1.05e+01


                                                       
Epoch:  71%|███████   | 71/100 [01:29<00:36,  1.25s/it]           

Iteration: 4000 -- ELBO=-5.42e+02 / RLL=-5.31e+02 / KL=1.08e+01


                                                       
Epoch:  73%|███████▎  | 73/100 [01:31<00:33,  1.25s/it]            

Iteration: 4100 -- ELBO=-4.99e+02 / RLL=-4.89e+02 / KL=1.06e+01


                                                       
Epoch:  74%|███████▍  | 74/100 [01:33<00:32,  1.26s/it]            

Iteration: 4200 -- ELBO=-4.75e+02 / RLL=-4.64e+02 / KL=1.04e+01


                                                       
Epoch:  76%|███████▌  | 76/100 [01:35<00:30,  1.27s/it]            

Iteration: 4300 -- ELBO=-6.38e+02 / RLL=-6.27e+02 / KL=1.11e+01


                                                       
Epoch:  78%|███████▊  | 78/100 [01:38<00:28,  1.27s/it]           

Iteration: 4400 -- ELBO=-6.54e+02 / RLL=-6.42e+02 / KL=1.15e+01


                                                       
Epoch:  80%|████████  | 80/100 [01:40<00:25,  1.26s/it]            

Iteration: 4500 -- ELBO=-5.05e+02 / RLL=-4.94e+02 / KL=1.08e+01


                                                       
Epoch:  81%|████████  | 81/100 [01:42<00:23,  1.26s/it]            

Iteration: 4600 -- ELBO=-6.18e+02 / RLL=-6.08e+02 / KL=1.08e+01


                                                       
Epoch:  83%|████████▎ | 83/100 [01:44<00:21,  1.27s/it]            

Iteration: 4700 -- ELBO=-7.47e+02 / RLL=-7.36e+02 / KL=1.13e+01


                                                       
Epoch:  85%|████████▌ | 85/100 [01:47<00:19,  1.27s/it]            

Iteration: 4800 -- ELBO=-5.74e+02 / RLL=-5.63e+02 / KL=1.09e+01


Epoch:  87%|████████▋ | 87/100 [01:49<00:16,  1.27s/it]

Iteration: 4900 -- ELBO=-7.00e+02 / RLL=-6.89e+02 / KL=1.08e+01


                                                       
Epoch:  88%|████████▊ | 88/100 [01:51<00:15,  1.27s/it]            

Iteration: 5000 -- ELBO=-9.10e+02 / RLL=-8.98e+02 / KL=1.20e+01


                                                       
Epoch:  90%|█████████ | 90/100 [01:53<00:12,  1.26s/it]            


Iteration: 5100 -- ELBO=-5.13e+02 / RLL=-5.03e+02 / KL=1.04e+01


                                                       , 45.05it/s][A
Epoch:  92%|█████████▏| 92/100 [01:55<00:10,  1.26s/it]            

Iteration: 5200 -- ELBO=-5.51e+02 / RLL=-5.41e+02 / KL=1.00e+01


                                                       
Epoch:  94%|█████████▍| 94/100 [01:58<00:07,  1.27s/it]            

Iteration: 5300 -- ELBO=-6.46e+02 / RLL=-6.36e+02 / KL=1.07e+01


                                                       
Epoch:  95%|█████████▌| 95/100 [02:00<00:06,  1.26s/it]            

Iteration: 5400 -- ELBO=-6.00e+02 / RLL=-5.89e+02 / KL=1.08e+01


                                                       
Epoch:  97%|█████████▋| 97/100 [02:02<00:03,  1.26s/it]            

Iteration: 5500 -- ELBO=-7.28e+02 / RLL=-7.17e+02 / KL=1.12e+01


                                                       
Epoch:  99%|█████████▉| 99/100 [02:04<00:01,  1.26s/it]            

Iteration: 5600 -- ELBO=-6.87e+02 / RLL=-6.76e+02 / KL=1.10e+01


                                                        
Epoch: 100%|██████████| 100/100 [02:06<00:00,  1.27s/it]           
Iteration in Epoch: 100%|██████████| 57/57 [00:01<00:00, 45.15it/s]

Iteration: 5700 -- ELBO=-4.69e+02 / RLL=-4.60e+02 / KL=9.87e+00





In [30]:
# Display the model's prior parameters after training (fit was called with learn_prior=True).
model.get_prior_params()

tensor([7.8248e-01, 8.5697e-04, 2.8882e-03, 8.3738e-01, 7.6415e-04, 9.1559e-01,
        7.2740e-01, 5.1719e-01, 5.8330e-04, 1.0756e-03, 1.3546e-03, 4.2201e-01,
        8.2058e-04, 9.3865e-01, 6.5549e-01, 5.0389e-03, 8.6805e-03, 5.8849e-04,
        9.9697e-04, 7.9068e-04, 7.6066e-04, 1.1435e-03, 6.6645e-04, 1.0982e-03,
        2.4720e-01, 1.2482e-03, 6.9792e-03, 8.4468e-04, 1.1672e-03, 7.0699e-04,
        9.9519e-04, 5.5826e-04, 6.2155e-01, 1.4910e-02, 8.2116e-04, 2.3724e-02,
        6.4800e-01, 1.0430e-02, 5.6161e-04, 9.4454e-03, 4.5106e-02, 7.0446e-01,
        1.4588e-03, 8.6612e-04, 8.2152e-04, 7.6196e-03, 7.3882e-01, 4.3006e-03,
        7.4542e-04, 6.6248e-01], grad_fn=<SoftplusBackward0>)

In [31]:
# Display the decoder's beta tensor, detached from the autograd graph.
model.decoder.get_beta().detach()

tensor([[3.2186e-08, 2.1772e-03, 1.2023e-03,  ..., 9.7794e-07, 3.6909e-07,
         1.1403e-08],
        [2.8716e-06, 2.5772e-02, 1.9077e-03,  ..., 1.3827e-05, 3.5287e-06,
         3.5271e-06],
        [1.8509e-05, 1.5041e-03, 6.6409e-04,  ..., 7.3777e-06, 6.2625e-06,
         2.5380e-06],
        ...,
        [1.8513e-06, 7.4842e-03, 4.8255e-03,  ..., 1.2720e-05, 3.0170e-06,
         2.9274e-06],
        [1.6940e-06, 1.1498e-02, 2.6513e-03,  ..., 3.4772e-05, 1.3194e-05,
         3.1730e-06],
        [1.0509e-08, 2.7682e-03, 3.2106e-03,  ..., 3.1044e-08, 2.5065e-07,
         1.9819e-08]])

In [32]:
# Switch the model to evaluation mode; the cell output is the module's repr.
model.eval()

InductiveLDA(
  (encoder): DirichletNN(
    (parameterizer): ParameterizerNN(
      (block_dict): ModuleDict(
        (input): NNBlock(
          (input_layer): Sequential(
            (0): Linear(in_features=1995, out_features=100, bias=True)
            (1): Softplus(beta=1, threshold=20)
          )
          (middle_layers): ModuleList()
          (output_layer): Sequential(
            (0): Linear(in_features=100, out_features=100, bias=True)
            (1): Softplus(beta=1, threshold=20)
          )
        )
        (alpha): NNBlock(
          (input_layer): Sequential(
            (0): Linear(in_features=100, out_features=100, bias=True)
            (1): Softplus(beta=1, threshold=20)
            (2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (3): Dropout(p=0.25, inplace=False)
          )
          (middle_layers): ModuleList()
          (output_layer): Sequential(
            (0): Linear(in_features=100, out_features=50, bia

In [33]:
from gensim.models import CoherenceModel
from gensim.corpora.dictionary import Dictionary
import gensim.corpora as corpora

In [34]:
# Invert the word -> id vocabulary so token ids can be decoded back into words.
id2word = dict(zip(corpus.vocab.values(), corpus.vocab.keys()))
# Decode every training document from token ids into a list of words.
texts = [
    list(map(id2word.__getitem__, doc))
    for doc in corpus.train.documents
]

In [35]:
# Decoder beta as a NumPy array; used below as the topic-word matrix.
Beta = model.decoder.get_beta().detach().numpy()
# Alternative kept for reference: read the decoder's `beta_unnorm` attribute instead.
# Beta = model.decoder.beta_unnorm.detach().numpy()

In [36]:
# Number of highest-weight words to keep per topic.
top_k = 10

result = {"topic-word-matrix": Beta}

if top_k > 0:
    # For every topic row, take the ids of the top_k largest weights in
    # descending order of weight and decode them to words.
    result["topics"] = [
        [id2word[word_id] for word_id in np.argsort(topic_row)[-top_k:][::-1]]
        for topic_row in result["topic-word-matrix"]
    ]

In [37]:
# Display the top_k words of every topic.
result["topics"]

[['ide',
  'scsi',
  'bus',
  'isa',
  'scsus',
  'controller',
  'simm',
  'drive',
  'team',
  'mhz'],
 ['article',
  'mean',
  'write',
  'good',
  'think',
  'anyone',
  'player',
  'way',
  'claim',
  'please'],
 ['dos',
  'world',
  'cd',
  'war',
  'ii',
  'family',
  'great',
  'money',
  'library',
  'business'],
 ['bike',
  'ride',
  'game',
  'fire',
  'dog',
  'helmet',
  'cop',
  'dod',
  'armenians',
  'hit'],
 ['anyone',
  'fax',
  'article',
  'think',
  'please',
  'know',
  'gun',
  'thing',
  'way',
  'car'],
 ['drive',
  'car',
  'connector',
  'monitor',
  'voltage',
  'cable',
  'pin',
  'printer',
  'floppy',
  'sale'],
 ['god',
  'eternal',
  'ride',
  'drug',
  'bike',
  'christianity',
  'life',
  'faith',
  'hell',
  'christian'],
 ['armenian',
  'turkish',
  'armenians',
  'israel',
  'armenia',
  'turks',
  'israeli',
  'genocide',
  'village',
  'lebanese'],
 ['write',
  'god',
  'israeli',
  'human',
  'life',
  'faith',
  'want',
  'tell',
  'article',
 

In [38]:
# NPMI topic coherence of the extracted topics, estimated from the decoded
# training texts.
# c_npmi is a sliding-window measure that gensim computes from `texts`. The
# `corpus` argument is meant for document-based measures such as u_mass and
# must be in gensim bag-of-words format indexed by `dictionary`; the dense
# count matrix (indexed by the project vocabulary) is neither, so it is not
# passed.
npmi = CoherenceModel(
    topics=result["topics"],
    texts=texts,
    dictionary=Dictionary(texts),
    coherence="c_npmi",
    topn=top_k)

In [39]:
# Compute and display the coherence score for the topics.
npmi.get_coherence()

0.06962826290808065