# A notebook to analyse baseline alignment and create related plots

In [33]:
### imports 
import pandas as pd 
import seaborn as sns 

from matplotlib import pyplot as plt

import plotly.express as px
import plotly.graph_objects as go


In [None]:
import os

# Run the notebook from the project root so relative paths such as
# "data/simulation_results/..." resolve. Only step up one level when the
# kernel was started inside a directory named "notebooks"; a plain
# str.replace would also mangle any other path component containing
# "notebooks" and leave a dangling trailing separator. Re-running the cell
# from the project root is a no-op.
cur_dir = os.path.normpath(os.getcwd())
if os.path.basename(cur_dir) == "notebooks":
    os.chdir(os.path.dirname(cur_dir))
print(os.getcwd())

In [None]:
# IPython autoreload extension: mode 2 re-imports modules before each cell
# runs, so edits to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Download data

In [58]:
... download from dropbox ... 

zsh:1: command not found: wget


## Load the data

In [None]:
## Load the baseline alignment simulation results
results_path = "data/simulation_results/baseline_alignment.csv"
df = pd.read_csv(results_path)

print("Columns:", df.columns)

## All models present in the results
model_names = df["model_"].unique()
print("All models:", model_names)

## Functions to summarise and plot the baseline results

In [53]:
def baseline_dataframe(df):
    """ Plot the baseline dataframe misalignment """
    grouped_misal = df[["model_", "misal_"]].groupby(['model_']).agg(["mean", "sem"]).reset_index()

    # Sort by mean misalignment in descending order
    grouped_misal = grouped_misal.sort_values(by=[('misal_', 'mean')], ascending=True)

    # Rename columns for clarity
    grouped_misal.columns = ["model_", "Mean Misalignment", "Std Error"]

    # Style the DataFrame
    styled_df = grouped_misal.style.format({
        'Mean Misalignment': "{:.2f}",
        'Std Error': "{:.2f}"
    }).background_gradient(cmap='coolwarm', subset=["Mean Misalignment"])

    # Return the styled DataFrame so it can be rendered in a notebook
    return styled_df


def plot_baseline_misal(df):
    """Plot mean misalignment per model and return the summary table.

    Groups ``df`` by ``model_`` and computes the mean and the standard error
    of the mean (``sem``) of ``misal_``. It draws a horizontal bar chart with
    one bar per model, ordered highest mean at the top. The standard error is
    shown as error bars, and the figure is displayed via ``plt.show()``.

    Args:
        df: DataFrame with at least the columns ``model_`` and ``misal_``.

    Returns:
        A pandas ``Styler`` for the per-model summary table (columns
        ``model_``, ``Mean Misalignment``, ``Std Error``; two-decimal
        formatting; ``coolwarm`` gradient on the mean). Returning it lets the
        table render when the call is the last expression of a notebook cell.
    """
    grouped_misal = (
        df[["model_", "misal_"]]
        .groupby(["model_"])
        .agg(["mean", "sem"])
        .reset_index()
    )
    # Descending sort combined with invert_yaxis() below puts the model with
    # the highest mean misalignment at the top of the chart.
    grouped_misal = grouped_misal.sort_values(by=[("misal_", "mean")], ascending=False)
    # Flatten the (column, statistic) MultiIndex produced by .agg.
    grouped_misal.columns = ["model_", "Mean Misalignment", "Std Error"]

    styled_df = grouped_misal.style.format({
        "Mean Misalignment": "{:.2f}",
        "Std Error": "{:.2f}",
    }).background_gradient(cmap="coolwarm", subset=["Mean Misalignment"])

    plt.figure(figsize=(8, 5))
    plt.barh(
        grouped_misal["model_"],
        grouped_misal["Mean Misalignment"],
        xerr=grouped_misal["Std Error"],
        color="green",
        ecolor="red",
        capsize=5,
    )
    plt.xlabel("Mean Misalignment")
    plt.ylabel("Model")
    plt.title("Model vs. Mean Misalignment with Error Bars")
    plt.gca().invert_yaxis()  # first row (highest mean) is drawn at the top
    plt.tight_layout()
    plt.show()

    # Return the table so a notebook cell ending with this call renders it.
    return styled_df


def plot_baseline_responses(df_temp):
    """Plot each model's response-category shares as horizontal stacked bars.

    Within each ``model_`` group, the ``resp`` values are counted and
    normalised so each model's bar sums to 1. The result is drawn as a
    horizontal stacked Plotly bar chart, which is shown immediately.

    Args:
        df_temp: DataFrame with at least the columns ``model_`` and ``resp``.

    Raises:
        ValueError: if ``resp`` does not take exactly three distinct values,
            since the chart labels exactly three categories.
    """
    frequency = (
        df_temp.groupby(['model_'])['resp']
        .value_counts(normalize=True)
        .unstack()
        .fillna(0)
    )
    # NOTE(review): models are ordered by the share of the response coded 1;
    # confirm which category that code corresponds to.
    frequency = frequency.sort_values(by=[1], axis=0)

    # Assumes the sorted resp codes map, in order, to these labels — TODO
    # confirm the resp encoding.
    labels = ['Deny', 'Partial approve', 'Approve']
    if frequency.shape[1] != len(labels):
        raise ValueError(
            f"Expected {len(labels)} response categories in 'resp', "
            f"found {frequency.shape[1]}: {list(frequency.columns)}"
        )
    frequency = frequency.reset_index()
    frequency.columns = ['model_'] + labels

    # Green for Deny, orange for Partial approve, red for Approve.
    colors = {
        'Deny': '#2ca02c',
        'Partial approve': '#ff7f0e',
        'Approve': '#d62728',
    }

    # One trace per category, added in label order so the stack order and
    # legend order stay Deny -> Partial approve -> Approve.
    fig = go.Figure()
    for label in labels:
        fig.add_trace(go.Bar(
            y=frequency['model_'],
            x=frequency[label],
            orientation='h',
            name=label,
            marker=dict(color=colors[label]),
            width=0.7,
        ))

    fig.update_layout(
        barmode='stack',  # stacked bars for easier comparison
        xaxis=dict(
            range=[0, 1],  # shares are fractions of each model's responses
            showgrid=True,
            gridcolor='lightgrey'
        ),
        yaxis=dict(
            title='',  # No label for the y-axis
            showgrid=False
        ),
        legend=dict(
            title='',  # No title for the legend
            orientation='h',
            x=0.5,
            y=-0.15,
            xanchor='center',
            font=dict(size=10)
        ),
        plot_bgcolor='white',  # Clean background
        width=600,
        height=400,
        margin=dict(l=50, r=50, t=30, b=50),
        title=dict(
            text='Approval Frequency by Model',
            x=0.5,
            xanchor='center',
            font=dict(size=16, color='black')
        )
    )

    # Keep the y-axis (model name) tick labels horizontal.
    fig.update_yaxes(tickangle=0)
    fig.show()
   


## Filter the data by temperature and plot the results

In [None]:
## Keep only the runs at the selected sampling temperature
temperature = 1
temperature_mask = df["temp_"] == temperature
tmp_df = df[temperature_mask]

## summary table (last expression, so the notebook renders it)
baseline_dataframe(tmp_df)

In [None]:
## Use the temperature-filtered frame like the other cells in this section;
## passing the full df would mix responses from every temperature.
plot_baseline_responses(tmp_df)

In [None]:
## plot mean misalignment per model for the selected temperature
plot_baseline_misal(df=tmp_df)