In [None]:
# Environment setup: from the parent directory, install the local `nnpatch` and `pycolors` checkouts in editable mode, then the plotting/runtime dependencies.
# NOTE(review): no versions are pinned here.
!cd .. && pip install -e ./../nnpatch ./../pycolors && pip install -U transformers kaleido && pip install circuitsvis python-dotenv --no-deps

In [None]:
# Upgrade kaleido. NOTE(review): the install cell above already runs `pip install -U transformers kaleido`, so this looks redundant.
!pip install -U kaleido


In [None]:
# System fonts: colour emoji (used in axis/legend labels) and the Computer Modern families; the figures below request "CMU Serif" / "Computer Modern".
!sudo apt install fonts-noto-color-emoji cm-super fonts-cmu

In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")
from nnsight import NNsight
import torch
import os
from tqdm.notebook import tqdm, trange

from nnsight import NNsight

from analysis.circuit_utils.visualisation import *
from analysis.circuit_utils.decoding import get_decoding_args, get_data, generate_title, get_plot_prior_patch, get_plot_context_patch, get_plot_weightcp_patch, get_plot_weightpc_patch

from main import load_model_and_tokenizer

import pandas as pd
import os
from plotly.colors import n_colors
from analysis.circuit_utils.visualisation import plot_das_results

In [None]:
%cd ..

In [None]:
# Display the per-model run configuration.
# NOTE(review): the only visible import of `CONFIGS` (from generate_run_script) is in the NEXT cell; on a fresh
# kernel this cell raises NameError unless the star import above happens to export it — confirm or move below.
CONFIGS

In [11]:
from pycolors import TailwindColorPalette, to_rgb
from analysis.circuit_utils.visualisation import format_label
import json
from analysis.circuit_utils.visualisation import plot_das_results
from model_utils.utils import construct_test_results_dir, construct_paths_and_dataset_kwargs

from generate_run_script import CONFIGS

# Subsplit identifier passed to `construct_test_results_dir` for each evaluation dataset.
DATASET2SUBSPLIT = {
    "BaseFakepedia": "nodup_relpid",
    "MultihopFakepedia": "nodup_relpid",
    "Arithmetic": "d2ub9",
}

def get_results_dir(model, use_instruct, finetuned, dataset, cwf, k, steering, seed=3, in_domain_demonstrations=False, aafp=False, afpp="end"):
    """Build the directory path that holds the evaluation results of one run.

    The model directory comes from ``construct_paths_and_dataset_kwargs`` (always
    requested for BaseFakepedia / ``nodup_relpid`` with ``TRAIN_SIZE=2048`` and
    PEFT on the four attention projections); the evaluation sub-directory is
    then resolved with ``construct_test_results_dir``.

    Args:
        model: Key into ``CONFIGS`` (the callers in this notebook pass "llama",
            "gemma" or "mistral").
        use_instruct: Use the config's ``instruct_model`` instead of ``base_model``.
        finetuned: When False, ``NO_TRAIN=True`` is passed to the path helper.
        dataset: Evaluation dataset name; must be a key of ``DATASET2SUBSPLIT``.
        cwf: Context-weight format used for the model path, and for the
            evaluation path unless ``steering`` is set.
        k: Number of demonstrations.
        steering: Resolve the steering results directory instead of the baseline one.
        seed, in_domain_demonstrations, aafp, afpp: Forwarded to the path helpers.

    Returns:
        Whatever ``construct_test_results_dir`` returns for these settings.
    """
    model_cfg = CONFIGS[model]

    # Steering runs are always evaluated with the "none" context-weight format.
    eval_cwf = "none" if steering else cwf

    # Only the final path component of the hub id is used.
    hub_id = model_cfg["instruct_model"] if use_instruct else model_cfg["base_model"]
    model_id = hub_id.split("/")[-1]

    path_info = construct_paths_and_dataset_kwargs(
        DATASET_NAME="BaseFakepedia",
        SUBSPLIT="nodup_relpid",
        SEED=seed,
        TRAIN_SIZE=2048,
        MODEL_ID=model_id,
        PEFT=True,
        BATCH_SZ=model_cfg["bs"],
        GRAD_ACCUM=model_cfg["ga"],
        NO_TRAIN=not finetuned,
        LORA_MODULES=["q_proj", "k_proj", "v_proj", "o_proj"],
        LOAD_IN_4BIT=False,
        LOAD_IN_8BIT=False,
        CONTEXT_WEIGHT_AT_END=False,
        CONTEXT_WEIGHT_FORMAT=cwf,
        ANSWER_FORMAT_PROMPT_POSITION=afpp,
        ADD_ANSWER_FORMAT_PROMPT=aafp,
    )
    # The helper returns a 9-tuple; only its fourth entry (the results root) is needed.
    _, _, _, model_results_dir, _, _, _, _, _ = path_info

    return construct_test_results_dir(
        model_results_dir,
        eval_name=dataset,
        subsplit=DATASET2SUBSPLIT[dataset],
        k_demonstrations=k,
        context_weight_format=eval_cwf,
        answer_format_prompt_position=afpp,
        add_answer_format_prompt=aafp,
        do_steering=steering,
        steering_prior_value=float(model_cfg["prior_value"]),
        steering_context_value=float(model_cfg["context_value"]),
        steering_layer=model_cfg["steering_layer"],
        in_domain_demonstrations=in_domain_demonstrations,
    )
    
