In [None]:
import altair as alt
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance, OpenAI, PartOfSpeech
from dotenv import load_dotenv
from hdbscan import HDBSCAN
import networkx as nx
import numpy as np
import openai
import os
import pandas as pd
from pathlib import Path
import re
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import spacy
from umap import UMAP

from nesta_ds_utils.viz.altair import saving as viz_save

from dsp_ai_eval import PROJECT_DIR, logging
from dsp_ai_eval.utils import text_cleaning as tc
from dsp_ai_eval.utils.clustering_utils import create_new_topic_model, create_df_for_viz

# Lift Altair's default cap on the number of rows it will embed in a chart
alt.data_transformers.disable_max_rows()

# Use the model id with its published casing ("MiniLM"); mismatched casing can create a
# separate cache entry or fail to resolve on case-sensitive lookups.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Load variables from a local .env file into os.environ before reading secrets
load_dotenv()

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY is None:
    # Surface the problem now rather than when the key is passed to create_new_topic_model later
    logging.warning("OPENAI_API_KEY is not set; cells that pass it to create_new_topic_model may fail")

In [None]:
# Input locations
SCITE_SUMMARY = PROJECT_DIR / 'inputs/data/scite/scite_summary.txt'
scite_embeddings_path = PROJECT_DIR / "inputs/data/embeddings/scite_embeddings.parquet"

# scite abstracts with precomputed embeddings (columns used below: 'title_abstract', 'embeddings')
scite_abstracts = pd.read_parquet(scite_embeddings_path)
scite_abstracts.head()

In [None]:
# Texts to cluster: one string per row of scite_abstracts
docs = scite_abstracts['title_abstract'].to_list()
# Log with context instead of a bare integer
logging.info("Loaded %d documents from 'title_abstract'", len(docs))

