## Zeroshot Topic Models (ZTMs) 

*   List item
*   List item



Using ZTMs to produce topic models
 

## Preprocessing

to ensure the citations are stripped of html and preprocessed according to a pipeline 

In [1]:
!pip install contextualized_topic_models pyLDAvis scispacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_core_sci_sm-0.5.0.tar.gz


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_core_sci_sm-0.5.0.tar.gz
  Using cached https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.0/en_core_sci_sm-0.5.0.tar.gz (15.9 MB)


In [2]:
from bs4 import BeautifulSoup
import string
import re

def remove_html(x):
    soup = BeautifulSoup(x, 'html.parser')
    tags = soup.find_all('cite')
    text = soup.get_text()
    text = text.replace('  ', '')
    text = text.replace(' et ', '')
    text = text.replace(' al ', '')
    text = text.replace(' et. ', '')
    text = text.replace(' al. ', '')
    return text

In [3]:
import scispacy
import spacy
import string

nlp = spacy.load("en_core_sci_sm")
nlp.add_pipe("merge_entities")

def merge_entities(x):
    doc = nlp(x)
    return " ".join([re.sub(r'\W+', '', t.text.strip().translate(str.maketrans('', '', string.punctuation)).replace(" ", "_").lower()) for t in doc])


In [4]:
import pandas as pd

citations = pd.read_csv('./example_doc_citations.csv')
citations['text'] = citations['text'].apply(
    lambda x : remove_html(x)
)
citations['ner_merged_text'] = citations['text'].apply(
    lambda x: merge_entities(x)
)

In [5]:
from gensim.corpora.dictionary import Dictionary

removal = ['ADV','PRON','CCONJ','PUNCT','PART','DET','ADP','SPACE', 'NUM', 'SYM']
remove_text = ['al', 'et', 'al.']
authors = [
    'mongeon',
    'hus'
]


unpreprocessed_corpus = citations['text']

preprocessed_documents = []
for doc in nlp.pipe(citations['ner_merged_text']):
    proj_tok = [token.lemma_.lower() for token in doc if "_" in token.text or (len(token.text) > 2 and token.text not in remove_text and token.pos_ not in removal and not token.is_stop and token.is_alpha)]
    proj_tok = [
        tok for tok in proj_tok
        if all([author not in tok for author in authors])
    ]
    preprocessed_documents.append(' '.join(proj_tok))
    
    
texts = [doc.split(' ') for doc in preprocessed_documents]
dictionary = Dictionary(texts)
dictionary.filter_extremes(no_below=2, no_above=0.9)


In [6]:
preprocessed_documents = [' '.join([dictionary[token[0]].replace("-", "_").replace("®", "_") for token in dictionary.doc2bow(text)]) for text in texts]

In [7]:
from contextualized_topic_models.models.ctm import ZeroShotTM, CombinedTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessingStopwords
import nltk
import torch
import random
import numpy as np

In [8]:
def fix_seeds():
    torch.manual_seed(10)
    torch.cuda.manual_seed(10)
    np.random.seed(10)
    random.seed(10)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True

In [9]:
from contextualized_topic_models.evaluation.measures import CoherenceNPMI, InvertedRBO

fix_seeds()

num_topics = [3, 5, 10, 15, 20, 25, 50, 100, 250]

corpus = [doc.split(' ') for doc in preprocessed_documents]
model_results = []
embedding_models = [
    "paraphrase-distilroberta-base-v2",
    "sentence-transformers/allenai-specter",
    "allenai/aspire-sentence-embedder",
    "allenai/aspire-contextualsentence-multim-compsci", 
]
for embedding_model in embedding_models:
    tp = TopicModelDataPreparation(embedding_model)
    training_dataset = tp.fit(text_for_contextual=unpreprocessed_corpus, text_for_bow=preprocessed_documents)
    for n_components in num_topics:
        print("num topics:", n_components)
        ztm = ZeroShotTM(bow_size=len(tp.vocab), contextual_size=768, 
                     n_components=n_components, num_epochs=50)
        ztm.fit(training_dataset, n_samples=20)
        coh = CoherenceNPMI(ztm.get_topic_lists(10), [doc.split(' ') for doc in preprocessed_documents])
        coh_score = coh.score()
        print("coherence score:", coh_score)
        diversity_score = InvertedRBO(ztm.get_topic_lists(10)).score()
        print("diversity score LDA:", diversity_score)
        model_results.append({
            "coherence": coh_score,
            "diversity": diversity_score,
            "num_topics": n_components,
            "model": ztm,
            "embedding_model": embedding_model,
            "tp": tp,
            "training_dataset": training_dataset
        })