def get_results(model, use_instruct, finetuned, dataset, cwf, k, steering, seed=3, in_domain_demonstrations=False, aafp=False, afpp="end"):
    """Load ``metrics.json`` from the results directory of the given configuration.

    All arguments are forwarded unchanged to ``get_results_dir``. The resolved
    file path is printed before it is opened.

    Returns:
        The parsed JSON content of ``metrics.json``.
    """
    metrics_path = os.path.join(
        get_results_dir(model, use_instruct, finetuned, dataset, cwf, k, steering, seed, in_domain_demonstrations, aafp, afpp),
        "metrics.json",
    )
    print(metrics_path)
    with open(metrics_path, "r") as handle:
        return json.load(handle)

def get_features(model, use_instruct, finetuned, dataset, cwf, k, steering, seed=3, in_domain_demonstrations=False, aafp=False, afpp="end"):
    """Load the saved feature tensor of a configuration as a NumPy array.

    The file ``features_<steering_layer>.pt`` is read from the directory
    returned by ``get_results_dir`` for the same arguments, where
    ``steering_layer`` comes from ``CONFIGS[model]``.

    Returns:
        ``numpy.ndarray`` with the contents of the stored tensor.
    """
    steering_layer = CONFIGS[model]["steering_layer"]
    results_dir = get_results_dir(model, use_instruct, finetuned, dataset, cwf, k, steering, seed, in_domain_demonstrations, aafp, afpp)
    features_file = os.path.join(results_dir, f"features_{steering_layer}.pt")
    # map_location="cpu" lets tensors that were saved from a GPU be loaded on a
    # machine without CUDA; detach() makes the NumPy conversion work even if the
    # stored tensor tracks gradients.
    features = torch.load(features_file, map_location="cpu")
    return features.detach().cpu().numpy()

def get_configurations(model, seed=3, in_domain_demonstrations=False, aafp=False, afpp="end"):
    """Enumerate the twelve BaseFakepedia evaluation configurations of a model.

    Every entry is ``(label, args)`` where ``args`` is the positional argument
    tuple accepted by ``get_results`` / ``get_features``. Entries are ordered by
    adaptation mode (FT, FS, ZS), then model variant (INSTRUCT, BASE), then
    context-weight format (instruction, float).

    Returns:
        ``(configurations, steering_configurations)``: two equally ordered lists
        that differ only in the ``steering`` flag (False vs. True).
    """
    # (label part, finetuned, k demonstrations)
    modes = (("FT", True, 0), ("FS", False, 10), ("ZS", False, 0))
    # (label part, use_instruct)
    variants = (("INSTRUCT", True), ("BASE", False))
    # (label part, cwf value)
    formats = (("INSTRUCTION", "instruction"), ("FLOAT", "float"))

    def enumerate_runs(steering):
        return [
            (
                f"{variant_tag} {mode_tag} {format_tag}",
                (model, use_instruct, finetuned, "BaseFakepedia", cwf, k, steering, seed, in_domain_demonstrations, aafp, afpp),
            )
            for mode_tag, finetuned, k in modes
            for variant_tag, use_instruct in variants
            for format_tag, cwf in formats
        ]

    return enumerate_runs(False), enumerate_runs(True)


# Performance Comparison

In [4]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pycolors import TailwindColorPalette, to_rgb
from collections import defaultdict
import pandas as pd

# Long display names keyed by checkpoint name. This dict used to be bound to
# `MODEL_NAMES` and was immediately overwritten by the emoji mapping below, so it
# was never visible to any reader; it is kept under its own name for reference.
MODEL_FULL_NAMES = {
    "Meta-Llama-3.1-8B-Instruct": "<b>Llama</b> 3.1 8B",
    "gemma-2-9b-it": "<b>Gemma</b> 2 9B",
    "Mistral-7B-Instruct-v0.3": "<b>Mistral</b> 7B",
}

