In [77]:
import string
import spacy
spacy_nlp = spacy.load("en_core_web_sm")
from spacy.lang.en.stop_words import STOP_WORDS

import utility_functions as utils
import importlib
import pandas as pd
from itertools import chain
import numpy as np
from tqdm import tqdm
import json
from scipy.sparse import dok_matrix
import matplotlib.pyplot as plt

from contextualized_topic_models.models.ctm import ZeroShotTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation

# Octis is the library which can use different implemented topic modelling techniques
from octis.preprocessing.preprocessing import Preprocessing
from octis.evaluation_metrics.coherence_metrics import Coherence
from octis.evaluation_metrics.diversity_metrics import TopicDiversity
from octis.models.LDA import LDA
from octis.models.CTM import CTM
from octis.models.ETM import ETM
from octis.models.NeuralLDA import NeuralLDA
from octis.models.NMF import NMF

importlib.reload(utils)
data = './preprocessed_df.pkl'

In [None]:
df = pd.read_pickle(data)
df[['Artist', 'Song', 'Tokens', 'Lyrics', 'Coast']].head()

In [None]:
# Create the corpus required by OCTIS to build up the dataset
with open('corpus.tsv', 'w', encoding='utf-8') as file:
    for lyrics in df['Lyrics']:
        if pd.notna(lyrics):
            file.write(lyrics + '\n')

In [None]:
# Flatten all tokens to create a single list of words
vocab = set(chain.from_iterable(df['Tokens'].tolist()))

# Save as vocabulary.json
with open("./vocabulary.json", 'w') as f:
    json.dump(list(vocab), f)

In [None]:
# Initialize the document-term matrix
num_docs = len(df)
num_terms = len(vocab)
doc_term_matrix = dok_matrix((num_docs, num_terms), dtype=np.int32)

# Build the matrix
token_to_index = {word: idx for idx, word in enumerate(vocab)}

for doc_idx, tokens in enumerate(df['Tokens']):
    for token in tokens:
        if token in token_to_index:
            word_idx = token_to_index[token]
            doc_term_matrix[doc_idx, word_idx] += 1

# Convert the matrix to a sparse format JSON
sparse_matrix = []
for (doc_idx, word_idx), freq in doc_term_matrix.items():
    sparse_matrix.append([doc_idx, word_idx, int(freq)])

# Save as doc_term_matrix.json
with open("./doc_term_matrix.json", 'w') as f:
    json.dump(sparse_matrix, f)

In [None]:
# Initialize preprocessing
preprocessor = Preprocessing(
    vocabulary=None,
    max_features=None,
    remove_punctuation=True,
    punctuation=string.punctuation,
    lemmatize=True,
    min_chars=2,
    min_words_docs=0,
    save_original_indexes=True,
    min_df=0.05,
    max_df=0.8,
    split=True
)

dataset = preprocessor.preprocess_dataset(documents_path="./corpus.tsv")

In [67]:
from octis.evaluation_metrics.coherence_metrics import Coherence
from octis.evaluation_metrics.diversity_metrics import TopicDiversity

# Initialize the coherence and diversity metrics
coherence_cv = Coherence(topk=10, measure='c_v')
coherence_umass = Coherence(topk=10, measure='u_mass')
topic_diversity = TopicDiversity(topk=10)

In [84]:
from itertools import product

# Define hyperparameter grids for each model
param_grids = {
    'LDA': {
        'num_topics': [2, 3, 4, 5],
        'iterations': [500, 1000],
        'random_state': [42]
    },
    'CTM': {
        'num_topics': [2, 3, 4, 5],
        'num_epochs': [5, 10],
    },
    'ETM': {
        'num_topics': [2, 3, 4, 5],
        'num_epochs': [50, 100],
    },
    'NeuralLDA': {
        'num_topics': [2, 3, 4, 5],
        'num_epochs': [50, 100],
        'lr': [2e-3, 1e-3],
    },
    'NMF': {
        'num_topics': [2, 3, 4, 5],
        'random_state': [42]
    }
}

def evaluate_coherence(model_output):
    # Initialize the coherence metric
    coherence_cv = Coherence(topk=10, measure='c_v')

    # Calculate and return coherence score
    coherence_score = coherence_cv.score(model_output)
    return coherence_score

def parameter_search(model_name, dataset, param_grid):
    param_combinations = list(product(*param_grid.values()))
    best_score = -float('inf')
    best_params = None
    best_model_output = None

    # Use tqdm to track progress in parameter search
    for params in tqdm(param_combinations, desc=f"Searching {model_name} Params"):
        # Create parameter dict
        param_dict = dict(zip(param_grid.keys(), params))

        # Initialize and train the model based on model name and params
        if model_name == 'LDA':
            model = LDA(**param_dict)
        elif model_name == 'CTM':
            model = CTM(**param_dict)
        elif model_name == 'ETM':
            model = ETM(**param_dict)
        elif model_name == 'NeuralLDA':
            model = NeuralLDA(**param_dict)
        elif model_name == 'NMF':
            model = NMF(**param_dict)

        # Train the model
        model_output = model.train_model(dataset)

        # Evaluate the model (e.g., using coherence score)
        score = evaluate_coherence(model_output)

        if score > best_score:
            best_score = score
            best_params = param_dict
            best_model_output = model_output

    return best_model_output, best_params, best_score