Batches:   0%|          | 0/4 [00:00<?, ?it/s]



num topics: 3


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.12433085097976	Time: 0:00:00.353339: : 50it [00:17,  2.79it/s]
Sampling: [20/20]: : 20it [00:06,  3.12it/s]


coherence score: -0.294257671477365
diversity score LDA: 1.0
num topics: 5


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 172.4755294549028	Time: 0:00:00.359181: : 50it [00:18,  2.77it/s]
Sampling: [20/20]: : 20it [00:06,  3.12it/s]


coherence score: -0.1189763516683188
diversity score LDA: 0.9652398526257143
num topics: 10


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 171.97549087389382	Time: 0:00:00.450055: : 50it [00:19,  2.55it/s]
Sampling: [20/20]: : 20it [00:07,  2.82it/s]


coherence score: -0.10610156642756122
diversity score LDA: 0.9613596490592063
num topics: 15


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.50566554401075	Time: 0:00:00.351447: : 50it [00:18,  2.73it/s]
Sampling: [20/20]: : 20it [00:06,  3.03it/s]


coherence score: -0.07570320683168348
diversity score LDA: 0.9529943432771428
num topics: 20


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 174.41088898052308	Time: 0:00:00.359962: : 50it [00:18,  2.70it/s]
Sampling: [20/20]: : 20it [00:06,  3.07it/s]


coherence score: -0.06209996452827473
diversity score LDA: 0.9415406893859398
num topics: 25


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 176.7942456270741	Time: 0:00:00.353641: : 50it [00:18,  2.69it/s]
Sampling: [20/20]: : 20it [00:06,  3.03it/s]


coherence score: -0.07404288070831037
diversity score LDA: 0.9530382753302857
num topics: 50


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 187.4638474340234	Time: 0:00:00.354440: : 50it [00:19,  2.55it/s]
Sampling: [20/20]: : 20it [00:06,  2.95it/s]


coherence score: -0.04459871110348148
diversity score LDA: 0.9253880469338017
num topics: 100


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 208.3258681652971	Time: 0:00:00.370891: : 50it [00:18,  2.65it/s]
Sampling: [20/20]: : 20it [00:06,  2.88it/s]


coherence score: -0.08529750729723469
diversity score LDA: 0.9316158575508168
num topics: 250


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 275.80222263649654	Time: 0:00:00.374240: : 50it [00:18,  2.66it/s]
Sampling: [20/20]: : 20it [00:07,  2.82it/s]


coherence score: -0.18723874414914435
diversity score LDA: 0.9533949223157451




Batches:   0%|          | 0/4 [00:00<?, ?it/s]



num topics: 3


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.12694633474243	Time: 0:00:00.385194: : 50it [00:19,  2.54it/s]
Sampling: [20/20]: : 20it [00:07,  2.58it/s]


coherence score: -0.18193505555408676
diversity score LDA: 0.9809680855214286
num topics: 5


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.35869751748183	Time: 0:00:00.402600: : 50it [00:19,  2.53it/s]
Sampling: [20/20]: : 20it [00:07,  2.80it/s]


coherence score: -0.1503836781094218
diversity score LDA: 0.9769944960614285
num topics: 10


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.79831021353507	Time: 0:00:00.373762: : 50it [00:19,  2.56it/s]
Sampling: [20/20]: : 20it [00:07,  2.81it/s]


coherence score: -0.07818394294015443
diversity score LDA: 0.9617767256963492
num topics: 15


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 175.2343805556653	Time: 0:00:00.378905: : 50it [00:24,  2.00it/s]
Sampling: [20/20]: : 20it [00:07,  2.82it/s]