# Emoji shown on the outer x-axis level, keyed by the short model keys used to
# index `CONFIGS`. Written as escapes so the literals survive re-encoding of the
# notebook file (they had been corrupted into mojibake).
MODEL_NAMES = {
    "llama": "\U0001F999",          # llama
    "gemma": "\U0001F48E",          # gem stone
    "mistral": "\U0001F32C\uFE0F",  # wind face
}

# Legend labels for the two evaluation settings compared in the bar charts.
column_map = {
    "baseline": "Baseline: Intent Instruction",
    "no_instruction": "Steering: No Instruction",
}
# Axis titles per metric key. Only "acc" and "pair_acc" have entries, so
# `metric_map[metric]` raises KeyError for any other key (e.g. "accuracy").
metric_map = {
    "acc": "Accuracy",
    "pair_acc": "PairAcc"
}
def get_str_colors(color):
    """Format an iterable of colour channels as an ``rgb(r,g,b)`` string."""
    channels = [str(channel) for channel in color]
    return "rgb(" + ",".join(channels) + ")"

# Bar colours for the two compared settings, keyed like `column_map`.
# NOTE(review): `COLORS` is not defined in any visible cell; presumably it is provided by the
# star import from analysis.circuit_utils.visualisation — confirm, or build a TailwindColorPalette here.
colors = {
    'no_instruction': get_str_colors(to_rgb(COLORS.get_shade(1, 500))),
    'baseline': get_str_colors(to_rgb(COLORS.get_shade(4, 300))),
}

def get_ds_configurations(model, dataset="BaseFakepedia", seed=3, in_domain_demonstrations=False, aafp=False, afpp="end"):
    """Return the three instruction-format configurations compared per dataset.

    Each entry is ``(label, args)`` with ``args`` being the positional argument
    tuple for ``get_results``. The settings are: instruct model fine-tuned
    (k=0), base model with 10 demonstrations, and instruct model zero-shot.

    Returns:
        ``(baseline_configs, steering_configs)``: two lists in the same order
        whose entries differ only in the ``steering`` flag.
    """
    # (label, use_instruct, finetuned, k demonstrations)
    settings = (
        ("INSTRUCT FT INSTRUCTION", True, True, 0),
        ("BASE FS INSTRUCTION", False, False, 10),
        ("INSTRUCT ZS INSTRUCTION", True, False, 0),
    )

    def expand(steering):
        return [
            (label, (model, use_instruct, finetuned, dataset, "instruction", k, steering, seed, in_domain_demonstrations, aafp, afpp))
            for label, use_instruct, finetuned, k in settings
        ]

    return expand(False), expand(True)

# Evaluation datasets (same names as the keys of DATASET2SUBSPLIT). Not referenced by the visible cells.
DATASETS = ["BaseFakepedia", "MultihopFakepedia", "Arithmetic"]

