In [9]:
!pip install circuitsvis python-dotenv --no-deps



In [20]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from peft import PeftModel
import einops
from nnsight import NNsight
from nnsight.models.LanguageModel import LanguageModel
import torch
import pandas as pd
import os
from transformer_lens import HookedTransformer
import numpy as np
from tqdm.notebook import tqdm, trange
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import seaborn as sns
import torch.nn as nn
import circuitsvis as cv
from torch.utils.data import DataLoader, TensorDataset
from collections import defaultdict
import lightning.pytorch as pl

from analysis.circuit_utils.visualisation import *
from analysis.circuit_utils.model import *
from analysis.circuit_utils.validation import *
from analysis.circuit_utils.few_shot import *
from analysis.circuit_utils.utils import *
from analysis.circuit_utils.metrics import *
from main import load_model_and_tokenizer

# NOTE(review): `device` is not referenced in the visible cells of this notebook.
device = "cuda:0"

# Default argument namespace from get_default_parser() (star-imported above), overriding only
# the context-weight format; paths_from_args(args) in the next cell derives the file paths from it.
args = get_default_parser().parse_args([
    "--context-weight-format", "float", 
])

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Resolve model / data / checkpoint paths from the parsed args; displayed for reference.
PATHS = paths_from_args(args)
PATHS