coherence score: -0.0419062886703115
diversity score LDA: 0.9355484757591837
num topics: 20


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 177.5326883000158	Time: 0:00:00.393716: : 50it [00:20,  2.41it/s]
Sampling: [20/20]: : 20it [00:07,  2.79it/s]


coherence score: -0.10302318631832794
diversity score LDA: 0.9504746185034586
num topics: 25


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 179.29714722759957	Time: 0:00:00.405937: : 50it [00:19,  2.51it/s]
Sampling: [20/20]: : 20it [00:07,  2.77it/s]


coherence score: -0.0718614175363379
diversity score LDA: 0.9252747233331905
num topics: 50


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 189.29381907543853	Time: 0:00:00.394177: : 50it [00:20,  2.47it/s]
Sampling: [20/20]: : 20it [00:07,  2.67it/s]


coherence score: -0.06589169292764963
diversity score LDA: 0.922158011268
num topics: 100


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 214.1371872777734	Time: 0:00:00.404369: : 50it [00:20,  2.44it/s]
Sampling: [20/20]: : 20it [00:08,  2.42it/s]


coherence score: -0.13470085053145056
diversity score LDA: 0.9323748521276811
num topics: 250


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 280.3263033590787	Time: 0:00:00.387842: : 50it [00:20,  2.47it/s]
Sampling: [20/20]: : 20it [00:07,  2.57it/s]


coherence score: -0.24930272252006086
diversity score LDA: 0.9593711264361051




Batches:   0%|          | 0/4 [00:00<?, ?it/s]



num topics: 3


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.08108462636298	Time: 0:00:00.389857: : 50it [00:20,  2.47it/s]
Sampling: [20/20]: : 20it [00:07,  2.73it/s]


coherence score: -0.1627809977022958
diversity score LDA: 0.98349875695
num topics: 5


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 172.99024178255374	Time: 0:00:00.409851: : 50it [00:20,  2.45it/s]
Sampling: [20/20]: : 20it [00:08,  2.48it/s]


coherence score: -0.13890030815135618
diversity score LDA: 1.0
num topics: 10


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.95297931810998	Time: 0:00:00.401288: : 50it [00:20,  2.43it/s]
Sampling: [20/20]: : 20it [00:07,  2.71it/s]


coherence score: -0.111088199058282
diversity score LDA: 0.942522478520635
num topics: 15


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 175.21124491347976	Time: 0:00:00.407747: : 50it [00:20,  2.44it/s]
Sampling: [20/20]: : 20it [00:07,  2.68it/s]


coherence score: -0.05922428433731063
diversity score LDA: 0.9350169084543537
num topics: 20


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 177.37505555665297	Time: 0:00:00.398551: : 50it [00:20,  2.41it/s]
Sampling: [20/20]: : 20it [00:07,  2.66it/s]


coherence score: -0.0673740795600226
diversity score LDA: 0.9352903108659398
num topics: 25


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 179.26368483821904	Time: 0:00:00.404892: : 50it [00:21,  2.32it/s]
Sampling: [20/20]: : 20it [00:07,  2.66it/s]


coherence score: -0.04455256281043065
diversity score LDA: 0.9285885143857857
num topics: 50


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 193.72675694828538	Time: 0:00:00.418129: : 50it [00:20,  2.39it/s]
Sampling: [20/20]: : 20it [00:07,  2.62it/s]


coherence score: -0.08943558002425787
diversity score LDA: 0.9261430645305656
num topics: 100


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 219.9129303788717	Time: 0:00:00.429765: : 50it [00:21,  2.38it/s]
Sampling: [20/20]: : 20it [00:07,  2.56it/s]


coherence score: -0.23336915523817126
diversity score LDA: 0.9596211491970462
num topics: 250


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 286.12413331621366	Time: 0:00:00.628262: : 50it [00:21,  2.33it/s]
Sampling: [20/20]: : 20it [00:08,  2.36it/s]


coherence score: -0.2869954160328756
diversity score LDA: 0.9715869265794804




Batches:   0%|          | 0/4 [00:00<?, ?it/s]



num topics: 3


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.43208446092763	Time: 0:00:00.412713: : 50it [00:21,  2.34it/s]
Sampling: [20/20]: : 20it [00:07,  2.58it/s]


