In [None]:
import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
from sentence_transformers import SentenceTransformer

from nesta_ds_utils.viz.altair import saving as viz_save

from dsp_ai_eval import PROJECT_DIR, config
from dsp_ai_eval.getters.gpt import get_gpt_themes_embeddings

# Set pandas' display width to 1000 characters.
pd.set_option("display.width", 1000)

# Seed value read from the project config.
SEED = config["seed"]

# Sentence-embedding model; the model name comes from the project config.
model = SentenceTransformer(config["embedding_model"])

In [None]:
# Load the GPT answers via the project getter. The cells below group this
# frame by its 'gpt_model', 'temperature' and 'heading' columns.
answers_long = get_gpt_themes_embeddings()

In [None]:
# Count how often each heading occurs for every (gpt_model, temperature) pair.
group_keys = ['gpt_model', 'temperature']
agg_df = (
    answers_long
    .groupby(group_keys + ['heading'])
    .size()
    .reset_index(name='counts')
)

# Within each (gpt_model, temperature) group, put the most frequent headings first.
agg_df = agg_df.sort_values(group_keys + ['counts'], ascending=[True, True, False])

# Keep the five most frequent headings per group. Take an explicit copy so that
# adding columns to this frame later does not raise SettingWithCopyWarning.
top_5_per_group = agg_df.groupby(group_keys).head(5).copy()

# agg_df holds one row per (gpt_model, temperature, heading), so the group size
# is the number of distinct headings produced for each (gpt_model, temperature).
n_headings = agg_df.groupby(group_keys).size().reset_index(name='count')

chart = alt.Chart(n_headings).mark_bar().encode(
    x='temperature:N',      # Temperature as a nominal (categorical) variable
    y='count:Q',            # Quantitative scale for the number of distinct headings
    color='temperature:N',  # Colour bars by temperature (nominal)
    column='gpt_model:N'    # One facet (sub-chart) per GPT model
).properties(
    width=400,   # Width of each facet
    height=400   # Height of each facet
)

# Make sure the output folder exists before writing any figure into it.
# NOTE(review): assumes PROJECT_DIR is a pathlib.Path (it is already combined with `/`).
figures_dir = PROJECT_DIR / "outputs/figures"
figures_dir.mkdir(parents=True, exist_ok=True)

chart.save(figures_dir / "gpt_n_headings.html")
viz_save.save(chart, "gpt_n_headings", figures_dir, save_png=True)

chart.display()

In [None]:
# Add a 'label' column combining 'heading' with 'gpt_model' and 'temperature',
# formatted as "<heading>, (<gpt_model>, <temperature>)". The chart below does
# not use it (it plots 'heading' directly).
# .assign returns a new frame, so this neither writes into a slice of agg_df
# (no SettingWithCopyWarning) nor changes result when the cell is re-run.
top_5_per_group = top_5_per_group.assign(
    label=lambda df: (
        df['heading'].astype(str)
        + ", ("
        + df['gpt_model'].astype(str)
        + ", "
        + df['temperature'].astype(str)
        + ")"
    )
)

# Creating the Altair chart
chart = alt.Chart(top_5_per_group).mark_bar().encode(
    x='counts:Q',           # Quantitative scale for counts
    y=alt.Y('heading:N'),   # Nominal scale for headings (no explicit sort is applied)
    color='temperature:N',  # Colour bars by temperature (nominal)
    column='gpt_model:N',   # One facet (sub-chart) per GPT model
    tooltip=['heading', 'gpt_model', 'temperature', 'counts']  # Show these details when hovering over a bar
).properties(
    width=250,
    height=500,
    title='Top Headings by GPT Model and Temperature'
)

# Make sure the output folder exists before writing any figure into it.
# NOTE(review): assumes PROJECT_DIR is a pathlib.Path (it is already combined with `/`).
figures_dir = PROJECT_DIR / "outputs/figures"
figures_dir.mkdir(parents=True, exist_ok=True)

chart.save(figures_dir / "gpt_frequent_headings.html")
viz_save.save(chart, "gpt_frequent_headings", figures_dir, save_png=True)

chart.display()

In [None]:
# A matplotlib version showing the same info as above: one horizontal bar per
# top heading, with one bar series (and legend entry) per (gpt_model, temperature).
fig, ax = plt.subplots(figsize=(10, 8))
for (gpt_model, temperature), group in top_5_per_group.groupby(['gpt_model', 'temperature']):
    group_label = f'{gpt_model}, {temperature}'
    # Suffix each heading with its group so equal headings from different
    # groups get their own bar instead of sharing one y position.
    ax.barh(group['heading'] + f" ({group_label})", group['counts'], label=group_label)

ax.set_xlabel('Counts')
ax.set_ylabel('Heading (GPT model, temperature)')
ax.set_title('Top Headings by GPT Model and Temperature')
ax.legend()

# Save before showing, and use a tight bounding box so the long y tick labels
# are not cut off in the saved image.
# NOTE(review): assumes PROJECT_DIR is a pathlib.Path (it is already combined with `/`).
figures_dir = PROJECT_DIR / 'outputs/figures'
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'plt_headings.png', dpi=300, bbox_inches='tight')

plt.show()