In [None]:
%load_ext autoreload
%autoreload 2

import os, json
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [15]:
# Interventions compared in every figure; each one also appears in the name of
# the per-benchmark evaluation results file.
intervention_methods = [
    "baseline",
    "ablation",
    "actadd",
]

# Display label -> run directory for that model (insertion order is preserved).
model_dirs = {
    "Qwen 1.8B": "runs/Qwen-1_8B-chat",
    "Qwen 7B": "runs/Qwen-7B-chat",
    "Gemma 2B": "runs/gemma-2b-it",
    "Gemma 7B": "runs/gemma-7b-it",
    "Yi 6B": "runs/yi-6b-chat",
    "Llama2 7B": "runs/Llama-2-7b-chat-hf",
    "Llama2 13B": "runs/Llama-2-13b-chat-hf",
    "Llama3 8B": "runs/Meta-Llama-3-8B-Instruct",
}

# One semi-transparent colour per intervention, shared by all plots.
color_discrete_map = {
    "baseline": 'rgba(31, 119, 180, 0.8)',
    "actadd": 'rgba(44, 160, 44, 0.8)',
    "ablation": 'rgba(214, 39, 40, 0.8)',
}

# Figure titles, one per plotted benchmark.
titles = [
    "HarmBench",
    "JailbreakBench",
    "Harmless Test (Alpaca)",
]

In [16]:
benchmark_names = ["harmbench_test", "jailbreakbench", "harmless"]

# Gather one record per (model, benchmark, intervention, metric) and build the
# frame in a single call: this avoids enlarging the DataFrame row by row and
# lets pandas infer column dtypes from the complete data.
score_records = []
for model, _dir in model_dirs.items():
    for benchmark in benchmark_names:
        for intervention in intervention_methods:
            results_path = os.path.join(_dir, f"evaluation/{benchmark}_{intervention}_eval_results.json")
            # Context manager so the file handle is closed right after parsing.
            with open(results_path) as results_file:
                result = json.load(results_file)
            # Both scores are the complement of the stored success rates.
            refusal_score = 1 - result["substring_matching_success_rate"]
            safety_score = 1 - result["llamaguard2_success_rate"]
            score_records.append({
                "model": model, "benchmark": benchmark, "intervention": intervention,
                "score": refusal_score, "metric": "refusal",
            })
            score_records.append({
                "model": model, "benchmark": benchmark, "intervention": intervention,
                "score": safety_score, "metric": "safety",
            })

# Explicit column list keeps the column order (and an empty-but-well-formed
# frame when no records were collected).
df = pd.DataFrame(score_records, columns=["model", "benchmark", "intervention", "score", "metric"])

In [17]:
def plot_bar_chart(benchmark_name, title_text, width=825, height=320):
    """Show a grouped bar chart of scores for one benchmark.

    Reads the module-level ``df`` and ``color_discrete_map``. Bars are grouped
    by model on the x axis, coloured by intervention and pattern-filled by
    metric.

    Args:
        benchmark_name: Value of the ``benchmark`` column to plot.
        title_text: Title displayed above the figure.
        width: Figure width in pixels.
        height: Figure height in pixels.

    Returns:
        None. The figure is rendered with ``fig.show()``.
    """
    # Frame/grid styling that is identical for both axes.
    axis_frame = dict(
        mirror=True, showgrid=True, gridcolor='darkgrey',
        zeroline=True, zerolinecolor='darkgrey',
        showline=True, linewidth=1, linecolor='darkgrey',
        tickfont=dict(size=14),
    )

    fig = px.bar(
        df[df["benchmark"] == benchmark_name], x="model", y="score",
        color="intervention", pattern_shape="metric",
        color_discrete_map=color_discrete_map, width=width, height=height
    )
    fig.update_layout(
        barmode='group', plot_bgcolor='white', margin=dict(l=20, r=20, t=40, b=20),
        title_text=title_text, title_font_size=18, title_x=0.5, font=dict(size=14),
        legend_title_text=None, bargroupgap=0, bargap=0.25
    )
    fig.update_xaxes(title_text=None, **axis_frame)
    fig.update_yaxes(
        title_font=dict(size=15), title_text="Score (%)", range=[0, 1],
        # Scores are fractions in [0, 1]; render ticks as percentages so they
        # agree with the "(%)" in the axis title.
        tickformat=".0%",
        **axis_frame
    )
    fig.show()