The topic-model configuration below follows the [BERTopic best-practices guide](https://maartengr.github.io/BERTopic/getting_started/best_practices/best_practices.html).

In [None]:
# Stack the per-row embedding vectors into one 2-D matrix (one row per document, in docs order).
# np.vstack avoids building a pd.Series per row as .apply(pd.Series) does, and it raises on
# ragged vectors instead of silently padding them with NaN.
# Assumes each cell holds a 1-D vector of equal length — TODO confirm against the parquet schema.
embeddings = np.vstack(scite_abstracts['embeddings'].to_numpy())

In [None]:
# First-pass topic model; parameters are forwarded to the project helper
topic_model = create_new_topic_model(
    hdbscan_min_cluster_size=150,
    tfidf_min_df=10,
    tfidf_max_df=1.0,
    tfidf_ngram_range=(1, 3),
    seed=42,
)

In [None]:
# Fit on the documents, passing the precomputed embedding matrix
topics, probs = topic_model.fit_transform(docs, embeddings=embeddings)

# Summary table of the discovered topics
topic_model.get_topic_info()

In [None]:
# Two-column table of topic id and topic name
topic_lookup = topic_model.get_topic_info().loc[:, ['Topic', 'Name']]

In [None]:
# Bar charts for the outlier topic (-1) and topics 0-9
topic_model.visualize_barchart(topics=[-1, *range(10)])

In [None]:
# Representative documents BERTopic stored for the fitted topics
representative_docs = topic_model.get_representative_docs()
representative_docs

In [None]:
# BERTopic's interactive topic map
topic_map_fig = topic_model.visualize_topics()
topic_map_fig

In [None]:
# Per-document plotting table (x/y coordinates plus topic labels) for the first-pass model
df_vis = create_df_for_viz(embeddings, topic_model, topics, docs, seed=42)
df_vis.head()

In [None]:
# Keep only points with x > 5 (hard-coded cut-off on the projected x coordinate).
# NOTE(review): this is not the HDBSCAN outlier topic (-1); confirm what region it removes.
plot_df = df_vis[df_vis['x'] > 5]

# Shared encodings: 2-D coordinates and hover tooltip
base = alt.Chart(plot_df).encode(
    x='x',
    y='y',
    tooltip=['Name:N', 'title_abstract:N'],
)

# Documents with category == 'main': large, fully opaque red markers
main_points = base.transform_filter(
    alt.datum.category == 'main'
).mark_circle(size=200, opacity=1, color='red')

# All other documents: small, faint markers coloured by topic name
other_points = base.transform_filter(
    alt.datum.category != 'main'
).mark_circle(size=30, opacity=0.2).encode(
    color='Name:N'
)

# Later layers are drawn on top, so add the highlighted points last to keep them visible
plot = (other_points + main_points).properties(
    width=800,
    height=600,
).interactive()

plot.save(PROJECT_DIR / 'outputs/figures/scite_abstracts.html')
viz_save.save(plot, 'scite_abstracts', PROJECT_DIR / 'outputs/figures', save_png=True)

plot.display()

In [None]:
# Same projection with the outlier topic (-1) and the x <= 5 region removed
df_no_outliers = df_vis.loc[(df_vis['x'] > 5) & (df_vis['topic'] != -1)]

plot = (
    alt.Chart(df_no_outliers)
    .mark_circle(size=30, opacity=0.25)
    .encode(x='x', y='y', color='Name:N', tooltip=['Name:N', 'title_abstract:N'])
    .properties(width=800, height=600)
    .interactive()
)

viz_save.save(plot, 'scite_abstracts_filtered', PROJECT_DIR / 'outputs/figures', save_png=True)

plot.display()

In [None]:
# Documents assigned to the outlier topic (-1), without the first model's topic/projection columns
noise_cluster = (
    df_vis.loc[df_vis['topic'] == -1]
    .drop(columns=['topic', 'Topic', 'Name', 'x', 'y'])
)
noise_cluster.head()

In [None]:
# Fail with a clear message instead of an opaque error from np.vstack / fit_transform
if noise_cluster.empty:
    raise ValueError("No outlier (topic -1) documents to re-cluster")

# Embedding matrix and texts for the outlier documents, both in noise_cluster row order.
# np.vstack avoids per-row pd.Series construction and raises on ragged vectors.
central_embeddings = np.vstack(noise_cluster['embeddings'].to_numpy())
central_docs = noise_cluster['title_abstract'].to_list()

In [None]:
# Second pass: re-cluster only the outlier documents with a smaller minimum cluster size.
# gpt_model / openai_api_key are forwarded to the helper (presumably for LLM topic labels — confirm).
noise_model_kwargs = dict(
    hdbscan_min_cluster_size=50,
    tfidf_min_df=10,
    tfidf_max_df=1.0,
    tfidf_ngram_range=(1, 3),
    gpt_model="gpt-3.5-turbo",
    openai_api_key=OPENAI_API_KEY,
    seed=42,
)
new_topic_model = create_new_topic_model(**noise_model_kwargs)

new_topics, new_probs = new_topic_model.fit_transform(central_docs, embeddings=central_embeddings)

new_topic_model.get_topic_info()

In [None]:
# Two-column table of topic id and topic name for the second-pass model
new_topic_lookup = new_topic_model.get_topic_info().loc[:, ['Topic', 'Name']]

In [None]:
df_vis_new = create_df_for_viz(central_embeddings,
                               new_topic_model,
                               new_topics,
                               central_docs,
                               seed=42)

# Re-attach the remaining document columns (e.g. title_abstract, category, embeddings).
# noise_cluster is a boolean-filtered slice of df_vis, so it keeps df_vis's non-contiguous
# index. Joining on raw index labels would pair rows that merely share a label, not the same
# document, and would drop the rest. Both frames follow noise_cluster's row order, so reset
# both indices and join positionally.
# Assumes create_df_for_viz returns one row per input doc, in input order — TODO confirm.
df_vis_new = df_vis_new.reset_index(drop=True).merge(
    noise_cluster.reset_index(drop=True),
    left_index=True,
    right_index=True,
    validate="one_to_one",
)
assert len(df_vis_new) == len(noise_cluster), "row count changed when re-attaching document columns"

In [None]:
# Representative documents stored by the second-pass model
new_representative_docs = new_topic_model.get_representative_docs()
new_representative_docs

In [None]:
# Which category == 'main' documents ended up in the outlier set
noise_cluster.query("category == 'main'")

In [None]:
# Keep only points with x > 5.
# NOTE(review): this cut-off was copied from the first plot, but the x values here come from
# create_df_for_viz on the outlier subset (x was dropped from noise_cluster); confirm 5 is still meaningful.
filtered_df = df_vis_new[df_vis_new['x'] > 5]

# Shared encodings: 2-D coordinates and hover tooltip
base = alt.Chart(filtered_df).encode(
    x='x',
    y='y',
    tooltip=['Name:N', 'title_abstract:N'],
)

# Documents with category == 'main': large red markers at 0.75 opacity
main_points = base.transform_filter(
    alt.datum.category == 'main'
).mark_circle(size=200, opacity=0.75, color='red')

# All other documents: small, faint markers coloured by second-pass topic name
other_points = base.transform_filter(
    alt.datum.category != 'main'
).mark_circle(size=30, opacity=0.25).encode(
    color='Name:N'
)

# Later layers are drawn on top, so add the highlighted points last to keep them visible
new_plot = (other_points + main_points).properties(
    width=800,
    height=600,
).interactive()

new_plot.save(PROJECT_DIR / 'outputs/figures/scite_abstracts_noise_cluster.html')
viz_save.save(new_plot, 'scite_abstracts_noise_cluster', PROJECT_DIR / 'outputs/figures', save_png=True)

new_plot.display()