In [None]:
import altair as alt
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance, OpenAI, PartOfSpeech
from dotenv import load_dotenv
from hdbscan import HDBSCAN
import json
import networkx as nx
import numpy as np
import openai
import os
import pandas as pd
from pathlib import Path
import re
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import spacy
from umap import UMAP

from nesta_ds_utils.viz.altair import saving as viz_save

from dsp_ai_eval import PROJECT_DIR, logging
from dsp_ai_eval.utils import text_cleaning as tc
from dsp_ai_eval.utils.clustering_utils import create_new_topic_model, create_df_for_viz, get_top_docs_per_topic

# Increase the maximum number of rows Altair will process
alt.data_transformers.disable_max_rows()

# Sentence-embedding model; handed to BERTopic.load() when the saved topic model is loaded.
embedding_model = SentenceTransformer('all-miniLM-L6-v2')

# Read variables from a .env file into the process environment before they are looked up.
load_dotenv()

# os.environ.get returns None (it does not raise) when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

In [None]:
# Load the abstracts table. Later cells read its 'title_abstract', 'embeddings',
# 'category' and 'total_cites' columns.
scite_abstracts = pd.read_parquet(PROJECT_DIR / "inputs/data/embeddings/scite_embeddings.parquet")
scite_abstracts.head()

In [None]:
# Load the per-cluster summaries. Later cells index this mapping by string
# cluster id (e.g. '0') and read the 'Name:' and 'Description:' fields.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's default text encoding.
with open(PROJECT_DIR / "outputs/data/cluster_summaries.json", encoding="utf-8") as file:
    cluster_summaries = json.load(file)

cluster_summaries

In [None]:
def to_snake_case(s):
    """
    Converts a string to lowercase with runs of non-word characters replaced by underscores.

    A non-word character is anything other than a letter, digit or underscore;
    each consecutive run of them collapses into a single underscore. Underscores
    left at the start or end of the result are removed.

    :param s: string to convert
    :return: the converted string
    """
    underscored = re.sub(r'\W+', '_', s)
    return underscored.lower().strip('_')

def clean_column_names(df):
    """
    Returns a DataFrame whose column labels have each been passed through ``to_snake_case``.

    :param df: pandas DataFrame with string column names
    :return: new pandas DataFrame with the renamed columns (``df`` itself is left unchanged)
    """
    renamed = {}
    for original_name in df.columns:
        renamed[original_name] = to_snake_case(original_name)
    return df.rename(columns=renamed)



# Topic ids come from each cluster's own key in cluster_summaries rather than
# from the row position, so they stay correct even when the keys are not
# exactly "0".."n-1" in order (for example if an outlier cluster "-1" is present).
topic_ids = [int(key) for key in cluster_summaries]

# Convert the nested dictionary into a list of dictionaries (same order as topic_ids)
data = list(cluster_summaries.values())

# Create a DataFrame with 'topic' as its first column
df = pd.DataFrame(data)
df.insert(0, 'topic', topic_ids)

# Normalise the JSON field names (e.g. 'Name:' -> 'name') before selecting columns
df = clean_column_names(df)

df[['topic', 'name', 'description', 'docs', 'keywords']].to_csv(PROJECT_DIR / "outputs/data/cluster_summaries_cleaned.csv", index=False)

In [None]:
# Build one frame per cluster, tag its rows with the integer topic id, and
# combine them into a single table.
# pd.concat returns a new frame and never modifies its inputs, so the per-cluster
# frames are collected in a list and concatenated once at the end.
# NOTE(review): pd.DataFrame(value) needs at least one list-like field in each
# summary (all-scalar dicts raise) — confirm against the JSON schema.
topic_frames = []

for key, value in cluster_summaries.items():
    temp_df = pd.DataFrame(value)
    temp_df['topic'] = int(key)
    topic_frames.append(temp_df)

# pd.concat raises on an empty list, so fall back to an empty frame in that case
df = pd.concat(topic_frames, ignore_index=True) if topic_frames else pd.DataFrame()

In [None]:
# Spot check: the 'Name:' field of the cluster stored under key '0'
cluster_summaries['0']['Name:']

