In [1]:
%load_ext autoreload
%autoreload 2

import torch
import pandas as pd
import pickle
import numpy as np

In [4]:
# Labelled prompt set. Every results file scored below is attached to these rows by position.
PROMPTS_CSV = "data/circuit_identification_data/final_toxicity_prompts_0.csv"
dfs = pd.read_csv(PROMPTS_CSV)

# Cut-off on the logit difference: a row is predicted positive when log_diff > threshold_logit.
# NOTE(review): how -0.63 was chosen is not recorded in this notebook — document its origin.
threshold_logit = -0.63

In [5]:
from src.utils.file_utils import load_pickle_from_gpu

def get_results(file_name, dfs, threshold_logit, toxicity_threshold=0.5, loader=None):
    """Score one ablation results file against the labelled prompts and print summary metrics.

    The loaded per-row logit differences are attached to ``dfs``. Three statistics are
    printed and returned:

    * mean (over prompt groups) of the within-prompt standard deviation of the logit
      difference. Groups with a single row give NaN and are skipped by the mean;
    * accuracy: fraction of rows where ``log_diff_0 > threshold_logit`` agrees with
      ``toxicity > toxicity_threshold``;
    * mean (over prompt groups) of the number of rows on the majority side of
      ``threshold_logit``. This is an absolute count, so it scales with group size.

    Args:
        file_name: Path handed to the loader (``load_pickle_from_gpu`` by default).
            The loaded object is iterated as a sequence of batches whose items expose
            ``.item()`` (e.g. scalar tensors). They are flattened in order and matched to
            the rows of ``dfs`` by position.
            NOTE(review): confirm the result order matches the CSV row order.
        dfs: DataFrame with at least ``prompt`` and ``toxicity`` columns. It is MUTATED
            in place: the ``log_diff_0`` and ``correct_0`` columns are (over)written on
            every call.
        threshold_logit: Logit-difference cut-off for the predicted label.
        toxicity_threshold: Cut-off on ``toxicity`` for the reference label (default 0.5).
        loader: Optional callable ``loader(file_name)`` that replaces
            ``load_pickle_from_gpu``, e.g. for a different storage format or for tests.

    Returns:
        Tuple ``(mean_std_dev, accuracy, mean_majority)``.

    Raises:
        ValueError: If the number of loaded values differs from ``len(dfs)``.
    """
    # Resolve the default lazily so a custom loader never touches the project helper.
    load = load_pickle_from_gpu if loader is None else loader
    raw_batches = load(file_name)
    log_diffs = np.array([item.item() for batch in raw_batches for item in batch])

    # Rows are matched by position, so a count mismatch means the file and CSV disagree.
    if len(log_diffs) != len(dfs):
        raise ValueError(
            f"{file_name}: loaded {len(log_diffs)} values but dfs has {len(dfs)} rows"
        )
    dfs['log_diff_0'] = log_diffs

    mean_std_dev = dfs.groupby('prompt')['log_diff_0'].std().mean()
    print(f"Mean standard deviation: {mean_std_dev}")

    # Vectorised replacement for the former row-wise apply(axis=1).
    predicted_positive = dfs['log_diff_0'] > threshold_logit
    dfs['correct_0'] = (dfs['toxicity'] > toxicity_threshold) == predicted_positive

    accuracy = dfs['correct_0'].mean()
    print(f"Accuracy: {accuracy}")

    # Per prompt group: size of the larger side of the threshold (above vs. not above).
    per_prompt = predicted_positive.groupby(dfs['prompt'])
    above_counts = per_prompt.sum()
    group_sizes = per_prompt.size()
    majority_counts = np.maximum(above_counts, group_sizes - above_counts)

    mean_majority = majority_counts.mean()
    print(f"Mean majority count per prompt group: {mean_majority}")

    return mean_std_dev, accuracy, mean_majority

### Seed 0 (`split0` result files)

