In [1]:
import os

import pandas as pd
import numpy as np

from contextualized_topic_models.models.ctm import CTM
from contextualized_topic_models.utils.data_preparation import TextHandler, bert_embeddings_from_file
from contextualized_topic_models.datasets.dataset import CTMDataset

In [2]:
# All Glassdoor inputs/outputs live under one data directory.
data_dir = "data/glassdoor"
load_path = f"{data_dir}/glassdoor_topics.parquet"                # raw reviews (parquet, 'text' column)
save_path = f"{data_dir}/glassdoor_sentences.txt"                  # full corpus, one document per line
test_path = f"{data_dir}/glassdoor_sentences_abridged.txt"         # abridged corpus used for training

In [3]:
# Materialise the plain-text corpora (one document per line) that TextHandler
# and the BERT encoder read. Guarded so that Restart & Run All skips the
# parquet load when both files already exist.
N_ABRIDGED = 100000  # number of reviews kept in the abridged training corpus


def _write_documents(texts, path):
    """Write one document per line to ``path``, UTF-8 encoded.

    Parameters
    ----------
    texts : pandas.Series
        Review texts; non-string values are written via ``str()``.
    path : str
        Destination file; overwritten if it exists.

    Embedded line breaks are replaced by a space. Without this, a multi-line
    review would be split across several lines and counted as several
    documents, misaligning the corpus with the source rows.
    """
    flattened = texts.astype(str).str.replace(r"[\r\n]+", " ", regex=True)
    with open(path, "w", encoding="utf-8") as fh:
        for line in flattened:
            fh.write(line + "\n")


if not (os.path.exists(save_path) and os.path.exists(test_path)):
    reviews_raw = pd.read_parquet(load_path)
    if not os.path.exists(save_path):
        _write_documents(reviews_raw["text"], save_path)
    if not os.path.exists(test_path):
        _write_documents(reviews_raw["text"].head(N_ABRIDGED), test_path)

In [6]:
# Build the vocabulary and the bag-of-words training matrix from the abridged corpus.
handler = TextHandler(test_path)
handler.prepare()

In [7]:
# Sentence-BERT embeddings for every line of the training corpus. Encoding is
# the slowest step of the notebook, so the array is cached on disk and reused on
# later runs instead of being recomputed and then discarded.
bert_encodings = "data/glassdoor/bert_encodings.npy"
if os.path.exists(bert_encodings):
    training_bert = np.load(bert_encodings)
else:
    training_bert = bert_embeddings_from_file(test_path, "bert-base-nli-mean-tokens")
    np.save(bert_encodings, training_bert)

# The cache is not keyed on the corpus file, so check it still lines up with the
# bag-of-words matrix: CTMDataset pairs the two row by row.
assert training_bert.shape[0] == handler.bow.shape[0], (
    f"cached BERT encodings ({training_bert.shape[0]} rows) do not match the "
    f"corpus ({handler.bow.shape[0]} rows); delete {bert_encodings} and re-run"
)

In [8]:
# Pair each bag-of-words row with its contextual (BERT) embedding.
training_dataset = CTMDataset(
    handler.bow,
    training_bert,
    handler.idx2token,
)

In [9]:
# Combined contextualized topic model: the BoW input width is the corpus
# vocabulary size, and the contextual input width is taken from the embedding
# matrix instead of hard-coding 768 (the width of bert-base encoders). Swapping
# the sentence-transformer model then cannot silently mismatch the network.
# NOTE(review): the saved output of the fit cell prints "N Components: 50",
# which disagrees with n_components=10 here, so that output is stale. Re-run to refresh.
# TODO: seed torch/numpy before building the model so topics are reproducible.
ctm = CTM(
    input_size=len(handler.vocab),
    bert_input_size=training_bert.shape[1],
    inference_type="combined",
    n_components=10,
    num_data_loader_workers=1,
    num_epochs=10,
    use_gpu=True,
)

In [10]:
# Train the model; checkpoints are written under this directory.
model_dir = "models/CTM/01"
ctm.fit(training_dataset, save_dir=model_dir)

Settings: 
               N Components: 50
               Topic Prior Mean: 0.0
               Topic Prior Variance: 0.98
               Model Type: prodLDA
               Hidden Sizes: (100, 100)
               Activation: softplus
               Dropout: 0.2
               Learn Priors: True
               Learning Rate: 0.002
               Momentum: 0.99
               Reduce On Plateau: False
               Save Dir: models/CTM/01
Epoch: [1/10]	Samples: [100637/1006370]	Train Loss: 90.95055373976753	Time: 0:01:20.846262
Epoch: [2/10]	Samples: [201274/1006370]	Train Loss: 88.00021321700301	Time: 0:01:21.818298
Epoch: [3/10]	Samples: [301911/1006370]	Train Loss: 87.3909494954639	Time: 0:01:21.253848
Epoch: [4/10]	Samples: [402548/1006370]	Train Loss: 87.04632548769806	Time: 0:01:20.516218
Epoch: [5/10]	Samples: [503185/1006370]	Train Loss: 86.8078568085747	Time: 0:01:20.250964
Epoch: [6/10]	Samples: [603822/1006370]	Train Loss: 86.62441147594899	Time: 0:01:20.467594
Epoch: [7/10]	Sa

In [13]:
# Top-5 words for each of the first ten topics.
ctm.get_topic_lists(5)[:10]

[['years', 'years.', '3', 'months', '2'],
 ['joke.', 'pieces', 'Pay.', 'lie.', 'parking.'],
 ['need', 'make', 'actually', 'employees', 'hard'],
 ["I'm", 'worst', 'wish', 'for.', 'there.'],
 ['help', 'culture', 'willing', 'truly', 'clients'],
 ['doing.', 'Do', 'doing!', 'Keep', "Don't"],
 ['place', 'It', 'This', 'experience', 'lot'],
 ['growing', 'learn', 'Lots', 'opportunity', 'Company'],
 ['Get', 'up!', 'Be', 'Keep', 'doing!'],
 ['benefits,', 'pay,', 'life', 'good', 'free']]