In [None]:
# Print every cluster's name line followed by its description
for cluster_id, details in cluster_summaries.items():
    heading = f"Cluster {cluster_id}: {details['Name:']}"
    print(heading)
    description = f"{details['Description:']}"
    print(description)

In [None]:
# Number of rows per value of 'category' (the charts below single out category == 'main')
scite_abstracts['category'].value_counts()

In [None]:
# Summary statistics for 'total_cites' (used further down to size the plotted points)
scite_abstracts['total_cites'].describe()

The steps below follow the [BERTopic best practices guide](https://maartengr.github.io/BERTopic/getting_started/best_practices/best_practices.html).

In [None]:
# Documents and their precomputed embeddings, taken from the same rows in the same order.
docs = scite_abstracts['title_abstract'].to_list()

# Stack the per-row vectors into a single 2-D array (one row per document).
# np.vstack copies the data once, whereas .apply(pd.Series) constructs a Series
# object for every row and is very slow on large frames.
# NOTE(review): assumes every row holds a vector of the same length and that the
# frame is non-empty; the array dtype follows the stored vectors — confirm.
embeddings = np.vstack(scite_abstracts['embeddings'].to_list())

In [None]:
# Load the previously saved BERTopic model, attaching the sentence-embedding
# model created in the setup cell.
topic_model = BERTopic.load(PROJECT_DIR / "outputs/models/bertopic_abstracts_model", embedding_model=embedding_model)

# Saved artefacts that accompany the model
TOPICS_INPATH = PROJECT_DIR / "outputs/data/bertopic_abstracts_model_topics.pkl"
PROBS_INPATH = PROJECT_DIR / "outputs/data/bertopic_abstracts_model_probs.npy"
REPRESENTATIVE_DOCS_INPATH = PROJECT_DIR / "outputs/data/bertopic_abstracts_representative_docs.pkl"

# NOTE(review): read_pickle can execute arbitrary code while unpickling; only
# load files that this project produced itself.
topics = pd.read_pickle(TOPICS_INPATH)
probs = np.load(PROBS_INPATH)
representative_docs = pd.read_pickle(REPRESENTATIVE_DOCS_INPATH)

# Show topics
topic_model.get_topic_info()

In [None]:
# Display the representative documents loaded from the pickle above.
# NOTE(review): this prints the whole object — consider showing a slice if it is large.
representative_docs

In [None]:
# BERTopic bar-chart view of the topics' top terms (library defaults)
topic_model.visualize_barchart()

In [None]:
# Topic probability distribution for the entry at index 30 of `probs`.
# NOTE(review): 30 looks like an arbitrary example document — confirm; this
# presumably requires `probs` to hold one probability row per document.
topic_model.visualize_distribution(probs[30])

In [None]:
# Project helper: returns a replacement for scite_abstracts together with
# top_docs_per_topic. scite_abstracts is both an input and an output here, so
# re-running this cell feeds the already-processed frame back in.
# NOTE(review): the trailing 10 is presumably the number of documents kept per topic — confirm.
scite_abstracts, top_docs_per_topic = get_top_docs_per_topic(scite_abstracts, topics, docs, probs,10)

In [None]:
# Inspect the entry at key/index 0 — presumably the top documents for topic 0; confirm.
top_docs_per_topic[0]

In [None]:
# BERTopic's interactive overview of the topics (library defaults)
topic_model.visualize_topics()

In [None]:
# Build the plotting table with the project helper create_df_for_viz, using a
# fixed seed. The merge that follows joins on this table's 'doc' column.
# NOTE(review): the charts read 'x', 'y', 'topic' and 'Name' — presumably produced
# by this helper; confirm.
df_vis = create_df_for_viz(embeddings,
                      topic_model,
                      topics,
                      docs,
                      seed=42)

df_vis.head(20)

In [None]:
# Attach the abstract metadata to each plotted document by matching the document
# text ('doc') against scite_abstracts['title_abstract'].
# NOTE(review): the join key is free text; if any title_abstract value occurs more
# than once in scite_abstracts, this left join repeats the matching rows — confirm
# uniqueness. Re-running this cell merges into the already-merged df_vis.
df_vis = df_vis.merge(scite_abstracts, left_on="doc", right_on="title_abstract", how="left")
df_vis.head()

In [None]:
# List the columns available on the merged plotting table
df_vis.columns

In [None]:
# Scatter plot of the documents: points whose category is 'main' are drawn
# larger, fully opaque and in red; every other point is coloured by its 'Name'.

# Define the base chart with common encodings
base = alt.Chart(df_vis).encode(
    x='x',
    y='y',
    size=alt.condition(
        alt.datum.category == 'main',  # Condition on the 'category' column
        alt.value(200),  # If True, size is 200
        alt.value(30)  # If False, size is 30
    ),
    opacity=alt.condition(
        alt.datum.category == 'main',  # Condition on the 'category' column
        alt.value(1),  # If True, fully opaque
        alt.value(0.5)  # If False, opacity is 0.5
    ),
    tooltip=['Name:N', 'doc:N']
)

# Chart for 'main' category points
main_points = base.transform_filter(
    alt.datum.category == 'main'
).mark_circle().encode(
    color=alt.value('red')  # Color is red for 'main'
)

# Chart for other points, colored by 'Name'
other_points = base.transform_filter(
    alt.datum.category != 'main'
).mark_circle().encode(
    color='Name:N'  # Color mapped by 'Name'
)

# Combine the charts
plot = (main_points + other_points).properties(
    width=800,
    height=600,
).interactive()

# Chart.save infers the output format from the file name; Altair 4.x only does
# so when the target is a str, so the Path is converted explicitly.
plot.save(str(PROJECT_DIR / 'outputs/figures/scite_abstracts.html'))
viz_save.save(plot, 'scite_abstracts', PROJECT_DIR / 'outputs/figures', save_png=True)

plot.display()

In [None]:
# Scatter plot of the documents with point size driven by citation count:
# 'main' category points are fully opaque and red; every other point is
# semi-transparent and coloured by its 'Name'.

# Define the base chart with common encodings
base = alt.Chart(df_vis).transform_calculate(
    # Create a new field for size, multiplying 'total_cites' by 10
    size_calculated='datum.total_cites * 10'
).encode(
    x='x',
    y='y',
    size=alt.Size('size_calculated:Q', scale=alt.Scale(range=[0, 2000])),  # Use the calculated field for size
    opacity=alt.condition(
        alt.datum.category == 'main',  # Condition on the 'category' column
        alt.value(1),  # If True, fully opaque
        alt.value(0.5)  # If False, opacity is 0.5
    ),
    tooltip=['Name:N', 'doc:N', 'total_cites:N']
)

# Chart for 'main' category points
main_points = base.transform_filter(
    alt.datum.category == 'main'
).mark_circle().encode(
    color=alt.value('red')  # Color is red for 'main'
)

# Chart for other points, colored by 'Name'
other_points = base.transform_filter(
    alt.datum.category != 'main'
).mark_circle().encode(
    color='Name:N'  # Color mapped by 'Name'
)

# Combine the charts
plot = (main_points + other_points).properties(
    width=800,
    height=600,
).interactive()

# Saved under its own name: the fixed-size chart is already written to
# 'scite_abstracts', and reusing that name here silently overwrote it.
# Chart.save infers the output format from the file name; Altair 4.x only does
# so when the target is a str, so the Path is converted explicitly.
plot.save(str(PROJECT_DIR / 'outputs/figures/scite_abstracts_by_citations.html'))
viz_save.save(plot, 'scite_abstracts_by_citations', PROJECT_DIR / 'outputs/figures', save_png=True)

plot.display()

In [None]:
# Scatter plot restricted to documents whose topic is not -1 (the outliers),
# coloured by 'Name', with name and document text in the tooltip.
plot = (
    alt.Chart(df_vis.loc[df_vis['topic'] != -1])
    .mark_circle(size=30, opacity=0.5)
    .encode(
        x='x',
        y='y',
        color='Name:N',
        tooltip=['Name:N', 'doc:N'],
    )
    .properties(width=800, height=600)
    .interactive()
)

# Write the figure to outputs/figures via the shared saving helper
viz_save.save(plot, 'scite_abstracts_filtered', PROJECT_DIR / 'outputs/figures', save_png=True)

plot.display()