{'BASE_MODEL': '/dlabscratch1/public/llm_weights/llama3.1_hf/Meta-Llama-3.1-8B-Instruct',
 'MODEL_NAME': 'Meta-Llama-3.1-8B-Instruct-bs8-ga2-NT-cwf_float',
 'DATAROOT': '/dlabscratch1/jminder/repositories/context-vs-prior-finetuning/data/BaseFakepedia',
 'TRAIN_DATA': '/dlabscratch1/jminder/repositories/context-vs-prior-finetuning/data/BaseFakepedia/splits/nodup_relpid/train.csv',
 'DATASET_CONFIG_NAME': 'BaseFakepedia_nodup_relpid-ts2048',
 'PEFT_MODEL': '/dlabscratch1/jminder/repositories/context-vs-prior-finetuning/data/BaseFakepedia/BaseFakepedia_nodup_relpid-ts2048/3/models/Meta-Llama-3.1-8B-Instruct-bs8-ga2-NT-cwf_float/model',
 'MERGED_MODEL': '/dlabscratch1/jminder/repositories/context-vs-prior-finetuning/data/BaseFakepedia/BaseFakepedia_nodup_relpid-ts2048/3/models/Meta-Llama-3.1-8B-Instruct-bs8-ga2-NT-cwf_float/merged',
 'VAL_DATA_ALL': '/dlabscratch1/jminder/repositories/context-vs-prior-finetuning/data/BaseFakepedia/splits/nodup_relpid/val.csv',
 'TRAIN_DATA_ALL': '/dlabscr

# Utils

In [45]:
RESULTS_BASE = "../overnight_results"  # directory holding the patching results files
LAYER_WISE = False  # if True, load() averages scores over dim 1 (kept as a size-1 dim)
K_MAX = 100  # top-k cut-off used by the similarity metrics below


def load(path, head, results_base=None, layer_wise=None):
    """Load activation- and attribution-patching scores from one results file.

    Args:
        path: file name relative to ``results_base``.
        head: site key (e.g. "o" or "q") to select when the stored results are dicts keyed by site.
        results_base: directory containing the file; defaults to the global RESULTS_BASE.
        layer_wise: if True, average both score tensors over dim 1 and keep it as a size-1 dim;
            defaults to the global LAYER_WISE.

    Returns:
        ``(act, attr)`` tensors on the CPU.

    NOTE(review): torch.load unpickles its input; only point this at trusted results files.
    """
    if results_base is None:
        results_base = RESULTS_BASE
    if layer_wise is None:
        layer_wise = LAYER_WISE

    # map_location="cpu": callers convert results with .numpy(), which fails for CUDA tensors.
    data = torch.load(os.path.join(results_base, path), map_location="cpu")

    act = data["activation_patching"]
    if isinstance(act, dict) and head in act:
        # NOTE(review): only the activation results are sliced with [:, -1, :] (presumably the last
        # token position); the attribution results below are used as stored — confirm shapes match.
        act = act[head][:, -1, :]

    attr = data["attribution_patching"]
    if isinstance(attr, dict) and head in attr:
        attr = attr[head]

    if layer_wise:
        act = act.mean(1).unsqueeze(1)
        attr = attr.mean(1).unsqueeze(1)
    return act, attr

def load_plot(path, head="o"):
    """Activation- and attribution-patching heatmaps for one results file, side by side.

    Both panels share one zero-centred RdBu colour range, so equal colours mean equal scores.
    """
    act, attr = load(path, head)
    top = max(act.max(), attr.max()).item()
    bottom = min(act.min(), attr.min()).item()
    # Symmetric range [-m, m], m being the largest magnitude seen, keeps white at zero.
    bound = max(top, -bottom)

    fig = make_subplots(rows=1, cols=2, subplot_titles=("Activation", "Attribution"))
    for col, scores in enumerate((act, attr), start=1):
        fig.add_trace(
            go.Heatmap(z=scores.numpy(), zmin=-bound, zmax=bound, colorscale='RdBu'),
            row=1,
            col=col,
        )
    return fig

def _symmetric_bounds(*tensors):
    """Return ``(-m, m)`` where m is the largest absolute value over all given tensors.

    Centres the diverging RdBu colour scale on zero.
    """
    hi = max(t.max() for t in tensors).item()
    lo = min(t.min() for t in tensors).item()
    bound = max(hi, -lo)
    return -bound, bound


def load_plot_double_old(path_a, path_b, a_name="A", b_name="B", title="", head="o"):
    """2x2 grid: activation (top row) and attribution (bottom row) heatmaps for two results files.

    If ``title`` contains "Zero Ablation", the activation and attribution rows get independent
    zero-centred colour ranges and colourbars; otherwise one range is shared by all four panels.
    """
    separate_scale = "Zero Ablation" in title
    title += f" (Site '{head}')"
    act_a, attr_a = load(path_a, head)
    act_b, attr_b = load(path_b, head)
    fig = make_subplots(rows=2, cols=2, subplot_titles=(f"Activation {a_name}", f"Activation {b_name}", f"Attribution {a_name}", f"Attribution {b_name}"))

    if separate_scale:
        act_range = _symmetric_bounds(act_a, act_b)
        attr_range = _symmetric_bounds(attr_a, attr_b)
        # One colourbar per row, placed beside that row.
        act_bar = dict(y=0.75, len=0.4)
        attr_bar = dict(y=0.25, len=0.4)
    else:
        act_range = attr_range = _symmetric_bounds(act_a, act_b, attr_a, attr_b)
        act_bar = attr_bar = dict(y=0.5, len=1.0)

    # The separate scales come entirely from per-trace zmin/zmax and colourbar placement.
    panels = (
        (act_a, act_range, act_bar, 1, 1),
        (act_b, act_range, act_bar, 1, 2),
        (attr_a, attr_range, attr_bar, 2, 1),
        (attr_b, attr_range, attr_bar, 2, 2),
    )
    for scores, (zmin, zmax), bar, row, col in panels:
        fig.add_trace(
            go.Heatmap(z=scores.numpy(), zmin=zmin, zmax=zmax, colorscale='RdBu', colorbar=dict(bar)),
            row=row,
            col=col,
        )

    fig.update_layout(height=800, width=1200, title_text=title)
    fig.update_xaxes(title_text="Head", row=2, col=1)
    fig.update_xaxes(title_text="Head", row=2, col=2)
    fig.update_yaxes(title_text="Layer", row=1, col=1)
    fig.update_yaxes(title_text="Layer", row=2, col=1)
    return fig


def load_plot_double(path_a, path_b, a_name="A", b_name="B", title="", head="o"):
    """Side-by-side "refined attribution" heatmaps for two results files.

    Refined attribution = the activation-patching score where one is present (non-zero) and the
    attribution-patching score elsewhere (the same refinement load_cosine applies). Both panels
    share one zero-centred RdBu colour range computed from the plotted values.
    """
    title += f" (Site '{head}')"
    act_a, attr_a = load(path_a, head)
    act_b, attr_b = load(path_b, head)

    # New tensors: activation-patching values take precedence wherever they are non-zero.
    ref_attr_a = torch.where(act_a != 0, act_a, attr_a)
    ref_attr_b = torch.where(act_b != 0, act_b, attr_b)
    zmin, zmax = _symmetric_bounds(ref_attr_a, ref_attr_b)

    fig = make_subplots(rows=1, cols=2, subplot_titles=(f"Refined Attribution {a_name}", f"Refined Attribution {b_name}"))
    for col, scores in ((1, ref_attr_a), (2, ref_attr_b)):
        fig.add_trace(
            # Both panels share the range, so a single colourbar suffices.
            go.Heatmap(z=scores.numpy(), zmin=zmin, zmax=zmax, colorscale='RdBu', showscale=(col == 2)),
            row=1,
            col=col,
        )

    fig.update_layout(height=400, width=1200, title_text=title)
    fig.update_xaxes(title_text="Head", row=1, col=1)
    fig.update_xaxes(title_text="Head", row=1, col=2)
    fig.update_yaxes(title_text="Layer", row=1, col=1)
    return fig


def load_plot_iou(path_a, path_b, a_name="A", b_name="B", title="", head="o"):
    """Plot the top-k IoU curve (iou_range with k_max=K_MAX) between two activation-patching maps.

    ``a_name``/``b_name`` are accepted for signature compatibility with the other load_* helpers.
    """
    act_a, _ = load(path_a, head)
    act_b, _ = load(path_b, head)
    ious = iou_range(act_a, act_b, k_max=K_MAX)

    # Derive the k axis from the result instead of hard-coding 1..99, which only fit K_MAX == 100.
    # NOTE(review): assumes iou_range returns one value per k starting at k=1 — confirm in circuit_utils.
    ks = list(range(1, len(ious) + 1))

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=ks, y=ious, mode="lines"))
    fig.update_layout(title_text=title, xaxis_title="k", yaxis_title="IoU")
    return fig

def topk_cosine(act_a, act_b, k):
    """Compare two score maps on the union of each map's top-k entries (after flattening).

    The comparison is done by ``cosine`` (star-imported from circuit_utils; presumably cosine
    similarity). ``k`` is capped at the number of elements, because torch.topk raises when k exceeds
    it (e.g. for layer-wise maps).
    """
    flat_a = act_a.flatten()
    flat_b = act_b.flatten()
    k = min(k, flat_a.numel(), flat_b.numel())
    top_a = torch.topk(flat_a, k, dim=-1).indices
    top_b = torch.topk(flat_b, k, dim=-1).indices
    # Entries that are top-k in either map.
    union = torch.cat([top_a, top_b], dim=-1).unique()
    return cosine(flat_a[union], flat_b[union])