def plot_das_results_by_model(dataset, models=["gemma", "llama", "mistral"], seed=3, in_domain_demonstrations=False, aafp=False, afpp="end", metric='accuracy', no_y_axis=False, COLORS=TailwindColorPalette(), legend_loc="top"):
    """Bar chart of baseline vs. steering results for several models on one dataset.

    For every model, the three configurations from ``get_ds_configurations`` are
    loaded with ``get_results`` once without and once with steering. The x axis
    is two-level: configuration label (inner) and model emoji from
    ``MODEL_NAMES`` (outer).

    Args:
        dataset: Evaluation dataset name forwarded to ``get_ds_configurations``.
        models: Keys of ``MODEL_NAMES`` / ``CONFIGS`` to include (never mutated).
        seed, in_domain_demonstrations, aafp, afpp: Forwarded to
            ``get_ds_configurations``.
        metric: Key read from each ``metrics.json``. The y-axis title uses
            ``metric_map`` when it has an entry for the key and the raw key
            otherwise (the default "accuracy" has no entry).
        no_y_axis: Hide y tick labels/title and use a slightly narrower figure.
        COLORS: Unused; bar colours come from the module-level ``colors`` dict.
        legend_loc: "top", "middle", any other non-None value (low position),
            or None to hide the legend.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    ORDER = ["baseline", "no_instruction"]
    # Per column: [inner-level labels (configuration), outer-level labels (model)].
    x_labels = defaultdict(lambda: [[], []])
    y_values_dict = {column: [] for column in ORDER}

    for model in models:
        model_name = MODEL_NAMES[model]
        configurations, steering_configurations = get_ds_configurations(model, dataset, seed, in_domain_demonstrations, aafp, afpp)
        for baseline, steering in zip(configurations, steering_configurations):
            y_values_dict["baseline"].append(get_results(*baseline[1])[metric])
            y_values_dict["no_instruction"].append(get_results(*steering[1])[metric])
            x_labels["baseline"][0].append(format_label(to_standart_label(baseline[0])))
            x_labels["no_instruction"][0].append(format_label(to_standart_label(steering[0])))
            x_labels["baseline"][1].append(model_name)
            x_labels["no_instruction"][1].append(model_name)

    # Add one bar trace per column.
    for column in ORDER:
        values = y_values_dict[column]
        fig.add_trace(go.Bar(
            name=column_map[column],
            x=x_labels[column],
            y=values,
            # Values up to 0.05 are shown as "0" and, using the same threshold,
            # placed outside the (tiny) bar so the label stays readable.
            text=[f'{v:.2f}' if v > 0.05 else '0' for v in values],
            textposition=['auto' if v > 0.05 else 'outside' for v in values],
            textfont=dict(size=20),
            marker_color=colors[column]
        ))

    fig.update_layout(
        # Fall back to the raw key instead of raising KeyError for metrics
        # without a display name.
        yaxis_title=metric_map.get(metric, metric),
        font=dict(size=20),
    )

    if legend_loc is not None:
        legend_y = {"top": 0.97, "middle": 0.4}.get(legend_loc, 0.27)
        fig.update_layout(legend=dict(
            groupclick="toggleitem",
            yanchor="top",
            y=legend_y,
            xanchor="right",
            x=0.97,
            orientation="v",
        ))
    else:
        fig.update_layout(showlegend=False)

    width = 750
    fig.update_yaxes(range=[0.0, 1.0])
    if no_y_axis:
        width = 720
        fig.update_layout(yaxis=dict(visible=True, showticklabels=False, ticks="", title=""))

    fig.update_xaxes(tickfont=dict(size=29))

    # Force a uniform minimum text size on the bars.
    fig.update_layout(
        uniformtext_minsize=25,
        uniformtext_mode='show',
    )
    # "CMU Serif" matches plot_das_results_by_dataset and the callers in this notebook.
    fig.update_layout(width=width, height=400, margin=dict(l=0, r=0, t=0, b=0), font_family="CMU Serif")
    return fig

# Short dataset tags used as the outer x-axis level in plot_das_results_by_dataset.
DATASET_NAMES = {
    "BaseFakepedia": "BF",
    "MultihopFakepedia": "MH",
    "Arithmetic": "AR",
}

def plot_das_results_by_dataset(model, datasets=["BaseFakepedia", "MultihopFakepedia", "Arithmetic"], metric='accuracy', no_y_axis=False, seed=3, in_domain_demonstrations=False, aafp=False, afpp="end", COLORS=TailwindColorPalette(), legend_loc="top"):
    """Bar chart of baseline vs. steering results for one model across datasets.

    For every dataset, the three configurations from ``get_ds_configurations``
    are loaded with ``get_results`` once without and once with steering. The x
    axis is two-level: configuration label (inner) and the short dataset tag
    from ``DATASET_NAMES`` (outer).

    Args:
        model: Model key forwarded to ``get_ds_configurations``.
        datasets: Keys of ``DATASET_NAMES`` to include (never mutated).
        metric: Key read from each ``metrics.json``. The y-axis title uses
            ``metric_map`` when it has an entry for the key and the raw key
            otherwise (the default "accuracy" has no entry).
        no_y_axis: Hide y tick labels/title and use a slightly narrower figure.
        seed, in_domain_demonstrations, aafp, afpp: Forwarded to
            ``get_ds_configurations``.
        COLORS: Unused; bar colours come from the module-level ``colors`` dict.
        legend_loc: "top", "middle", any other non-None value (low position),
            or None to hide the legend.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    ORDER = ["baseline", "no_instruction"]
    # Per column: [inner-level labels (configuration), outer-level labels (dataset)].
    x_labels = defaultdict(lambda: [[], []])
    y_values_dict = {column: [] for column in ORDER}

    for dataset in datasets:
        configurations, steering_configurations = get_ds_configurations(model, dataset, seed, in_domain_demonstrations, aafp, afpp)
        dataset_name = DATASET_NAMES[dataset]
        for baseline, steering in zip(configurations, steering_configurations):
            y_values_dict["baseline"].append(get_results(*baseline[1])[metric])
            y_values_dict["no_instruction"].append(get_results(*steering[1])[metric])
            x_labels["baseline"][0].append(format_label(to_standart_label(baseline[0])))
            x_labels["no_instruction"][0].append(format_label(to_standart_label(steering[0])))
            x_labels["baseline"][1].append(dataset_name)
            x_labels["no_instruction"][1].append(dataset_name)

    def get_num(num):
        """Bar label: two decimals above 0.1, '0.1' for (0.05, 0.1], '0' otherwise."""
        if num > 0.1:
            return f'{num:.2f}'
        elif num > 0.05:
            return '0.1'
        else:
            return '0'

    # Add one bar trace per column.
    for column in ORDER:
        values = y_values_dict[column]
        fig.add_trace(go.Bar(
            name=column_map[column],
            x=x_labels[column],
            y=values,
            text=[get_num(v) for v in values],
            # Labels of very small bars go outside the bar so they stay readable.
            textposition=['auto' if v > 0.05 else 'outside' for v in values],
            textfont=dict(size=20),
            marker_color=colors[column]
        ))

    fig.update_layout(
        # Fall back to the raw key instead of raising KeyError for metrics
        # without a display name.
        yaxis_title=metric_map.get(metric, metric),
        font=dict(size=20),
    )

    if legend_loc is not None:
        legend_y = {"top": 0.97, "middle": 0.4}.get(legend_loc, 0.27)
        fig.update_layout(legend=dict(
            groupclick="toggleitem",
            yanchor="top",
            y=legend_y,
            xanchor="right",
            x=0.97,
            orientation="v",
        ))
    else:
        fig.update_layout(showlegend=False)

    width = 760
    fig.update_yaxes(range=[0.0, 1.0])
    if no_y_axis:
        width = 720
        fig.update_layout(yaxis=dict(visible=True, showticklabels=False, ticks="", title=""))

    fig.update_xaxes(tickfont=dict(size=29))

    # Force a uniform minimum text size on the bars.
    fig.update_layout(
        uniformtext_minsize=25,
        uniformtext_mode='show',
    )
    fig.update_layout(width=width, height=400, margin=dict(l=0, r=0, t=0, b=0), font_family="CMU Serif")
    return fig


