In [None]:
import pickle
import torch
from transformers import BertModel
from tqdm import tqdm 
from octis.preprocessing.preprocessing import Preprocessing
from sentence_transformers import SentenceTransformer
from Code.TNTM.TNTM_SentenceTransformer import TNTM_SentenceTransformer
import numpy as np
import os
import pandas as pd
import functools
import operator

In [31]:
# Work around the SSLError that shows up when using SentenceTransformer.
# NOTE(security): the session below has TLS certificate verification switched
# off, and it is registered as the HTTP backend for huggingface_hub. Prefer
# pointing REQUESTS_CA_BUNDLE at a valid CA bundle where that is possible.
import requests
from huggingface_hub import configure_http_backend


def backend_factory() -> requests.Session:
    """Return a fresh ``requests.Session`` with certificate verification disabled."""
    insecure_session = requests.Session()
    insecure_session.verify = False
    return insecure_session


# Register the factory so huggingface_hub builds its sessions through it.
configure_http_backend(backend_factory=backend_factory)

In [None]:
# Seed every RNG library this notebook already imports so reruns are repeatable.
# torch.manual_seed on its own leaves NumPy's global generator unseeded.
# NOTE(review): whether downstream code draws from NumPy's RNG is not visible
# here; seeding it is harmless if it does not.
SEED = 41
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'current device: {device}')

In [33]:
# Read the preprocessed corpus from disk (presumably one list of words per
# e-mail, judging by the file name — confirm against the preprocessing step).
cleaned_data_path = "Data/Preprocessed_Data"
file_path = os.path.join(cleaned_data_path, "email2words.pickle")

# NOTE: pickle.load can execute arbitrary code; only open files you produced yourself.
with open(file_path, mode="rb") as corpus_file:
    corpus = pickle.load(corpus_file)

In [None]:
# Build the vocabulary: every distinct word of the corpus, listed in order of
# first appearance.
#
# A plain set would hand the words back in a different order on every kernel
# start (string hashing is randomised per process), while the word embeddings
# are loaded from a cache written by an earlier session — row i of that cache
# would then no longer describe vocab[i]. dict.fromkeys de-duplicates like a
# set but keeps insertion order, which is stable across sessions.
# NOTE(review): a cache that was produced with a set-based vocabulary order has
# to be regenerated once against this vocabulary.
vocab = list(dict.fromkeys(word for each_lst in corpus for word in each_lst))
print(f"Total number of unique words: {len(vocab)}")

In [None]:
# Load the pretrained multilingual SentenceTransformer stored under
# Data/Bert_Embedder.
modelPath = "Data/Bert_Embedder"
embeddings_model = SentenceTransformer(modelPath)

In [None]:
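# # One-off step, kept commented out: embed every word of `vocab` and cache the
# # result as vocab_embeddings.pickle (un-comment to regenerate the cache).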
# unique_words_embeddings = [embeddings_model.encode(word) for word in tqdm(vocab)]
# unique_words_embeddings = torch.Tensor(unique_words_embeddings)

# cleaned_data_path = "Data/Preprocessed_Data"
# file_path = os.path.join(cleaned_data_path, "vocab_embeddings.pickle")

# with open(file_path, "wb") as file:
#     pickle.dump(unique_words_embeddings, file)
# unique_words_embeddings.shape

In [None]:
# Load the cached word embeddings. The commented-out generation cell above
# builds them with one row per entry of `vocab`.
cleaned_data_path = "Data/Preprocessed_Data"
file_path = os.path.join(cleaned_data_path, "vocab_embeddings.pickle")

# NOTE: pickle.load can execute arbitrary code; only open files you produced yourself.
with open(file_path, "rb") as file:
    vocab_embeddings = pickle.load(file)

# Fail early when the cache was built for a different vocabulary: a row-count
# mismatch means words and vectors can no longer be paired up by index.
if vocab_embeddings.shape[0] != len(vocab):
    raise ValueError(
        f"vocab_embeddings.pickle holds {vocab_embeddings.shape[0]} rows but the "
        f"vocabulary has {len(vocab)} words; regenerate the embeddings cache."
    )
vocab_embeddings.shape

In [None]:
# Load the cached e-mail embeddings; they are handed to the model as the
# document embeddings further down.
cleaned_data_path = "Data/Preprocessed_Data"
file_path = os.path.join(cleaned_data_path, "email_embeddings.pickle")

# NOTE: pickle.load can execute arbitrary code; only open files you produced yourself.
with open(file_path, mode="rb") as embeddings_file:
    sentence_embedding = pickle.load(embeddings_file)
sentence_embedding.shape

In [39]:

# Number of topics. It sets both the model size and the name of the output
# directory, so it is defined once instead of being repeated as a literal.
N_TOPICS = 20

# Configure the topic model. The meaning of each hyper-parameter is defined by
# TNTM_SentenceTransformer, which is not visible here — the notes below are
# read off the argument names only and should be confirmed against the class.
tntm = TNTM_SentenceTransformer(
    n_topics = N_TOPICS,
    save_path = f"Data/Auxillary_Data/{N_TOPICS}_topics",
    n_dims = 11,
    n_hidden_units = 200,
    n_encoder_layers = 2,
    enc_lr = 1e-3,
    dec_lr = 1e-3,
    n_epochs = 1000,
    batch_size = 128,
    #batch_size = 256,
    dropout_rate_encoder = 0.3,
    prior_variance =  0.995,
    prior_mean = None,
    n_topwords = 200,
    device = device,  # 'cuda' or 'cpu', chosen in the seed/device cell
    validation_set_size = 0.2,  # presumably the held-out fraction — TODO confirm
    early_stopping = True,
    n_epochs_early_stopping = 15,  # presumably the patience in epochs — TODO confirm
)

In [None]:
# Fit the topic model on the corpus together with the precomputed word-level
# and document-level embeddings; the return value is inspected in the cells below.
result = tntm.fit(
    corpus=corpus,
    vocab=vocab,
    word_embeddings=vocab_embeddings,
    document_embeddings=sentence_embedding,
)

In [41]:
# The second element of the fit result holds the weights; rescale every row so
# that its entries sum to 1.
weights = result[1]
row_totals = weights.sum(axis=1, keepdims=True)
normalize_weights = weights / row_totals

In [None]:
# For every cluster, keep the top_k words with the largest normalised weight.
top_k = 5
top_words_per_cluster = []
for cluster_idx, word_weights in enumerate(normalize_weights):
    # argsort is ascending: take the last top_k positions and reverse them so
    # the heaviest word comes first.
    top_k_indices = word_weights.argsort()[-top_k:][::-1]

    # Translate the positions into words through the cluster's list in result[0].
    top_words_per_cluster.append(
        [result[0][cluster_idx][i] for i in top_k_indices]
    )

# Report the selected words, numbering the clusters from 1.
for cluster_idx, words in enumerate(top_words_per_cluster):
    print(f"Cluster {cluster_idx + 1}: {words}")