def load_cosine(path_a, path_b, a_name="A", b_name="B", title="", head="o"):
    """Top-K_MAX cosine similarity (see topk_cosine) between two results files' refined scores.

    Refined score = activation-patching value where non-zero, attribution-patching value elsewhere.
    Name/title arguments are accepted for signature compatibility with the other load_* helpers.
    """
    act_a, attr_a = load(path_a, head)
    act_b, attr_b = load(path_b, head)
    refined_a = torch.where(act_a != 0, act_a, attr_a)
    refined_b = torch.where(act_b != 0, act_b, attr_b)
    # Reuse topk_cosine rather than duplicating its top-k union logic.
    return topk_cosine(refined_a, refined_b, K_MAX)



def load_plot_cosine_range(path_a, path_b, a_name="A", b_name="B", title="", head="o"):
    """Plot top-k cosine similarity (see topk_cosine) against k for two results files.

    Scores are refined: the activation-patching value where it is non-zero, otherwise the
    attribution-patching value. k runs from 1 to K_MAX - 1, consistent with load_plot_iou, and is
    capped at the number of entries.
    """
    act_a, attr_a = load(path_a, head)
    act_b, attr_b = load(path_b, head)
    refined_a = torch.where(act_a != 0, act_a, attr_a)
    refined_b = torch.where(act_b != 0, act_b, attr_b)

    # torch.topk raises if k exceeds the number of elements (e.g. for layer-wise maps).
    k_stop = min(K_MAX, refined_a.numel() + 1, refined_b.numel() + 1)
    ks = list(range(1, k_stop))
    sims = [topk_cosine(refined_a, refined_b, k) for k in ks]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=ks, y=sims, mode="lines"))
    fig.update_layout(title_text=title, xaxis_title="k", yaxis_title="Cosine similarity")
    return fig


def load_iou_auc(path_a, path_b, a_name="A", b_name="B", title="", head="o"):
    """Return the IoU AUC (iou_auc with k_max=K_MAX) of two files' activation-patching maps.

    Only the activation scores are compared; names and title exist for signature compatibility.
    """
    scores_a = load(path_a, head)[0]
    scores_b = load(path_b, head)[0]
    return iou_auc(scores_a, scores_b, k_max=K_MAX)

In [43]:
def parse_element(element):
    """Decode a model label such as "FS-I", "FT-F", "ZS-I" or "IOI".

    Tokens are split on "-". Returns ``(model_type, cwf, zero_ablation)``:
      * model_type: "ioi" if IOI, else "fs0" if ZS, else "ft" if FT, else "fs10".
      * cwf: "instruction" if an "I" token is present, else "float".
      * zero_ablation: whether a "ZERO" token is present.
    """
    tokens = set(element.split("-"))
    zero_ablation = "ZERO" in tokens
    cwf = "instruction" if "I" in tokens else "float"

    # Later checks take priority: IOI over ZS over FT.
    if "IOI" in tokens:
        model_type = "ioi"
    elif "ZS" in tokens:
        model_type = "fs0"
    elif "FT" in tokens:
        model_type = "ft"
    else:
        model_type = "fs10"
    return model_type, cwf, zero_ablation

FILE_STORE = defaultdict(list)

def construct_file_store(results_base=None):
    """Index the ``.pt`` results files in ``results_base`` (default RESULTS_BASE) into FILE_STORE.

    Keys are ``(MODEL_TYPE, CWF, ZERO_ABL, TYPE)`` tuples plus ``"IOI"`` for files whose name
    contains "ioi"; each value ends up as a single file name. Other names are split on "_" after
    dropping ".pt" and an optional trailing "heads..." token:

        <prefix>_<model>_<x>_cwf-<cwf>_<n>[_<type>]     e.g. last_ft_i0_cwf-float_n100_cif

    Names containing "_zero" (zero ablation) carry one extra leading token. A missing <type>
    means "global patch". Files containing "all_act" and non-.pt entries are ignored.

    Safe to call repeatedly: FILE_STORE is cleared first.

    Raises:
        ValueError: for a .pt file whose name does not fit the layout above.
    """
    if results_base is None:
        results_base = RESULTS_BASE
    # Start from scratch: after a previous call the values are strings, so .append() would fail.
    FILE_STORE.clear()

    for file in os.listdir(results_base):
        # The parsing below strips a 3-character ".pt" suffix.
        if not file.endswith(".pt"):
            continue
        if "ioi" in file:
            FILE_STORE["IOI"].append(file)
            continue
        if "all_act" in file:
            continue

        parts = file[:-3].split("_")
        # Newer file names end with a "heads[...]" token that carries no key information.
        if "heads" in parts[-1]:
            parts = parts[:-1]

        ZERO_ABL = "_zero" in file
        fields = parts[2:] if ZERO_ABL else parts[1:]
        if len(fields) == 4:
            MODEL_TYPE, _, CWF, _ = fields
            TYPE = "global patch"
        elif len(fields) == 5:
            MODEL_TYPE, _, CWF, _, TYPE = fields
        else:
            raise ValueError(f"Unrecognised results file name: {file}")

        CWF = CWF.split("-")[1]  # "cwf-float" -> "float"
        FILE_STORE[(MODEL_TYPE, CWF, ZERO_ABL, TYPE)].append(file)

    # Collapse each list to one file: the lexicographically smallest name is kept.
    # NOTE(review): an earlier comment said "keep only the last one"; confirm which run should win.
    for key in list(FILE_STORE):
        FILE_STORE[key] = min(FILE_STORE[key])

