In [None]:
import altair as alt
import ast
import pandas as pd

from dsp_ai_eval import config
from dsp_ai_eval.getters.gpt import get_gpt_themes_embeddings, get_representative_docs, get_topics, get_probs, get_topic_model, get_cluster_summaries_cleaned
from dsp_ai_eval.utils.clustering_utils import create_df_for_viz

# Seed value read from the project config; it is passed to `create_df_for_viz`
# (as `seed=SEED`) when the visualisation dataframe is built further down.
SEED = config["seed"]

In [None]:
def temp_plot(df_vis, gpt_model='gpt-3.5-turbo', temps=(0, 0.25, 0.5)):
    """Scatter-plot the rows of one GPT model, coloured by the `temperature` column.

    Args:
        df_vis: DataFrame providing at least the columns ``x``, ``y``,
            ``gpt_model``, ``temperature``, ``topic_name`` and ``doc``.
        gpt_model: Only rows whose ``gpt_model`` equals this value are plotted.
        temps: Iterable of temperature values to keep; rows whose
            ``temperature`` is not among them are dropped.

    Returns:
        An interactive 800x600 Altair chart of semi-transparent circles with
        ``topic_name`` and ``doc`` shown in the tooltip.
    """
    # The domain is fixed (not derived from `temps`), so a given temperature
    # maps to the same colour in every plot produced by this function.
    temp_scale = alt.Scale(
        domain=[0, 0.25, 0.5, 1],
        range=['#0d0887', '#7e03a8', '#cc4778', '#f0f921'],
    )

    # Materialise `temps` so any iterable (list, tuple, set, generator) is
    # accepted by `isin`.
    wanted_temps = list(temps)
    selected = df_vis[
        (df_vis['gpt_model'] == gpt_model)
        & df_vis['temperature'].isin(wanted_temps)
    ]

    fig = (
        alt.Chart(selected)
        .mark_circle(size=200)
        .encode(
            x='x',
            xOffset="random:Q",
            y='y',
            yOffset="random:Q",
            color=alt.Color('temperature', scale=temp_scale),
            opacity=alt.value(0.5),
            tooltip=['topic_name', 'doc'],
        )
        # Adds a per-row `random` field that drives the x/y offset channels above.
        .transform_calculate(random="random()")
        .properties(width=800, height=600)
        .interactive()
    )

    return fig

In [None]:
# Load the answers table; it carries the cleaned answer text and an
# `embeddings` column alongside per-answer metadata.
answers_long = get_gpt_themes_embeddings()

docs = answers_long['answer_cleaned'].tolist()

# `embeddings` is parsed with `ast.literal_eval`, i.e. each cell is expected to
# be a Python literal serialised as text (presumably a list of floats — TODO confirm).
answers_long['embeddings'] = answers_long['embeddings'].apply(ast.literal_eval)

# Build the 2-D embedding matrix in a single DataFrame construction instead of
# `.apply(pd.Series)`, which creates one Series per row and is much slower.
embeddings = pd.DataFrame(answers_long["embeddings"].tolist()).values

topic_model = get_topic_model()

cluster_summaries = get_cluster_summaries_cleaned()

topics = get_topics()
probs = get_probs()
representative_docs = get_representative_docs()

df_vis = create_df_for_viz(embeddings, topic_model, topics, docs, seed=SEED)

# Attach the cluster summaries to each row via its `topic` value.
df_vis = df_vis.merge(cluster_summaries, on='topic', how='left')

# NOTE(review): this is an index-on-index join, so it assumes the rows of
# `df_vis` still share index labels with `answers_long` — confirm.
df_vis = df_vis.merge(
    answers_long[['answer_cleaned', 'temperature', 'gpt_model', 'heading']],
    left_index=True,
    right_index=True,
)

# Assign the filled columns back explicitly. The previous
# `df_vis[col].fillna(..., inplace=True)` is a chained in-place call on a column
# selection: pandas 2.2 warns about it and under Copy-on-Write it leaves
# `df_vis` unchanged.
df_vis["topic_name"] = df_vis["topic_name"].fillna("NA")
df_vis["heading"] = df_vis["heading"].fillna("NA")

In [None]:
# gpt-3.5-turbo answers at all four temperatures; the chart is the cell output.
temp_plot(
    df_vis,
    gpt_model="gpt-3.5-turbo",
    temps=[0, 0.25, 0.5, 1],
)

In [None]:
# gpt-4 answers at all four temperatures; the chart is the cell output.
temp_plot(
    df_vis,
    gpt_model="gpt-4",
    temps=[0, 0.25, 0.5, 1],
)