In [6]:
# Split 0: ablated-circuit runs at scales 0.3 / 0.4 / 0.5 plus the no-edges baseline.
# Paths are built from shared prefixes instead of repeating the long literals.
ablated_prefix = "work/bias_abl/results_abl_ablated_bias_EAP-IG_step3000_2357edges_with_toxicity_EAP_step10000_9785edges.json"
baseline_prefix = "work/bias_abl/results_abl_no_edges_bias.json"

print("-0.3:")
std_3_0, acc_3_0, _ = get_results(f"{ablated_prefix}_scaleby0.3_split0.pkl", dfs, threshold_logit)
print("-0.4:")
std_4_0, acc_4_0, _ = get_results(f"{ablated_prefix}_scaleby0.4_split0.pkl", dfs, threshold_logit)
print("-0.5:")
std_5_0, acc_5_0, _ = get_results(f"{ablated_prefix}_scaleby0.5_split0.pkl", dfs, threshold_logit)
print("baseline:")
# The baseline file name also carries "scaleby0.3"; kept exactly as on disk.
std_b_0, acc_b_0, _ = get_results(f"{baseline_prefix}_scaleby0.3_split0.pkl", dfs, threshold_logit)

-0.3:
Mean standard deviation: 0.08174432047072691
Accuracy: 0.7418981481481481
Mean majority count per prompt group: 34.583333333333336
-0.4:


Mean standard deviation: 0.0793114990929329
Accuracy: 0.7228009259259259
Mean majority count per prompt group: 34.6875
-0.5:
Mean standard deviation: 0.07702623520133105
Accuracy: 0.7297453703703703
Mean majority count per prompt group: 34.1875
baseline:
Mean standard deviation: 0.08718439555977958
Accuracy: 0.7934027777777778
Mean majority count per prompt group: 34.520833333333336


### Seed 1 (`split1` result files)

In [8]:
# Split 1: ablated-circuit runs at scales 0.3 / 0.4 plus the no-edges baseline (no 0.5 run here).
# NOTE(review): `dfs` was loaded from final_toxicity_prompts_0.csv. Confirm the split1 result
# files are aligned row-for-row with that same CSV: a length mismatch raises, a reordering does not.
ablated_prefix = "work/bias_abl/results_abl_ablated_bias_EAP-IG_step3000_2357edges_with_toxicity_EAP_step10000_9785edges.json"
baseline_prefix = "work/bias_abl/results_abl_no_edges_bias.json"

print("-0.3:")
std_3_1, acc_3_1, _ = get_results(f"{ablated_prefix}_scaleby0.3_split1.pkl", dfs, threshold_logit)
print("-0.4:")
std_4_1, acc_4_1, _ = get_results(f"{ablated_prefix}_scaleby0.4_split1.pkl", dfs, threshold_logit)
print("baseline:")
std_b_1, acc_b_1, _ = get_results(f"{baseline_prefix}_scaleby0.3_split1.pkl", dfs, threshold_logit)


-0.3:
Mean standard deviation: 0.08524864738837272
Accuracy: 0.7216435185185185
Mean majority count per prompt group: 33.979166666666664
-0.4:
Mean standard deviation: 0.08089876037899744
Accuracy: 0.6886574074074074
Mean majority count per prompt group: 34.041666666666664
baseline:
Mean standard deviation: 0.08954408683916733
Accuracy: 0.8501157407407407
Mean majority count per prompt group: 34.354166666666664


### Seed 2 (`split2` result files)

In [10]:
# Split 2: ablated-circuit runs at scales 0.3 / 0.4 plus the no-edges baseline (no 0.5 run here).
# NOTE(review): `dfs` was loaded from final_toxicity_prompts_0.csv. Confirm the split2 result
# files are aligned row-for-row with that same CSV: a length mismatch raises, a reordering does not.
ablated_prefix = "work/bias_abl/results_abl_ablated_bias_EAP-IG_step3000_2357edges_with_toxicity_EAP_step10000_9785edges.json"
baseline_prefix = "work/bias_abl/results_abl_no_edges_bias.json"

