In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn.functional as F
import math

import TopicVAE

from sklearn.datasets import fetch_20newsgroups
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer

import argparse
from types import SimpleNamespace

import gensim.downloader as api
from gensim.models import Word2Vec, FastText, KeyedVectors
from os.path import isfile

import tools

import random
random.seed(1234)

import pandas as pd

Import the data (20NewsGroups) and make the doc-term matrix, which is the input to all of our models

In [2]:
newsgroups_train = fetch_20newsgroups(subset='train')

vectorizer = CountVectorizer(stop_words = 'english', min_df=.01, max_df=0.9, 
                             token_pattern = u'(?ui)\\b[a-z]{3,}\\b')
count_vecs = vectorizer.fit_transform(newsgroups_train.data)
doc_term_matrix = count_vecs.toarray()
doc_term_matrix.shape # number of documents, number of words (in vocab)

# note: vectorizer.get_feature_names() != vectorizer.vocabulary_

doc_term_matrix_tensor = torch.from_numpy(doc_term_matrix).float()

args_dict = {"en1_units" : 100, "en2_units" : 100, "num_topic" : 50, 
             "batch_size" : 200, "optimizer" : 80, "learning_rate" : 0.002, 
             "momentum" : 0.99, "num_epoch" : 80, "init_mult" : 1, 
             "variance" : 0.995, "start" : True, "nogpu" : True, 
             "embedding_dim" : 300, "freeze" : False}
args = SimpleNamespace(**args_dict)
args.num_input = doc_term_matrix_tensor.shape[1]

In [3]:
newsgroups_test = fetch_20newsgroups(subset='test')

count_vecs_test = vectorizer.transform(newsgroups_test.data)
doc_term_matrix_test = count_vecs_test.toarray()

# note: vectorizer.get_feature_names() != vectorizer.vocabulary_

doc_term_tensor_test = torch.from_numpy(doc_term_matrix_test).float()

# Experiments

## Getting Pretrained Vectors (20NewsGroups)

In [4]:
### make input to language models (word2vec, fasttext, etc.) ###

# we would do some more preprocessing later
newsgroups_train_preproc = []
for document in newsgroups_train.data:
    newsgroups_train_preproc.append(document.split())
    
# dict_word_freq = dict(zip(vectorizer.get_feature_names(), list(doc_term_matrix.sum(0))))

### Word2Vec: Skip-Gram

In [5]:
### make language model using word2vec ###

w2v = Word2Vec(sg=1, negative=5, size=300, window=10, min_count=1, max_vocab_size=None, seed=1, workers=1)
lm_w2v_20newsgroups = tools.create_language_model("lm_w2v_20newsgroups", w2v, doc_term_matrix,
                                            vectorizer.get_feature_names(), 
                                            sentences = newsgroups_train_preproc)

### get embedding matrix for word2vec language model trained on 20newsgroups ###
embedding_matrix_w2v_20newsgroups = tools.create_embedding_matrix(lm_w2v_20newsgroups, 
                                                                  vectorizer.get_feature_names())


### FastText: Skip-Gram

In [6]:
fasttext = FastText(sg=1, negative=5,size=300, window=10, min_count=1, max_vocab_size=None, seed=1, workers=1)
lm_fasttext_20newsgroups = tools.create_language_model("lm_fasttext_20newsgroups", fasttext, doc_term_matrix,
                                                       vectorizer.get_feature_names(), 
                                                       sentences = newsgroups_train_preproc)

### get embedding matrix for word2vec language model trained on 20newsgroups ###
embedding_matrix_fasttext_20newsgroups = tools.create_embedding_matrix(lm_fasttext_20newsgroups, 
                                                                       vectorizer.get_feature_names())


### Word2Vec: CBOW

In [7]:
### make language model using word2vec ###

w2v_cbow = Word2Vec(sg=0, negative=5, size=300, window=10, min_count=1, max_vocab_size=None, seed=1, workers=1)
lm_w2v_cbow_20newsgroups = tools.create_language_model("lm_w2v_cbow_20newsgroups", w2v_cbow, doc_term_matrix,
                                                       vectorizer.get_feature_names(), sentences = newsgroups_train_preproc)

### get embedding matrix for word2vec language model trained on 20newsgroups ###
embedding_matrix_w2v_cbow_20newsgroups = tools.create_embedding_matrix(lm_w2v_cbow_20newsgroups, 
                                                                  vectorizer.get_feature_names())