In [None]:
fig = plot_das_results_by_model("BaseFakepedia", legend_loc=None, no_y_axis=False, metric='pair_acc', COLORS=TailwindColorPalette())
fig.update_layout(font_family="CMU Serif", width=760)
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/pair_acc_BF_models.pdf")
fig.show()

In [None]:
fig = plot_das_results_by_dataset("llama", legend_loc="top", datasets=["BaseFakepedia", "MultihopFakepedia", "Arithmetic"],  no_y_axis=False, metric='pair_acc',COLORS=TailwindColorPalette())
fig.update_layout(font_family="CMU Serif", width=760)
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/generalization_llama_pair_acc.pdf")
fig.show()

In [None]:
fig = plot_das_results_by_dataset("mistral", legend_loc="top", datasets=["BaseFakepedia", "MultihopFakepedia", "Arithmetic"],  no_y_axis=False, metric='pair_acc',COLORS=TailwindColorPalette())
fig.update_layout(font_family="CMU Serif", width=760)
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/generalization_mistral_pair_acc.pdf")
fig.show()

In [None]:
fig = plot_das_results_by_dataset("gemma", legend_loc="top", datasets=["BaseFakepedia", "MultihopFakepedia", "Arithmetic"],  no_y_axis=False, metric='pair_acc',COLORS=TailwindColorPalette())
fig.update_layout(font_family="CMU Serif", width=760)
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/generalization_gemma_pair_acc.pdf")
fig.show()

# Feature Distribution

In [None]:
%cd ..

In [5]:
import os
import torch
import pandas as pd
device = "cuda" if torch.cuda.is_available() else "cpu"  # NOTE(review): not referenced by any visible cell
# we just need one of the results for the labels
# NOTE(review): these Llama / seed-3 labels are later used to index the features of every model;
# that presumes all runs share the same test-set order — confirm.
df = pd.read_csv("data/BaseFakepedia/BaseFakepedia_nodup_relpid-ts2048/3/models/Meta-Llama-3.1-8B-Instruct-NT/results/BaseFakepedia-sp_nodup_relpid-k0_OOD-cwf_instruction/test.csv")
# Boolean mask: True where the `weight_context` column equals 1.
labels = df["weight_context"].to_numpy() == 1


In [None]:
# Preview the baseline configuration list for Llama (the steering list is not displayed).
configs, steering_configs = get_configurations("llama", in_domain_demonstrations=False)
configs

In [7]:
from plotly.colors import n_colors
from plotly.subplots import make_subplots
from pycolors import TailwindColorPalette, to_rgb
import numpy as np