coherence score: -0.10421348292576112
diversity score LDA: 1.0
num topics: 5


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 174.5867801092367	Time: 0:00:00.408312: : 50it [00:21,  2.34it/s]
Sampling: [20/20]: : 20it [00:07,  2.59it/s]


coherence score: -0.09157956584780594
diversity score LDA: 0.995049627085
num topics: 10


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 173.88796630550726	Time: 0:00:00.421214: : 50it [00:22,  2.25it/s]
Sampling: [20/20]: : 20it [00:07,  2.57it/s]


coherence score: -0.10836200217803393
diversity score LDA: 0.9566104916220635
num topics: 15


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 176.3305719619153	Time: 0:00:00.426301: : 50it [00:21,  2.33it/s]
Sampling: [20/20]: : 20it [00:07,  2.54it/s]


coherence score: -0.0893630676267676
diversity score LDA: 0.9171393575982313
num topics: 20


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 177.9860762681732	Time: 0:00:00.440010: : 50it [00:21,  2.30it/s]
Sampling: [20/20]: : 20it [00:07,  2.55it/s]


coherence score: -0.05580319966025936
diversity score LDA: 0.9254114872051504
num topics: 25


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 180.07477308193742	Time: 0:00:00.424527: : 50it [00:21,  2.30it/s]
Sampling: [20/20]: : 20it [00:08,  2.32it/s]


coherence score: -0.07689032434216966
diversity score LDA: 0.9132932877954761
num topics: 50


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 193.82512987910872	Time: 0:00:00.428583: : 50it [00:21,  2.28it/s]
Sampling: [20/20]: : 20it [00:07,  2.54it/s]


coherence score: -0.09411504321138661
diversity score LDA: 0.9381236065526647
num topics: 100


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 217.62912971120417	Time: 0:00:00.429246: : 50it [00:21,  2.29it/s]
Sampling: [20/20]: : 20it [00:08,  2.47it/s]


coherence score: -0.17887581623239343
diversity score LDA: 0.9499786003006969
num topics: 250


Epoch: [50/50]	 Seen Samples: [39550/39550]	Train Loss: 282.9947782919169	Time: 0:00:00.439273: : 50it [00:21,  2.28it/s]
Sampling: [20/20]: : 20it [00:08,  2.29it/s]


coherence score: -0.31226021816547955
diversity score LDA: 0.9776125841885076


## Select Top Topic Model

In [10]:
results_df = pd.DataFrame(model_results)

In [11]:
results_df.sort_values(by='coherence', ascending=False)[
    ['coherence','diversity','num_topics', 'embedding_model']
]

Unnamed: 0,coherence,diversity,num_topics,embedding_model
12,-0.041906,0.935548,15,sentence-transformers/allenai-specter
23,-0.044553,0.928589,25,allenai/aspire-sentence-embedder
6,-0.044599,0.925388,50,paraphrase-distilroberta-base-v2
31,-0.055803,0.925411,20,allenai/aspire-contextualsentence-multim-compsci
21,-0.059224,0.935017,15,allenai/aspire-sentence-embedder
4,-0.0621,0.941541,20,paraphrase-distilroberta-base-v2
15,-0.065892,0.922158,50,sentence-transformers/allenai-specter
22,-0.067374,0.93529,20,allenai/aspire-sentence-embedder
14,-0.071861,0.925275,25,sentence-transformers/allenai-specter
5,-0.074043,0.953038,25,paraphrase-distilroberta-base-v2


In [12]:
results_df.sort_values(by='diversity', ascending=False)[
    ['coherence','diversity','num_topics', 'embedding_model']
]

Unnamed: 0,coherence,diversity,num_topics,embedding_model
0,-0.294258,1.0,3,paraphrase-distilroberta-base-v2
19,-0.1389,1.0,5,allenai/aspire-sentence-embedder
27,-0.104213,1.0,3,allenai/aspire-contextualsentence-multim-compsci
28,-0.09158,0.99505,5,allenai/aspire-contextualsentence-multim-compsci
18,-0.162781,0.983499,3,allenai/aspire-sentence-embedder
9,-0.181935,0.980968,3,sentence-transformers/allenai-specter
35,-0.31226,0.977613,250,allenai/aspire-contextualsentence-multim-compsci
10,-0.150384,0.976994,5,sentence-transformers/allenai-specter
26,-0.286995,0.971587,250,allenai/aspire-sentence-embedder
1,-0.118976,0.96524,5,paraphrase-distilroberta-base-v2


