In [1]:
import json
from pathlib import Path
import numpy as np
import plotly.graph_objects as go

In [2]:
# Line/marker colour for each group key used in the plots ("F", "M", "N").
COLORS = {
    "F": "#FF7F0E",
    "M": "#1F77B4",
    "N": "#2CA02C",
}
# Semi-transparent (alpha 0.2) counterparts of COLORS, used to shade the
# +/- one standard deviation band around each line.
FILL_COLORS = {
    "F": "rgba(255, 127, 14, 0.2)",
    "M": "rgba(31, 119, 180, 0.2)",
    "N": "rgba(44,160,44, 0.2)",
}

def load_results(artifact_path, normalized=True):
    """Load steering coefficients and per-example outputs for the winogenerated evaluation.

    Reads ``evaluation/winogenerated/coeffs.json`` (its ``"coeffs"`` entry) and
    ``evaluation/winogenerated/steering_outputs.json`` from beneath ``artifact_path``.

    Args:
        artifact_path: Run directory, given as a ``str`` or ``pathlib.Path``.
        normalized: When true, each record's ``F_probs``, ``M_probs`` and
            ``N_probs`` lists are rescaled element-wise by their sum, so the
            three values at every position add up to 1.

    Returns:
        A ``(coeffs, outputs)`` tuple: the value stored under ``"coeffs"`` and the
        decoded contents of ``steering_outputs.json`` (records are rewritten in
        place when ``normalized`` is true).
    """
    # Path() makes plain string paths work with the `/` operator below.
    eval_dir = Path(artifact_path) / "evaluation" / "winogenerated"

    # Context managers guarantee the file handles are closed, even on a parse error.
    with open(eval_dir / "coeffs.json", "r", encoding="utf-8") as coeffs_file:
        coeffs = json.load(coeffs_file)["coeffs"]
    with open(eval_dir / "steering_outputs.json", "r", encoding="utf-8") as outputs_file:
        outputs = json.load(outputs_file)

    if normalized:
        for record in outputs:
            F_probs = np.array(record["F_probs"])
            M_probs = np.array(record["M_probs"])
            N_probs = np.array(record["N_probs"])
            # Compute the denominator once from the un-normalized values.
            # NOTE: positions where all three probabilities are 0 are not
            # special-cased and are left to numpy's division semantics.
            total = F_probs + M_probs + N_probs
            record["F_probs"] = (F_probs / total).tolist()
            record["M_probs"] = (M_probs / total).tolist()
            record["N_probs"] = (N_probs / total).tolist()

    return coeffs, outputs


def get_avg_std(x):
    """Return the mean and standard deviation of ``x`` taken along axis 0.

    Args:
        x: Array-like accepted by ``np.mean`` / ``np.std``.

    Returns:
        A ``(avg, std)`` tuple, both reduced over the first axis.
    """
    return np.mean(x, axis=0), np.std(x, axis=0)


def plot_steering(coeffs, outputs, title_text=None, width=425, height=300, legend_title="Gender", error_band=False, x_range=None):
    """Plot the mean probability of each group ("N", "F", "M") against the steering coefficient.

    Args:
        coeffs: x-values of the plot (steering coefficients).
        outputs: Iterable of records, each providing ``"N_probs"``, ``"F_probs"``
            and ``"M_probs"`` sequences; these are averaged across records.
            Presumably each sequence is aligned with ``coeffs`` -- confirm against
            the producer of ``steering_outputs.json``.
        title_text: Optional figure title.
        width: Figure width in pixels.
        height: Figure height in pixels.
        legend_title: Title shown above the legend.
        error_band: When true, shade mean +/- one standard deviation per group.
        x_range: Optional ``[min, max]`` range for the x-axis.

    Returns:
        The configured ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    for group in ["N", "F", "M"]:
        avg, std = get_avg_std([x[f'{group}_probs'] for x in outputs])
        # legendgroup ties the line to its error band, so toggling the legend
        # entry hides or shows both together.
        fig.add_trace(go.Scatter(
            x=coeffs, y=avg, mode='lines+markers', name=group, marker_color=COLORS[group],
            legendgroup=group, showlegend=True
        ))
        if error_band:
            # Invisible upper bound; only serves as the fill target for the trace below.
            fig.add_trace(go.Scatter(
                x=coeffs, y=avg+std, mode='lines', marker=dict(color="#444"), line=dict(width=0),
                legendgroup=group, hoverinfo='skip', showlegend=False
            ))
            # Invisible lower bound; 'tonexty' fills the area up to the previous trace.
            fig.add_trace(go.Scatter(
                x=coeffs, y=avg-std, mode='lines', marker=dict(color="#444"), line=dict(width=0),
                fillcolor=FILL_COLORS[group], fill='tonexty',
                legendgroup=group, hoverinfo='skip', showlegend=False
            ))

    fig.update_layout(
        width=width, height=height, plot_bgcolor='white',
        margin=dict(l=10, r=10, t=20, b=25),
        font=dict(size=14), title_text=title_text,
        title_font=dict(size=16), title_x=0.48, title_y=0.98,
        legend_title_text=legend_title, legend_title_font=dict(size=15),
    )
    fig.update_xaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        zeroline=True, zerolinecolor='black',
        title_text="Steering Coefficient (λ)",
        title_font=dict(size=15), tickfont=dict(size=13),
        showline=True, linewidth=1, linecolor='darkgrey',
        title_standoff=1, nticks=10, range=x_range,
    )
    # The data are fractions in [0, 1] while the axis title is in percent, so
    # label the ticks 0..100 to keep the tick labels and the unit consistent.
    percent_ticks = list(range(0, 101, 20))
    fig.update_yaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        zeroline=True, zerolinecolor='darkgrey',
        title_text="Probability (%)",
        title_font=dict(size=15), tickfont=dict(size=13),
        showline=True, linewidth=1, linecolor='darkgrey',
        title_standoff=2, range=[0, 1],
        tickmode='array',
        tickvals=[p / 100 for p in percent_ticks],
        ticktext=[str(p) for p in percent_ticks],
    )
    return fig

In [3]:
# Load the Qwen-1_8B-chat run and plot its normalized steering sweep with error bands.
artifact_path = Path("runs_gender/Qwen-1_8B-chat")
coeffs, outputs = load_results(artifact_path, normalized=True)
fig = plot_steering(
    coeffs,
    outputs,
    title_text="Winogenerated",
    width=470,
    height=300,
    error_band=True,
    x_range=[-81, 81],
)
fig.show()
# Uncomment to export the figure as a PDF:
# fig.write_image("plots/qwen-1.8b-winogenerated-steering.pdf")