In [None]:
import os
import pickle

import gensim
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.linear_model import LogisticRegression
from tqdm.notebook import tqdm

from utilities import get_nodeid2text

In [None]:
def _unpickle(path):
    """Deserialize and return the single object stored in the pickle file at ``path``.

    NOTE: unpickling can execute arbitrary code -- only point this at trusted local files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Bag-of-words corpus and token dictionary, passed to the LDA model as corpus / id2word.
corpus = _unpickle("data/corpus.pkl")
id2word = _unpickle("data/id2word.pkl")
# Per-node text table; the cells below read its "words_clean" and "label" columns.
nodeid2text = pd.read_pickle("data/nodeid2text_gensim.pkl")
# Only the (train, valid, test) index split is used here; the first return value is discarded.
_, (train_idx, valid_idx, test_idx) = get_nodeid2text()
# Topic counts swept by the LDA experiments.
num_topics_list = [10, 20, 40, 80]

In [None]:
# Fit one latent Dirichlet allocation model per entry of `num_topics_list` and cache
# each model's per-document topic features to gammas/<n>_topics.npy.
# NOTE(review): the model is fit on `corpus` as-is; confirm it contains training
# documents only, otherwise validation/test text leaks into the topic model.

# np.save does not create missing directories, so make sure the cache dir exists.
os.makedirs("gammas", exist_ok=True)

# The bag-of-words view of every document does not depend on the topic count,
# so build it once rather than once per model.
all_bows = [id2word.doc2bow(text) for text in nodeid2text["words_clean"]]

for n_topics in tqdm(num_topics_list):
    # random_state seeds the model's initialization so the cached features are reproducible
    # (LdaMulticore's multi-worker updates may still cause minor run-to-run variation).
    lda_model = gensim.models.LdaMulticore(
        corpus=corpus, id2word=id2word, num_topics=n_topics, random_state=0
    )
    # Apply the model to the entire dataset. Per gensim's docs, `inference` returns the
    # variational gamma parameters (one row per document, one column per topic); they
    # are not normalized to sum to 1.
    gammas, _ = lda_model.inference(all_bows)
    np.save(f"gammas/{n_topics}_topics.npy", gammas)

In [None]:
# Reload the cached gamma matrices produced by the LDA cell, one per topic count.
all_gammas = []
for n_topics in tqdm(num_topics_list):
    all_gammas.append(np.load(f"gammas/{n_topics}_topics.npy"))


def split_xy(features, idx):
    """Return the (features, labels) pair for the rows selected by ``idx``."""
    return features[idx], nodeid2text.iloc[idx]["label"]


# Rows follow num_topics_list; columns are (train, valid, test) accuracy.
acc = np.zeros((len(num_topics_list), 3), dtype=float)
for row, (n_topics, gammas) in enumerate(zip(tqdm(num_topics_list), all_gammas)):
    # Train a logistic-regression classifier on the LDA features of the training split only.
    classifier = LogisticRegression(random_state=0).fit(*split_xy(gammas, train_idx))
    # Score that same classifier on each split, in train / valid / test order.
    train_acc, valid_acc, test_acc = (
        classifier.score(*split_xy(gammas, idx))
        for idx in (train_idx, valid_idx, test_idx)
    )
    acc[row] = np.array([train_acc, valid_acc, test_acc], dtype=float)
    print(f"Num Topics: {n_topics}")
    for split_name, score in (
        ("Training", train_acc),
        ("Validation", valid_acc),
        ("Test", test_acc),
    ):
        print(f"{split_name} Accuracy: {score}")


In [None]:
# Compare validation accuracy (column 1 of `acc`) across the swept LDA topic counts.
fig, ax = plt.subplots()
# String tick labels keep the bars evenly spaced instead of placed on a numeric axis.
ax.bar([str(n_topics) for n_topics in num_topics_list], acc[:, 1])
ax.set(
    title="Validation Acc. v. Num Topics",
    xlabel="Number of Topics",
    ylabel="Validation Accuracy",
)
# Render the figure without echoing the BarContainer repr as cell output.
plt.show()

In [None]:
# Display the accuracy matrix with labeled rows (topic counts) and columns (splits);
# the underlying `acc` array is left unchanged.
pd.DataFrame(
    acc,
    index=pd.Index(num_topics_list, name="num_topics"),
    columns=["train", "valid", "test"],
)