### FastText: CBOW

In [8]:
fasttext_cbow = FastText(sg=0, negative=5,size=300, window=10, min_count=1, max_vocab_size=None, seed=1, workers=1)
lm_fasttext_cbow_20newsgroups = tools.create_language_model("lm_fasttext_cbow_20newsgroups", fasttext_cbow,
                                                            doc_term_matrix, vectorizer.get_feature_names(), 
                                                            sentences = newsgroups_train_preproc)

### get embedding matrix for word2vec language model trained on 20newsgroups ###
embedding_matrix_fasttext_cbow_20newsgroups = tools.create_embedding_matrix(lm_fasttext_cbow_20newsgroups, 
                                                                       vectorizer.get_feature_names())


## Getting Pretrained Vectors (trained on outside)

### FastText: from Wiki

In [9]:
# pretrained_language_model = api.load("fasttext-wiki-news-subwords-300")
# pretrained_language_model.save("fasttext-wiki-news-subwords-300")

lm_fasttext_wiki = KeyedVectors.load("fasttext-wiki-news-subwords-300")

embedding_matrix_fasttext_wiki = np.random.randn(len(vectorizer.get_feature_names()), 300)
iterator = 0
for word in vectorizer.get_feature_names():
    if word in lm_fasttext_wiki.wv.vocab:
        embedding_matrix_fasttext_wiki[iterator] = lm_fasttext_wiki.wv.word_vec(word)
    else:
        continue
        # embedding_matrix2[iterator] = pretrained_language_model.wv.most_similar(word)
        # or something like that
    iterator += 1




### Word2Vec: from ???

## "Miao" with Pretrained Vectors (on 20NewsGroups)

This isn't actually the method described in Miao et. al., since the encoder is different (it's not MLP)

### Word2Vec: Skip-Gram

In [10]:
if isfile("model_GSMLDA_w2v_20news"):
    model_GSMLDA_w2v_20news = torch.load("model_GSMLDA_w2v_20news")
else:
    model_GSMLDA_w2v_20news = TopicVAE.GSMLDA(args, embedding_matrix_w2v_20newsgroups)
    optimizer_GSMLDA_w2v_20news = torch.optim.Adam(model_GSMLDA_w2v_20news.parameters(), args.learning_rate, 
                                            betas=(args.momentum, 0.999))
    model_GSMLDA_w2v_20news = TopicVAE.train(model_GSMLDA_w2v_20news, args, optimizer_GSMLDA_w2v_20news, 
                                             doc_term_matrix_tensor)
    torch.save(model_GSMLDA_w2v_20news, "model_GSMLDA_w2v_20news")
    
    
n = 10

  p = F.softmax(z)                                                # mixture probability
  loss_epoch += loss.data[0]    # add loss to loss_epoch


Epoch 0, loss=707.623779296875
Epoch 5, loss=570.6974487304688
Epoch 10, loss=561.2691650390625
Epoch 15, loss=555.5097045898438
Epoch 20, loss=551.9954833984375
Epoch 25, loss=548.9158935546875
Epoch 30, loss=547.31884765625
Epoch 35, loss=545.2451782226562
Epoch 40, loss=543.6106567382812
Epoch 45, loss=541.9308471679688
Epoch 50, loss=544.9839477539062
Epoch 55, loss=542.3426513671875
Epoch 60, loss=542.0008544921875
Epoch 65, loss=543.2259521484375
Epoch 70, loss=539.7322998046875
Epoch 75, loss=540.2738037109375


### FastText: Skip-Gram

In [11]:
if isfile("model_GSMLDA_fasttext_20news"):
    model_GSMLDA_fasttext_20news = torch.load("model_GSMLDA_fasttext_20news")
else:
    model_GSMLDA_fasttext_20news = TopicVAE.GSMLDA(args, embedding_matrix_fasttext_20newsgroups)
    optimizer_GSMLDA_fasttext_20news = torch.optim.Adam(model_GSMLDA_fasttext_20news.parameters(), args.learning_rate, 
                                            betas=(args.momentum, 0.999))
    model_GSMLDA_fasttext_20news = TopicVAE.train(model_GSMLDA_fasttext_20news, args, optimizer_GSMLDA_fasttext_20news, 
                                             doc_term_matrix_tensor)
    torch.save(model_GSMLDA_fasttext_20news, "model_GSMLDA_fasttext_20news")
    