def get_file(MODEL_TYPE, CWF, ZERO_ABL, TYPE):
    """Return the results file name indexed by construct_file_store for one configuration.

    For MODEL_TYPE == "ioi" the IOI entry is returned and the other arguments are ignored.

    Raises:
        ValueError: if no matching file was indexed.
    """
    key = "IOI" if MODEL_TYPE == "ioi" else (MODEL_TYPE, CWF, ZERO_ABL, TYPE)
    # .get() rather than [] so a miss does not insert an empty list into the defaultdict.
    res = FILE_STORE.get(key)
    if not res:
        if MODEL_TYPE == "ioi":
            raise ValueError("No IOI results file found")
        raise ValueError(f"No file found for {MODEL_TYPE} {CWF} {'Zero Ablation' if ZERO_ABL else ''} {TYPE}")
    return res

# Patch-type key (lower-cased config suffix) -> display name used in plot titles.
TYPE_TO_NAME = {
    "global patch": "Global Patch",
    "pif": "Prior Information Flow",
    "cif": "Context Information Flow",
}

def INPS(config_string):
    """Resolve a comparison config string into the positional arguments of the load_* helpers.

    Format: ``"<A>$<B>"`` or ``"<A>$<B>$<TYPE>"``. A and B are model labels understood by
    parse_element (e.g. "FS-I", "FT-F", "ZS-I", "IOI"). TYPE is a key of TYPE_TO_NAME, matched
    case-insensitively, and defaults to "Global Patch".

    Returns:
        ``(first_file, second_file, first_name, second_name, type_display_name)``

    Raises:
        ValueError: malformed config string, unknown TYPE, or no results file (from get_file).
    """
    parts = config_string.split("$")
    if len(parts) == 3:
        first, second, global_cfg = parts
    elif len(parts) == 2:
        first, second = parts
        global_cfg = "Global Patch"
    else:
        raise ValueError(f"Expected 'A$B' or 'A$B$TYPE', got {config_string!r}")

    TYPE = global_cfg.lower()
    # Validate before the file lookup so a typo fails with a clear message.
    if TYPE not in TYPE_TO_NAME:
        raise ValueError(f"Unknown patch type {global_cfg!r}; expected one of {sorted(TYPE_TO_NAME)}")

    def resolve(element):
        """Return (file name, display name) for one model label."""
        MODEL_TYPE, CWF, ZERO_ABL = parse_element(element)
        file = get_file(MODEL_TYPE, CWF, ZERO_ABL, TYPE)
        name = f"{'Zero Ablation ' if ZERO_ABL else ''}{MODEL_TYPE} {CWF}"
        return file, name

    first_file, first_name = resolve(first)
    second_file, second_name = resolve(second)

    # Echo the resolved files so each result in the notebook shows which files it came from.
    print(first_file, second_file)
    return first_file, second_file, first_name, second_name, TYPE_TO_NAME[TYPE]
    
def plot_(config_string, head="o"):
    """Refined-attribution heatmaps (load_plot_double) for the pair named by ``config_string``."""
    return load_plot_double(*INPS(config_string), head=head)

def iou_auc_(config_string, head="o"):
    """IoU AUC (load_iou_auc) for the pair named by ``config_string``."""
    return load_iou_auc(*INPS(config_string), head)

def iou_auc_range_(config_string, head="o"):
    """Top-k IoU curve figure (load_plot_iou) for the pair named by ``config_string``."""
    return load_plot_iou(*INPS(config_string), head)

def cosine_(config_string, head="o"):
    """Top-K_MAX cosine similarity (load_cosine) for the pair named by ``config_string``."""
    return load_cosine(*INPS(config_string), head)

def cosine_range_(config_string, head="o"):
    """Cosine-vs-k figure (load_plot_cosine_range) for the pair named by ``config_string``."""
    return load_plot_cosine_range(*INPS(config_string), head)

def plot_metric(config_string, head="o"):
    """Print IoU AUC and cosine similarity for a comparison, then show its heatmaps.

    The config is resolved once and reused; each INPS call prints the file names.
    """
    inputs = INPS(config_string)
    print("IOUAUC", load_iou_auc(*inputs, head))
    print("COSINE", load_cosine(*inputs, head))
    load_plot_double(*inputs, head=head).show()

construct_file_store()

In [44]:
cosine_range_("FS-I$FT-I", "o")

last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n100_heads['o', 'q'].pt


In [36]:
# Sanity check: the IOI results compared with themselves (cosine should presumably be 1 for every k).
# `load_cosine_range` is not defined in this notebook; the cosine-vs-k plotter is load_plot_cosine_range.
load_plot_cosine_range("ioi.pt", "ioi.pt")

In [25]:
iou_auc_("FT-F$FT-I$CIF", "o")

last_ft_i0_cwf-float_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt


59.15729235444721

In [26]:
iou_auc_range_("FS-F$FS-I$PIF", "o")

last_fs10_i0_cwf-float_n19_pif_heads['o', 'q'].pt last_fs10_i0_cwf-instruction_n23_pif_heads['o', 'q'].pt


In [27]:
iou_auc_range_("FT-F$FS-I$CIF", "o")

last_ft_i0_cwf-float_n100_cif_heads['o', 'q'].pt last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt


In [28]:
iou_auc_range_("ZS-I$ZS-F$CIF", "q")

last_fs0_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_cif_heads['o', 'q'].pt


In [32]:
plot_metric("FS-I$FT-I$PIF", "o")

