In [1]:
%load_ext autoreload
%autoreload 2
from plots.data_functions import *

[INFO] Configured API keys: HF_TOKEN, OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY, OPENROUTER_API_KEY


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

# -----------------------------------------------------------
#  Plot configuration shared by the Sankey helper
# -----------------------------------------------------------

# Hex colour per ethical stance (links, stance nodes, blended sub-category nodes).
# Earlier palette candidates: goldish-yellow #E0B274 / light-green #8CC888 and
# dusty orange #E9A178 / muted sage green #A8C3A0.
STANCE_COLOURS = dict([
    ("Consequentialism", "#E07474"),  # deeper pastel red
    ("Deontology",       "#74A0E0"),  # deeper pastel blue
    ("Other",            "#D9D9D9"),
    ("None",             "#000000"),
])

# Order in which stances are iterated / stacked on the right-hand side.
stance_cols = ["Other", "Deontology", "Consequentialism"]

# Category -> its two decision groups; dict order is the plotting order.
category2two_groups = {
    'Species':        ['Animals', 'Humans'],
    'SocialValue':    ['Low', 'High'],
    'Gender':         ['Female', 'Male'],
    'Age':            ['Young', 'Old'],
    'Fitness':        ['Unfit', 'Fit'],
    'Utilitarianism': ['Less', 'More'],
}

# Display text for every node name (decision groups first, then stances).
label_map = {
    "Animals": "Sparing Animals",
    "Humans": "Sparing Humans",
    "Low": "Sparing Low Status",
    "High": "Sparing High Status",
    "Female": "Sparing Women",
    "Male": "Sparing Men",
    "Young": "Sparing Young",
    "Old": "Sparing Old",
    "Unfit": "Sparing Unfit",
    "Fit": "Sparing Fit",
    "Less": "Sparing Individuals",
    "More": "Sparing Groups",
}
# Stance nodes are labelled with their own names.
label_map.update({stance: stance for stance in ("Consequentialism", "Deontology", "Other")})

# Alternating light / dark grey per category, following category2two_groups order.
CATEGORY_COLORS = {
    category: ('rgba(220, 220, 220, 0.6)' if position % 2 == 0 else 'rgba(150, 150, 150, 0.6)')
    for position, category in enumerate(category2two_groups)
}

# -----------------------------------------------------------
#  Sankey: sub-category -> ethical stance
# -----------------------------------------------------------
def _blend_stance_colours(shares: dict) -> str:
    """Blend STANCE_COLOURS hex codes weighted by `shares` (stance -> fraction).

    If exactly one stance has a non-zero share, its palette colour is returned
    verbatim. Otherwise the RGB channels are averaged with the given weights,
    truncated to int and returned as a lowercase "#rrggbb" string.
    """
    nonzero = [stance for stance, share in shares.items() if share > 0]
    if len(nonzero) == 1:
        return STANCE_COLOURS[nonzero[0]]

    r = g = b = 0.0
    for stance, share in shares.items():
        hex_code = STANCE_COLOURS[stance].lstrip('#')
        sr, sg, sb = (int(hex_code[j:j + 2], 16) for j in (0, 2, 4))
        r += sr * share
        g += sg * share
        b += sb * share
    return f'#{int(r):02x}{int(g):02x}{int(b):02x}'