In [85]:
best_models = {}

# Perform hyperparameter search for each model with progress bars
for model_name, param_grid in tqdm(param_grids.items(), desc="Overall Model Parameter Search"):
    best_output, best_params, best_score = parameter_search(model_name, dataset, param_grid)
    best_models[model_name] = {'output': best_output, 'params': best_params, 'score': best_score}
    print(f"Best {model_name} Params: {best_params} with Score: {best_score}")

Overall Model Parameter Search:   0%|          | 0/5 [00:00<?, ?it/s]
Searching LDA Params:   0%|          | 0/8 [00:00<?, ?it/s][A
Searching LDA Params:  12%|█▎        | 1/8 [00:07<00:53,  7.59s/it][A
Searching LDA Params:  25%|██▌       | 2/8 [00:16<00:48,  8.10s/it][A
Searching LDA Params:  38%|███▊      | 3/8 [00:22<00:36,  7.32s/it][A
Searching LDA Params:  50%|█████     | 4/8 [00:29<00:29,  7.41s/it][A
Searching LDA Params:  62%|██████▎   | 5/8 [00:35<00:20,  6.74s/it][A
Searching LDA Params:  75%|███████▌  | 6/8 [00:41<00:13,  6.60s/it][A
Searching LDA Params:  88%|████████▊ | 7/8 [00:47<00:06,  6.23s/it][A
Searching LDA Params: 100%|██████████| 8/8 [00:53<00:00,  6.72s/it][A
Overall Model Parameter Search:  20%|██        | 1/5 [00:53<03:34, 53.74s/it]

Best LDA Params: {'num_topics': 2, 'iterations': 500, 'random_state': 42} with Score: 0.6017872659552621



Searching CTM Params:   0%|          | 0/8 [00:00<?, ?it/s][A
Searching CTM Params:  12%|█▎        | 1/8 [00:01<00:11,  1.70s/it][A
Searching CTM Params:  25%|██▌       | 2/8 [00:03<00:10,  1.68s/it][A
Searching CTM Params:  38%|███▊      | 3/8 [00:05<00:08,  1.72s/it][A
Searching CTM Params:  50%|█████     | 4/8 [00:07<00:07,  1.85s/it][A
Searching CTM Params:  62%|██████▎   | 5/8 [00:09<00:05,  1.85s/it][A
Searching CTM Params:  75%|███████▌  | 6/8 [00:10<00:03,  1.89s/it][A
Searching CTM Params:  88%|████████▊ | 7/8 [00:12<00:01,  1.93s/it][A
Searching CTM Params: 100%|██████████| 8/8 [00:14<00:00,  1.87s/it][A
Overall Model Parameter Search:  40%|████      | 2/5 [01:08<01:32, 30.91s/it]

Best CTM Params: {'num_topics': 3, 'num_epochs': 5} with Score: 0.6173094858246226



Searching ETM Params:   0%|          | 0/8 [00:00<?, ?it/s][A

model: ETM(
  (t_drop): Dropout(p=0.5, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=720, bias=False)
  (alphas): Linear(in_features=300, out_features=2, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=720, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=2, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=2, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.2 .. Rec_loss: 1111.18 .. NELBO: 1111.38
****************************************************************************************************
****************************************************************************************************
VALIDATION .. LR: 0.005 .. KL_theta: 0.07 .. Rec_loss: 277.77 .. NELBO: 277.84
******


Searching ETM Params:  12%|█▎        | 1/8 [00:04<00:29,  4.22s/it][A

model: ETM(
  (t_drop): Dropout(p=0.5, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=720, bias=False)
  (alphas): Linear(in_features=300, out_features=2, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=720, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=2, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=2, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.01 .. Rec_loss: 1111.96 .. NELBO: 1111.97
****************************************************************************************************
****************************************************************************************************
VALIDATION .. LR: 0.005 .. KL_theta: 0.0 .. Rec_loss: 278.08 .. NELBO: 278.08
******


Searching ETM Params:  25%|██▌       | 2/8 [00:10<00:31,  5.31s/it][A

model: ETM(
  (t_drop): Dropout(p=0.5, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=720, bias=False)
  (alphas): Linear(in_features=300, out_features=3, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=720, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=3, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=3, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.09 .. Rec_loss: 1113.3 .. NELBO: 1113.39
****************************************************************************************************
****************************************************************************************************
VALIDATION .. LR: 0.005 .. KL_theta: 0.08 .. Rec_loss: 278.55 .. NELBO: 278.63
******


Searching ETM Params:  38%|███▊      | 3/8 [00:12<00:20,  4.09s/it][A

model: ETM(
  (t_drop): Dropout(p=0.5, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=720, bias=False)
  (alphas): Linear(in_features=300, out_features=3, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=720, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=3, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=3, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.04 .. Rec_loss: 1113.2 .. NELBO: 1113.24
****************************************************************************************************
****************************************************************************************************
VALIDATION .. LR: 0.005 .. KL_theta: 0.06 .. Rec_loss: 278.55 .. NELBO: 278.61
******


Searching ETM Params:  50%|█████     | 4/8 [00:20<00:21,  5.38s/it][A

model: ETM(
  (t_drop): Dropout(p=0.5, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=720, bias=False)
  (alphas): Linear(in_features=300, out_features=4, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=720, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=4, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=4, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.14 .. Rec_loss: 1113.76 .. NELBO: 1113.9
****************************************************************************************************
****************************************************************************************************
VALIDATION .. LR: 0.005 .. KL_theta: 0.1 .. Rec_loss: 278.77 .. NELBO: 278.87
*******


Searching ETM Params:  62%|██████▎   | 5/8 [00:22<00:12,  4.31s/it][A

model: ETM(
  (t_drop): Dropout(p=0.5, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=720, bias=False)
  (alphas): Linear(in_features=300, out_features=4, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=720, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=4, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=4, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.01 .. Rec_loss: 1114.02 .. NELBO: 1114.03
****************************************************************************************************
****************************************************************************************************
VALIDATION .. LR: 0.005 .. KL_theta: 0.01 .. Rec_loss: 278.91 .. NELBO: 278.92
*****


Searching ETM Params:  75%|███████▌  | 6/8 [00:25<00:07,  3.65s/it][A

model: ETM(
  (t_drop): Dropout(p=0.5, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=720, bias=False)
  (alphas): Linear(in_features=300, out_features=5, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=720, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=5, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=5, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.02 .. Rec_loss: 1114.78 .. NELBO: 1114.8
****************************************************************************************************
****************************************************************************************************
VALIDATION .. LR: 0.005 .. KL_theta: 0.03 .. Rec_loss: 279.27 .. NELBO: 279.3
*******


Searching ETM Params:  88%|████████▊ | 7/8 [00:30<00:04,  4.32s/it][A

model: ETM(
  (t_drop): Dropout(p=0.5, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=720, bias=False)
  (alphas): Linear(in_features=300, out_features=5, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=720, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=5, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=5, bias=True)
)
****************************************************************************************************
Epoch----->1 .. LR: 0.005 .. KL_theta: 0.01 .. Rec_loss: 1114.75 .. NELBO: 1114.76
****************************************************************************************************
****************************************************************************************************
VALIDATION .. LR: 0.005 .. KL_theta: 0.01 .. Rec_loss: 279.33 .. NELBO: 279.34
*****


Searching ETM Params: 100%|██████████| 8/8 [00:35<00:00,  4.41s/it][A
Overall Model Parameter Search:  60%|██████    | 3/5 [01:43<01:05, 32.92s/it]

Best ETM Params: {'num_topics': 5, 'num_epochs': 50} with Score: 0.639376252501009



Searching NeuralLDA Params:   0%|          | 0/16 [00:00<?, ?it/s][A

Epoch: [1/50]	Samples: [956/47800]	Train Loss: 1306.3124754837866	Time: 0:00:00.031102
Epoch: [1/50]	Samples: [205/10250]	Validation Loss: 1152.355449695122	Time: 0:00:00.002602
Epoch: [2/50]	Samples: [1912/47800]	Train Loss: 1308.0964712996863	Time: 0:00:00.030453
Epoch: [2/50]	Samples: [205/10250]	Validation Loss: 1151.0301829268292	Time: 0:00:00.004017
Epoch: [3/50]	Samples: [2868/47800]	Train Loss: 1299.848538833682	Time: 0:00:00.035441
Epoch: [3/50]	Samples: [205/10250]	Validation Loss: 1148.5410299161585	Time: 0:00:00.002554
Epoch: [4/50]	Samples: [3824/47800]	Train Loss: 1242.8418295632846	Time: 0:00:00.027743
Epoch: [4/50]	Samples: [205/10250]	Validation Loss: 1147.4167063643292	Time: 0:00:00.002639
Epoch: [5/50]	Samples: [4780/47800]	Train Loss: 1218.9853638206066	Time: 0:00:00.029398
Epoch: [5/50]	Samples: [205/10250]	Validation Loss: 1143.5890196265243	Time: 0:00:00.002293
Epoch: [6/50]	Samples: [5736/47800]	Train Loss: 1259.0131978948746	Time: 0:00:00.035583
Epoch: [6/50]	S


Searching NeuralLDA Params:   6%|▋         | 1/16 [00:02<00:31,  2.13s/it][A

Epoch: [1/50]	Samples: [956/47800]	Train Loss: 1313.5469322044978	Time: 0:00:00.033491
Epoch: [1/50]	Samples: [205/10250]	Validation Loss: 1152.2649437881098	Time: 0:00:00.003875
Epoch: [2/50]	Samples: [1912/47800]	Train Loss: 1299.7031495162134	Time: 0:00:00.029481
Epoch: [2/50]	Samples: [205/10250]	Validation Loss: 1151.361513910061	Time: 0:00:00.002913
Epoch: [3/50]	Samples: [2868/47800]	Train Loss: 1298.8786692599372	Time: 0:00:00.023132
Epoch: [3/50]	Samples: [205/10250]	Validation Loss: 1148.8041730182927	Time: 0:00:00.002213
Epoch: [4/50]	Samples: [3824/47800]	Train Loss: 1279.2679458682007	Time: 0:00:00.027659
Epoch: [4/50]	Samples: [205/10250]	Validation Loss: 1145.9502334222561	Time: 0:00:00.003559
Epoch: [5/50]	Samples: [4780/47800]	Train Loss: 1263.526714500523	Time: 0:00:00.025383
Epoch: [5/50]	Samples: [205/10250]	Validation Loss: 1143.9386766387195	Time: 0:00:00.002245
Epoch: [6/50]	Samples: [5736/47800]	Train Loss: 1256.084875130753	Time: 0:00:00.027085
Epoch: [6/50]	Sa


Searching NeuralLDA Params:  12%|█▎        | 2/16 [00:04<00:29,  2.08s/it][A

Epoch: [1/100]	Samples: [956/95600]	Train Loss: 1316.3228049816946	Time: 0:00:00.036850
Epoch: [1/100]	Samples: [205/20500]	Validation Loss: 1157.9471370045733	Time: 0:00:00.003166
Epoch: [2/100]	Samples: [1912/95600]	Train Loss: 1295.2522391474895	Time: 0:00:00.044960
Epoch: [2/100]	Samples: [205/20500]	Validation Loss: 1155.2202076981707	Time: 0:00:00.003556
Epoch: [3/100]	Samples: [2868/95600]	Train Loss: 1263.5988902327406	Time: 0:00:00.027017
Epoch: [3/100]	Samples: [205/20500]	Validation Loss: 1149.2368997713415	Time: 0:00:00.002867
Epoch: [4/100]	Samples: [3824/95600]	Train Loss: 1268.0293785957113	Time: 0:00:00.026457
Epoch: [4/100]	Samples: [205/20500]	Validation Loss: 1144.3816930259147	Time: 0:00:00.002993
Epoch: [5/100]	Samples: [4780/95600]	Train Loss: 1255.5929654811716	Time: 0:00:00.026317
Epoch: [5/100]	Samples: [205/20500]	Validation Loss: 1143.270274390244	Time: 0:00:00.003889
Epoch: [6/100]	Samples: [5736/95600]	Train Loss: 1275.387290794979	Time: 0:00:00.026605
Epoc


Searching NeuralLDA Params:  19%|█▉        | 3/16 [00:06<00:29,  2.28s/it][A

Epoch: [1/100]	Samples: [956/95600]	Train Loss: 1330.0997483002093	Time: 0:00:00.028081
Epoch: [1/100]	Samples: [205/20500]	Validation Loss: 1151.9510480182928	Time: 0:00:00.004087
Epoch: [2/100]	Samples: [1912/95600]	Train Loss: 1287.77886375523	Time: 0:00:00.026642
Epoch: [2/100]	Samples: [205/20500]	Validation Loss: 1150.9565548780488	Time: 0:00:00.002353
Epoch: [3/100]	Samples: [2868/95600]	Train Loss: 1284.7082488885983	Time: 0:00:00.027672
Epoch: [3/100]	Samples: [205/20500]	Validation Loss: 1147.824776105183	Time: 0:00:00.002369
Epoch: [4/100]	Samples: [3824/95600]	Train Loss: 1301.4014938546024	Time: 0:00:00.026162
Epoch: [4/100]	Samples: [205/20500]	Validation Loss: 1145.2487709603658	Time: 0:00:00.002611
Epoch: [5/100]	Samples: [4780/95600]	Train Loss: 1267.4022865455022	Time: 0:00:00.029305
Epoch: [5/100]	Samples: [205/20500]	Validation Loss: 1143.6507621951218	Time: 0:00:00.002274
Epoch: [6/100]	Samples: [5736/95600]	Train Loss: 1314.437181289226	Time: 0:00:00.026240
Epoch:


Searching NeuralLDA Params:  25%|██▌       | 4/16 [00:08<00:25,  2.16s/it][A

Epoch: [1/50]	Samples: [956/47800]	Train Loss: 1213.783325706067	Time: 0:00:00.032014
Epoch: [1/50]	Samples: [205/10250]	Validation Loss: 1152.8027677210366	Time: 0:00:00.003394
Epoch: [2/50]	Samples: [1912/47800]	Train Loss: 1199.4566308185147	Time: 0:00:00.028023
Epoch: [2/50]	Samples: [205/10250]	Validation Loss: 1153.7237328506098	Time: 0:00:00.003637
Epoch: [3/50]	Samples: [2868/47800]	Train Loss: 1205.0922463389122	Time: 0:00:00.030235
Epoch: [3/50]	Samples: [205/10250]	Validation Loss: 1155.8576838795732	Time: 0:00:00.002915
Epoch: [4/50]	Samples: [3824/47800]	Train Loss: 1210.2283603556486	Time: 0:00:00.026746
Epoch: [4/50]	Samples: [205/10250]	Validation Loss: 1157.968221227134	Time: 0:00:00.003736
Epoch: [5/50]	Samples: [4780/47800]	Train Loss: 1194.4533946783472	Time: 0:00:00.028759
Epoch: [5/50]	Samples: [205/10250]	Validation Loss: 1156.2775247713414	Time: 0:00:00.002533
Epoch: [6/50]	Samples: [5736/47800]	Train Loss: 1211.6824905203976	Time: 0:00:00.029109
Epoch: [6/50]	S


Searching NeuralLDA Params:  31%|███▏      | 5/16 [00:10<00:22,  2.03s/it][A

Epoch: [1/50]	Samples: [956/47800]	Train Loss: 1215.5511408211298	Time: 0:00:00.032062
Epoch: [1/50]	Samples: [205/10250]	Validation Loss: 1152.567830602134	Time: 0:00:00.003454
Epoch: [2/50]	Samples: [1912/47800]	Train Loss: 1226.7707570606694	Time: 0:00:00.030504
Epoch: [2/50]	Samples: [205/10250]	Validation Loss: 1152.7583031631098	Time: 0:00:00.002344
Epoch: [3/50]	Samples: [2868/47800]	Train Loss: 1231.1626405596235	Time: 0:00:00.028771
Epoch: [3/50]	Samples: [205/10250]	Validation Loss: 1155.2747141768293	Time: 0:00:00.002729
Epoch: [4/50]	Samples: [3824/47800]	Train Loss: 1186.3028978164225	Time: 0:00:00.028011
Epoch: [4/50]	Samples: [205/10250]	Validation Loss: 1155.8814977134145	Time: 0:00:00.002852
Epoch: [5/50]	Samples: [4780/47800]	Train Loss: 1201.580650169979	Time: 0:00:00.028497
Epoch: [5/50]	Samples: [205/10250]	Validation Loss: 1156.8998285060975	Time: 0:00:00.003018
Epoch: [6/50]	Samples: [5736/47800]	Train Loss: 1196.8239490716528	Time: 0:00:00.025936
Epoch: [6/50]	S


Searching NeuralLDA Params:  38%|███▊      | 6/16 [00:12<00:19,  1.97s/it][A

Epoch: [1/100]	Samples: [956/95600]	Train Loss: 1209.451049293933	Time: 0:00:00.027245
Epoch: [1/100]	Samples: [205/20500]	Validation Loss: 1152.3001810213414	Time: 0:00:00.003059
Epoch: [2/100]	Samples: [1912/95600]	Train Loss: 1214.8429001046024	Time: 0:00:00.027376
Epoch: [2/100]	Samples: [205/20500]	Validation Loss: 1153.2921589176829	Time: 0:00:00.002412
Epoch: [3/100]	Samples: [2868/95600]	Train Loss: 1196.8736597803347	Time: 0:00:00.026345
Epoch: [3/100]	Samples: [205/20500]	Validation Loss: 1157.9451410060976	Time: 0:00:00.002358
Epoch: [4/100]	Samples: [3824/95600]	Train Loss: 1190.3820279811716	Time: 0:00:00.027126
Epoch: [4/100]	Samples: [205/20500]	Validation Loss: 1157.5220179115854	Time: 0:00:00.002711
Epoch: [5/100]	Samples: [4780/95600]	Train Loss: 1204.5756406903765	Time: 0:00:00.026586
Epoch: [5/100]	Samples: [205/20500]	Validation Loss: 1155.458350800305	Time: 0:00:00.002833
Epoch: [6/100]	Samples: [5736/95600]	Train Loss: 1212.5296973064853	Time: 0:00:00.025484
Epoc


Searching NeuralLDA Params:  44%|████▍     | 7/16 [00:14<00:18,  2.00s/it][A

Epoch: [1/100]	Samples: [956/95600]	Train Loss: 1217.6855877353557	Time: 0:00:00.038017
Epoch: [1/100]	Samples: [205/20500]	Validation Loss: 1152.2216653963415	Time: 0:00:00.002880
Epoch: [2/100]	Samples: [1912/95600]	Train Loss: 1219.3607887683054	Time: 0:00:00.028300
Epoch: [2/100]	Samples: [205/20500]	Validation Loss: 1152.3012480945122	Time: 0:00:00.002664
Epoch: [3/100]	Samples: [2868/95600]	Train Loss: 1203.4967965481171	Time: 0:00:00.025146
Epoch: [3/100]	Samples: [205/20500]	Validation Loss: 1153.0567549542684	Time: 0:00:00.003069
Epoch: [4/100]	Samples: [3824/95600]	Train Loss: 1206.5789994116108	Time: 0:00:00.029827
Epoch: [4/100]	Samples: [205/20500]	Validation Loss: 1153.9561261432927	Time: 0:00:00.003187
Epoch: [5/100]	Samples: [4780/95600]	Train Loss: 1212.5322143043934	Time: 0:00:00.026505
Epoch: [5/100]	Samples: [205/20500]	Validation Loss: 1151.312052210366	Time: 0:00:00.002858
Epoch: [6/100]	Samples: [5736/95600]	Train Loss: 1193.2700297463389	Time: 0:00:00.024632
Epo


Searching NeuralLDA Params:  50%|█████     | 8/16 [00:17<00:17,  2.24s/it][A

Epoch: [1/50]	Samples: [956/47800]	Train Loss: 1193.3189559361924	Time: 0:00:00.031533
Epoch: [1/50]	Samples: [205/10250]	Validation Loss: 1152.4163205030488	Time: 0:00:00.004556
Epoch: [2/50]	Samples: [1912/47800]	Train Loss: 1184.4515396182007	Time: 0:00:00.026602
Epoch: [2/50]	Samples: [205/10250]	Validation Loss: 1153.515086699695	Time: 0:00:00.005252
Epoch: [3/50]	Samples: [2868/47800]	Train Loss: 1167.511334662657	Time: 0:00:00.027295
Epoch: [3/50]	Samples: [205/10250]	Validation Loss: 1154.8846131859757	Time: 0:00:00.002721
Epoch: [4/50]	Samples: [3824/47800]	Train Loss: 1181.1899189330543	Time: 0:00:00.031810
Epoch: [4/50]	Samples: [205/10250]	Validation Loss: 1153.5417540015244	Time: 0:00:00.002627
Epoch: [5/50]	Samples: [4780/47800]	Train Loss: 1167.22164291318	Time: 0:00:00.030593
Epoch: [5/50]	Samples: [205/10250]	Validation Loss: 1146.6237280868902	Time: 0:00:00.004032
Epoch: [6/50]	Samples: [5736/47800]	Train Loss: 1159.5073221757323	Time: 0:00:00.028439
Epoch: [6/50]	Sam


Searching NeuralLDA Params:  56%|█████▋    | 9/16 [00:19<00:16,  2.35s/it][A

Epoch: [1/50]	Samples: [956/47800]	Train Loss: 1191.6721773666318	Time: 0:00:00.035478
Epoch: [1/50]	Samples: [205/10250]	Validation Loss: 1152.2186785442073	Time: 0:00:00.003952
Epoch: [2/50]	Samples: [1912/47800]	Train Loss: 1196.5241157819037	Time: 0:00:00.027747
Epoch: [2/50]	Samples: [205/10250]	Validation Loss: 1152.3397389481706	Time: 0:00:00.002592
Epoch: [3/50]	Samples: [2868/47800]	Train Loss: 1178.770601791318	Time: 0:00:00.025125
Epoch: [3/50]	Samples: [205/10250]	Validation Loss: 1153.5943025914635	Time: 0:00:00.002552
Epoch: [4/50]	Samples: [3824/47800]	Train Loss: 1173.7351799490064	Time: 0:00:00.026812
Epoch: [4/50]	Samples: [205/10250]	Validation Loss: 1154.8837938262195	Time: 0:00:00.003099
Epoch: [5/50]	Samples: [4780/47800]	Train Loss: 1183.686372254184	Time: 0:00:00.027574
Epoch: [5/50]	Samples: [205/10250]	Validation Loss: 1152.6774199695121	Time: 0:00:00.002906
Epoch: [6/50]	Samples: [5736/47800]	Train Loss: 1166.60648208682	Time: 0:00:00.025170
Epoch: [6/50]	Sam


Searching NeuralLDA Params:  62%|██████▎   | 10/16 [00:21<00:13,  2.30s/it][A

Epoch: [1/100]	Samples: [956/95600]	Train Loss: 1192.7949382191423	Time: 0:00:00.036555
Epoch: [1/100]	Samples: [205/20500]	Validation Loss: 1152.0147103658537	Time: 0:00:00.003385
Epoch: [2/100]	Samples: [1912/95600]	Train Loss: 1174.0873594403765	Time: 0:00:00.032159
Epoch: [2/100]	Samples: [205/20500]	Validation Loss: 1152.0216558689024	Time: 0:00:00.002783
Epoch: [3/100]	Samples: [2868/95600]	Train Loss: 1174.9910107217572	Time: 0:00:00.031550
Epoch: [3/100]	Samples: [205/20500]	Validation Loss: 1150.7003429878048	Time: 0:00:00.003707
Epoch: [4/100]	Samples: [3824/95600]	Train Loss: 1162.311372254184	Time: 0:00:00.034425
Epoch: [4/100]	Samples: [205/20500]	Validation Loss: 1146.658565167683	Time: 0:00:00.003223
Epoch: [5/100]	Samples: [4780/95600]	Train Loss: 1170.3130883891213	Time: 0:00:00.030329
Epoch: [5/100]	Samples: [205/20500]	Validation Loss: 1143.323466082317	Time: 0:00:00.004665
Epoch: [6/100]	Samples: [5736/95600]	Train Loss: 1159.567975287657	Time: 0:00:00.037272
Epoch:


Searching NeuralLDA Params:  69%|██████▉   | 11/16 [00:24<00:12,  2.49s/it][A

Epoch: [1/100]	Samples: [956/95600]	Train Loss: 1193.455650169979	Time: 0:00:00.047914
Epoch: [1/100]	Samples: [205/20500]	Validation Loss: 1152.5327124618902	Time: 0:00:00.002869
Epoch: [2/100]	Samples: [1912/95600]	Train Loss: 1189.3061911610878	Time: 0:00:00.029881
Epoch: [2/100]	Samples: [205/20500]	Validation Loss: 1152.9272675304878	Time: 0:00:00.003523
Epoch: [3/100]	Samples: [2868/95600]	Train Loss: 1190.3814722803347	Time: 0:00:00.030491
Epoch: [3/100]	Samples: [205/20500]	Validation Loss: 1153.184922827744	Time: 0:00:00.002513
Epoch: [4/100]	Samples: [3824/95600]	Train Loss: 1181.8481138859834	Time: 0:00:00.028982
Epoch: [4/100]	Samples: [205/20500]	Validation Loss: 1153.5223561356706	Time: 0:00:00.002582
Epoch: [5/100]	Samples: [4780/95600]	Train Loss: 1175.340726333682	Time: 0:00:00.027055
Epoch: [5/100]	Samples: [205/20500]	Validation Loss: 1151.7985137195121	Time: 0:00:00.002409
Epoch: [6/100]	Samples: [5736/95600]	Train Loss: 1174.6892120489017	Time: 0:00:00.030499
Epoch


Searching NeuralLDA Params:  75%|███████▌  | 12/16 [00:28<00:11,  2.91s/it][A

Epoch: [1/50]	Samples: [956/47800]	Train Loss: 1181.519719207636	Time: 0:00:00.031503
Epoch: [1/50]	Samples: [205/10250]	Validation Loss: 1153.0141149009146	Time: 0:00:00.003005
Epoch: [2/50]	Samples: [1912/47800]	Train Loss: 1183.769972541841	Time: 0:00:00.030524
Epoch: [2/50]	Samples: [205/10250]	Validation Loss: 1154.2534632240854	Time: 0:00:00.003786
Epoch: [3/50]	Samples: [2868/47800]	Train Loss: 1164.0793548967051	Time: 0:00:00.040652
Epoch: [3/50]	Samples: [205/10250]	Validation Loss: 1155.7674256859757	Time: 0:00:00.009230
Epoch: [4/50]	Samples: [3824/47800]	Train Loss: 1167.0504870554394	Time: 0:00:00.033498
Epoch: [4/50]	Samples: [205/10250]	Validation Loss: 1154.836794969512	Time: 0:00:00.002834
Epoch: [5/50]	Samples: [4780/47800]	Train Loss: 1155.7491337604602	Time: 0:00:00.027757
Epoch: [5/50]	Samples: [205/10250]	Validation Loss: 1151.8128715701218	Time: 0:00:00.004063
Epoch: [6/50]	Samples: [5736/47800]	Train Loss: 1158.3637960904812	Time: 0:00:00.035497
Epoch: [6/50]	Sa


Searching NeuralLDA Params:  81%|████████▏ | 13/16 [00:32<00:09,  3.08s/it][A

Epoch: [1/50]	Samples: [956/47800]	Train Loss: 1191.6138287787658	Time: 0:00:00.029720
Epoch: [1/50]	Samples: [205/10250]	Validation Loss: 1152.4995665015244	Time: 0:00:00.003618
Epoch: [2/50]	Samples: [1912/47800]	Train Loss: 1177.6419979079499	Time: 0:00:00.031485
Epoch: [2/50]	Samples: [205/10250]	Validation Loss: 1154.1273866234756	Time: 0:00:00.002628
Epoch: [3/50]	Samples: [2868/47800]	Train Loss: 1166.5060269024582	Time: 0:00:00.028919
Epoch: [3/50]	Samples: [205/10250]	Validation Loss: 1155.6586270960365	Time: 0:00:00.003811
Epoch: [4/50]	Samples: [3824/47800]	Train Loss: 1163.6949447567993	Time: 0:00:00.027289
Epoch: [4/50]	Samples: [205/10250]	Validation Loss: 1155.5199266387194	Time: 0:00:00.002279
Epoch: [5/50]	Samples: [4780/47800]	Train Loss: 1161.1521394482218	Time: 0:00:00.027997
Epoch: [5/50]	Samples: [205/10250]	Validation Loss: 1152.0209651295731	Time: 0:00:00.002659
Epoch: [6/50]	Samples: [5736/47800]	Train Loss: 1165.0869263206066	Time: 0:00:00.032422
Epoch: [6/50]


Searching NeuralLDA Params:  88%|████████▊ | 14/16 [00:35<00:06,  3.21s/it][A

Epoch: [1/100]	Samples: [956/95600]	Train Loss: 1187.6001977641213	Time: 0:00:00.033870
Epoch: [1/100]	Samples: [205/20500]	Validation Loss: 1152.492978277439	Time: 0:00:00.004436
Epoch: [2/100]	Samples: [1912/95600]	Train Loss: 1172.8392308446653	Time: 0:00:00.024221
Epoch: [2/100]	Samples: [205/20500]	Validation Loss: 1155.3724561737804	Time: 0:00:00.002724
Epoch: [3/100]	Samples: [2868/95600]	Train Loss: 1175.3372777196653	Time: 0:00:00.025833
Epoch: [3/100]	Samples: [205/20500]	Validation Loss: 1155.4211937881098	Time: 0:00:00.003093
Epoch: [4/100]	Samples: [3824/95600]	Train Loss: 1168.3710692337866	Time: 0:00:00.024825
Epoch: [4/100]	Samples: [205/20500]	Validation Loss: 1148.7686070884147	Time: 0:00:00.002825
Epoch: [5/100]	Samples: [4780/95600]	Train Loss: 1153.0177497384936	Time: 0:00:00.029301
Epoch: [5/100]	Samples: [205/20500]	Validation Loss: 1142.171660632622	Time: 0:00:00.002624
Epoch: [6/100]	Samples: [5736/95600]	Train Loss: 1158.1183479341005	Time: 0:00:00.026895
Epoc


Searching NeuralLDA Params:  94%|█████████▍| 15/16 [00:38<00:03,  3.13s/it][A

Epoch: [1/100]	Samples: [956/95600]	Train Loss: 1179.2978066161088	Time: 0:00:00.028482
Epoch: [1/100]	Samples: [205/20500]	Validation Loss: 1152.2553925304878	Time: 0:00:00.002975
Epoch: [2/100]	Samples: [1912/95600]	Train Loss: 1174.1072543475418	Time: 0:00:00.026631
Epoch: [2/100]	Samples: [205/20500]	Validation Loss: 1152.8133431783538	Time: 0:00:00.002697
Epoch: [3/100]	Samples: [2868/95600]	Train Loss: 1172.6026739016736	Time: 0:00:00.025888
Epoch: [3/100]	Samples: [205/20500]	Validation Loss: 1154.071169969512	Time: 0:00:00.002614
Epoch: [4/100]	Samples: [3824/95600]	Train Loss: 1164.665320998954	Time: 0:00:00.030206
Epoch: [4/100]	Samples: [205/20500]	Validation Loss: 1152.7749666539635	Time: 0:00:00.002397
Epoch: [5/100]	Samples: [4780/95600]	Train Loss: 1161.8813496992677	Time: 0:00:00.024623
Epoch: [5/100]	Samples: [205/20500]	Validation Loss: 1147.6921303353658	Time: 0:00:00.002566
Epoch: [6/100]	Samples: [5736/95600]	Train Loss: 1167.299776085251	Time: 0:00:00.026402
Epoch


Searching NeuralLDA Params: 100%|██████████| 16/16 [00:41<00:00,  2.58s/it][A
Overall Model Parameter Search:  80%|████████  | 4/5 [02:25<00:36, 36.21s/it]

Best NeuralLDA Params: {'num_topics': 2, 'num_epochs': 50, 'lr': 0.001} with Score: 0.776061086940576



Searching NMF Params:   0%|          | 0/4 [00:00<?, ?it/s][A
Searching NMF Params:  25%|██▌       | 1/4 [00:01<00:05,  1.75s/it][A
Searching NMF Params:  50%|█████     | 2/4 [00:03<00:03,  1.71s/it][A
Searching NMF Params:  75%|███████▌  | 3/4 [00:05<00:01,  1.80s/it][A
Searching NMF Params: 100%|██████████| 4/4 [00:07<00:00,  1.80s/it][A
Overall Model Parameter Search: 100%|██████████| 5/5 [02:32<00:00, 30.49s/it]

Best NMF Params: {'num_topics': 2, 'random_state': 42} with Score: 0.6035502327204885





In [86]:
# Summarize results
import pandas as pd

summary = []
for model_name, model_info in best_models.items():
    summary.append({
        'Model': model_name,
        'Best_Params': model_info['params'],
        'Best_Score': model_info['score']
    })

summary_df = pd.DataFrame(summary)
print(summary_df)

       Model                                        Best_Params  Best_Score
0        LDA  {'num_topics': 2, 'iterations': 500, 'random_s...    0.601787
1        CTM                 {'num_topics': 3, 'num_epochs': 5}    0.617309
2        ETM                {'num_topics': 5, 'num_epochs': 50}    0.639376
3  NeuralLDA   {'num_topics': 2, 'num_epochs': 50, 'lr': 0.001}    0.776061
4        NMF              {'num_topics': 2, 'random_state': 42}    0.603550


In [95]:
best_models['LDA']

{'output': {'topic-word-matrix': array([[0.00185909, 0.00224993, 0.00613317, ..., 0.00109303, 0.00014323,
          0.00078758],
         [0.00183788, 0.00327841, 0.00513837, ..., 0.00229301, 0.00082292,
          0.00063022]], dtype=float32),
  'topics': [['nigga',
    'make',
    'fuck',
    'man',
    'shit',
    'let',
    'one',
    'want',
    'time',
    'come'],
   ['nigga',
    'want',
    'say',
    'see',
    'one',
    'make',
    'back',
    'come',
    'shit',
    'take']],
  'topic-document-matrix': array([[0.99390554, 0.97734255, 0.49285874, ..., 0.6869083 , 0.07091396,
          0.02274054],
         [0.00609445, 0.02265745, 0.50714129, ..., 0.3130917 , 0.92908597,
          0.97725952]]),
  'test-topic-document-matrix': array([[0.        , 0.74635935, 0.99336243, 0.10644579, 0.7301811 ,
          0.08416618, 0.91168082, 0.35630509, 0.23246025, 0.82655567,
          0.98141748, 0.6418463 , 0.48036799, 0.03879594, 0.61559248,
          0.79487765, 0.21062225, 0.29875869

In [96]:
best_models['CTM']

{'output': {'topics': [['double',
    'somethin',
    'hear',
    'thing',
    'world',
    'whatever',
    'stress',
    'bass',
    'ho',
    'battle'],
   ['take',
    'well',
    'little',
    'as',
    'tell',
    'guy',
    'another',
    'brother',
    'nothing',
    'wit'],
   ['clean',
    'another',
    'huh',
    'bass',
    'great',
    'blue',
    'flow',
    'makin',
    'ghetto',
    'cube']],
  'topic-document-matrix': array([[0.08450763, 0.317814  , 0.33435184, ..., 0.25211883, 0.08910767,
          0.22139151],
         [0.33204313, 0.16490646, 0.07378573, ..., 0.12000985, 0.31651188,
          0.18019159],
         [0.58344924, 0.51727953, 0.59186239, ..., 0.62787131, 0.59438041,
          0.59841692]]),
  'topic-word-matrix': array([[ 0.0170679 ,  0.03369807, -0.04947468, ..., -0.12947851,
          -0.1389453 , -0.06324163],
         [-0.02002003, -0.06443809, -0.00666221, ..., -0.022369  ,
          -0.01695826,  0.00541135],
         [-0.08495021, -0.09038385, -0