Epoch 0, loss=691.7111206054688
Epoch 5, loss=570.8758544921875
Epoch 10, loss=561.243896484375
Epoch 15, loss=556.7259521484375
Epoch 20, loss=552.0405883789062
Epoch 25, loss=551.0697631835938
Epoch 30, loss=548.2799682617188
Epoch 35, loss=549.1189575195312
Epoch 40, loss=546.4395751953125
Epoch 45, loss=545.6307373046875
Epoch 50, loss=542.1295776367188
Epoch 55, loss=543.5731811523438
Epoch 60, loss=543.1550903320312
Epoch 65, loss=542.8093872070312
Epoch 70, loss=542.3997802734375
Epoch 75, loss=542.4989013671875


### Word2Vec: CBOW

In [12]:
if isfile("model_GSMLDA_w2v_cbow_20news"):
    model_GSMLDA_w2v_cbow_20news = torch.load("model_GSMLDA_w2v_cbow_20news")
else:
    model_GSMLDA_w2v_cbow_20news = TopicVAE.GSMLDA(args, embedding_matrix_w2v_cbow_20newsgroups)
    optimizer_GSMLDA_w2v_cbow_20news = torch.optim.Adam(model_GSMLDA_w2v_cbow_20news.parameters(), args.learning_rate, 
                                            betas=(args.momentum, 0.999))
    model_GSMLDA_w2v_cbow_20news = TopicVAE.train(model_GSMLDA_w2v_cbow_20news, args, optimizer_GSMLDA_w2v_cbow_20news, 
                                             doc_term_matrix_tensor)
    torch.save(model_GSMLDA_w2v_cbow_20news, "model_GSMLDA_w2v_cbow_20news")
    

Epoch 0, loss=761.8871459960938
Epoch 5, loss=584.682373046875
Epoch 10, loss=568.9644775390625
Epoch 15, loss=562.3456420898438
Epoch 20, loss=555.24853515625
Epoch 25, loss=552.6206665039062
Epoch 30, loss=551.4342041015625
Epoch 35, loss=547.6995849609375
Epoch 40, loss=546.9485473632812
Epoch 45, loss=545.9097290039062
Epoch 50, loss=543.5358276367188
Epoch 55, loss=543.0621948242188
Epoch 60, loss=540.9735717773438
Epoch 65, loss=540.71826171875
Epoch 70, loss=540.1134033203125
Epoch 75, loss=540.0799560546875


### FastText: CBOW

In [13]:
if isfile("model_GSMLDA_fasttext_cbow_20news"):
    model_GSMLDA_fasttext_cbow_20news = torch.load("model_GSMLDA_fasttext_cbow_20news")
else:
    model_GSMLDA_fasttext_cbow_20news = TopicVAE.GSMLDA(args, embedding_matrix_fasttext_cbow_20newsgroups)
    optimizer_GSMLDA_fasttext_cbow_20news = torch.optim.Adam(model_GSMLDA_fasttext_cbow_20news.parameters(), args.learning_rate, 
                                            betas=(args.momentum, 0.999))
    model_GSMLDA_fasttext_cbow_20news = TopicVAE.train(model_GSMLDA_fasttext_cbow_20news, args, optimizer_GSMLDA_fasttext_cbow_20news, 
                                             doc_term_matrix_tensor)
    torch.save(model_GSMLDA_fasttext_cbow_20news, "model_GSMLDA_fasttext_cbow_20news")
    

Epoch 0, loss=763.6353759765625
Epoch 5, loss=585.0731811523438
Epoch 10, loss=568.3031616210938
Epoch 15, loss=561.2235717773438
Epoch 20, loss=554.6532592773438
Epoch 25, loss=554.6126098632812
Epoch 30, loss=550.8827514648438
Epoch 35, loss=548.2445068359375
Epoch 40, loss=545.28955078125
Epoch 45, loss=543.41162109375
Epoch 50, loss=544.2401123046875
Epoch 55, loss=542.5443115234375
Epoch 60, loss=543.7285766601562
Epoch 65, loss=541.4929809570312
Epoch 70, loss=540.6140747070312
Epoch 75, loss=539.5203247070312


# Look at Coherences, "Miao" Pretrained Vectors 20NewsGroups (old)