def make_sankey(
        df: pd.DataFrame,
        node_pad: int = 15,
        node_thickness: int = 170,
        width: int = 900,
        height: int = 500,
) -> go.Figure:
    """
    Build a static Sankey / alluvial plot of sub-category -> ethical-stance flows.

      • Left-hand nodes (sub-categories) follow the order of `category2two_groups`,
        with a small gap inside a category and a larger gap between categories.
      • Right-hand nodes are the entries of `stance_cols` that receive any flow;
        "Other" (and "Refusal"/"") ribbons are intentionally not drawn.
      • Links and stance nodes use `STANCE_COLOURS`; each sub-category node is
        coloured by the flow-weighted blend of the stances it feeds.

    Each ribbon's value is `<stance>Count` divided by the sum of
    `Total_Rationales` over the sub-category's whole category.

    Parameters
    ----------
    df : DataFrame with columns `Category`, `Subcategory`, `Total_Rationales`
        and `<stance>Count` for each non-excluded stance in `stance_cols`
        (as produced by `get_catdf`). `Subcategory` values must be unique.
    node_pad, node_thickness : forwarded to the Sankey node spec.
    width, height : figure size in pixels.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If no (sub-category, stance) pair has a positive flow.
    """
    # ------------------------------------------------------------------
    # 1) Sub-category order = insertion order of `category2two_groups`
    # ------------------------------------------------------------------
    subcats_ordered = [sub for subs in category2two_groups.values() for sub in subs]

    # ------------------------------------------------------------------
    # 2) Re-index so rows follow that order. `reindex` does NOT raise for
    #    sub-categories missing from `df`: it inserts all-NaN rows, which are
    #    dropped here because they carry no flow. (It does raise if
    #    `Subcategory` has duplicate labels.)
    # ------------------------------------------------------------------
    df_ord = (
        df.set_index("Subcategory")
          .reindex(subcats_ordered)
          .reset_index()
          .dropna(subset=["Category"])
    )

    # ------------------------------------------------------------------
    # 3) One flow per (sub-category, stance), normalised by the total
    #    rationales of the sub-category's category (shared denominator for
    #    both sub-categories of a category).
    # ------------------------------------------------------------------
    excluded_stances = {"Refusal", "", "Other"}
    cat_totals = df_ord.groupby("Category")["Total_Rationales"].sum()

    flows = []
    for _, row in df_ord.iterrows():
        cat_count = cat_totals[row["Category"]]
        for stance in stance_cols:
            if stance in excluded_stances:
                continue
            flow_value = row[stance + "Count"] / cat_count if cat_count > 0 else 0
            if flow_value > 0:   # omit empty ribbons
                flows.append(dict(source=row["Subcategory"], target=stance, value=flow_value))
    flow_df = pd.DataFrame(flows, columns=["source", "target", "value"])
    if flow_df.empty:
        raise ValueError("make_sankey: no positive sub-category -> stance flows to plot.")

    # ------------------------------------------------------------------
    # 4) Node & link arrays in the form Plotly expects
    # ------------------------------------------------------------------
    sources_seen = set(flow_df["source"])
    targets_seen = set(flow_df["target"])
    subcats_present = [s for s in subcats_ordered if s in sources_seen]
    stance_present = [s for s in stance_cols if s in targets_seen]
    n_src = len(subcats_present)

    nodes = subcats_present + stance_present
    node_index = {name: i for i, name in enumerate(nodes)}

    link_source = flow_df["source"].map(node_index)
    link_target = flow_df["target"].map(node_index)
    link_value = flow_df["value"]
    link_colour = flow_df["target"].map(STANCE_COLOURS)   # ribbons coloured by stance

    # Sources pinned just inside the left edge, stances just inside the right.
    # Node labels below are drawn at 1 - y, i.e. node y is treated as top-down.
    eps = 1e-3
    x = [eps] * n_src + [1 - eps] * len(stance_present)
    y = [0.0] * len(nodes)

    # --- source (left) y positions: bars proportional to outgoing flow ---
    gap_intra = 0.010     # gap between sub-categories of the same category
    gap_inter = 0.060     # gap between categories

    source_share = flow_df.groupby("source")["value"].sum()
    source_share = source_share / source_share.sum()

    # Categories in dict order, keeping only sub-categories that have flow.
    cats_present = [
        (cat, [s for s in subs if s in sources_seen])
        for cat, subs in category2two_groups.items()
        if any(s in sources_seen for s in subs)
    ]
    n_intra = sum(max(len(subs) - 1, 0) for _, subs in cats_present)
    n_inter = max(len(cats_present) - 1, 0)

    # Vertical space left for the bars once all gaps are reserved.
    space_for_bars = 1 - 2 * eps - n_intra * gap_intra - n_inter * gap_inter
    bar_height = source_share * space_for_bars

    # Walk down from the top edge, recording each bar's midpoint.
    midpoints = {}
    cursor = 1 - eps
    for ci, (_, subs) in enumerate(cats_present):
        for si, sub in enumerate(subs):
            h = bar_height[sub]
            cursor -= h / 2          # middle of the bar
            midpoints[sub] = cursor
            cursor -= h / 2          # bottom edge of the bar
            if si < len(subs) - 1:
                cursor -= gap_intra
        if ci < len(cats_present) - 1:
            cursor -= gap_inter

    for i, sub in enumerate(subcats_present):
        y[i] = midpoints[sub]

    # --- target (right) y positions, stacked in `stance_present` order ---
    padding = 0.025       # gap between stance nodes
    target_share = flow_df.groupby("target")["value"].sum()
    target_share = (target_share / target_share.sum()).reindex(stance_present)
    bar_space = 1 - 2 * eps - padding * (len(stance_present) - 1)
    target_mid = (
        1 - eps
        - (target_share.cumsum() - target_share / 2) * bar_space
        - np.arange(len(stance_present)) * padding
    )
    for i, stance in enumerate(stance_present):
        y[n_src + i] = target_mid[stance]

    # ------------------------------------------------------------------
    # 5) Node colours: blended stance colour on the left, palette on the right
    # ------------------------------------------------------------------
    node_colors = ["rgba(220,220,220,0.6)"] * len(nodes)
    for i, sub in enumerate(subcats_present):
        sub_flows = flow_df[flow_df["source"] == sub]
        total_flow = sub_flows["value"].sum()
        if total_flow > 0:
            shares = {
                stance: sub_flows.loc[sub_flows["target"] == stance, "value"].sum() / total_flow
                for stance in stance_present
            }
            node_colors[i] = _blend_stance_colours(shares)
    for i, stance in enumerate(stance_present, start=n_src):
        if stance in STANCE_COLOURS:
            node_colors[i] = STANCE_COLOURS[stance]

    # ------------------------------------------------------------------
    # 6) Figure, node labels and category labels
    # ------------------------------------------------------------------
    fig = go.Figure(
        go.Sankey(
            arrangement="fixed",
            node=dict(
                pad=node_pad,
                thickness=node_thickness,
                line=dict(width=0.5, color="rgba(0,0,0,0.25)"),
                color=node_colors,
                x=x,
                y=y,
            ),
            link=dict(
                source=link_source,
                target=link_target,
                value=link_value,
                color=link_colour,
            ),
        )
    )

    # Node labels are annotations centred on the nodes (not Sankey labels).
    for node_x, node_y, name in zip(x, y, nodes):
        fig.add_annotation(
            x=node_x,
            y=1 - node_y,
            text=label_map.get(name, name),
            showarrow=False,
            xanchor="center",
            yanchor="middle",
            font=dict(size=12),
        )

    # Category names (as they appear in df) in sub-category order, de-duplicated.
    label_cats = []
    for sub in subcats_present:
        cat = df_ord.loc[df_ord["Subcategory"] == sub, "Category"].values[0]
        if cat not in label_cats:
            label_cats.append(cat)

    # Category labels are evenly spaced rather than aligned to node midpoints;
    # the first category gets the smallest y, like the first sub-category's label.
    label_span = 1 - 2 * gap_intra - gap_inter
    label_ys = np.linspace(label_span, 1 - label_span, len(label_cats))[::-1]
    for cat, label_y in zip(label_cats, label_ys):
        fig.add_annotation(
            x=-0.215,  # left of the source nodes
            y=label_y,
            text=f"<b>{cat}</b>",
            showarrow=False,
            xanchor="center",
            yanchor="middle",
            font=dict(size=14, color="rgba(0,0,0,0.7)"),
        )

    tb = 15
    lr = 100
    fig.update_layout(
        font_size=12,
        margin=dict(l=lr + 85, r=lr, t=tb, b=tb),  # extra left room for category labels
        width=width,
        height=height,
    )
    return fig

