In [None]:
# Create the Snowflake session object
import yaml
import os
import json
from snowflake.snowpark.session import Session

# Read the connection settings from connection.json (resolved against the current
# working directory). The context manager closes the file handle deterministically
# instead of leaving it to the garbage collector.
with open('connection.json', encoding='utf-8') as connection_file:
    connection_parameters = json.load(connection_file)

# Open the Snowpark session used by every query in this notebook.
session = Session.builder.configs(connection_parameters).create()
# Turn on the session's SQL simplifier flag.
session.sql_simplifier_enabled = True


In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def create_bandit_visualization(
    session,
    *,
    arms_table="ARMS",
    config_table="CONFIG",
    events_table="EVENTS",
    judge_outputs_table="JUDGE_OUTPUTS"
):
    """
    Animated visuals of bandit learning, parameterized by table names.

    Queries run through ``session.sql``:
      * active models and their ``CREDITS_PER_MTOK`` from ``arms_table``
      * the ``perf_weight`` value from ``config_table``
      * exploration events from ``events_table`` left-joined to
        ``judge_outputs_table`` on ``EVENT_ID``

    Every exploration event is one round. A round's score is the mean of the
    ``PARSED_SCORE`` values of its judge rows (0.0 when it has none), and it is
    credited to the event's ``CHOSEN_MODEL``.

    Args:
        session: object exposing ``sql(query)`` whose result supports
            ``collect()`` and ``to_pandas()`` (a Snowpark session in this notebook).
        arms_table, config_table, events_table, judge_outputs_table: table names.
            They are interpolated into the SQL text verbatim, so they must be
            trusted identifiers, never user-supplied input.

    Returns:
        ``(fig1, fig2)``: the pure-performance animation and the cost-adjusted
        animation. ``(None, None)`` when there are no active models or no
        exploration events (a message is printed in that case).

    Raises:
        IndexError: if ``config_table`` has no ``perf_weight`` row.
    """

    # Pull active models + costs
    model_rows = session.sql(
        f"""SELECT MODEL_ID, CREDITS_PER_MTOK
            FROM {arms_table}
            WHERE IS_ACTIVE
            ORDER BY MODEL_ID"""
    ).collect()
    if not model_rows:
        # Report the table that was actually queried, not a hard-coded name.
        print(f"No active models in {arms_table} table.")
        return None, None

    all_models = [r[0] for r in model_rows]
    model_costs = {r[0]: float(r[1]) for r in model_rows}

    # perf_weight
    weight_rows = session.sql(
        f"SELECT VALUE FROM {config_table} WHERE KEY='perf_weight'"
    ).collect()
    if not weight_rows:
        # Same exception type the bare [0][0] lookup used to raise, with a usable message.
        raise IndexError(f"No 'perf_weight' row found in {config_table} table.")
    perf_weight = float(weight_rows[0][0])

    # Exploration events joined to judge outputs.
    # DENSE_RANK over (TS, EVENT_ID) gives all judge rows of one event the same
    # round number; ROW_NUMBER would number every joined row separately and split
    # a multi-judge event into several rounds.
    df = session.sql(f"""
        SELECT 
            e.EVENT_ID,
            e.TS,
            e.CHOSEN_MODEL,
            e.IS_EXPLORATION,
            e.REWARD,
            jo.JUDGE_MODEL,
            jo.PARSED_SCORE,
            DENSE_RANK() OVER (ORDER BY e.TS, e.EVENT_ID) AS ROUND_NUMBER
        FROM {events_table} e
        LEFT JOIN {judge_outputs_table} jo
          ON e.EVENT_ID = jo.EVENT_ID
        WHERE e.IS_EXPLORATION = TRUE
        ORDER BY e.TS, e.EVENT_ID, jo.JUDGE_MODEL
    """).to_pandas()

    if df.empty:
        print("No exploration data found!")
        return None, None

    performance_data, cost_adjusted_data = _accumulate_round_scores(
        df, all_models, model_costs, perf_weight
    )

    perf_df = pd.DataFrame(performance_data)
    cost_df = pd.DataFrame(cost_adjusted_data)
    if perf_df.empty:
        print("No performance data to visualize!")
        return None, None

    # Pure performance animation
    fig1 = px.bar(
        perf_df,
        x="model",
        y="cumulative_avg_score",
        color="model",
        animation_frame="round",
        animation_group="model",
        title="Multi-Armed Bandit: Pure Judge Performance Over Time",
        labels={
            "cumulative_avg_score": "Cumulative Average Judge Score",
            "model": "Model",
            "round": "Exploration Round",
        },
        range_y=[0, 1],
        height=600,
        hover_data=["exploration_count"],
    )

    # Cost-adjusted animation; keep at least 0..3 on the y axis and leave 10% headroom.
    ymax = max(3, float(cost_df["cost_adjusted_score"].max() or 0) * 1.1)
    fig2 = px.bar(
        cost_df,
        x="model",
        y="cost_adjusted_score",
        color="model",
        animation_frame="round",
        animation_group="model",
        title="Multi-Armed Bandit: Cost-Adjusted Performance Over Time",
        labels={
            "cost_adjusted_score": "Cost-Adjusted Score (Performance + Cost Efficiency)",
            "model": "Model",
            "round": "Exploration Round",
        },
        range_y=[0, ymax],
        height=600,
        hover_data=["exploration_count", "model_cost", "pure_performance"],
    )

    # common styling + slower animation
    for fig in (fig1, fig2):
        fig.update_layout(
            title_font_size=20,
            xaxis_title_font_size=14,
            yaxis_title_font_size=14,
            showlegend=True,
            template="plotly_white",
            xaxis={"categoryorder": "array", "categoryarray": all_models},
        )
        _slow_down_animation(fig)

    # per-frame annotations (n = exploration_count)
    for f1, f2 in zip(fig1.frames, fig2.frames):
        # Frame names are parsed back to the integer round number to select that round's rows.
        rnum = int(f1.name)
        p_slice = perf_df[perf_df["round"] == rnum]
        c_slice = cost_df[cost_df["round"] == rnum]

        f1.layout.annotations = [
            dict(
                x=row["model"],
                y=row["cumulative_avg_score"] + 0.05,
                text=f"n={int(row['exploration_count'])}",
                showarrow=False,
                font=dict(size=10, color="black"),
            )
            for _, row in p_slice.iterrows()
        ]
        f2.layout.annotations = [
            dict(
                x=row["model"],
                y=row["cost_adjusted_score"] + 0.1,
                text=f"n={int(row['exploration_count'])}",
                showarrow=False,
                font=dict(size=10, color="black"),
            )
            for _, row in c_slice.iterrows()
        ]

    return fig1, fig2