last_fs10_i0_cwf-instruction_n23_pif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n24_pif_heads['o', 'q'].pt
IOUAUC 26.64739341902037
last_fs10_i0_cwf-instruction_n23_pif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n24_pif_heads['o', 'q'].pt
COSINE 0.7660793662071228
last_fs10_i0_cwf-instruction_n23_pif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n24_pif_heads['o', 'q'].pt


In [31]:
plot_metric("FS-I$ZS-I$CIF", "q")

last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
IOUAUC 32.90408398582089
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
COSINE 0.6924096345901489
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt


In [116]:
plot_metric("FS-I$FT-I$PIF", "o")

last_fs10_i0_cwf-instruction_n23_pif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n24_pif_heads['o', 'q'].pt
IOUAUC 26.64739341902037
last_fs10_i0_cwf-instruction_n23_pif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n24_pif_heads['o', 'q'].pt
COSINE 0.7066217064857483
last_fs10_i0_cwf-instruction_n23_pif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n24_pif_heads['o', 'q'].pt


In [117]:
plot_metric("FT-I$ZS-F", "o")

last_ft_i0_cwf-instruction_n100_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_heads['o', 'q'].pt
IOUAUC 10.079267159622402
last_ft_i0_cwf-instruction_n100_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_heads['o', 'q'].pt
COSINE 0.27627405524253845
last_ft_i0_cwf-instruction_n100_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_heads['o', 'q'].pt


In [118]:
# NOTE(review): stale cell. INPS is a function taking an "A$B[$TYPE]" string (e.g. "FT-F$FT-I$PIF"),
# so subscripting it raises the TypeError shown below; the config here needs rewriting.
load_plot_iou(*INPS["F-I-PIF"])

TypeError: 'function' object is not subscriptable

In [None]:
# NOTE(review): stale cell. INPS is now a function taking an "A$B[$TYPE]" string, so INPS[...]
# raises TypeError on a fresh kernel (as in the cell above); the printed values below were
# presumably produced by an older version of INPS.
print("I-F:", load_iou_auc(*INPS["F-I"]))
print("I-F-CIF:", load_iou_auc(*INPS["F-I-CIF"]))
print("I-F-PIF:", load_iou_auc(*INPS["F-I-PIF"]))
print("I-F-CTP:", load_iou_auc(*INPS["F-I-CTP"]))
# NOTE(review): the next two lines are labelled "PTC" but both look up "F-I-CTP" (the output repeats
# the CTP value); presumably one of them was meant to use a PTC config.
print("I-F-PTC:", load_iou_auc(*INPS["F-I-CTP"]))
print("I-F-PTC:", load_iou_auc(*INPS["F-I-CTP"]))

I-F: 30.402542048050307
I-F-CIF: 19.97782837881256
I-F-PIF: 15.594129093073795
I-F-CTP: 29.475131230323413
I-F-PTC: 29.475131230323413


In [None]:
# NOTE(review): `a` is not defined anywhere in this notebook (hidden kernel state). It presumably held
# a results dict loaded with torch.load from a deleted cell; this cell fails on a fresh kernel.
a["attribution_patching"]

tensor([[ 1.3658e-06, -6.8909e-06, -4.6764e-06,  ...,  7.2364e-06,
         -8.1091e-06,  2.4591e-05],
        [-1.5227e-04, -2.3842e-05, -1.2488e-04,  ..., -6.1231e-05,
         -1.7952e-05, -5.7369e-07],
        [ 1.1112e-04,  3.7718e-04,  2.7432e-04,  ...,  2.5384e-05,
         -1.4487e-04, -9.9612e-05],
        ...,
        [-1.8943e-03, -1.9520e-03,  4.9043e-04,  ..., -5.0133e-04,
         -2.3839e-03, -1.4066e-03],
        [-4.1595e-03,  7.1400e-03, -1.3857e-03,  ...,  5.0807e-03,
         -4.3616e-03, -4.6444e-04],
        [ 1.6687e-03, -4.6904e-03,  7.6556e-03,  ...,  4.7433e-03,
         -1.3988e-03,  9.1553e-05]])

In [None]:
# BASE_MODEL is not defined in this notebook; the base-model path is stored in PATHS["BASE_MODEL"].
tokenizer = AutoTokenizer.from_pretrained(PATHS["BASE_MODEL"])

In [None]:
act, attr = load("all_act.pt", "o")

In [None]:
# Compare how activation patching and raw attribution patching rank the same components.
# Rank 1 = highest score. Ranks are 1-based because a log axis cannot show 0: 0-based ranks
# would silently drop the top-ranked component from the plot.
act_order = torch.argsort(act.flatten(), descending=True)
act_rank = torch.argsort(act_order) + 1  # inverse permutation: position of each entry in act_order

attr_order = torch.argsort(attr.flatten(), descending=True)
attr_rank = torch.argsort(attr_order) + 1

fig = go.Figure()
fig.add_trace(go.Scatter(x=act_rank.cpu().numpy(), y=attr_rank.cpu().numpy(), mode='markers'))
fig.update_xaxes(type="log", title_text="Activation Rank")
fig.update_yaxes(type="log", title_text="Attribution Rank")
fig.update_layout(title_text="Rank Correlation")
fig

In [12]:
# Same comparison against "refined" attribution: the activation-patching score where it is non-zero,
# the attribution score elsewhere. Built as a new tensor so `attr` is not modified in place
# (otherwise re-running the previous cell would silently plot refined values).
act_order = torch.argsort(act.flatten(), descending=True)
act_rank = torch.argsort(act_order) + 1  # 1-based so the log axis can show the top rank