def plot_feature_distribution(configs, labels, COLORS=TailwindColorPalette()):
    """Plot per-configuration feature distributions next to their PairAcc.

    Left panel: for every configuration, two half-violins of the flattened
    feature values, split by ``labels`` ("intent = ctx" where True, "intent =
    prior" where False). Right panel: horizontal bars with each configuration's
    ``pair_acc``. A text box reports the Pearson correlation between the
    absolute difference of the two group means and ``pair_acc``.

    Args:
        configs: Sequence of ``(name, args)`` pairs; ``args`` is unpacked into
            both ``get_features`` and ``get_results``. Must be non-empty.
        labels: Boolean NumPy array used to index the flattened features
            (presumably one entry per test example — confirm against the
            stored feature shape).
        COLORS: Palette object providing ``get_shade``.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    # Import from the public namespace; `scipy.stats.stats` is a deprecated alias.
    from scipy.stats import pearsonr

    def as_rgb(channels):
        # Plotly colour string with integer channels.
        return f"rgb({','.join(map(lambda x: str(int(x)), channels))})"

    names = [name for name, _ in configs]
    print(names)
    data = [get_features(*values).flatten() for _, values in configs]
    accuracies = {name: get_results(*values)["pair_acc"] for name, values in configs}

    colors_accs = [to_rgb(COLORS.get_shade(4, 300)) for _ in range(len(names))]
    colors_zero = n_colors(to_rgb(COLORS.get_shade(1, 500)), to_rgb(COLORS.get_shade(1, 600)), len(names))
    colors_one = n_colors(to_rgb(COLORS.get_shade(7, 500)), to_rgb(COLORS.get_shade(7, 600)), len(names))

    # Colours per configuration for each kind of trace.
    colors = {
        'one': [as_rgb(c) for c in colors_one],
        'zero': [as_rgb(c) for c in colors_zero],
        'accs': [as_rgb(c) for c in colors_accs],
    }
    fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing = 0.02, subplot_titles=(r"Distribution of Subspace Values", r"PairAcc"))

    # Only the last configuration contributes legend entries, so each group appears once.
    add_legend = [False for _ in range(len(names))]
    add_legend[-1] = True
    for data_line, color_one, color_zero, name, legend in zip(data, colors['one'], colors['zero'], names, add_legend):
        formated_name = format_label(name)
        fig.add_trace(go.Violin(x=data_line[labels], y=[formated_name]*len(data_line[labels]), side='positive', width=3, points=False, line_color=color_one,  name="intent = ctx", legendgroup=name.split(" ")[0], orientation='h', showlegend=legend, line=dict(color="black"), meanline_visible=True), row=1, col=1)
        fig.add_trace(go.Violin(x=data_line[~labels], y=[formated_name]*len(data_line[~labels]),side='positive', width=3, points=False, line_color=color_zero,  name="intent = prior", legendgroup=name.split(" ")[0], orientation='h', showlegend=legend, line=dict(color="black"), meanline_visible=True), row=1, col=1)

    # Horizontal accuracy bars sharing the y axis with the violins.
    for name, accuracy in accuracies.items():
        fig.add_trace(go.Bar(
            y=[format_label(name)],
            x=[accuracy],
            orientation='h',
            text=[f"{accuracy:.2f}"],
            textposition='auto' if accuracy > 0.1 else 'outside',
            insidetextfont=dict(size=25),
            name=name,
            showlegend=False,
            marker_color=colors['accs'][names.index(name)],
        ), row=1, col=2)
    fig.update_layout(
        uniformtext_minsize=25,
        uniformtext_mode='show',
    )

    # PairAcc lies in [0, 1].
    fig.update_xaxes(range=[0, 1], row=1, col=2)

    # Correlate the separation of the two groups with the accuracy.
    mean_diffs = [np.abs(np.mean(data_line[labels]) - np.mean(data_line[~labels])) for data_line in data]
    accs = [accuracies[name] for name in names]
    correlation, p_value = pearsonr(mean_diffs, accs)

    # Text box with the correlation, drawn as a paper-anchored annotation.
    legend_annotations = [
        dict(
        x=0.99,
        y=0.98,
        xref="paper",
        yref="paper",
        text=f"Pearson correlation between <br> subspace value mean difference <br> and PairAcc:<br> {correlation:.3f} (p={p_value:.3f})",
        showarrow=False,
        font=dict(size=25),
        align="right",
        bgcolor="rgba(255, 255, 255, 0.95)",
        bordercolor="black",
        borderwidth=1,
        borderpad=4,
        xanchor="right",
        yanchor="top")
    ]
    fig.update_layout(annotations=legend_annotations)

    # Axis titles.
    fig.update_xaxes(title_text="Feature Value", row=1, col=1)
    fig.update_xaxes(title_text="PairAcc", row=1, col=2)

    fig.update_layout(
        height=500,
        width=1800
    )

    fig.update_layout(xaxis_showgrid=True, xaxis_zeroline=False)
    fig.update_layout(yaxis_showgrid=True, yaxis_zeroline=False)
    fig.update_layout(font=dict(size=25))

    # The title is switched off for the exported figures; flip the flag to show it.
    add_title = False
    if add_title:
        fig.update_layout(title="Llama3.1 8B \u2013 Distribution of Feature R_{cp,16} and Relationship with Model PairAcc on BaseFakepedia")
    else:
        fig.update_layout(margin=dict(t=0, b=0, l=0, r=0))

    # Legend placement: `legend` sits at the right edge of the left panel.
    fig.update_layout(legend=dict(
        yanchor="top",
        y=0.98,
        xanchor="right",
        x=0.485,
        orientation="h",
        bgcolor="rgba(255, 255, 255, 0.9)"
    ), legend2=dict(
        yanchor="top",
        y=0.98,
        xanchor="right",
        x=0.995,
        orientation="h",
        bgcolor="rgba(255, 255, 255, 0.9)"
    ), margin=dict(l=0, r=0, t=0, b=0))

    return fig


In [None]:
configs, steering_configs = get_configurations("llama", in_domain_demonstrations=False)
fig = plot_feature_distribution(configs, labels)
fig.update_layout(font_family="CMU Serif")
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/llama_distribution.pdf")
fig.show()

In [None]:
configs, steering_configs = get_configurations("mistral", in_domain_demonstrations=False)
fig = plot_feature_distribution(configs, labels)
fig.update_layout(font_family="CMU Serif")
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/mistral_distribution.pdf")
fig.show()

In [None]:
configs, steering_configs = get_configurations("gemma", in_domain_demonstrations=False)
fig = plot_feature_distribution(configs, labels)
fig.update_layout(font_family="CMU Serif")
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/gemma_distribution.pdf")
fig.show()


# Results Baseline vs Steering

In [None]:
%cd ..

In [16]:
import pandas as pd
import os
import json
from plotly.colors import n_colors
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from plotly.colors import n_colors
import torch
import numpy as np
import einops

# Legend labels per evaluation setting.
# NOTE: this rebinds the module-level `column_map` / `metric_map` defined in the "Performance Comparison"
# cell; functions defined there see these values once this cell has run.
column_map = {
    "baseline": "Baseline: Intent Instruction",
    "with_instruction": "Steering: Same Instruction",
    "against_instruction": "Steering: Opposite Instruction",
    "one_word": "Steering: Only One Word Instruction",
    "one_word_instruction": "Steering: One Word Instruction and Same Instruction",
    "no_instruction": "Steering: No Instruction",
    "baseline_one_word_instruction": "Baseline: Intent + One Word Instruction"
}
# Axis titles per metric key (only "acc" and "pair_acc" are mapped).
metric_map = {
    "acc": "Accuracy",
    "pair_acc": "PairAcc"
}

from pycolors import TailwindColorPalette

def plot_das_results(baseline_configs, steering_configs, metric='pair_acc', COLORS=TailwindColorPalette(), extended_legend=False):
    """Grouped bar chart comparing baseline and steering results per configuration.

    Note: this definition shadows the ``plot_das_results`` imported from
    ``analysis.circuit_utils.visualisation`` earlier in the notebook.

    Args:
        baseline_configs: Sequence of ``(label, args)`` pairs; ``args`` is
            unpacked into ``get_results``.
        steering_configs: Same structure for the steering runs.
        metric: Key read from each ``metrics.json``. The y-axis title uses
            ``metric_map`` when it has an entry for the key and the raw key
            otherwise.
        COLORS: Palette object providing ``get_shade``.
        extended_legend: Also add a second legend explaining the FT / ICL / ZS
            colour code and an annotation explaining the instruction-format
            icons. Relies on ``get_label_color``, which is not defined in the
            visible cells (presumably provided by the star import — confirm).

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    _data = {key: get_results(*value)[metric] for key, value in baseline_configs}
    _steering_data = {key: get_results(*value)[metric] for key, value in steering_configs}

    fig = go.Figure()

    color_no_instruction = to_rgb(COLORS.get_shade(1, 500))
    color_baseline = to_rgb(COLORS.get_shade(4, 300))
    # One colour per bar for each column.
    colors = {
        'baseline': [f"rgb({','.join(map(str, color_baseline))})"] * len(_data),
        'no_instruction': [f"rgb({','.join(map(str, color_no_instruction))})"] * len(_data),
    }
    for column, data in zip(["baseline", "no_instruction"], [_data, _steering_data]):
        x_labels = [format_label(key) for key in data.keys()]
        y_values = list(data.values())

        fig.add_trace(go.Bar(
            name=column_map[column],
            x=x_labels,
            y=y_values,
            text=[f'{v:.2f}' if v != 0 else '0' for v in y_values],
            # Zero-height bars cannot hold a label, so it is placed outside.
            textposition=['auto' if v != 0 else 'outside' for v in y_values],
            marker_color=colors[column]
        ))

    fig.update_layout(
        barmode='group',
        # Fall back to the raw key instead of raising KeyError for metrics
        # without a display name.
        yaxis_title=metric_map.get(metric, metric),
        font=dict(size=25),
        xaxis=dict(tickangle=-25),
        width=2000,
        height=700
    )

    fig.update_layout(legend=dict(
        yanchor="bottom",
        y=0.88,
        xanchor="right",
        x=0.992,
        orientation="h",
        borderwidth=4,
        bordercolor="white"
    ))
    fig.update_yaxes(range=[0.0, 1.0])

    if extended_legend:
        # Invisible marker traces whose only purpose is to create entries in
        # the second legend explaining the label colour code.
        legend_entries = (
            ("FT", ' Finetuning (FT)'),
            ("FS", ' In-Context Learning (ICL)'),
            ("ZS", ' Zero-Shot (ZS)'),
        )
        for tag, entry_name in legend_entries:
            fig.add_trace(go.Scatter(
                x=[None],
                y=[None],
                mode='markers',
                marker=dict(
                    size=10,
                    color=get_label_color(tag, COLORS)
                ),
                legendgroup='config',
                showlegend=True,
                name=entry_name,
                legend='legend2'
            ))

        # Annotation explaining the instruction-format icons. The emoji are
        # written as escapes (pointing finger, keycap one) so they survive
        # re-encoding of the notebook file.
        legend_annotations = [
            dict(
            x=.995,
            y=0.79,
            xref="paper",
            yref="paper",
            text="\U0001FAF5  IF = instruction<br>1\uFE0F\u20E3  IF = float",
            showarrow=False,
            font=dict(size=20),
            align="left",
            bgcolor="rgba(255, 255, 255, 0.9)",
            borderpad=10,
            xanchor="right",
            height=48,
            width=292,
            yanchor="top")
        ]
        fig.update_layout(annotations=legend_annotations)

    fig.update_layout(legend2=dict(
        yanchor="top",
        y=0.98,
        xanchor="right",
        x=0.995,
        orientation="v",
        bgcolor="rgba(255, 255, 255, 0.9)"
    ), margin=dict(l=0, r=0, t=0, b=0))

    return fig