def _accumulate_round_scores(events_df, all_models, model_costs, perf_weight):
    """Replay the rounds in order and return (performance rows, cost-adjusted rows), one row per model per round."""
    performance_rows = []
    cost_adjusted_rows = []
    running = {m: {"total_score": 0.0, "count": 0} for m in all_models}

    # groupby visits each round's rows once, in ascending round order.
    for rnum, round_df in events_df.groupby("ROUND_NUMBER", sort=True):
        chosen_model = round_df["CHOSEN_MODEL"].iloc[0]
        scores = round_df["PARSED_SCORE"]
        # A round without any parsed judge score counts as 0.0.
        avg_score = 0.0 if scores.isna().all() else float(scores.mean())

        # Update the chosen model's running totals. Events whose model is not
        # among the active arms are skipped instead of raising KeyError.
        chosen_stats = running.get(chosen_model)
        if chosen_stats is not None:
            chosen_stats["total_score"] += avg_score
            chosen_stats["count"] += 1

        # Emit a row per model per round so every animation frame has every bar.
        for m in all_models:
            cnt = running[m]["count"]
            cum_avg = (running[m]["total_score"] / cnt) if cnt > 0 else 0.0

            if cum_avg > 0:
                # Cost efficiency is the inverse cost; non-positive costs contribute 0.
                ce = 1.0 / model_costs[m] if model_costs[m] > 0 else 0.0
                cost_adj = perf_weight * cum_avg + (1 - perf_weight) * ce
            else:
                cost_adj = 0.0

            is_current = (m == chosen_model)
            performance_rows.append(
                {
                    "round": rnum,
                    "model": m,
                    "cumulative_avg_score": cum_avg,
                    "exploration_count": cnt,
                    "is_current_round_model": is_current,
                }
            )
            cost_adjusted_rows.append(
                {
                    "round": rnum,
                    "model": m,
                    "cost_adjusted_score": cost_adj,
                    "exploration_count": cnt,
                    "is_current_round_model": is_current,
                    "model_cost": model_costs[m],
                    "pure_performance": cum_avg,
                }
            )

    return performance_rows, cost_adjusted_rows


def _slow_down_animation(fig):
    """Lengthen the play button's frame and transition durations, if the figure has animation controls."""
    # With no animation controls on the figure (expected when there is only one
    # round), updatemenus is empty and indexing [0] would raise.
    if not fig.layout.updatemenus:
        return
    fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1500
    fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 500