In [None]:
# model_GSMLDA_w2v_20news_beta = model_GSMLDA_w2v_20news.get_beta()
# tools.print_top_words(model_GSMLDA_w2v_20news_beta, vectorizer.get_feature_names(), n_top_words = 20)

In [None]:
# pretrained_20news_models = [model_GSMLDA_w2v_20news.get_beta(), model_GSMLDA_fasttext_20news.get_beta(), 
#                             model_GSMLDA_w2v_cbow_20news.get_beta(), model_GSMLDA_fasttext_cbow_20news.get_beta()]
# pretrained_20news_coherences = [tools.topic_coherence(beta, 20, doc_term_matrix) for beta in pretrained_20news_models]

# pretrained_20news_coherences_means = [coherences.mean() for coherences in pretrained_20news_coherences]
# pretrained_20news_coherences_means_df = pd.DataFrame(pretrained_20news_coherences_means).to_latex()

# print(pretrained_20news_coherences_means)

## Miao with Pretrained Vectors (outside text)

### FastText: from Wiki

In [14]:
if isfile("model_GSMLDA_fasttext_wiki"):
    model_GSMLDA_fasttext_wiki = torch.load("model_GSMLDA_fasttext_wiki")
else:
    model_GSMLDA_fasttext_wiki = TopicVAE.GSMLDA(args, embedding_matrix_fasttext_wiki)
    optimizer_GSMLDA_fasttext_wiki = torch.optim.Adam(model_GSMLDA_fasttext_wiki.parameters(), args.learning_rate, 
                                            betas=(args.momentum, 0.999))
    model_GSMLDA_fasttext_wiki = TopicVAE.train(model_GSMLDA_fasttext_wiki, args, optimizer_GSMLDA_fasttext_wiki, 
                                             doc_term_matrix_tensor)
    torch.save(model_GSMLDA_fasttext_wiki, "model_GSMLDA_fasttext_wiki")


Epoch 0, loss=1819.1591796875
Epoch 5, loss=1000.1744995117188
Epoch 10, loss=928.9122924804688
Epoch 15, loss=708.18994140625
Epoch 20, loss=647.6550903320312
Epoch 25, loss=616.5306396484375
Epoch 30, loss=599.0319213867188
Epoch 35, loss=590.3344116210938
Epoch 40, loss=579.5968627929688
Epoch 45, loss=567.510498046875
Epoch 50, loss=561.5753784179688
Epoch 55, loss=556.7196655273438
Epoch 60, loss=555.0986938476562
Epoch 65, loss=555.7420043945312
Epoch 70, loss=551.6532592773438
Epoch 75, loss=551.369873046875


In [None]:
GSMLDA_fasttext_wiki_coherences = tools.topic_coherence(model_GSMLDA_fasttext_wiki.get_beta(), 20, doc_term_matrix)

In [None]:
GSMLDA_fasttext_wiki_coherences_mean = GSMLDA_fasttext_wiki_coherences.mean()

In [None]:
print(GSMLDA_fasttext_wiki_coherences_mean)

### Word2Vec: from ???

## Miao without Pretrained Vectors

In [15]:
if isfile("model_GSMLDA"):
    model_GSMLDA = torch.load("model_GSMLDA")
else:
    model_GSMLDA = TopicVAE.GSMLDA(args)
    optimizer_GSMLDA = torch.optim.Adam(model_GSMLDA.parameters(), args.learning_rate, betas=(args.momentum, 0.999))
    model_GSMLDA = TopicVAE.train(model_GSMLDA, args, optimizer_GSMLDA, doc_term_matrix_tensor)
    torch.save(model_GSMLDA, "model_GSMLDA")


Epoch 0, loss=727.3043212890625
Epoch 5, loss=571.98828125
Epoch 10, loss=567.2332153320312
Epoch 15, loss=561.876708984375
Epoch 20, loss=559.117431640625
Epoch 25, loss=558.0926513671875
Epoch 30, loss=555.7078247070312
Epoch 35, loss=555.4019775390625
Epoch 40, loss=554.8189697265625
Epoch 45, loss=552.0185546875
Epoch 50, loss=551.605712890625
Epoch 55, loss=550.105712890625
Epoch 60, loss=548.0165405273438
Epoch 65, loss=547.9912719726562
Epoch 70, loss=545.1119384765625
Epoch 75, loss=546.1055297851562


