In [None]:
import pandas as pd 
# Load the 2008 political-blogs corpus and expose the raw text under the
# `doc` column, which the preprocessing cell below reads.
df = pd.read_csv('../data/stm_datasets/poliblogs2008.csv')
df = df.assign(doc=df['documents'])
df

In [None]:
import sys
# Make the local gtm sources importable. The membership check keeps this cell
# idempotent: re-running it no longer appends duplicate entries to sys.path.
GTM_SRC_DIR = '../gtm/'
if GTM_SRC_DIR not in sys.path:
    sys.path.append(GTM_SRC_DIR)

from utils import text_processor
# Clean the raw documents with the project's text_processor.
# 'en_core_web_sm' is the language model it loads (presumably spaCy), and
# `pos_tags_to_keep` restricts the kept tokens to verbs and nouns.
p = text_processor(
    'en_core_web_sm',
    pos_tags_to_keep=['VERB', 'NOUN'],
)
raw_docs = df['doc']
df['doc_clean'] = p.process_docs(raw_docs)

In [None]:
from corpus import GTMCorpus

# Create a GTMCorpus object
# Covariate formulas for topic prevalence and topic content.
# NOTE(review): an earlier draft had `+ C(speech_year)` commented out on both
# formulas; confirm that column exists in poliblogs2008.csv before adding it.
PREVALENCE_FORMULA = "~ rating"
CONTENT_FORMULA = "~ rating"

train_dataset = GTMCorpus(
    df,
    prevalence=PREVALENCE_FORMULA,
    content=CONTENT_FORMULA,
)

# Shape of the prevalence design matrix built from PREVALENCE_FORMULA.
train_dataset.M_prevalence_covariates.shape

In [None]:
from gtm import GTM

# Train the model
import torch

# Seed torch's global RNG before building/training the model so that a
# Restart & Run All reproduces the same topics (assuming GTM draws its
# randomness from torch).
# NOTE(review): also seed numpy / random if GTM uses them internally.
SEED = 42
torch.manual_seed(SEED)

tm = GTM(
    train_dataset,
    n_topics=20,
    doc_topic_prior='dirichlet',  # alternative: 'logistic_normal'
    alpha=0.1,
    prevalence_covariates_regularization=0.1,
    update_prior=True,
    encoder_hidden_layers=[],     # structure of the encoder neural net
    decoder_hidden_layers=[300],  # structure of the decoder neural net
    num_epochs=2,
    print_every=10000,
    log_every=1,
    w_prior=None,
    batch_size=250,
)

In [None]:
# Visualize a single topic of the fitted model as a word cloud.
WORDCLOUD_TOPIC = 18
tm.plot_wordcloud(topic_id=WORDCLOUD_TOPIC)

In [None]:
# Assess the quality of the learned word embeddings:
# list the words whose decoder embeddings have the largest dot product with
# the embedding of `specific_word`. The word itself may appear in the list,
# since its dot product with itself is often the largest.

import torch
import torch.nn.functional as F

specific_word = 'tax'
n_closest = 8

# Switch the autoencoder to eval mode up front (the original cell called this
# after the computation). It does not change the raw weights read below.
tm.AutoEncoder.eval()

# Position of the word in the training vocabulary (first match), with a clear
# error instead of an IndexError when the word is missing.
vocab_list = list(train_dataset.vocab)
try:
    word_id = vocab_list.index(specific_word)
except ValueError:
    raise ValueError(
        f"{specific_word!r} is not in the training vocabulary"
    ) from None

# Pure inspection: no gradients are needed.
with torch.no_grad():
    # Assumes dec_1 maps hidden units -> vocabulary, so after the transpose
    # column j of `words` is the embedding of word j — TODO confirm.
    words = tm.AutoEncoder.decoder['dec_1'].weight.T
    logit = torch.matmul(words.T[word_id], words)
    # `logit` is 1-D; pass dim explicitly (the implicit choice for 1-D input
    # was dim 0, and omitting it is deprecated).
    beta = F.softmax(logit, dim=-1)
    vals, indices = torch.topk(beta, n_closest)

vals = vals.cpu().tolist()
indices = indices.cpu().tolist()

# NOTE(review): ids come from train_dataset.vocab but are mapped back with
# tm.id2token; confirm both use the same id ordering.
[tm.id2token[idx] for idx in indices]

In [None]:
# Estimate the effect of the prevalence covariates on topic prevalence.
# topic_ids=None is passed through unchanged (presumably meaning all topics).
dfc = tm.estimate_effect(
    train_dataset,
    n_samples=10,
    topic_ids=None,
)
dfc

In [None]:
import statsmodels.api as sm

# Plain OLS regression of one topic's per-document proportions on the
# prevalence design matrix, as a simple point of comparison.
# Assumes Y is (n_docs, n_topics) — TODO confirm.
OLS_TOPIC_ID = 1

Y = tm.get_doc_topic_distribution(train_dataset)
X = train_dataset.M_prevalence_covariates
model = sm.OLS(Y[:, OLS_TOPIC_ID], X)
results = model.fit()
covs = train_dataset.prevalence_colnames

coefs = list(results.params)
std_errs = list(results.bse)
assert len(covs) == len(coefs), "covariate names do not match the fitted coefficients"

# One row per covariate with numeric columns. The previous form,
# pd.DataFrame([covs, params]), produced a 2-row frame that mixed names and
# numbers and so had object dtype.
pd.DataFrame({'coef': coefs, 'std_err': std_errs}, index=list(covs))

In [None]:
# Show the top documents for one topic.
TOP_DOCS_TOPIC = 10
tm.get_top_docs(train_dataset, topic_id=TOP_DOCS_TOPIC)

In [None]:
# Prevalence-covariate size recorded on the model's prior.
# NOTE(review): presumably this should equal X.shape[1] from the OLS cell — confirm.
gtm_prior = tm.prior
gtm_prior.prevalence_covariates_size

In [None]:
# Sample from the prior with one draw per row of X (the prevalence design
# matrix), conditioning on X, and display the first element of the result.
# NOTE(review): epoch=10 while the model was trained for num_epochs=2 —
# confirm what `epoch` controls in the prior.
n_rows = X.shape[0]
prior_draw = tm.prior.sample(
    N=n_rows,
    M_prevalence_covariates=X,
    epoch=10,
)
prior_draw[0]