def create_detailed_dashboard(
    session,
    *,
    arm_stats_table="ARM_STATS",
    events_table="EVENTS"
):
    """
    Static dashboard with current scores, exploration/exploitation, EWMA, cost vs perf.

    Builds a 2x2 subplot figure:
      1. bar chart of ``SCORE`` per model (from ``arm_stats_table``)
      2. bars of event counts per model, split by ``IS_EXPLORATION`` (from ``events_table``)
      3. bar chart of ``EWMA_PERF`` per model
      4. scatter of ``EWMA_COST`` vs ``EWMA_PERF``, bubble size driven by ``PULLS``

    Args:
        session: object exposing ``sql(query)`` whose result supports ``to_pandas()``.
        arm_stats_table, events_table: table names. They are interpolated into the
            SQL text verbatim, so they must be trusted identifiers.

    Returns:
        The plotly figure. Panels whose source query returned no rows are left empty.
    """

    arm_stats = session.sql(f"""
        SELECT MODEL_ID, PULLS, EWMA_PERF, EWMA_COST, SCORE, LAST_UPDATE
        FROM {arm_stats_table}
        ORDER BY SCORE DESC
    """).to_pandas()

    exploration_df = session.sql(f"""
        SELECT 
            CHOSEN_MODEL,
            IS_EXPLORATION,
            COUNT(*) AS COUNT,
            AVG(REWARD) AS AVG_REWARD
        FROM {events_table}
        GROUP BY CHOSEN_MODEL, IS_EXPLORATION
        ORDER BY CHOSEN_MODEL, IS_EXPLORATION
    """).to_pandas()

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=[
            "Current Model Scores (ARM_STATS)",
            "Exploration vs Exploitation Count",
            "Model Performance (EWMA)",
            "Cost vs Performance",
        ],
        specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "bar"}, {"type": "scatter"}]],
    )

    has_arm_stats = not arm_stats.empty

    # Plot 1: Scores
    if has_arm_stats:
        fig.add_trace(
            go.Bar(x=arm_stats["MODEL_ID"], y=arm_stats["SCORE"], name="Current Score"),
            row=1,
            col=1,
        )

    # Plot 2: Exploration vs exploitation
    if not exploration_df.empty:
        # One row per model, one column per IS_EXPLORATION value; missing combinations become 0.
        pivot = exploration_df.pivot(
            index="CHOSEN_MODEL", columns="IS_EXPLORATION", values="COUNT"
        ).fillna(0)
        if False in pivot.columns:
            fig.add_trace(
                go.Bar(x=pivot.index, y=pivot[False], name="Exploitation"),
                row=1,
                col=2,
            )
        if True in pivot.columns:
            fig.add_trace(
                go.Bar(x=pivot.index, y=pivot[True], name="Exploration"),
                row=1,
                col=2,
            )

    # Plot 3: EWMA perf
    if has_arm_stats:
        fig.add_trace(
            go.Bar(x=arm_stats["MODEL_ID"], y=arm_stats["EWMA_PERF"], name="EWMA Perf"),
            row=2,
            col=1,
        )

    # Plot 4: cost vs perf scatter
    if has_arm_stats:
        fig.add_trace(
            go.Scatter(
                x=arm_stats["EWMA_COST"],
                y=arm_stats["EWMA_PERF"],
                mode="markers+text",
                text=arm_stats["MODEL_ID"],
                textposition="top center",
                name="Models",
                # Bounded sizes: a raw PULLS * 3 + 10 grows without limit and
                # would cover the whole panel after a few hundred pulls.
                marker=dict(size=_bubble_sizes(arm_stats["PULLS"]), opacity=0.7),
            ),
            row=2,
            col=2,
        )

    fig.update_layout(height=800, title_text="Multi-Armed Bandit Dashboard", showlegend=True)
    return fig


def _bubble_sizes(pulls, *, base_size=10, per_pull=3, max_size=60):
    """Map pull counts to marker diameters, rescaling so the largest bubble never exceeds max_size."""
    # Missing or non-numeric pull counts are treated as zero pulls.
    counts = pd.to_numeric(pulls, errors="coerce").fillna(0).clip(lower=0)
    sizes = counts * per_pull + base_size
    if sizes.max() > max_size:
        # Scale proportionally to the busiest arm so relative differences stay visible.
        # counts.max() is positive here, otherwise sizes could not exceed base_size.
        sizes = base_size + counts / counts.max() * (max_size - base_size)
    return sizes


In [None]:
pure_performance_fig, cost_adjusted_fig = create_bandit_visualization(session)
# create_bandit_visualization returns (None, None) when it has nothing to plot
# (and prints why), so only show the figures when they were actually built.
if pure_performance_fig is not None:
    pure_performance_fig.show()
    cost_adjusted_fig.show()

dashboard_fig = create_detailed_dashboard(session)
dashboard_fig.show()