In [1]:
%load_ext autoreload
%autoreload 2

import os, json
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [2]:
# Interventions evaluated for every model; each name is also the stem of the
# per-intervention result file read further down.
intervention_methods = ["baseline", "ablation", "actadd"]

# Display label -> directory holding that model's evaluation outputs.
model_dirs = {
    "Qwen 1.8B": "runs/Qwen-1_8B-chat",
    "Qwen 7B": "runs/Qwen-7B-chat",
    "Gemma 2B": "runs/gemma-2b-it",
    "Gemma 7B": "runs/gemma-7b-it",
    "Yi 6B": "runs/yi-6b-chat",
    "Llama2 7B": "runs/Llama-2-7b-chat-hf",
    "Llama2 13B": "runs/Llama-2-13b-chat-hf",
    "Llama3 8B": "runs/Meta-Llama-3-8B-Instruct",
}

## Model Coherence

In [3]:
# Gather one record per (model, intervention, benchmark) and build the frame
# once, instead of appending to a DataFrame row by row.
records = []
for model, _dir in model_dirs.items():
    for intervention in intervention_methods:
        results_path = os.path.join(_dir, f"coherence_evals/{intervention}.json")
        # Context manager so the file handle is closed as soon as it is parsed.
        with open(results_path) as fh:
            results = json.load(fh)  # read as a mapping: benchmark name -> accuracy
        for benchmark, acc in results.items():
            records.append(
                {"model": model, "benchmark": benchmark, "intervention": intervention, "accuracy": acc}
            )

df = pd.DataFrame(records, columns=["model", "benchmark", "intervention", "accuracy"])

# Long -> wide: one row per (model, benchmark), one accuracy column per intervention.
df = df.pivot(index=["model", "benchmark"], columns="intervention", values="accuracy").reset_index()

In [4]:
# RGBA colour used for each intervention's traces in the plots below.
color_discrete_map = {
    "baseline": 'rgba(31, 119, 180, 0.9)',
    "actadd": 'rgba(44, 160, 44, 0.9)',
    "ablation": 'rgba(214, 39, 40, 0.9)',
}

# Benchmark key (as stored in the "benchmark" column) -> figure title.
# The two lists are derived from one mapping so they stay index-aligned.
_benchmark_titles = {"arc": "ARC", "mmlu": "MMLU", "truthful_qa": "TruthfulQA"}
benchmark_names = list(_benchmark_titles.keys())
titles = list(_benchmark_titles.values())

def plot_acc(intervention="ablation", width=725, height=300):
    """Show one baseline-vs-intervention accuracy figure per benchmark.

    For every entry of the module-level ``benchmark_names`` the matching rows
    of the module-level pivoted ``df`` are drawn as a dot plot: models on the
    y-axis, accuracy on the x-axis, one marker for the baseline and one for
    ``intervention``, connected by a grey arrow-marked segment.

    Args:
        intervention: Column of ``df`` compared against ``"baseline"``; it must
            also be a key of ``color_discrete_map``.
        width: Figure width passed to the plotly layout.
        height: Figure height passed to the plotly layout.

    Returns:
        None. Each figure is displayed with ``fig.show()``.
    """
    # Frame/grid styling shared by both axes.
    axis_style = dict(
        mirror=True, showgrid=True, gridcolor='darkgrey', zeroline=True, zerolinecolor='darkgrey',
        showline=True, linewidth=1, linecolor='darkgrey', title_font=dict(size=15),
    )

    for benchmark_name, title in zip(benchmark_names, titles):
        temp = df[df["benchmark"] == benchmark_name]
        models = temp["model"].tolist()

        fig = go.Figure()

        # One two-point segment per model, from the baseline accuracy to the
        # intervention accuracy, drawn with arrow markers.
        for model, baseline_acc, intervene_acc in zip(models, temp["baseline"], temp[intervention]):
            fig.add_trace(go.Scatter(
                y=[model, model], x=[baseline_acc, intervene_acc],
                mode="markers+lines", showlegend=False,
                marker=dict(color="grey", symbol="arrow", size=12, angleref="previous", standoff=8),
            ))

        # Endpoint markers. The second trace reads the column selected by
        # `intervention`; it previously always read "ablation", which misplaced
        # the markers whenever another intervention was requested.
        for column in ("baseline", intervention):
            fig.add_trace(go.Scatter(
                y=models, x=temp[column].tolist(), mode="markers", name=column,
                marker=dict(color=color_discrete_map[column], size=15),
            ))

        fig.update_layout(
            width=width, height=height, plot_bgcolor='white', margin=dict(l=20, r=20, t=40, b=20),
            title_text=title, title_font_size=18, title_x=0.45,
            font=dict(size=14), bargroupgap=0, bargap=0.2
        )
        fig.update_xaxes(tickfont=dict(size=13), title_text="Accuracy (%)", title_standoff=2, **axis_style)
        fig.update_yaxes(tickfont=dict(size=13.5), title_text=None, **axis_style)
        fig.show()

