In [2]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
import kagglehub
organizations_cornell_university_arxiv_path = kagglehub.dataset_download('cornell-university/arxiv')

Downloading from https://www.kaggle.com/api/v1/datasets/download/cornell-university/arxiv?dataset_version_number=232...


100%|██████████| 1.44G/1.44G [00:14<00:00, 105MB/s]

Extracting files...





# Topic Modeling arXiv Abstracts with BERTopic

This problem consists of grouping a large amount of unseen & unlabeled research papers from arXiv based on their keywords found in the abstract section. The tool being utilized is BERTopic which consists of modular layers that can be customized or simply use the default settings.
<br>
<br>
Note: this python notebook is best utilized in Kaggle because of the access to 30GB RAM memory and 2x T4 GPUs

# Quick Overview

1) Loading data (with Category selection)
2) Pre-process the data
3) Default BERTopic (with examples)
4) Pre-compute embeddings
5) Fine-tune BERTopic Layers
    - Embedding
    - Clustering
    - Tokenization of topics
    - Weight tokens
    - Representation of topics
6) Custom BERTopic model and fitting
7) Outputs
8) Visualizations

In [3]:
!pip install -q bertopic

# the following resolves warning with .fit_transform() method
!pip install -q huggingface_hub[hf_xet]
!pip install -q hf_xet

from bertopic import BERTopic
from sentence_transformers import SentenceTransformer

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.6/150.6 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m103.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m76.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m51.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [6]:
import numpy as np
import pandas as pd

# Input data files are available in the read-only "../input/" directory