In [32]:
GSMLDA_coherences = tools.topic_coherence(model_GSMLDA.get_beta(), 20, doc_term_matrix)

In [33]:
GSMLDA_coherences_mean = GSMLDA_coherences.mean()

In [34]:
print(GSMLDA_coherences_mean)

-1148.3212446470566


In [35]:
if isfile("model_GSMLDA2"):
    model_GSMLDA2 = torch.load("model_GSMLDA2")
else:
    model_GSMLDA2 = TopicVAE.GSMLDA(args)
    optimizer_GSMLDA2 = torch.optim.Adam(model_GSMLDA2.parameters(), args.learning_rate, betas=(args.momentum, 0.999))
    model_GSMLDA2 = TopicVAE.train(model_GSMLDA2, args, optimizer_GSMLDA2, doc_term_matrix_tensor)
    torch.save(model_GSMLDA2, "model_GSMLDA2")

Epoch 0, loss=710.1594848632812
Epoch 5, loss=571.350830078125
Epoch 10, loss=564.021240234375
Epoch 15, loss=560.5138549804688
Epoch 20, loss=557.8563842773438
Epoch 25, loss=556.1616821289062
Epoch 30, loss=554.48095703125
Epoch 35, loss=554.1760864257812
Epoch 40, loss=552.4703369140625
Epoch 45, loss=552.4207153320312
Epoch 50, loss=550.7462768554688
Epoch 55, loss=549.5635375976562
Epoch 60, loss=548.7001953125
Epoch 65, loss=548.9349365234375
Epoch 70, loss=547.9478759765625
Epoch 75, loss=547.1461181640625


In [36]:
GSMLDA2_coherences = tools.topic_coherence(model_GSMLDA2.get_beta(), 20, doc_term_matrix)
GSMLDA2_coherences_mean = GSMLDA2_coherences.mean()
print(GSMLDA2_coherences_mean)


-586.8383929756021


# NVLDA

In [16]:
if isfile("model_LDA"):
    model_LDA = torch.load("model_LDA")
else:
    model_LDA = TopicVAE.LDA(args)
    optimizer_LDA = torch.optim.Adam(model_LDA.parameters(), args.learning_rate, betas=(args.momentum, 0.999))
    model_GSMLDA = TopicVAE.train(model_LDA, args, optimizer_LDA, doc_term_matrix_tensor)
    torch.save(model_LDA, "model_LDA")


  p = F.softmax(z)                                                # mixture probability


Epoch 0, loss=614.6697998046875
Epoch 5, loss=585.0731811523438
Epoch 10, loss=567.2426147460938
Epoch 15, loss=562.9029541015625
Epoch 20, loss=558.6740112304688
Epoch 25, loss=557.590087890625
Epoch 30, loss=557.8560791015625
Epoch 35, loss=556.2362670898438
Epoch 40, loss=553.95068359375
Epoch 45, loss=554.6138916015625
Epoch 50, loss=554.067626953125
Epoch 55, loss=553.5233764648438
Epoch 60, loss=551.6415405273438
Epoch 65, loss=552.6060791015625
Epoch 70, loss=551.7682495117188
Epoch 75, loss=549.5260009765625


# Compare Coherences

In [18]:
# model_GSMLDA_w2v_20news_beta = model_GSMLDA_w2v_20news.get_beta()
# tools.print_top_words(model_GSMLDA_w2v_20news_beta, vectorizer.get_feature_names(), n_top_words = 20)

models = [model_GSMLDA_w2v_20news, model_GSMLDA_fasttext_20news, 
          model_GSMLDA_w2v_cbow_20news, model_GSMLDA_fasttext_cbow_20news, 
          model_GSMLDA_fasttext_wiki, model_GSMLDA, model_LDA]

models_betas = [model.get_beta() for model in models]
coherences = [tools.topic_coherence(beta, 20, doc_term_matrix) for beta in models_betas]
coherences_means = [coherences.mean() for coherences in coherences]
coherences_means_df = pd.DataFrame(coherences_means).to_latex()

print(coherences_means)


[-603.2732064836742, -609.1866787140507, -638.079224907247, -618.5259246022016, -559.4402619940417, -1148.3212446470566, -1148.3212446470566]


In [29]:
tools.topic_coherence(model_LDA.get_beta(), 20, doc_term_matrix)