In [682]:
import os
from glob import glob

# Judge outputs of the current run (previous run: "data/20250422/all_models/").
RESULTS_DIR = "data/20250501/all_models2/"


def parse_judge_filename(path: str) -> tuple[str, str]:
    """Split a judge CSV path into (model_id, sample_num).

    Expected basename pattern: "<model_id>_<...>_w<sample_num>.csv".
    The model id is everything before the first "_"; the sample number is
    everything after the LAST "_w", so an earlier "_w" in the name is harmless.

    Raises
    ------
    ValueError
        If the basename contains no "_w" sample marker.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    if "_w" not in stem:
        raise ValueError(f"Unexpected judge file name (no '_w<sample>' suffix): {path}")
    model_id = stem.split("_")[0]
    sample_num = stem.rsplit("_w", 1)[1]
    return model_id, sample_num


# Sorted so the load order (and dict order) is reproducible across machines.
file_paths = sorted(glob(os.path.join(RESULTS_DIR, "judge", "*.csv")))

# model_id -> {sample_num -> judge DataFrame}
model_dfs = {}
for path in file_paths:
    model_id, sample_num = parse_judge_filename(path)
    # keep_default_na=False: empty cells stay "" instead of becoming NaN.
    model_dfs.setdefault(model_id, {})[sample_num] = pd.read_csv(path, keep_default_na=False)

In [None]:
# Category -> its two decision groups, in plotting order. Duplicates the mapping
# defined with the Sankey palette so this cell can run on its own; keep both in sync.
category2two_groups = dict([
    ('Species',        ['Animals', 'Humans']),
    ('SocialValue',    ['Low', 'High']),
    ('Gender',         ['Female', 'Male']),
    ('Age',            ['Young', 'Old']),
    ('Fitness',        ['Unfit', 'Fit']),
    ('Utilitarianism', ['Less', 'More']),
])

def get_catdf_single(df: pd.DataFrame, num1num2, category_groups=None, macro_map=None):
    """Count rationales per (category, decision group) and per stance for one sample.

    Parameters
    ----------
    df : one judge DataFrame. Reads `phenomenon_category`, `decision_category`,
        `rationales` (a "; "-separated string) and, when filtering, `num1`/`num2`.
    num1num2 : '==' keeps rows with num1 == num2, '!=' keeps rows with
        num1 != num2; any falsy value ('', False, None) applies no filter.
    category_groups : mapping category -> list of decision groups. Defaults to
        the notebook-level `category2two_groups`.
    macro_map : mapping rationale -> macro stance. Defaults to `TAXONOMY_MACRO_MAP`.

    Returns
    -------
    DataFrame with one row per (Category, Subcategory) and columns
    Category, Subcategory, Count (rows in the group), ConsequentialismCount,
    DeontologyCount, OtherCount, Total_Rationales, sorted by Category then
    Subcategory. Rationales missing from `macro_map` count toward
    Total_Rationales but toward no stance. Blank pieces (e.g. from a trailing
    "; ") are not counted as rationales.

    Raises
    ------
    ValueError
        If `num1num2` is truthy but neither '==' nor '!='.
    """
    if category_groups is None:
        category_groups = category2two_groups
    if macro_map is None:
        macro_map = TAXONOMY_MACRO_MAP
    if num1num2 and num1num2 not in ('==', '!='):
        raise ValueError(f"num1num2 must be '==', '!=' or falsy, got {num1num2!r}")

    # Collect plain records and build the frame once (concat-in-a-loop is quadratic).
    records = []
    for category, groups in category_groups.items():
        category_df = df[df['phenomenon_category'] == category]
        if num1num2 == '==':
            category_df = category_df[category_df.num1 == category_df.num2]
        elif num1num2 == '!=':
            category_df = category_df[category_df.num1 != category_df.num2]

        for group in groups:
            group_df = category_df[category_df['decision_category'] == group]

            macro_counts = {}
            total_rationales = 0
            for rationales_str in group_df['rationales']:
                if not (isinstance(rationales_str, str) and rationales_str):
                    continue
                for rationale in rationales_str.split('; '):
                    if not rationale.strip():
                        continue  # blank piece, e.g. from a trailing "; "
                    total_rationales += 1
                    if rationale in macro_map:
                        macro = macro_map[rationale]
                        macro_counts[macro] = macro_counts.get(macro, 0) + 1

            records.append({
                'Category': category,
                'Subcategory': group,
                'Count': len(group_df),
                'ConsequentialismCount': macro_counts.get('Consequentialism', 0),
                'DeontologyCount': macro_counts.get('Deontology', 0),
                'OtherCount': macro_counts.get('Other', 0),
                'Total_Rationales': total_rationales,
            })

    columns = ['Category', 'Subcategory', 'Count', 'ConsequentialismCount',
               'DeontologyCount', 'OtherCount', 'Total_Rationales']
    cat_df = pd.DataFrame(records, columns=columns)
    return cat_df.sort_values(['Category', 'Subcategory'])


def get_catdf_all_samples(dfs: dict, num1num2):
    """Sum the per-sample category tables over every sample of one model.

    Parameters
    ----------
    dfs : mapping sample id -> judge DataFrame (one model).
    num1num2 : forwarded to `get_catdf_single` ('==', '!=' or falsy).

    Returns
    -------
    DataFrame with one row per (Category, Subcategory) and the count columns
    of `get_catdf_single` summed across samples.

    Raises
    ------
    ValueError
        If `dfs` is empty.
    """
    if not dfs:
        raise ValueError("get_catdf_all_samples: dfs must contain at least one sample dataframe.")

    per_sample = [get_catdf_single(df, num1num2) for df in dfs.values()]

    # No sample-id column is attached: it would only be summed and dropped, and
    # pandas versions disagree on whether groupby().sum() keeps string columns.
    return (
        pd.concat(per_sample, ignore_index=True)
        .groupby(['Category', 'Subcategory'])
        .sum()
        .reset_index()
    )


def get_catdf(df: dict | pd.DataFrame, num1num2=False):
    if isinstance(df, dict):
        return get_catdf_all_samples(df, num1num2)
    elif isinstance(df, pd.DataFrame):
        return get_catdf_single(df, num1num2)
    else:
        raise ValueError("Input must be a dictionary of dataframes or a single dataframe.")

In [880]:
# Sankey for gpt-4o-mini over all rows (no num1/num2 filter), summed over samples.
cat_df = get_catdf(model_dfs['gpt-4o-mini-2024-07-18'])
fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [879]:
# Sankey for gemini-2.5-flash-preview over all rows (no num1/num2 filter).
cat_df = get_catdf(model_dfs['gemini-2.5-flash-preview'])
fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [678]:
# Sankey for llama-4-scout over all rows (no num1/num2 filter).
cat_df = get_catdf(model_dfs['llama-4-scout'])
fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [679]:
# Sankey for gpt-4o over all rows (no num1/num2 filter).
cat_df = get_catdf(model_dfs['gpt-4o-2024-08-06'])
fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [680]:
# Sankey for gpt-4.1 over all rows (no num1/num2 filter).
cat_df = get_catdf(model_dfs['gpt-4.1-2025-04-14'])
fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [903]:
# Sankey for gpt-4o-mini restricted to rows where num1 == num2.
cat_df = get_catdf(model_dfs['gpt-4o-mini-2024-07-18'], num1num2='==')
fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [905]:
# Sankey for gemini-2.5-flash-preview restricted to rows where num1 == num2.
cat_df = get_catdf(model_dfs['gemini-2.5-flash-preview'], num1num2='==')

fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [709]:
# Sankey for gpt-3.5-turbo; num1num2='' applies no num1/num2 filter (same as the default).
cat_df = get_catdf(model_dfs['gpt-3.5-turbo-0125'], num1num2='')
fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [710]:
# Sankey for gpt-4.1-nano; num1num2='' applies no num1/num2 filter (same as the default).
cat_df = get_catdf(model_dfs['gpt-4.1-nano-2025-04-14'], num1num2='')
fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [714]:
# Sankey for gemini-2.5-flash-preview restricted to rows where num1 != num2.
cat_df = get_catdf(model_dfs['gemini-2.5-flash-preview'], num1num2='!=')

fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig

In [708]:
# Sankey for gemini-flash-1.5-8b restricted to rows where num1 != num2.
cat_df = get_catdf(model_dfs['gemini-flash-1.5-8b'], num1num2='!=')

fig = make_sankey(cat_df)
# fig.write_image("sankey_v2.svg")     # needs `pip install kaleido`; every cell uses this path, rename per model
fig