import os
file_path = ''
for dirname, _, filenames in os.walk(organizations_cornell_university_arxiv_path):
    for filename in filenames:
        file_path = os.path.join(dirname, filename)
        print('file_path:', file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"

file_path: /root/.cache/kagglehub/datasets/cornell-university/arxiv/versions/232/arxiv-metadata-oai-snapshot.json


# Loading data
#### (Specific category: cs.AI)

The data is in a json file which contains different types of information including submitter, authors, title, comments, category, and the abstract. The total number of articles is 2,725,401 while specific categories contain a managable amount of articles ideal for limited memory (e.g. cs.AI has 12,180 articles).

In [7]:
import json

# Examples of Categories: cs.AI, stat.ML, cs.LG (aka Machine Learning)
category_desired = 'stat.ML'
papers = []
with open(file_path, 'r') as f:
    for i, line in enumerate(f):
        paper = json.loads(line)
        if paper.get('categories') == category_desired:
            papers.append(paper)
print("Number of papers in", category_desired, ":", len(papers))

# Convert list of papers to a DataFrame
df = pd.DataFrame(papers)
df.head()

Number of papers in stat.ML : 1601


Unnamed: 0,id,submitter,authors,title,comments,journal-ref,doi,report-no,categories,license,abstract,versions,update_date,authors_parsed
0,705.2363,Marten Wegkamp,Marten Wegkamp,Lasso type classifiers with a reject option,Published at http://dx.doi.org/10.1214/07-EJS0...,"Electronic Journal of Statistics 2007, Vol. 1,...",10.1214/07-EJS058,IMS-EJS-EJS_2007_58,stat.ML,,We consider the problem of binary classifica...,"[{'version': 'v1', 'created': 'Wed, 16 May 200...",2009-09-29,"[[Wegkamp, Marten, ]]"
1,706.3499,Bharath Sriperumbudur,Bharath K. Sriperumbudur and Gert R. G. Lanckriet,Metric Embedding for Nearest Neighbor Classifi...,"9 pages, 1 table",,,,stat.ML,,The distance metric plays an important role ...,"[{'version': 'v1', 'created': 'Sun, 24 Jun 200...",2007-06-26,"[[Sriperumbudur, Bharath K., ], [Lanckriet, Ge..."
2,707.3536,Patrick Erik Bradley,Patrick Erik Bradley,Degenerating families of dendrograms,"13 pages, 8 figures","J. Classif. 25, 27-42 (2008)",10.1007/s00357-008-9009-5,,stat.ML,,Dendrograms used in data analysis are ultram...,"[{'version': 'v1', 'created': 'Tue, 24 Jul 200...",2008-06-28,"[[Bradley, Patrick Erik, ]]"
3,707.4072,Patrick Erik Bradley,Patrick Erik Bradley,Families of dendrograms,"7 pages, 3 figures. To appear in: Proceedings ...",,10.1007/978-3-540-78246-9_12,,stat.ML,,A conceptual framework for cluster analysis ...,"[{'version': 'v1', 'created': 'Fri, 27 Jul 200...",2009-12-01,"[[Bradley, Patrick Erik, ]]"
4,708.2377,Roberto Alamino,"Roberto C. Alamino, Nestor Caticha",Online Learning in Discrete Hidden Markov Models,"8 pages, 6 figures",,10.1063/1.2423274,,stat.ML,,We present and analyse three online algorith...,"[{'version': 'v1', 'created': 'Fri, 17 Aug 200...",2007-08-20,"[[Alamino, Roberto C., ], [Caticha, Nestor, ]]"


# Pre-process the data

(add small description)

In [9]:
# Drop rows with missing or very short abstracts
df = df.dropna(subset=['abstract'])
df = df[df['abstract'].str.len() > 100]

# Extract just the abstracts & lowercase
abstracts = df['abstract'].tolist()
abstracts = [doc.lower() for doc in abstracts]

# Check how many abstracts are ready
print(f"Number of cleaned abstracts: {len(abstracts)}")

Number of cleaned abstracts: 1601


In [10]:
import nltk
from nltk.corpus import stopwords

# We will remove stopwords (e.g. the, at, are)
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

def preprocess(text):
    # convert to lowercase
    text = text.lower()

    # only keep letters and space, so we drop punctuation/symbols/numbers
    text = ''.join(ch for ch in text if ch.isalpha() or ch.isspace())
    words = text.split()

    # parse out stopwords keeping keywords and words with small length
    words = [w for w in words if w not in stop_words and len(w) > 2]
    return ' '.join(words)

# Apply to a list of abstracts
abstracts = [preprocess(doc) for doc in abstracts]
abstracts[:3]


[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


['consider problem binary classification one particular cost choose classify observation present simple proof oracle inequality excess risk structural risk minimizers using lasso type penalty',
 'distance metric plays important role nearest neighbor classification usually euclidean distance metric assumed mahalanobis distance metric optimized improve performance paper study problem embedding arbitrary metric spaces euclidean space goal improve accuracy classifier propose solution appealing framework regularization reproducing kernel hilbert space prove representerlike theorem classification embedding function determined solving semidefinite program interesting connection softmargin linear binary support vector machine classifier although main focus paper present general theoretical framework metric embedding setting demonstrate performance proposed method benchmark datasets show performs better mahalanobis metric learning algorithm terms leaveoneout generalization errors',
 'dendrogram

### Example of how a desired outcome should look:
['note formally describe functionality calculate valid domain bdd represent solution space valid configuration formalization largely base clab configuration framework', <br>
 'motivation profile hidden markov model phmms popular useful tool detection remote homologue protein family unfortunately performance satisfactory protein twilight zone ...']

# Default BERTopic

In [80]:
topic_model = BERTopic()
topics, probs = topic_model.fit_transform(abstracts[:1000])

In [81]:
topic_model.get_topic_info()

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,386,-1_data_model_algorithm_models,"[data, model, algorithm, models, learning, met...",[paper study statistical properties semisuperv...
1,0,56,0_graph_graphs_network_clustering,"[graph, graphs, network, clustering, nodes, sp...",[partitioning graph groups vertices within gro...
2,1,52,1_matrix_subspace_algorithm_data,"[matrix, subspace, algorithm, data, completion...",[due challenging applications collaborative fi...
3,2,52,2_lasso_sparse_dictionary_group,"[lasso, sparse, dictionary, group, sparsity, r...",[study problem learning sparse linear regressi...
4,3,51,3_gaussian_processes_process_inference,"[gaussian, processes, process, inference, mode...",[study gaussian process regression model conte...
5,4,48,4_kernel_kernels_reproducing_test,"[kernel, kernels, reproducing, test, density, ...",[paper propose family tractable kernels dense ...
6,5,44,5_dirichlet_topic_process_model,"[dirichlet, topic, process, model, models, inf...",[nonparametric mixture models based dirichlet ...
7,6,42,6_brain_data_neuroimaging_functional,"[brain, data, neuroimaging, functional, analys...",[brain decoding involves determination subject...
8,7,38,7_clustering_cluster_clusters_data,"[clustering, cluster, clusters, data, method, ...",[mean shift clustering finds modes data probab...
9,8,34,8_graphical_models_structure_graph,"[graphical, models, structure, graph, model, v...",[propose new class semiparametric exponential ...


In [82]:
topic_model.get_topic(0)

[('graph', np.float64(0.05608021890451161)),
 ('graphs', np.float64(0.04391361855068794)),
 ('network', np.float64(0.0308825668121992)),
 ('clustering', np.float64(0.028191382569537415)),
 ('nodes', np.float64(0.02595808221364081)),
 ('spectral', np.float64(0.02435532602525256)),
 ('vertex', np.float64(0.024072513312475972)),
 ('networks', np.float64(0.023856096534178348)),
 ('model', np.float64(0.023688207143516368)),
 ('stochastic', np.float64(0.021385265929755804))]

In [83]:
# Show documents' topic, probability, Top_n_words
topic_model.get_document_info(abstracts[:1000]).head()

Unnamed: 0,Document,Topic,Name,Representation,Representative_Docs,Top_n_words,Probability,Representative_document
0,consider problem binary classification one par...,2,2_lasso_sparse_dictionary_group,"[lasso, sparse, dictionary, group, sparsity, r...",[study problem learning sparse linear regressi...,lasso - sparse - dictionary - group - sparsity...,0.6531,False
1,distance metric plays important role nearest n...,15,15_manifold_embedding_manifolds_lle,"[manifold, embedding, manifolds, lle, data, me...",[paper presents new framework manifold learnin...,manifold - embedding - manifolds - lle - data ...,0.459526,False
2,dendrograms used data analysis ultrametric spa...,7,7_clustering_cluster_clusters_data,"[clustering, cluster, clusters, data, method, ...",[mean shift clustering finds modes data probab...,clustering - cluster - clusters - data - metho...,0.442203,False
3,conceptual framework cluster analysis viewpoin...,7,7_clustering_cluster_clusters_data,"[clustering, cluster, clusters, data, method, ...",[mean shift clustering finds modes data probab...,clustering - cluster - clusters - data - metho...,0.454911,False
4,present analyse three online algorithms learni...,-1,-1_data_model_algorithm_models,"[data, model, algorithm, models, learning, met...",[paper study statistical properties semisuperv...,data - model - algorithm - models - learning -...,0.0,False


In [84]:
# note: outputs may differ and not save
try:
  topic_model.visualize_topics()
except Exception as error:
  print("Unable to display visualization. The count for the first topic is likely most of the documents. Please try fitting the model again.")

## Example output with cs.AI category:
![](https://raw.githubusercontent.com/dnxv/BERTopic/refs/heads/main/outputs/visualize_topics-embedding_BAAI-min_50.jpg)

In [85]:
topic_model.visualize_documents(abstracts)

## Example output with 10,000+ datapoints from cs.AI category:
![](https://raw.githubusercontent.com/dnxv/BERTopic/refs/heads/main/outputs/embedding-BAAI-bge-base%2Cmin_size_50_1.jpg)

# Pre-compute embeddings

In [86]:
# Prepare embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedding_model.encode(abstracts, show_progress_bar=True)

Batches:   0%|          | 0/51 [00:00<?, ?it/s]

In [87]:
from umap import UMAP

reduced_embeddings = UMAP(n_neighbors=30, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)

df_plot = pd.DataFrame({
    "x1": [point[0] for point in reduced_embeddings],
    "x2": [point[1] for point in reduced_embeddings],
    "docs": abstracts,
})

df_plot["docs_short"] = df_plot["docs"].str[:100] + "..."
df_plot.head(10)

Unnamed: 0,x1,x2,docs,docs_short
0,-0.65706,4.779006,consider problem binary classification one par...,consider problem binary classification one par...
1,-1.518607,3.958432,distance metric plays important role nearest n...,distance metric plays important role nearest n...
2,-3.145614,6.410346,dendrograms used data analysis ultrametric spa...,dendrograms used data analysis ultrametric spa...
3,-3.203511,6.571202,conceptual framework cluster analysis viewpoin...,conceptual framework cluster analysis viewpoin...
4,-4.546526,5.217755,present analyse three online algorithms learni...,present analyse three online algorithms learni...
5,-2.27789,4.325052,recent years kernel density estimation exploit...,recent years kernel density estimation exploit...
6,-0.375601,3.811363,thesis responds challenges using large number ...,thesis responds challenges using large number ...
7,-3.622922,2.599043,simulated annealing popular method approaching...,simulated annealing popular method approaching...
8,-4.468881,6.162722,present nested chinese restaurant process ncrp...,present nested chinese restaurant process ncrp...
9,-2.742345,3.328779,provide selfcontained proof theorem relating p...,provide selfcontained proof theorem relating p...


In [88]:
import plotly.express as px
import pandas as pd
import plotly.io as pio

total_docs = len(df_plot)
fig = px.scatter(df_plot, x="x1", y="x2",  hover_data=["docs_short"])
fig.update_traces(marker=dict(line=dict(width=0.5, color='white')))
fig.update_layout(
    title=f"Abstracts",
    title_font_size=20
)

fig.show()

In [89]:
# BERTopic with default parameters/models (e.g. all-MiniLM-L6-v2, UMAP, HDBSCAN)
topic_model = BERTopic()

# fit_transform runs faster since embeddings were already processed prior
topics, probs = topic_model.fit_transform(abstracts, embeddings)

In [90]:
topic_model.get_topic_info()

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,536,-1_data_model_learning_algorithm,"[data, model, learning, algorithm, method, usi...",[hierarchical probabilistic models gaussian mi...
1,0,163,0_graph_graphs_network_graphical,"[graph, graphs, network, graphical, model, net...",[latent space model family random graphs assig...
2,1,78,1_gaussian_process_processes_model,"[gaussian, process, processes, model, gps, var...",[study gaussian process regression model conte...
3,2,69,2_lasso_sparse_dictionary_group,"[lasso, sparse, dictionary, group, sparsity, r...",[study problem learning sparse linear regressi...
4,3,57,3_forest_trees_ensembles_forests,"[forest, trees, ensembles, forests, ensemble, ...",[tree ensembles random forest boosted trees re...
5,4,56,4_clustering_cluster_clusters_data,"[clustering, cluster, clusters, data, density,...",[goal data clustering partition data points gr...
6,5,55,5_variational_inference_models_generative,"[variational, inference, models, generative, l...",[stochastic variational inference relatively w...
7,6,52,6_kernel_kernels_reproducing_hilbert,"[kernel, kernels, reproducing, hilbert, test, ...",[paper propose family tractable kernels dense ...
8,7,50,7_topic_dirichlet_inference_models,"[topic, dirichlet, inference, models, topics, ...",[paper proposes novel dynamic hierarchical dir...
9,8,49,8_matrix_lowrank_tensor_completion,"[matrix, lowrank, tensor, completion, rank, al...",[consider problem noisy bit matrix completion ...


## Example output from cs.AI category:
| Topic | Count | Name                         | Representation                                               | Representative Docs                                     |
|-------|-------|------------------------------|---------------------------------------------------------------|----------------------------------------------------------|
| -1    | 4525  | -1_the_of_and_to             | [the, of, and, to, in, we, that, for, is, this]               | [ In this work we propose a planning and acti...         |
| 0     | 665   | 0_belief_of_theory_the       | [belief, of, theory, the, is, probability, in, evidence, a, that] | [ The paper presents a novel view of the Demp...         |
| 1     | 452   | 1_game_games_the_of          | [game, games, the, of, player, to, in, and, we, strategy]     | [ In many board games and other abstract game...         |
| 2     | 434   | 2_networks_bayesian_network_the | [networks, bayesian, network, the, inference, structure, in, of, learning, probabilistic] | [ Structure and parameters in a Bayesian netw... |
| 3     | 399   | 3_intelligence_of_the_and    | [intelligence, of, the, and, cognitive, artificial, agents, in, is, systems] | [ The overarching problem in artificial intel... |

<br>
<br>

## Error Output with high concentration on Topic -1:
| Topic | Count | Name                                    | Representation                                      | Representative\_Docs                                |   |
| ----- | ----- | --------------------------------------- | --------------------------------------------------- | --------------------------------------------------- | - |
| -1    | 10318 | -1\_model\_paper\_models\_learning      | \[model, paper, models, learning, data, problem...] | \[markov decision processes mdps well studied f...] |   |
| 0     | 1726  | 0\_learning\_models\_model\_data        | \[learning, models, model, data, knowledge, pap...] | \[artificial intelligence techniques used class...] |   |
| 1     | 136   | 1\_problem\_search\_algorithm\_problems | \[problem, search, algorithm, problems, algorit...] | \[constraint satisfaction problem csp framework...] |   |



# Fine-tune BERTopic with alternative modeling techniques

### Swap or remove any of the following:
1) Embedding model <br>
    * Note: The following embedding models kept grouping majority (12134) into group 0 with filler words
        * word2vec, Universal Sentence Encoder (USE), hugging face transformer (distilbert-base) (40+ min) <br>
    
2) Reducing dimensionality of embeddings (default=UMAP, PCA, t-SVD, cuML UMAP, or remove this layer)<br>
3) Clustering into topics (default=HDBSCAN, k-Means, sklearn.cluster, cuML HDBSCAN) <br>
4) Tokenization of topics <br>
5) Weight tokens
6) Represent topics with one or multiple representations <br>

(Gathered from here: https://maartengr.github.io/BERTopic/index.html#modularity)

### Embedding model ([more models here](https://maartengr.github.io/BERTopic/getting_started/embeddings/embeddings.html#sentence-transformers))

In [91]:
### (Default): all-MiniLM-L6-v2
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

### (alternative): bge-base
# embedding_model = SentenceTransformer("BAAI/bge-base-en-v1.5")

##### Utilize GPUs if available

In [92]:
import torch
if torch.cuda.device_count() > 1:
  print(f"Using {torch.cuda.device_count()} GPUs")
  embedding_model = torch.nn.DataParallel(embedding_model)
else:
  print("Using a single GPU or CPU")
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
  embedding_model = embedding_model.to(device)

Using a single GPU or CPU


### Dimensionality Reduction ([more models here](https://maartengr.github.io/BERTopic/getting_started/dim_reduction/dim_reduction.html))

In [93]:
### (Default): UMAP
from umap import UMAP
dim_model = UMAP(n_components=2, n_neighbors=10, min_dist=0.0, metric='cosine')

### (alternative): cuML UMAP (Note: ideal with GPU, error with cuda version if P100 is used)
# from cuml.manifold import UMAP
# dim_model = UMAP(n_components=2, n_neighbors=5, min_dist=0.0, metric='cosine')

### (alternative): PCA
# from sklearn.decomposition import PCA
# dim_model = PCA(n_components=2)

### Clustering ([more models here](https://maartengr.github.io/BERTopic/getting_started/clustering/clustering.html))

In [94]:
### (Default): HDSCAN
# from hdbscan import HDBSCAN
# cluster_model = HDBSCAN(min_cluster_size=20, metric='euclidean', cluster_selection_method='eom', prediction_data=True)

### (alternative): cuML HDSCAN (ideal with GPU)
# from cuml.cluster import HDBSCAN
# cluster_model = HDBSCAN(min_samples=50, gen_min_span_tree=True, prediction_data=True)

### (alternative): k-Means
from sklearn.cluster import KMeans
cluster_model = KMeans(n_clusters=20)

### Representation Model ([more models here](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html))

In [95]:
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance

# KeyBERT
keybert = KeyBERTInspired()

# MMR
mmr = MaximalMarginalRelevance(diversity=0.3)

# All representation models
representation_model = {
    "KeyBERT": keybert,
    "MMR": mmr,
}

# Custom BERTopic model and fitting

In [96]:
# Create the BERTopic model with modular/custom sub-models
topic_model = BERTopic(verbose=True,

           # Sub-models
           embedding_model=embedding_model,
           umap_model=dim_model,
           hdbscan_model=cluster_model,
           representation_model=representation_model,

           #Hyperparameters
           min_topic_size=25
)

In [97]:
### Fit the model on your list of abstracts, but no pre-processed embeddings
# topics, probs = topic_model.fit_transform(abstracts)

### Faster fitting of the model with abstracts pre-processed embeddings
topics, probs = topic_model.fit_transform(abstracts, embeddings)

2025-05-14 16:08:14,713 - BERTopic - Dimensionality - Fitting the dimensionality reduction algorithm
2025-05-14 16:08:21,922 - BERTopic - Dimensionality - Completed ✓
2025-05-14 16:08:21,925 - BERTopic - Cluster - Start clustering the reduced embeddings
2025-05-14 16:08:21,969 - BERTopic - Cluster - Completed ✓
2025-05-14 16:08:21,975 - BERTopic - Representation - Fine-tuning topics using representation models.
2025-05-14 16:08:24,152 - BERTopic - Representation - Completed ✓


# Outputs

In [98]:
# Get a summary of all topics
topic_model.get_topic_info()

Unnamed: 0,Topic,Count,Name,Representation,KeyBERT,MMR,Representative_Docs
0,0,150,0_gaussian_process_model_processes,"[gaussian, process, model, processes, models, ...","[gaussian, models, prediction, modelling, lear...","[gaussian, processes, models, variational, lea...",[multioutput gaussian processes received incre...
1,1,130,1_matrix_dictionary_algorithm_data,"[matrix, dictionary, algorithm, data, sparse, ...","[algorithms, lowrank, algorithm, matrix, matri...","[matrix, sparse, tensor, decomposition, lowran...",[consider generalization lowrank matrix comple...
2,2,115,2_lasso_regression_sparse_group,"[lasso, regression, sparse, group, regularizat...","[lasso, regularization, sparse, optimization, ...","[lasso, regression, sparse, group, regularizat...",[group lasso penalized regression method used ...
3,3,110,3_kernel_learning_kernels_test,"[kernel, learning, kernels, test, reproducing,...","[kernels, kernel, classification, supervised, ...","[kernel, learning, kernels, hilbert, distribut...",[study strictly proper scoring rules reproduci...
4,4,93,4_graph_graphs_network_vertex,"[graph, graphs, network, vertex, clustering, s...","[graphs, nodes, adjacency, clustering, cluster...","[graphs, spectral, nodes, networks, vertices, ...",[partitioning graph groups vertices within gro...
5,5,92,5_brain_data_kernel_method,"[brain, data, kernel, method, analysis, model,...","[fmri, neuroimaging, classification, imaging, ...","[brain, data, kernel, imaging, cca, functional...",[imaging genetic research essentially focused ...
6,6,85,6_algorithms_problems_algorithm_statistical,"[algorithms, problems, algorithm, statistical,...","[algorithms, optimization, algorithm, regulari...","[algorithms, convex, optimization, convergence...",[consider statistical algorithmic aspects solv...
7,7,81,7_sampling_carlo_monte_algorithm,"[sampling, carlo, monte, algorithm, markov, st...","[sgmcmc, mcmc, sampling, bayesian, stochastic,...","[sampling, monte, stochastic, algorithms, mcmc...",[hamiltonian monte carlo hmc popular markov ch...
8,8,81,8_variational_inference_models_generative,"[variational, inference, models, generative, d...","[variational, stochastic, probabilistic, bayes...","[variational, models, generative, distribution...",[stochastic variational inference svi paradigm...
9,9,76,9_forest_trees_data_ensembles,"[forest, trees, data, ensembles, forests, ense...","[ensembles, classification, forests, forest, e...","[forest, ensembles, forests, ensemble, boostin...",[missing data expected issue large amounts dat...


In [99]:
# Show top keywords for specific Topic
topic_model.get_topic(1)

[('matrix', np.float64(0.04432528328039532)),
 ('dictionary', np.float64(0.027532395201076802)),
 ('algorithm', np.float64(0.026082073785749598)),
 ('data', np.float64(0.025180198812647936)),
 ('sparse', np.float64(0.02387081271662058)),
 ('tensor', np.float64(0.022564080474555014)),
 ('subspace', np.float64(0.020480526917789912)),
 ('problem', np.float64(0.019544534767337578)),
 ('decomposition', np.float64(0.018440210165769556)),
 ('lowrank', np.float64(0.017763061428544016))]

# Visualizations

In [100]:
# note: outputs may differ and not save
try:
  topic_model.visualize_topics()
except Exception as error:
  print("Unable to display visualization. The count for the first topic is likely most of the documents. Please try fitting the model again.")

In [101]:
topic_model.visualize_documents(abstracts)

In [102]:
topic_model.visualize_barchart()

In [103]:
keybert_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["KeyBERT"].values()]
print(keybert_labels)

# Get document info
document_info = topic_model.get_document_info(abstracts)
document_info["KeyBERT"] = document_info["KeyBERT"].apply(lambda x: x[0])
all_labels = document_info["KeyBERT"]

['gaussian', 'algorithms', 'lasso', 'kernels', 'graphs', 'fmri', 'algorithms', 'sgmcmc', 'variational', 'ensembles', 'classifiers', 'clustering', 'graphical', 'pca', 'optimisation', 'forecasting', 'topics', 'causal', 'clinical', 'bayesian']


In [104]:
df_plot = pd.DataFrame({
    "x1": [point[0] for point in reduced_embeddings],
    "x2": [point[1] for point in reduced_embeddings],
    "docs": abstracts,
    "label": all_labels
})
df_plot["docs_short"] = df_plot["docs"].str[:100] + "..."
df_plot.head(10)


Unnamed: 0,x1,x2,docs,label,docs_short
0,-0.65706,4.779006,consider problem binary classification one par...,lasso,consider problem binary classification one par...
1,-1.518607,3.958432,distance metric plays important role nearest n...,classifiers,distance metric plays important role nearest n...
2,-3.145614,6.410346,dendrograms used data analysis ultrametric spa...,clustering,dendrograms used data analysis ultrametric spa...
3,-3.203511,6.571202,conceptual framework cluster analysis viewpoin...,clustering,conceptual framework cluster analysis viewpoin...
4,-4.546526,5.217755,present analyse three online algorithms learni...,sgmcmc,present analyse three online algorithms learni...
5,-2.27789,4.325052,recent years kernel density estimation exploit...,kernels,recent years kernel density estimation exploit...
6,-0.375601,3.811363,thesis responds challenges using large number ...,classifiers,thesis responds challenges using large number ...
7,-3.622922,2.599043,simulated annealing popular method approaching...,optimisation,simulated annealing popular method approaching...
8,-4.468881,6.162722,present nested chinese restaurant process ncrp...,topics,present nested chinese restaurant process ncrp...
9,-2.742345,3.328779,provide selfcontained proof theorem relating p...,forecasting,provide selfcontained proof theorem relating p...


In [124]:
import plotly.express as px

fig = px.scatter(df_plot, x="x1", y="x2", color="label", hover_data=["docs_short"])
fig.update_layout(
    title=f"Category: {category_desired}",
    title_font_size=20
)
print("Embedding Model:", topic_model.embedding_model)
print("\nDimensionality Reduction Model:", topic_model.umap_model)
print("\nClustering Model:", topic_model.hdbscan_model)
print("\nRepresentation Model(s):", topic_model.representation_model)
fig.show()

Embedding Model: <bertopic.backend._sentencetransformers.SentenceTransformerBackend object at 0x7e062c7b3e10>

Dimensionality Reduction Model: UMAP(angular_rp_forest=True, metric='cosine', min_dist=0.0, n_neighbors=10, tqdm_kwds={'bar_format': '{desc}: {percentage:3.0f}%| {bar} {n_fmt}/{total_fmt} [{elapsed}]', 'desc': 'Epochs completed', 'disable': True})

Clustering Model: KMeans(n_clusters=20)

Representation Model(s): {'KeyBERT': KeyBERTInspired(), 'MMR': MaximalMarginalRelevance(diversity=0.3)}


# Optimize Hyperparameters (UMAP & HDBSCAN)


In [106]:
import itertools

# Generate parameter vectors
def grid(param_grid):

  param_combinations = list(itertools.product(*param_grid.values()))
  param_list = [dict(zip(param_grid.keys(), values)) for values in param_combinations]

  return param_list


In [107]:
param_grid = {
    "st": ["all-MiniLM-L6-v2", "BAAI/bge-base-en-v1.5"], #embedding model (SentenceTransformer)
    "nn": [2, 15, 20],             # umap_n_neighbors
    "cs": [15, 20, 25],            # hdbscan_min_cluster_size
    "ts": [10, 25, 50],            # min_topic_size
}

### Note: If you would like to proceed with Hyperparameter tuning, uncomment the last line of code

In [108]:
from sklearn.metrics import silhouette_score

def GridSearch(param_grid):

    param_combinations = grid(param_grid)
    print("Total combinations:", len(param_combinations))

    best_score = -1
    best_params = {}

    for params in param_combinations:

        embedding_model = SentenceTransformer(params["st"])
        print("emb done")
        topic_model = BERTopic(verbose=False,
           embedding_model=embedding_model,
           # umap_model=UMAP(n_components=2, n_neighbors=params["nn"], min_dist=0.0, metric='cosine'),
           hdbscan_model=HDBSCAN(min_cluster_size=params["cs"], metric='euclidean', cluster_selection_method='eom', prediction_data=True),
           representation_model={"KeyBERT": KeyBERTInspired()},
           min_topic_size=params["ts"],
           top_n_words=10
        )
        topics, probs = topic_model.fit_transform(abstracts)

        # Get embeddings and labels
        embeddings = embedding_model.encode(abstracts, convert_to_tensor=False)
        labels = topic_model.get_document_info(abstracts)['Topic'].values

        # Compute silhouette score
        score = silhouette_score(embeddings, labels, metric='euclidean')
        output = "Silhouette Score:" + str(score) + "| Params:" + str(params)
        print(output)
        print(topic_model.get_topic_info()[["Topic", "Count", "Name"]])

        #========================= Visualization ====================================
        keybert_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["KeyBERT"].values()]

        # Get document info
        document_info = topic_model.get_document_info(abstracts)
        document_info["KeyBERT"] = document_info["KeyBERT"].apply(lambda x: x[0])
        all_labels = document_info["KeyBERT"]

        df_plot = pd.DataFrame({
            "x1": [point[0] for point in reduced_embeddings],
            "x2": [point[1] for point in reduced_embeddings],
            "docs": abstracts,
            "label": all_labels
        })
        df_plot["docs_short"] = df_plot["docs"].str[:100] + "..."

        fig = px.scatter(df_plot, x="x1", y="x2", color="label", hover_data=["docs_short"])
        fig.show()
        #==========================================================================

        # Update best score and parameters if current score is higher
        if score > best_score:
            best_score = score
            best_params = {
                'n_neighbors': params["nn"],
                # 'n_components': params["nc"],
                'min_cluster_size': params["cs"]
            }

        print("=======================================================================================================\n")
    print("best_score", best_score)
    print("best_params", best_params)
# GridSearch(param_grid)