In [23]:
def plot_scatter(benchmark_name, title_text, width=650, height=400):
    """Show per-model baseline/intervention scores for one benchmark as a dot plot.

    Reads the module-level ``df`` and ``color_discrete_map``. Each model gets
    one row on the y axis (alphabetical order); the refusal and safety metrics
    are offset slightly below/above that row. Grey arrow traces run from the
    baseline score to the actadd and ablation scores, and coloured markers are
    drawn on top for each intervention.

    Args:
        benchmark_name: Value of the ``benchmark`` column to plot.
        title_text: Title displayed above the figure.
        width: Figure width in pixels.
        height: Figure height in pixels.

    Returns:
        None. The figure is rendered with ``fig.show()``.
    """
    # One row per (model, benchmark, metric) with a score column per intervention.
    temp = df.pivot(index=["model", "benchmark", "metric"], columns="intervention", values="score").reset_index()
    temp = temp.sort_values("model")

    # Vertical offset and marker symbol for each metric within a model's row.
    jitter = {"refusal": -0.15, 'safety': 0.15}
    marker_symbols = {"refusal": "circle", "safety": "diamond"}
    # Row position of each model, derived from all models in the data.
    model_loc = {model: i for i, model in enumerate(sorted(temp["model"].unique()))}

    # Rows of the requested benchmark, split by metric (refusal first, then safety).
    metric_frames = {
        metric: temp[(temp["benchmark"] == benchmark_name) & (temp["metric"] == metric)]
        for metric in jitter
    }

    fig = go.Figure()

    # Arrow from the baseline score to each intervention score, per model and metric.
    for metric, score_df in metric_frames.items():
        for model, baseline_score, actadd_score, ablation_score in zip(
            score_df["model"], score_df["baseline"], score_df["actadd"], score_df["ablation"]
        ):
            y_pos = model_loc[model] + jitter[metric]
            for target_score in (actadd_score, ablation_score):
                fig.add_trace(go.Scatter(
                    y=[y_pos, y_pos], x=[baseline_score, target_score], mode="markers+lines", showlegend=False,
                    marker=dict(color="grey", symbol="arrow", size=10, angleref="previous", standoff=5)
                ))

    # Coloured markers, added after the arrows so they are drawn on top.
    for intervention in ["baseline", "actadd", "ablation"]:
        for metric, score_df in metric_frames.items():
            fig.add_trace(go.Scatter(
                y=[model_loc[model] + jitter[metric] for model in score_df["model"].tolist()],
                x=score_df[intervention].tolist(), mode="markers",
                name=intervention, legendgroup=metric, legendgrouptitle_text=metric,
                marker=dict(color=color_discrete_map[intervention], size=10, symbol=marker_symbols[metric])
            ))

    fig.update_layout(
        width=width, height=height, plot_bgcolor='white', font=dict(size=14), margin=dict(l=20, r=20, t=40, b=20),
        title_text=title_text, title_font_size=18, title_x=0.45, scattermode="group", scattergap=0.75,
    )
    fig.update_xaxes(
        mirror=True, showgrid=True, gridcolor='darkgrey', zeroline=True, zerolinecolor='darkgrey',
        showline=True, linewidth=1, linecolor='darkgrey',
        title_font=dict(size=15), tickfont=dict(size=13), title_text="Refusal/Safe %", title_standoff=2,
        range=[-0.05, 1.05],
        # Scores are fractions in [0, 1]; render ticks as percentages so they
        # agree with the "%" in the axis title.
        tickformat=".0%",
    )
    fig.update_yaxes(
        mirror=True, showgrid=False, gridcolor='darkgrey', showline=True, linewidth=1, linecolor='darkgrey',
        title_font=dict(size=15), tickfont=dict(size=13.5), title_text=None,
        # Label each integer row position with its model name.
        tickvals=list(model_loc.values()), ticktext=list(model_loc.keys()),
    )
    fig.show()

In [19]:
# One grouped bar chart per benchmark, pairing each benchmark key with the
# title at the same position in `titles`.
for benchmark_name, chart_title in zip(benchmark_names, titles):
    plot_bar_chart(benchmark_name, title_text=chart_title)

In [24]:
# One dot plot per benchmark, pairing each benchmark key with the title at the
# same position in `titles`.
for benchmark_name, chart_title in zip(benchmark_names, titles):
    plot_scatter(benchmark_name, chart_title)