array([-1105.52860479, -1164.20327375, -1197.69839877, -1123.01620449,
       -1043.89114835, -1160.561923  , -1084.250682  , -1181.9325668 ,
       -1186.24327668, -1138.67951998, -1103.92344374, -1048.49830253,
       -1181.17941133, -1180.88287797, -1204.30671623, -1118.81880687,
       -1030.34422392, -1045.8553841 , -1232.24115401, -1167.89015564,
       -1173.79258331, -1172.38319386, -1128.58679358, -1235.17897937,
       -1249.96496039, -1154.00782508, -1209.48595554, -1094.69523856,
       -1254.911861  , -1070.44923484, -1219.2687084 , -1144.3954316 ,
       -1163.73173323, -1177.94262558, -1136.26436815, -1003.17274192,
       -1242.48854519, -1143.1228774 , -1092.03274886, -1096.468171  ,
       -1152.84049136, -1140.69190316, -1075.64430344, -1245.91205374,
       -1165.3811309 , -1237.28810388, -1184.1547856 , -1134.11874779,
       -1139.04698852, -1078.69307215])

In [37]:
tools.print_top_words(model_GSMLDA2.get_beta(), vectorizer.get_feature_names(), n_top_words = 20)

---------------Printing the Topics------------------

     edu com posting host nntp writes university distribution article new reply mail just apr usa league john like michael don

     file information space program available number section source email internet image software ftp pub code line using anonymous set new
midea
     people god think don said know say believe did right just question like life way israeli come says israel things

     use key chip information encryption public data used bit new space keys technology number clipper message using available access need
polit
     people said know don think just right did going say time didn like says went told state president things came

     edu com article posting university like just new world good problem does don distribution know computer writes reply science news

     new government public number used time use national possible does years questions case order long april want information right non

     don think like

# Compare Perplexities

In [25]:
tools.perplexity(model_GSMLDA, doc_term_tensor_test)

perplexities = [tools.perplexity(model, doc_term_tensor_test).item() for model in models]

In [26]:
print(perplexities)

[871.3388671875, 877.4763793945312, 875.93212890625, 879.0941772460938, 929.1653442382812, 912.7308959960938, 914.7857055664062]


## Compare Coherences OLD.

In [None]:
import matplotlib.pyplot as plt

In [None]:
# plt.style.use("seaborn-deep")

x = GSMLDA_without_embedding_coherence
y = GSMLDA2_20newsgroups_coherence

plt.hist([x, y], label = ["without embedding", "with 20newsgroups embedding"])
plt.legend(loc = 'upper right')
plt.show()

# t test

In [None]:
new_word_vecs = dict(zip(vectorizer.get_feature_names(), [model_GSMLDA2.word_embedding.weight[i] for i in range(model_GSMLDA2.word_embedding.weight.shape[0])]))


In [None]:
model_GSMLDA2.word_embedding.weight.detach().numpy()

In [None]:
len(vectorizer.get_feature_names())

In [None]:
len([model_GSMLDA2.word_embedding.weight[i] for i in range(model_GSMLDA2.word_embedding.weight.shape[0])])


In [None]:
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
cos_sim_matrix = cosine_similarity(model_GSMLDA2.word_embedding.weight.detach().numpy(), 
                                   model_GSMLDA2.word_embedding.weight.detach().numpy())

In [None]:
def n_closest_words(word, cos_sim_matrix, n):
    word_index = vectorizer.get_feature_names().index(word)
    close_words_indices = np.argsort(cos_sim_matrix[word_index])[-n:]
    print(close_words_indices)
    return [vectorizer.get_feature_names()[j] for j in close_words_indices]
    

In [None]:
n_closest_words("nasa", cos_sim_matrix, 20)

In [None]:
model_GSMLDA_cos_sim_matrix = cosine_similarity(model_GSMLDA_without_embedding.word_embedding.weight.detach().numpy(), 
                                   model_GSMLDA_without_embedding.word_embedding.weight.detach().numpy())
n_closest_words("amendment", model_GSMLDA_cos_sim_matrix, 20)


In [None]:
lm_20newsgroups.most_similar("nasa")

In [None]:
tools.perplexity(model_GSMLDA, doc_term_tensor_test)

In [None]:
doc_term_tensor_test.shape

In [None]:
doc_term_matrix_tensor.shape

In [None]:
doc_term_matrix_test.shape