refined_attr = torch.where(act != 0, act, attr)
attr_order = torch.argsort(refined_attr.flatten(), descending=True)
attr_rank = torch.argsort(attr_order) + 1

fig = go.Figure()
fig.add_trace(go.Scatter(x=act_rank.cpu().numpy(), y=attr_rank.cpu().numpy(), mode='markers'))
fig.update_xaxes(type="log", title_text="Activation Rank")
fig.update_yaxes(type="log", title_text="Refined Attribution Rank")
fig.update_layout(title_text="Rank Correlation")
fig

# Plot Results

In [46]:
def get_metrics(config, head):
    """Return ``(cosine similarity, IoU AUC)`` for a comparison config string at site ``head``."""
    cos_sim = cosine_(config, head)
    auc = iou_auc_(config, head)
    return cos_sim, auc

def plot_similarity(similarity_matrix_cos, similarity_matrix_auc, order, TYPE):
    """Plot pairwise model similarities as two descending bar charts: cosine similarity and IoU AUC.

    Args:
        similarity_matrix_cos: square matrix indexed ``[i, j]`` in the order of ``order``.
        similarity_matrix_auc: same layout, IoU AUC values.
        order: model labels such as "FS-I", "FT-F", "ZS-I", "IOI".
        TYPE: config suffix; a "PIF"/"CIF" substring selects the title prefix.

    Only the upper triangle (i < j) is used. IOI-vs-zero-shot and ZS-I-vs-ZS-F pairs are skipped.
    Bars involving IOI are red, zero-shot green, all others gray. The IoU AUC axis tops out at K_MAX,
    and the title is prefixed when LAYER_WISE is set.

    Returns:
        The plotly figure.
    """
    model_names = order

    # Bar colours for baseline comparisons
    ioi_color = 'red'
    zeroshot_color = 'green'
    other_color2 = 'gray'

    # Label colours per category
    finetune_color = 'purple'
    fewshot_color = 'blue'
    float_color = '#EC6E94'
    instruction_color = 'teal'

    def get_label_color(label):
        """Colour for one word of an x-axis label (check order matters: 'F' also matches 'FT'/'FS')."""
        if 'FT' in label:
            return finetune_color
        elif 'FS' in label:
            return fewshot_color
        elif 'F' in label:
            return float_color
        elif "ZS" in label:
            return zeroshot_color
        elif "IOI" in label:
            return ioi_color
        elif 'I' in label:
            return instruction_color
        else:
            return 'black'  # Default color

    def bar_color(names):
        """Bar colour for a model pair: IOI baseline first, then zero-shot baseline, else gray."""
        if 'IOI' in names:
            return ioi_color
        if 'ZS-I' in names or 'ZS-F' in names:
            return zeroshot_color
        return other_color2

    # Long format: one record per model pair and metric.
    data1 = []
    data2 = []
    for i in range(len(model_names)):
        for j in range(i + 1, len(model_names)):
            mn = [model_names[i], model_names[j]]
            involves_zero_shot = "ZS-I" in mn or "ZS-F" in mn
            if ("IOI" in mn and involves_zero_shot) or ("ZS-I" in mn and "ZS-F" in mn):
                continue

            pair = f'{model_names[i]} – {model_names[j]}'
            color = bar_color(mn)
            data1.append({
                'model_pair': pair,
                'similarity_value': similarity_matrix_cos[i, j],
                'color': color
            })
            data2.append({
                'model_pair': pair,
                'similarity_value': similarity_matrix_auc[i, j],
                'color': color
            })

    data1 = sorted(data1, key=lambda x: x['similarity_value'], reverse=True)
    data2 = sorted(data2, key=lambda x: x['similarity_value'], reverse=True)

    def axis_floor(values):
        """Lower y-limit 10% beyond the smallest bar (0 when there are no bars).

        For a negative minimum, 0.9 * min would lie above that bar and clip it, hence 1.1 * min.
        """
        if not values:
            return 0.0
        smallest = min(values)
        return 0.9 * smallest if smallest >= 0 else 1.1 * smallest

    ymin1 = axis_floor([d['similarity_value'] for d in data1])
    ymin2 = axis_floor([d['similarity_value'] for d in data2])
    ymax1 = 1.0
    ymax2 = K_MAX

    fig = make_subplots(rows=1, cols=2, subplot_titles=("Cosine Similarity", "IOU AUC"))
    fig.add_trace(go.Bar(
        x=[d['model_pair'] for d in data1],
        y=[d['similarity_value'] for d in data1],
        marker_color=[d['color'] for d in data1],
        name='Cosine Similarity',
        showlegend=False
    ), row=1, col=1)
    fig.add_trace(go.Bar(
        x=[d['model_pair'] for d in data2],
        y=[d['similarity_value'] for d in data2],
        marker_color=[d['color'] for d in data2],
        name='IOU AUC',
        showlegend=False
    ), row=1, col=2)

    title = "LAYERWISE - " if LAYER_WISE else ""
    if "PIF" in TYPE:
        title += "Prior-Info-Flow: "
    elif "CIF" in TYPE:
        title += "Context-Info-Flow: "

    fig.update_layout(
        title=title + 'Comparison of Similarity Measures Between Model Pairs',
        showlegend=True,
        legend=dict(
            x=1.05,
            y=1,
            traceorder='normal',
            itemsizing='constant',
            font=dict(size=12),
            bgcolor='rgba(255, 255, 255, 0.8)'
        )
    )

    def format_label(label):
        """HTML label: '-F'/'-I' suffixes become emoji markers, words coloured by category, FS/FT bold."""
        label = label.replace("-F", ' 1️⃣')
        label = label.replace("-I", ' 🫵')
        words = label.replace("-", " ").split()
        spans = []
        for word in words:
            bold = word in ["FS", "FT"]
            # Close <b> in the same span so bolding does not leak into the following words.
            spans.append(
                f'<span style="color:{get_label_color(word)};">'
                f'{"<b>" if bold else ""} {word} {"</b>" if bold else ""}</span>'
            )
        return ''.join(spans)

    # Default tick labels are replaced by the coloured annotations below.
    fig.update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_xaxes(showticklabels=False, row=1, col=2)
    fig.update_yaxes(range=[ymin1, ymax1], row=1, col=1)
    fig.update_yaxes(range=[ymin2, ymax2], row=1, col=2)

    def axis_labels(data, xref, y, xanchor, **extra):
        """One rotated, coloured label annotation per bar of ``data``."""
        return [
            dict(
                x=idx,
                y=y,
                xref=xref,
                yref='paper',
                text=format_label(d['model_pair']),
                showarrow=False,
                xanchor=xanchor,
                yanchor='top',
                align='center',
                textangle=90,
                **extra
            )
            for idx, d in enumerate(data)
        ]

    annotations = (
        axis_labels(data1, 'x1', -0.01, 'center')
        + axis_labels(data2, 'x2', -0.005, 'left', xshift=-10)
    )

    # Invisible (None-valued) scatter traces that exist only to populate the colour legend.
    legend_entries = [
        (finetune_color, 'Finetuning', 'Finetuning (FT)'),
        (fewshot_color, 'Fewshotting', 'FewShot (FS)'),
        (ioi_color, 'IOI', 'IOI'),
        (zeroshot_color, 'Zero Shot', 'Zero Shot (ZS)'),
    ]
    for color, group, name in legend_entries:
        fig.add_trace(go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(size=10, color=color),
            legendgroup=group,
            showlegend=True,
            name=name
        ))

    # Text legend for the emoji CWF markers.
    legend_annotations = [
        dict(
            x=1.056,
            y=0.8,
            xref='paper',
            yref='paper',
            text='1️⃣   CWF=float',
            showarrow=False,
            xanchor='left',
            yanchor='top',
            font=dict(size=12)
        ),
        dict(
            x=1.056,
            y=0.77,
            xref='paper',
            yref='paper',
            text='🫵   CWF=instruction',
            showarrow=False,
            xanchor='left',
            yanchor='top',
            font=dict(size=12)
        )
    ]

    fig.update_layout(annotations=annotations + legend_annotations)
    fig.update_layout(margin=dict(l=20, r=20, t=100, b=150))
    fig.update_layout(width=1800, height=800)
    return fig