### Baseline vs. Intervention (Ablation and ActAdd)

In [5]:
# Baseline vs. the "ablation" intervention: one figure per benchmark.
plot_acc(intervention="ablation")

In [6]:
# Baseline vs. the "actadd" intervention: one figure per benchmark.
plot_acc(intervention="actadd")

## Cross Entropy Loss

In [7]:
# NOTE: this rebinds the module-level `df`, replacing the pivoted accuracy table
# that plot_acc reads; re-run the coherence loading cell before calling plot_acc again.
# Records are collected first and the frame is built once, rather than row by row.
records = []
for model, _dir in model_dirs.items():
    for intervention in intervention_methods:
        results_path = os.path.join(_dir, f"loss_evals/{intervention}_loss_eval.json")
        # Context manager so the file handle is closed as soon as it is parsed.
        with open(results_path) as fh:
            results = json.load(fh)
        for benchmark in results:
            # Each benchmark entry is indexed by "ce_loss"; any other keys are ignored.
            records.append({
                "model": model,
                "benchmark": benchmark,
                "intervention": intervention,
                "ce_loss": results[benchmark]["ce_loss"],
            })

df = pd.DataFrame(records, columns=["model", "benchmark", "intervention", "ce_loss"])

In [10]:
# Benchmark keys plotted in the cross-entropy loss section.
# NOTE: this rebinds the module-level `benchmark_names` that plot_acc iterates
# over, so a later plot_acc call would filter on these names instead of the
# accuracy benchmarks.
benchmark_names = ["pile", "alpaca", "alpaca_custom_completions"]

def plot_ce_loss(benchmark_name):
    """Display a grouped bar chart of cross-entropy loss for one benchmark.

    Rows of the module-level ``df`` whose "benchmark" equals ``benchmark_name``
    are plotted with one bar group per model and one bar per intervention,
    coloured through ``color_discrete_map``.

    Args:
        benchmark_name: Value of the "benchmark" column to plot; also used
            verbatim as the figure title.

    Returns:
        None. The figure is displayed with ``show()``.
    """
    rows = df[df["benchmark"] == benchmark_name]

    # Frame/grid styling applied identically to both axes.
    axis_frame = dict(
        mirror=True,
        showgrid=True,
        gridcolor='darkgrey',
        zeroline=True,
        zerolinecolor='darkgrey',
        showline=True,
        linewidth=1,
        linecolor='darkgrey',
        title_font=dict(size=15),
        tickfont=dict(size=14),
    )

    chart = px.bar(
        rows,
        x="model",
        y="ce_loss",
        color="intervention",
        color_discrete_map=color_discrete_map,
        width=795,
        height=320,
    )
    chart.update_layout(
        barmode='group',
        bargroupgap=0,
        bargap=0.25,
        plot_bgcolor='white',
        margin=dict(l=20, r=20, t=40, b=20),
        font=dict(size=14),
        title_text=benchmark_name,
        title_font_size=18,
        title_x=0.5,
        legend_title_text="Intervention",
        legend_title_font_size=14,
    )
    chart.update_xaxes(title_text="Model", title_standoff=2, **axis_frame)
    chart.update_yaxes(title_text="CE Loss", **axis_frame)
    chart.show()

In [11]:
# Render one cross-entropy loss chart per benchmark currently listed in benchmark_names.
for benchmark_name in benchmark_names:
    plot_ce_loss(benchmark_name)