In [None]:
configs, steering_configs = get_configurations("llama", in_domain_demonstrations=False)
fig = plot_das_results(configs, steering_configs, metric='pair_acc')
fig.update_layout(font_family="CMU Serif", width=1800, height=450, legend=dict(orientation="v", yanchor="top", y=0.985, xanchor="right", x=0.995, font=dict(size=25)))
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/llama_pair_acc.pdf")
fig.show()

In [None]:
# Debug helper: list the result directories of the fine-tuned Llama instruct model (seed 3, cwf "instruction").
# NOTE(review): looks like leftover exploration; nothing below uses its output.
!ls data/BaseFakepedia/BaseFakepedia_nodup_relpid-ts2048/3/models/Meta-Llama-3.1-8B-Instruct-peftq_proj_k_proj_v_proj_o_proj-bs8-ga2-cwf_instruction/results/


In [None]:
configs, steering_configs = get_configurations("mistral", in_domain_demonstrations=False)
fig = plot_das_results(configs, steering_configs, metric='pair_acc')
fig.update_layout(font_family="CMU Serif", width=1800, height=550, legend=dict(orientation="v", yanchor="top", y=0.985, xanchor="right", x=0.995, font=dict(size=25)))
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/mistral_pair_acc.pdf")
fig.show()

In [None]:
configs, steering_configs = get_configurations("gemma", in_domain_demonstrations=False)
fig = plot_das_results(configs, steering_configs, metric='pair_acc')
fig.update_layout(font_family="CMU Serif", width=1800, height=550, legend=dict(orientation="v", yanchor="top", y=0.985, xanchor="right", x=0.995, font=dict(size=25)))
# Make sure the output directory exists before exporting the figure.
os.makedirs("plots", exist_ok=True)
fig.write_image("plots/gemma_pair_acc.pdf")
fig.show()