In [47]:
order = ["FS-I", "FS-F", "FT-I", "FT-F", "ZS-I", "ZS-F", "IOI"]
# Each unordered pair once (a != b, a <= b lexicographically).
combs = [(a, b) for a in order for b in order if a != b and a <= b]
TYPE = "$CIF"
# Set explicitly: load() and plot_similarity() read these globals, so inheriting them from
# whichever cell ran last would make this result depend on execution order.
LAYER_WISE = False
K_MAX = 100
similarity_matrix_cos = np.zeros((len(order), len(order)))
similarity_matrix_auc = np.zeros((len(order), len(order)))
for model_a, model_b in combs:
    cosine_v, iou_auc_v = get_metrics(f"{model_a}${model_b}{TYPE}", "o")
    i, j = order.index(model_a), order.index(model_b)
    # Metrics are treated as symmetric: fill both triangles.
    similarity_matrix_cos[i, j] = similarity_matrix_cos[j, i] = cosine_v
    similarity_matrix_auc[i, j] = similarity_matrix_auc[j, i] = iou_auc_v

plot_similarity(similarity_matrix_cos, similarity_matrix_auc, order, TYPE)


last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-float_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-float_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt ioi.pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt ioi

In [48]:
order = ["FS-I", "FS-F", "FT-I", "FT-F", "ZS-I", "ZS-F", "IOI"]
# Each unordered pair once (a != b, a <= b lexicographically).
combs = [(a, b) for a in order for b in order if a != b and a <= b]
TYPE = "$CIF"
LAYER_WISE = False
K_MAX = 100
similarity_matrix_cos = np.zeros((len(order), len(order)))
similarity_matrix_auc = np.zeros((len(order), len(order)))
for model_a, model_b in combs:
    cosine_v, iou_auc_v = get_metrics(f"{model_a}${model_b}{TYPE}", "o")
    i, j = order.index(model_a), order.index(model_b)
    # Metrics are treated as symmetric: fill both triangles.
    similarity_matrix_cos[i, j] = similarity_matrix_cos[j, i] = cosine_v
    similarity_matrix_auc[i, j] = similarity_matrix_auc[j, i] = iou_auc_v

plot_similarity(similarity_matrix_cos, similarity_matrix_auc, order, TYPE)


last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-float_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_ft_i0_cwf-float_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_cif_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt ioi.pt
last_fs10_i0_cwf-instruction_n100_cif_heads['o', 'q'].pt ioi

