In [7]:
import ast
from pathlib import Path

import h5py
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Figure defaults for every plot below:
#  - embed fonts as TrueType (Type 42) instead of matplotlib's default Type 3 in PDF/PS
#    exports, so text in the saved figures stays editable;
#  - 8 pt default font size.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.size': 8,
})

In [8]:
# Attention heads to plot, keyed by task and then by model name.
# Each entry is a (layer, head) pair; the per-head plotting cell matches these pairs
# against the (layer, head) fields parsed from the HDF5 "svs_used_decomp..." keys.
# Only the (task, model) combinations listed here are visited by the plotting loops
# (e.g. there is no gemma-2-2b entry for "gt").
AHS = {
    "ioi": {
        "gpt2-small": [(8, 6), (4, 11), (3, 0), (9, 9), (9, 6), (10, 0)],
        "pythia-160m": [(2,6), (4,11), (7,9), (8,2), (8,10), (10,7)],
        "gemma-2-2b": [(12, 7), (10, 0), (2, 5), (3, 0), (10, 3)] # Top-5 most active heads
    },
    "gp": {
        "gpt2-small": [(10, 9), (9, 7), (4, 3), (1, 10)],
        "pythia-160m": [(3, 2), (9, 3), (1, 2), (2, 11), (11, 9)], # Top-5 most active heads
        "gemma-2-2b": [(12, 7), (3, 0), (21, 4), (2, 5), (2, 6)] # Top-5 most active heads
    },
    "gt": {
        "gpt2-small": [(5,5), (6,1), (6,9), (7,10), (8,11), (9,1)],
        "pythia-160m": [(3, 2), (3, 0), (1, 7), (3, 1), (1, 2)], # Top-5 most active heads
    }
}

# Number of singular vectors available per attention head for each model. The aggregate
# plotting cell divides each traced SV count by this value to get a fraction.
# NOTE(review): presumably the per-head dimension d_head of each model — confirm.
ranks = {
    "gpt2-small": 64,
    "pythia-160m": 64,
    "gemma-2-2b": 256
}

In [9]:
# Per-head distributions. For every (task, model) in AHS, collect the number of singular
# vectors recorded in each non-empty "svs_used_decomp" dataset, but only for the heads
# listed in AHS. Save one KDE plot per (task, model), with one curve per head.
fig_dir = Path("figures/sparse_attn_decomp")
fig_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories

for task in AHS:
    for model in AHS[task]:
        print(model, task)
        n_prompts = 256 if task in ["ioi", "gt"] else 100

        filename = f"tracing_results/{model}_{task}_{n_prompts}_0.hdf5"

        AHS_TASK = AHS[task][model]
        AHS_data = {x: [] for x in AHS_TASK}  # (layer, head) -> list of SV counts

        with h5py.File(filename, "r") as f:
            for key in f.keys():
                if not key.startswith("svs_used_decomp"):
                    continue
                # The key suffix is a 5-tuple literal whose 2nd/3rd fields are (layer, head).
                # literal_eval parses it without executing arbitrary code (unlike eval).
                _, layer, ah_idx, _, _ = ast.literal_eval(key.split("_")[-1])
                if (layer, ah_idx) not in AHS_data:
                    continue  # head not selected for this plot: don't load its data
                n_svs = len(f[key][:])  # load the dataset once
                if n_svs > 0:  # Cases where we did trace.
                    AHS_data[layer, ah_idx].append(n_svs)

        df = pd.DataFrame([(f"AH{head}", count) for head, counts in AHS_data.items() for count in counts],
                          columns=['AH', 'n_svs'])
        if df.empty:
            # Nothing to draw: kdeplot would leave no legend, and the legend edits below would fail.
            print(f"  no traced decompositions for the selected heads in {filename}; skipping plot")
            continue

        # Per-(task, model) KDE bandwidth override.
        bw_method = 0.5 if (task, model) == ("gp", "pythia-160m") else 0.3

        fig, ax = plt.subplots(1, 1, figsize=(3, 1.9))
        sns.kdeplot(df, x="n_svs", hue="AH", bw_method=bw_method, alpha=0.75, linewidth=1.5,
                    ax=ax, common_norm=False)
        ax.set_xlabel("Number of singular vectors in " + r"$|S^{\ell ads}|$")
        ax.set_xlim(left=0)
        # Edit the legend: drop the "AH" title and use 8 pt entries.
        legend = ax.get_legend()
        if legend is not None:
            legend.set_title(None)
            for text in legend.get_texts():
                text.set_fontsize(8)
        fig.tight_layout()
        fig.savefig(fig_dir / f"n_svs_{model}_{task}.pdf", bbox_inches='tight', dpi=800)
        plt.close(fig)  # avoid accumulating open figures across the loop

gpt2-small ioi
pythia-160m ioi
gemma-2-2b ioi
gpt2-small gp
pythia-160m gp
gemma-2-2b gp
gpt2-small gt
pythia-160m gt


In [10]:
# General results: per-task distribution of the fraction of available singular vectors used, one curve per model

In [11]:
# Aggregate distributions. For each task, pool every non-empty "svs_used_decomp" dataset
# across all heads, and express its length as a fraction of the model's available singular
# vectors (ranks[model]). Save one KDE plot per task, with one curve per model.
fig_dir = Path("figures/sparse_attn_decomp")
fig_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories

for task in AHS.keys():
    data = {}  # model -> list of fractions of available SVs
    for model in AHS[task]:
        n_prompts = 256 if task in ["ioi", "gt"] else 100

        rank = ranks[model]

        filename = f"tracing_results/{model}_{task}_{n_prompts}_0.hdf5"

        data[model] = []

        with h5py.File(filename, "r") as f:
            for key in f.keys():
                if not key.startswith("svs_used_decomp"):
                    continue
                # All heads are pooled, so the (layer, head) suffix of the key is not needed.
                n_svs = len(f[key][:])  # load the dataset once
                if n_svs > 0:  # Cases where we did trace.
                    data[model].append(n_svs / rank)  # fraction of the available SVs

    df_total = pd.DataFrame([(model_name, frac) for model_name, fracs in data.items() for frac in fracs],
                            columns=['model', 'n_svs'])
    if df_total.empty:
        # Nothing to draw: kdeplot would leave no legend, and the legend edits below would fail.
        print(f"{task}: no traced decompositions found; skipping plot")
        continue

    fig, ax = plt.subplots(1, 1, figsize=(3, 1.9))
    sns.kdeplot(df_total, x="n_svs", hue="model", bw_method=0.25, alpha=0.75, linewidth=1.5,
                ax=ax, common_norm=False)
    ax.set_xlabel("Fraction of singular vectors in " + r"$S^{\ell ads}$")
    ax.set_xlim(left=0)
    # Edit the legend: drop the "model" title and use 8 pt entries.
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title(None)
        for text in legend.get_texts():
            text.set_fontsize(8)
    fig.tight_layout()
    fig.savefig(fig_dir / f"distribution_n-svs_{task}.pdf", bbox_inches='tight', dpi=800)
    plt.close(fig)  # avoid accumulating open figures across the loop