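# Ablations Analysis

Compares logit-lens and ablation metrics for GPT-2 XL, GPT-J and GPT-NeoX-20B when the in-context demonstrations carry permuted, half-permuted, or random labels.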

In [1]:
import numpy as np
import pandas as pd
import json
import pickle

import sys
import pickle
import os

sys.path.append("../../model")
from model_params import *

sys.path.append("../../data")
from dataset_params import *

## 1) Load Metrics

In [2]:
# Models whose precomputed results are analysed below.
models = ["gpt2_xl", "gpt_j", "gpt_neox_20B"]

# Datasets included in every table, in display order.
# ("unnatural" and "sst2_ab" are intentionally excluded from this analysis.)
datasets = [
    "sst2", "agnews", "trec", "dbpedia",
    "rte", "mrpc", "tweet_eval_hate", "sick",
    "poem_sentiment", "ethos", "financial_phrasebank",
    "medical_questions_pairs",
    "tweet_eval_stance_feminist", "tweet_eval_stance_atheism",
]

# Demonstration-label settings; each one has its own results sub-directory.
settings = ["permuted_incorrect_labels", "half_permuted_incorrect_labels", "random_labels"]

In [3]:
# Logit-lens metrics, nested as logit_lens_results[model][setting][dataset].
# Each .npy file is loaded with allow_pickle=True and unwrapped with .item(),
# so only point this at trusted result files.
logit_lens_results = {}
for model_name in models:
    by_setting = logit_lens_results[model_name] = {}
    for setting in settings:
        by_dataset = by_setting[setting] = {}
        for dataset_name in datasets:
            npy_path = f"../../results/logit_lens/{model_name}/{setting}/{dataset_name}.npy"
            by_dataset[dataset_name] = np.load(npy_path, allow_pickle=True).item()

In [4]:
# Ablation metrics, nested as ablation_results[model][setting][dataset].
# Each .npy file is loaded with allow_pickle=True and unwrapped with .item(),
# so only point this at trusted result files.
ablation_results = {}
for model_name in models:
    by_setting = ablation_results[model_name] = {}
    for setting in settings:
        by_dataset = by_setting[setting] = {}
        for dataset_name in datasets:
            npy_path = f"../../results/ablations/{model_name}/{setting}/{dataset_name}.npy"
            by_dataset[dataset_name] = np.load(npy_path, allow_pickle=True).item()

## 2) Attention, MLP, and Late-Layer Ablations

In [5]:
# Per-model early-exit layer index, used by the "late_layers" ablation below.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
early_exiting_layers_path = "../../results/early_exiting_layers/early_exiting_layers.pickle"
with open(early_exiting_layers_path, "rb") as handle:
    early_exiting_layers = pickle.load(handle)

In [6]:
# Compare the calibrated correct-over-incorrect metric at the last demonstration
# across ablation types, for every model and label setting.
#   - "none" and "late_layers" read the logit-lens results;
#     "late_layers" is evaluated at the model's early-exit layer.
#   - "attention" and "mlp" read the ablation results at the last layer.
metric = "cal_correct_over_incorrect"
abl_types = ["none", "attention", "mlp", "late_layers"]
true_prfx, false_prfx = 0, 1
layer_indx = -1 # last layer
demo_indx = -1 # last demo

row_labels = datasets + ["average"]
tp = {}  # tp[model][setting][abl_type] -> per-dataset "Correct Labels" scores, then their mean
fp = {}  # fp[model][setting][abl_type] -> per-dataset scores for `setting`, then their mean
for model in models:
    print(f"Model: {model}\n")
    tp[model] = {}
    fp[model] = {}
    for setting in settings:
        tp[model][setting] = {key: [] for key in abl_types}
        fp[model][setting] = {key: [] for key in abl_types}
        for abl_type in abl_types:
            # Cell-local layer index. Reassigning the global `layer_indx` here would
            # leave it at the last model's early-exit layer after this cell runs,
            # silently changing the "last layer" used by later cells.
            abl_layer = early_exiting_layers[model] if abl_type == "late_layers" else layer_indx
            tp_vals = tp[model][setting][abl_type]
            fp_vals = fp[model][setting][abl_type]
            for dataset in datasets:
                if abl_type in ("none", "late_layers"):
                    res = np.mean(logit_lens_results[model][setting][dataset][metric], axis=0)[0]
                else:
                    res = ablation_results[model][setting][dataset][abl_type][metric][0]
                tp_vals.append(res[true_prfx][abl_layer][demo_indx])
                fp_vals.append(res[false_prfx][abl_layer][demo_indx])
            # Final "average" row: unweighted mean over datasets.
            tp_vals.append(sum(tp_vals) / len(datasets))
            fp_vals.append(sum(fp_vals) / len(datasets))
        tp_df = pd.DataFrame(tp[model][setting], index=row_labels)
        fp_df = pd.DataFrame(fp[model][setting], index=row_labels)
        print("Correct Labels\n", tp_df)
        print(setting + "\n", fp_df)
        print()
    print("********************************************")

Model: gpt2_xl

Correct Labels
                                 none  attention       mlp  late_layers
sst2                        0.718267   0.841000  0.845000     0.721600
agnews                      0.633000   0.658000  0.592000     0.617000
trec                        0.371000   0.342000  0.372000     0.339000
dbpedia                     0.797000   0.778000  0.627000     0.747000
rte                         0.548000   0.485000  0.477000     0.506000
mrpc                        0.537000   0.509000  0.505000     0.543000
tweet_eval_hate             0.530000   0.557000  0.535000     0.544000
sick                        0.432000   0.444667  0.405333     0.403333
poem_sentiment              0.575333   0.528667  0.554333     0.590000
ethos                       0.527000   0.506000  0.534000     0.540000
financial_phrasebank        0.713000   0.716333  0.703333     0.778333
medical_questions_pairs     0.453000   0.464000  0.484000     0.481000
tweet_eval_stance_feminist  0.414000   0.4136

## 3) Head Ablation

In [7]:
# Head ablation for a single model. Compares ablating the "ours" and "null" head
# sets against the unablated logit-lens baseline. Each table entry is rendered as
# "C_{S}": the central value and the spread |upper - central|, both in percent.
# Scratch variables deliberately avoid the names `tp`/`fp` so the results dicts
# built by the previous cell are not overwritten.
model = "gpt_j"
metric = "cal_correct_over_incorrect"
true_prfx, false_prfx = 0, 1
layer_indx = -1 # last layer
demo_indx = -1 # last demo


def _final_score(scores, stat_idx, prfx):
    """Return ``scores[stat_idx][prfx]`` at the configured layer/demo index.

    ``stat_idx`` 0 is the central value and 1 the "upper" value of each result.
    NOTE(review): "upper" is presumably mean + one SD — confirm against the
    code that writes these results.
    """
    return scores[stat_idx][prfx][layer_indx][demo_indx]


def _fmt_mean_spread(center, upper):
    """Format ``center`` and ``|upper - center|``, both scaled by 100, as ``"C_{S}"``."""
    return "{:.2f}_{{{:.2f}}}".format(center * 100, abs(upper - center) * 100)


head_abl_columns = ["tp_delta_ours", "gap_reduction_ours", "tp_delta_null", "gap_reduction_null"]
results = {}
for setting in settings:
    print(f"Setting: {setting}")
    results[setting] = {col: [] for col in head_abl_columns}
    # col -> (central values, upper values), one entry per dataset; averaged after the loop.
    head_abl_values = {col: ([], []) for col in head_abl_columns}
    for dataset in datasets:
        baseline = np.mean(logit_lens_results[model][setting][dataset][metric], axis=0)
        for variant in ("ours", "null"):
            ablated = ablation_results[model][setting][dataset][f"heads_{variant}"][metric]
            for stat_idx in (0, 1):
                base_tp = _final_score(baseline, stat_idx, true_prfx)
                base_fp = _final_score(baseline, stat_idx, false_prfx)
                abl_tp = _final_score(ablated, stat_idx, true_prfx)
                abl_fp = _final_score(ablated, stat_idx, false_prfx)
                # Change in the correct-label score caused by the ablation.
                head_abl_values[f"tp_delta_{variant}"][stat_idx].append(abl_tp - base_tp)
                # Fraction of the (correct - incorrect) gap closed when the
                # incorrect-label score is replaced by its ablated value.
                head_abl_values[f"gap_reduction_{variant}"][stat_idx].append(
                    1 - (base_tp - abl_fp) / (base_tp - base_fp)
                )
        for col in head_abl_columns:
            centers, uppers = head_abl_values[col]
            results[setting][col].append(_fmt_mean_spread(centers[-1], uppers[-1]))

    # "average" row: mean of the central and upper values over datasets.
    n_datasets = len(datasets)
    for col in head_abl_columns:
        centers, uppers = head_abl_values[col]
        results[setting][col].append(
            _fmt_mean_spread(sum(centers) / n_datasets, sum(uppers) / n_datasets)
        )

    df = pd.DataFrame(results[setting], index=datasets + ["average"])
    print(df)
    print()

Setting: permuted_incorrect_labels
                           tp_delta_ours gap_reduction_ours tp_delta_null  \
sst2                         5.86_{0.36}       54.56_{0.83}   5.46_{0.32}   
agnews                       2.40_{0.11}       32.34_{0.39}   2.70_{0.12}   
trec                        -5.90_{0.03}       19.65_{0.10}   2.10_{0.00}   
dbpedia                      2.10_{0.21}       31.83_{1.24}   1.80_{0.18}   
rte                          1.90_{0.02}       95.16_{0.02}  -0.50_{0.01}   
mrpc                        -5.70_{0.04}       89.02_{0.38}  -3.50_{0.03}   
tweet_eval_hate             -4.10_{0.07}       10.63_{0.15}  -1.50_{0.03}   
sick                        -3.63_{0.05}       15.29_{0.33}   2.27_{0.04}   
poem_sentiment               1.67_{0.03}       30.76_{0.39}   1.47_{0.02}   
ethos                       -6.00_{0.14}       28.61_{0.11}  -3.00_{0.08}   
financial_phrasebank         2.30_{0.09}       32.67_{0.58}   2.33_{0.09}   
medical_questions_pairs      0.30_{0.00} 