In [19]:
order = ["FS-I", "FS-F", "FT-I", "FT-F", "ZS-I", "ZS-F", "IOI"]
# Each unordered pair once (a != b, a <= b lexicographically).
combs = [(a, b) for a in order for b in order if a != b and a <= b]
TYPE = ""  # global patch
LAYER_WISE = True
K_MAX = 10
similarity_matrix_cos = np.zeros((len(order), len(order)))
similarity_matrix_auc = np.zeros((len(order), len(order)))
for model_a, model_b in combs:
    cosine_v, iou_auc_v = get_metrics(f"{model_a}${model_b}{TYPE}", "o")
    i, j = order.index(model_a), order.index(model_b)
    # Metrics are treated as symmetric: fill both triangles.
    similarity_matrix_cos[i, j] = similarity_matrix_cos[j, i] = cosine_v
    similarity_matrix_auc[i, j] = similarity_matrix_auc[j, i] = iou_auc_v

plot_similarity(similarity_matrix_cos, similarity_matrix_auc, order, TYPE)



last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n100_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_ft_i0_cwf-instruction_n100_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_ft_i0_cwf-float_n100_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_ft_i0_cwf-float_n100_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_fs0_i0_cwf-instruction_n100_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt last_fs0_i0_cwf-float_n100_heads['o', 'q'].pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt ioi.pt
last_fs10_i0_cwf-instruction_n100_heads['o', 'q'].pt ioi.pt
last_fs10_i0_cwf-float_n100_heads['o', 'q'].pt last_fs10_i0_cwf-inst

In [12]:
l = [972, 974, 984, 970, 961, 983, 973, 967, 964, 969, 960, 977, 980, 976, 987, 979, 966, 991, 982, 968, 988, 981, 985, 971, 963, 990, 965, 989, 975, 986, 978, 962, 1014, 923, 1020, 884, 999, 1006, 878, 953, 775, 1009, 1004, 995, 768, 996, 872, 869, 1015, 871, 889, 893, 785, 1021, 1022, 944, 992, 791, 1018, 1016, 1017, 888, 873, 771, 433, 935, 772, 789, 929, 947, 798, 1000, 880, 774, 918, 895, 997, 1012, 994, 788, 1008, 881, 777, 787, 892, 916, 795, 883, 993, 782, 786, 792, 959, 870, 926, 490, 885, 799, 876, 816, 436, 950, 890, 1023, 770, 847, 907, 955, 790, 865, 778, 796, 793, 879, 416, 866, 1001, 794, 936, 784, 867, 891, 776, 958, 781, 1007, 783, 809, 821, 887, 1005, 427, 951, 894, 946, 535, 769, 930, 851, 532, 773, 941, 446, 853, 938, 868, 842, 440, 875, 797, 848, 956, 832, 899, 937, 928, 939, 498, 908, 948, 826, 1003, 827, 568, 779, 780, 942, 822, 805, 952, 444, 387, 910, 914, 439, 921, 493, 825, 932, 906, 731, 901, 934, 846, 861, 898, 949, 954, 856, 940, 1019, 931, 570, 857, 1002, 529, 874, 531, 396, 385, 538, 886, 801, 841, 912, 727, 486, 828, 673, 904, 413, 814, 877, 462, 927, 649, 919, 571, 725, 957, 397, 475, 488, 526, 840, 843, 945, 581, 481, 452, 920, 808, 489, 933, 501, 850, 943, 626, 905, 924, 704, 537, 409, 915, 824, 802, 998, 584, 831, 325, 913, 665, 811, 863, 471, 429]

In [13]:
[divmod(i, 32) for i in l]

[(30, 12),
 (30, 14),
 (30, 24),
 (30, 10),
 (30, 1),
 (30, 23),
 (30, 13),
 (30, 7),
 (30, 4),
 (30, 9),
 (30, 0),
 (30, 17),
 (30, 20),
 (30, 16),
 (30, 27),
 (30, 19),
 (30, 6),
 (30, 31),
 (30, 22),
 (30, 8),
 (30, 28),
 (30, 21),
 (30, 25),
 (30, 11),
 (30, 3),
 (30, 30),
 (30, 5),
 (30, 29),
 (30, 15),
 (30, 26),
 (30, 18),
 (30, 2),
 (31, 22),
 (28, 27),
 (31, 28),
 (27, 20),
 (31, 7),
 (31, 14),
 (27, 14),
 (29, 25),
 (24, 7),
 (31, 17),
 (31, 12),
 (31, 3),
 (24, 0),
 (31, 4),
 (27, 8),
 (27, 5),
 (31, 23),
 (27, 7),
 (27, 25),
 (27, 29),
 (24, 17),
 (31, 29),
 (31, 30),
 (29, 16),
 (31, 0),
 (24, 23),
 (31, 26),
 (31, 24),
 (31, 25),
 (27, 24),
 (27, 9),
 (24, 3),
 (13, 17),
 (29, 7),
 (24, 4),
 (24, 21),
 (29, 1),
 (29, 19),
 (24, 30),
 (31, 8),
 (27, 16),
 (24, 6),
 (28, 22),
 (27, 31),
 (31, 5),
 (31, 20),
 (31, 2),
 (24, 20),
 (31, 16),
 (27, 17),
 (24, 9),
 (24, 19),
 (27, 28),
 (28, 20),
 (24, 27),
 (27, 19),
 (31, 1),
 (24, 14),
 (24, 18),
 (24, 24),
 (29, 31),
 (27, 6

In [23]:
act, attr = load("all_act.pt", "o")

In [28]:
# plot
fig = go.Figure()
fig.add_trace(go.Heatmap(z=act.mean(1).unsqueeze(1).numpy()))