In [17]:
ztm_model = results_df.iloc[27]['model']

In [18]:
ztm_model.get_topic_lists(5)

[['datum', 'observe', 'pubme', 'issue', 'english_language'],
 ['literature', 'scopus_index', 'conduct', 'obtain', 'abstract'],
 ['scopus', 'large', 'database', 'journal', 'wos']]

## Visualize topics

In [15]:
import pyLDAvis as vis

ztm = results_df.sort_values(by='coherence', ascending=False).iloc[0]['model']
tp = results_df.sort_values(by='coherence', ascending=False).iloc[0]['tp']
training_dataset = results_df.sort_values(by='coherence', ascending=False).iloc[0]['training_dataset']
lda_vis_data = ztm.get_ldavis_data_format(tp.vocab, training_dataset, n_samples=1)

ztm_pd = vis.prepare(**lda_vis_data)
vis.display(ztm_pd)

  from collections import Iterable
Sampling: [1/1]: : 1it [00:00,  2.37it/s]
  by='saliency', ascending=False).head(R).drop('saliency', 1)


## Compute Topic Index

In [16]:
import numpy as np
import json
for i, result in results_df.iterrows():
    documents = []
    topic_predictions = result['model'].get_thetas(training_dataset, n_samples=5)
    for topics, doi, cite_id in zip(topic_predictions, citations['source_doi'], citations['id']):
        topic_number = np.argmax(topics)
        keywords = [term for term in result['model'].get_topic_lists(10)[topic_number]]
        documents.append({
            "doi": doi,
            "cite_id": cite_id,
            "keywords": [word for word in set(keywords)]
        })
        with open(f'./ztm_{result["num_topics"]}_{result["embedding_model"].replace("/", "_")}_topic_index.json', 'w+') as f:
            json.dump({
                "embedding_model": result["embedding_model"],
                "topics": result["num_topics"],
                "diversity": result["diversity"],
                "coherence": result["coherence"],
                "documents": documents
            }, f)

Sampling: [5/5]: : 5it [00:01,  2.51it/s]
Sampling: [5/5]: : 5it [00:01,  2.52it/s]
Sampling: [5/5]: : 5it [00:01,  2.51it/s]
Sampling: [5/5]: : 5it [00:01,  2.55it/s]
Sampling: [5/5]: : 5it [00:01,  2.53it/s]
Sampling: [5/5]: : 5it [00:01,  2.52it/s]
Sampling: [5/5]: : 5it [00:02,  2.46it/s]
Sampling: [5/5]: : 5it [00:02,  2.44it/s]
Sampling: [5/5]: : 5it [00:02,  2.37it/s]
Sampling: [5/5]: : 5it [00:01,  2.51it/s]
Sampling: [5/5]: : 5it [00:01,  2.53it/s]
Sampling: [5/5]: : 5it [00:02,  2.49it/s]
Sampling: [5/5]: : 5it [00:01,  2.52it/s]
Sampling: [5/5]: : 5it [00:02,  2.47it/s]
Sampling: [5/5]: : 5it [00:01,  2.52it/s]
Sampling: [5/5]: : 5it [00:02,  2.46it/s]
Sampling: [5/5]: : 5it [00:02,  2.45it/s]
Sampling: [5/5]: : 5it [00:02,  2.33it/s]
Sampling: [5/5]: : 5it [00:02,  2.48it/s]
Sampling: [5/5]: : 5it [00:02,  2.46it/s]
Sampling: [5/5]: : 5it [00:02,  2.47it/s]
Sampling: [5/5]: : 5it [00:02,  2.44it/s]
Sampling: [5/5]: : 5it [00:02,  2.45it/s]
Sampling: [5/5]: : 5it [00:02,  2.