In [1]:
from contextualized_topic_models.models.ctm import CombinedTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessing
import nltk

In [2]:
nltk.download('stopwords')

data_path = "../data/"

documents = [line.strip() for line in open(data_path + "dbpedia_sample_abstract_20k_unprep.txt", "r").readlines()]
sp = WhiteSpacePreprocessing(documents, stopwords_language='english') 
preprocessed_documents, unpreprocessed_corpus, vocab = sp.preprocess()

print(preprocessed_documents[:3])
print(unpreprocessed_corpus[:3])
print(len(vocab))

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/godpeny/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


['mid peninsula highway proposed across peninsula canadian province ontario although highway connecting hamilton fort south international study published ministry', 'died march american photographer specialized photography operated studio silver spring maryland later lived florida magazine photographer year', 'henry howard august august british peer son henry howard father died march behind became']
['The Mid-Peninsula Highway is a proposed freeway across the Niagara Peninsula in the Canadian province of Ontario. Although plans for a highway connecting Hamilton to Fort Erie south of the Niagara Escarpment have surfaced for decades,it was not until The Niagara Frontier International Gateway Study was published by the Ministry', "Monte Zucker (died March 15, 2007) was an American photographer. He specialized in wedding photography, entering it as a profession in 1947. In the 1970s he operated a studio in Silver Spring, Maryland. Later he lived in Florida. He was Brides Magazine's Wedding

# Modeling

In [3]:
tp = TopicModelDataPreparation("paraphrase-distilroberta-base-v1")
training_dataset = tp.fit(text_for_contextual=unpreprocessed_corpus, text_for_bow=preprocessed_documents)
ctm = CombinedTM(bow_size=len(tp.vocab), contextual_size=768, n_components=50, num_epochs=5)

ctm.fit(training_dataset)

Batches:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch: [10/10]	 Seen Samples: [200000/200000]	Train Loss: 136.5857074951172	Time: 0:01:03.019182: : 10it [10:32, 63.27s/it]


In [4]:
ctm.get_topic_lists(5)

[['season', 'league', 'team', 'tournament', 'football'],
 ['town', 'county', 'located', 'census', 'city'],
 ['son', 'de', 'french', 'wife', 'daughter'],
 ['name', 'king', 'church', 'century', 'roman'],
 ['school', 'high', 'house', 'built', 'students'],
 ['born', 'played', 'former', 'player', 'made'],
 ['railway', 'line', 'company', 'service', 'services'],
 ['century', 'greek', 'ancient', 'king', 'period'],
 ['published', 'game', 'developed', 'magazine', 'video'],
 ['album', 'american', 'music', 'band', 'released'],
 ['university', 'professor', 'born', 'served', 'american'],
 ['series', 'film', 'produced', 'directed', 'written'],
 ['album', 'released', 'band', 'studio', 'music'],
 ['family', 'species', 'found', 'genus', 'mm'],
 ['american', 'born', 'played', 'university', 'former'],
 ['member', 'politician', 'party', 'elected', 'served'],
 ['built', 'building', 'house', 'story', 'historic'],
 ['series', 'published', 'book', 'american', 'television'],
 ['world', 'summer', 'competed', 'ol

# Prediction

In [10]:
import numpy as np

In [6]:
topics_predictions = ctm.get_thetas(training_dataset, n_samples=3) 

Sampling: [3/3]: : 3it [02:59, 60.00s/it]


In [20]:
print(preprocessed_documents[0])

mid peninsula highway proposed across peninsula canadian province ontario although highway connecting hamilton fort south international study published ministry


In [17]:
print(topics_predictions)
topic_idx = np.argmax(topics_predictions[0])

print(topic_idx)

[[0.01077348 0.01356979 0.00509916 ... 0.03638788 0.0092768  0.01338069]
 [0.00246786 0.00571247 0.00603287 ... 0.00244921 0.0123746  0.06028802]
 [0.00385155 0.00399819 0.26056468 ... 0.00345671 0.00545027 0.03472301]
 ...
 [0.00311637 0.02559119 0.00579893 ... 0.00491325 0.04074635 0.01217081]
 [0.00881516 0.00299022 0.03167975 ... 0.00859608 0.10077539 0.00538099]
 [0.00325098 0.00294681 0.00615367 ... 0.00189291 0.00317795 0.02111261]]
40


In [21]:
ctm.get_topic_lists(10)[topic_idx]

['station',
 'located',
 'city',
 'railway',
 'line',
 'road',
 'airport',
 'owned',
 'street',
 'mill']