print("-0.3:")
std_3_2, acc_3_2, _ = get_results(f"{ablated_prefix}_scaleby0.3_split2.pkl", dfs, threshold_logit)
print("-0.4:")
std_4_2, acc_4_2, _ = get_results(f"{ablated_prefix}_scaleby0.4_split2.pkl", dfs, threshold_logit)
print("baseline:")
std_b_2, acc_b_2, _ = get_results(f"{baseline_prefix}_scaleby0.3_split2.pkl", dfs, threshold_logit)

-0.3:
Mean standard deviation: 0.07945041440760921
Accuracy: 0.7424768518518519
Mean majority count per prompt group: 33.645833333333336
-0.4:
Mean standard deviation: 0.07858758939444574
Accuracy: 0.7517361111111112
Mean majority count per prompt group: 33.9375
baseline:
Mean standard deviation: 0.08578017668111598
Accuracy: 0.8101851851851852
Mean majority count per prompt group: 33.75


### Group the results across seeds

In [11]:
# One row per seed/split (0, 1, 2). Scale 0.5 was only run for seed 0, so seeds 1-2 are NaN there.
df = pd.DataFrame({
    "std_3": [std_3_0, std_3_1, std_3_2],
    "std_4": [std_4_0, std_4_1, std_4_2],
    "std_5": [std_5_0, np.nan, np.nan],
    "std_b": [std_b_0, std_b_1, std_b_2],
    "acc_3": [acc_3_0, acc_3_1, acc_3_2],
    "acc_4": [acc_4_0, acc_4_1, acc_4_2],
    "acc_5": [acc_5_0, np.nan, np.nan],
    "acc_b": [acc_b_0, acc_b_1, acc_b_2]
})

# Percentage change of each ablated run relative to the same seed's baseline.
# Column order matches the former hand-written version: std_{3,4,5}_change, then acc_{3,4,5}_change.
for metric in ("std", "acc"):
    for scale in ("3", "4", "5"):
        baseline = df[f"{metric}_b"]
        df[f"{metric}_{scale}_change"] = (df[f"{metric}_{scale}"] - baseline) / baseline * 100

In [12]:
# Per-seed metrics (rows = seeds 0, 1, 2) and % change vs. that seed's baseline.
df

Unnamed: 0,std_3,std_4,std_5,std_b,acc_3,acc_4,acc_5,acc_b,std_3_change,std_4_change,std_5_change,acc_3_change,acc_4_change,acc_5_change
0,0.081744,0.079311,0.077026,0.087184,0.741898,0.722801,0.729745,0.793403,-6.239735,-9.030167,-11.651351,-6.491612,-8.898614,-8.023341
1,0.085249,0.080899,,0.089544,0.721644,0.688657,,0.850116,-4.797011,-9.654827,,-15.112321,-18.992512,
2,0.07945,0.078588,,0.08578,0.742477,0.751736,,0.810185,-7.37905,-8.384906,,-8.357143,-7.214286,


In [13]:
# Column-wise mean over the seeds. NaN entries are skipped, so the *_5 columns
# (scale 0.5) reflect seed 0 only and are not directly comparable to the 3-seed columns.
df.mean(axis=0, skipna=True)


std_3            0.082148
std_4            0.079599
std_5            0.077026
std_b            0.087503
acc_3            0.735340
acc_4            0.721065
acc_5            0.729745
acc_b            0.817901
std_3_change    -6.138599
std_4_change    -9.023300
std_5_change   -11.651351
acc_3_change    -9.987025
acc_4_change   -11.701804
acc_5_change    -8.023341
dtype: float64

In [14]:
# Sample standard deviation (ddof=1) across seeds. Single-seed columns (scale 0.5) come out NaN.
df.std(axis=0, ddof=1)

std_3           0.002920
std_4           0.001182
std_5                NaN
std_b           0.001902
acc_3           0.011865
acc_4           0.031575
acc_5                NaN
acc_b           0.029133
std_3_change    1.293987
std_4_change    0.634988
std_5_change         NaN
acc_3_change    4.535587
acc_4_change    6.369855
acc_5_change         NaN
dtype: float64