# Paper Figures

In [None]:
import os
from collections import defaultdict
import json
import numpy as np
import pandas as pd
from scipy import stats
import torch

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.lines import Line2D

from vis import (
    plot_top_bottom_evolution, 
    table_generator_three_way, 
    table_generator_two_way, 
    table_generator_three_way_bloom,
    table_generator_two_way_bloom
)

## IE Evolution of Top-5 and Bottom-5 Features

In [None]:
# Pythia (run version 446): IE evolution of the top-5 and bottom-5 features
# across the 1B / 4B / 286B-token checkpoints.
plot_top_bottom_evolution(
    version_num=446,
    max_examples=3469,
    node_threshold=0.1,
    task="subjectverb",
    ckpt_num=20,
    top_k=5,
    ckpt_tokens=["1B", "4B", "286B"],
    figsize=(6, 16),
    save_path="figs/ie_evolution/top5_bottom5_pythia.pdf",
    fontsize=28,
)

In [None]:
# OLMo (run version 447): same plot for the 4B / 33B / 3T-token checkpoints.
# fontsize is not passed, so plot_top_bottom_evolution's default applies.
plot_top_bottom_evolution(
    version_num=447,
    max_examples=3469,
    node_threshold=0.1,
    task="subjectverb",
    ckpt_num=20,
    top_k=5,
    ckpt_tokens=["4B", "33B", "3T"],
    figsize=(8, 14),
    save_path="figs/ie_evolution/top5_bottom5_olmo.pdf",
)

## Pythia Top-10 & Top-100 3D plot of RelIE in 3-way comparison

In [None]:
# Pythia 3-way RelIE scatter: each point is one row of the latents CSV, placed
# by the three entries of its ``rel_ie_value`` (plotted as the 1B / 4B / 286B
# axes), and rendered from three camera angles for top_k = 10 and 100.
version_num = 446
base_dir = "./workspace/logs/ie_dicts_zeroshot"
save_dir = f"{base_dir}/version_{version_num}"
node_threshold = 0.1
ckpt_num = 20
max_examples = 3469
task = "subjectverb"

# Camera positions as (elev, azim) pairs.
views = [
    (20, 30),    # front-angle
    (20, 120),   # side-ish
    (90, -90),   # top-view
]

# Legend proxy artists. Only the three pure channels are listed; a point's
# actual colour mixes all three channels (see the binning below).
legend_elements = [
    Line2D([0], [0],
           marker='o', color='w',
           markerfacecolor=(1, 0, 0), markersize=8,
           label='RelIE 1B high (x > 0.7)'),
    Line2D([0], [0],
           marker='o', color='w',
           markerfacecolor=(0, 1, 0), markersize=8,
           label='RelIE 4B high (y > 0.7)'),
    Line2D([0], [0],
           marker='o', color='w',
           markerfacecolor=(0, 0, 1), markersize=8,
           label='RelIE 286B high (z > 0.7)'),
]

# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)

for top_k in [10, 100]:
    csv_path = f"{save_dir}/latents_{task}_ckpt{ckpt_num}_thresh{node_threshold}_n{max_examples}_topk{top_k}.csv"
    df = pd.read_csv(csv_path)
    # rel_ie_value is stored as a bracketed, space-separated string of numbers.
    df['rel_ie_array'] = df['rel_ie_value'].str.strip('[]') \
                                        .apply(lambda s: np.fromstring(s, sep=' '))

    coords = np.vstack(df['rel_ie_array'].values)
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    # Colours depend only on the data, so compute them once per top_k.
    # np.digitize with edges [0.3, 0.7] (right=False) yields
    #   bin 0: v < 0.3,   bin 1: 0.3 <= v < 0.7,   bin 2: v >= 0.7;
    # dividing by 2 maps low/mid/high to 0 / 0.5 / 1 on the R, G, B channel.
    colors = np.stack(
        [np.digitize(axis_vals, [0.3, 0.7]) / 2.0 for axis_vals in (x, y, z)],
        axis=1,
    )

    for elev, azim in views:
        fig = plt.figure(constrained_layout=True)
        ax = fig.add_subplot(111, projection='3d')

        ax.scatter(x, y, z, c=colors, marker='o', s=20, edgecolor='k', linewidth=0.2)

        ax.set_xlabel('RelIE 1B')
        ax.set_ylabel('RelIE 4B')
        ax.set_zlabel('RelIE 286B')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.set_zlim(0, 1)

        # Small pad for the title itself.
        ax.set_title(
            f'Pythia 3-Way Comparison RelIE for TopK={top_k}',
            pad=0
        )

        ax.legend(
            handles=legend_elements,
            title='Channel = high bin',
            loc='upper left',             # legend corner anchored to ...
            bbox_to_anchor=(1.05, 1.0),   # ... this point, outside the axes (axes-fraction coords)
            borderaxespad=0.,             # pad between axes and legend
        )

        ax.view_init(elev=elev, azim=azim)
        # pad_inches only takes effect together with bbox_inches='tight'; the
        # tight bbox also includes the legend that sits outside the axes.
        # NOTE(review): "pyhtia" looks like a typo for "pythia"; left unchanged
        # so any existing references to these filenames keep working.
        fig.savefig(
            f"figs/pyhtia1b_3compar_RelIE_topk{top_k}_elev{elev}_azim{azim}.png",
            bbox_inches='tight',
            pad_inches=0.0,
        )
        plt.show()
        plt.close(fig)  # this loop creates six figures; release each one

## BLOOM MultiBLiMP Overlap Multilinguality Scores

### 2-way

In [None]:
def get_feature_overlap_across_tasks_2way(
        task_list,
        version_num,
        max_examples,
        submod_layer,
        node_threshold=0.1,
        ckpt_num=20,
        top_k=100
    ):
    """Collect, per language, the union of both models' top-k |IE| features.

    For every task this loads the saved effects dict
    ``./workspace/logs/ie_dicts_zeroshot/{version_num}/{task}_ckpt{ckpt_num}_thresh{node_threshold}_n{max_examples}.pt``
    and takes the indices of the ``top_k`` largest absolute values of
    ``effects[f"m1_layer{submod_layer}_out"]`` and
    ``effects[f"m2_layer{submod_layer}_out"]`` (presumably 1-D per-feature
    effect vectors -- TODO confirm). The language code is the second
    ``_``-separated field of the task name, e.g. ``"multiblimp_eng_SV-#"`` -> ``"eng"``.

    Args:
        task_list: non-empty list of task names.
        version_num: run directory name (callers pass e.g. ``"version_387"``).
        max_examples: example count encoded in the filenames.
        submod_layer: layer index used to build the effects keys.
        node_threshold: threshold encoded in the filenames.
        ckpt_num: checkpoint number encoded in the filenames.
        top_k: number of features taken from each model.

    Returns:
        ``(lang2fidset, comparison)``. ``lang2fidset`` maps language -> set of
        feature indices. ``comparison`` is the first value of the
        ``comparison`` column of the last task's ``..._topk10.csv`` latents file.

    Raises:
        ValueError: if ``task_list`` is empty.
    """
    if not task_list:
        raise ValueError("task_list must contain at least one task")

    save_dir = f"./workspace/logs/ie_dicts_zeroshot/{version_num}"
    lang2fidset = {}
    featdf = None

    for task in task_list:
        effects = torch.load(
            f"{save_dir}/{task}_ckpt{ckpt_num}_thresh{node_threshold}_n{max_examples}.pt"
        )
        # This CSV is only read for its comparison label, so the topk10 file
        # is used regardless of ``top_k``.
        featdf = pd.read_csv(
            f"{save_dir}/latents_{task}_ckpt{ckpt_num}_thresh{node_threshold}_n{max_examples}_topk10.csv"
        )
        lang = task.split("_")[1]

        fids = set()
        for model in ("m1", "m2"):
            scores = effects[f"{model}_layer{submod_layer}_out"].abs()
            fids.update(torch.topk(scores, k=top_k).indices.tolist())
        print(f"# of feats for both models for {task}: ", len(fids))

        lang2fidset[lang] = fids

    comparison = featdf["comparison"].unique().tolist()[0]
    return lang2fidset, comparison

def make_overlap_df(lang2fid, normalize=True):
    """Build a symmetric language x language overlap matrix of feature-id sets.

    Languages appear in the fixed display order arb, hin, fra, eng, spa, por,
    keeping only those present in ``lang2fid``. Any other languages follow in
    sorted order, so no language raises KeyError or is silently dropped.

    Args:
        lang2fid: mapping language code -> set of feature ids.
        normalize: if True, each cell is the Jaccard index |A & B| / |A | B|
            (0.0 when both sets are empty); otherwise it is the raw
            intersection size.

    Returns:
        pandas.DataFrame with the ordered languages as both index and columns.
    """
    preferred = ["arb", "hin", "fra", "eng", "spa", "por"]
    langs = [lang for lang in preferred if lang in lang2fid]
    langs += sorted(lang for lang in lang2fid if lang not in preferred)

    mat = np.zeros((len(langs), len(langs)), dtype=float)
    for i, la in enumerate(langs):
        for j, lb in enumerate(langs):
            inter = lang2fid[la] & lang2fid[lb]
            if normalize:
                union = lang2fid[la] | lang2fid[lb]
                mat[i, j] = len(inter) / len(union) if union else 0.0
            else:
                mat[i, j] = len(inter)
    return pd.DataFrame(mat, index=langs, columns=langs)

def plot_heatmap(df, title=None, save_path=None, fontsize=16):
    """Draw an overlap matrix as an annotated heatmap.

    Each cell is annotated with its value. Integral values are shown as
    integers; other values (e.g. Jaccard scores from
    ``make_overlap_df(normalize=True)``) are shown with two decimals instead
    of being truncated to 0. Text is white on cells below half of the matrix
    maximum and black otherwise.

    Args:
        df: DataFrame whose index/columns label the rows/columns.
        title: optional figure title.
        save_path: if given, the figure is saved there with a tight bbox.
        fontsize: font size for tick labels and annotations (title gets +2).
    """
    values = df.to_numpy()
    n_rows, n_cols = values.shape

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(values, aspect='equal')
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(df.columns, rotation=45, ha='right', fontsize=fontsize)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(df.index, fontsize=fontsize)
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Feature count')

    thresh = values.max() / 2.0
    for i in range(n_rows):
        for j in range(n_cols):
            val = values[i, j]
            label = f"{int(val)}" if float(val).is_integer() else f"{val:.2f}"
            ax.text(
                j, i, label,
                ha='center', va='center',
                color="white" if val < thresh else "black",
                fontsize=fontsize
            )

    if title:
        ax.set_title(title, fontsize=fontsize + 2)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # callers draw many heatmaps in a loop


# Languages per task: SV-G task lists use only svg_langs, the others use langs.
langs = ["eng", "fra", "spa", "por", "arb", "hin"]
svg_langs = ["por", "arb", "hin"]

tasks = ["SV-#", "SV-G", "SV-P"]
# Passed as max_examples, i.e. the n{...} part of the saved-effects filenames.
task2len = {
    "SV-#": 100,
    "SV-G": 100,
    "SV-P": 290
}

overlap_2way_dir = "figs/multilingual_overlap_2way"
# savefig does not create missing directories.
os.makedirs(overlap_2way_dir, exist_ok=True)

for top_k in [10, 100]:
    for task in tasks:
        print("#" * 40)
        print("task: ", task)
        task_langs = svg_langs if task == "SV-G" else langs
        task_list = [f"multiblimp_{lang}_{task}" for lang in task_langs]

        for version_num in ["387", "400", "409"]:
            print("-" * 20)
            print("version_num: ", version_num)
            lang2fid, comparison = get_feature_overlap_across_tasks_2way(
                task_list=task_list,
                version_num="version_" + version_num,
                max_examples=task2len[task],
                submod_layer=12,
                top_k=top_k
            )

            overlap_df = make_overlap_df(lang2fid, normalize=False)
            # Make the comparison label filename-safe: '#' -> 'num', spaces -> '_', drop dots.
            stem = (
                f"multilang_overlap_topk{top_k}_{task}_{comparison}"
                .replace("#", "num")
                .replace(" ", "_")
                .replace(".", "")
            )
            final_path = os.path.join(overlap_2way_dir, stem + ".png")
            plot_heatmap(overlap_df, title=f"Feature set overlap - {comparison}", save_path=final_path)


### 3-way

In [None]:
import torch
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def get_feature_overlap_across_tasks_3way(
        task_list,
        version_num,
        max_examples,
        submod_layer,
        node_threshold=0.1,
        ckpt_num=20,
        top_k=100,
        effects_k='m0_layer12_out'
    ):
    """Collect, per language, the top-k |IE| features of one effects entry.

    For every task this loads the saved effects dict
    ``./workspace/logs/ie_dicts_zeroshot/{version_num}/{task}_ckpt{ckpt_num}_thresh{node_threshold}_n{max_examples}.pt``
    and takes the indices of the ``top_k`` largest absolute values of
    ``effects[effects_k]`` (presumably a 1-D per-feature effect vector --
    TODO confirm). The language code is the second ``_``-separated field of
    the task name, e.g. ``"multiblimp_eng_SV-#"`` -> ``"eng"``.

    Args:
        task_list: non-empty list of task names.
        version_num: run directory name (callers pass e.g. ``"version_454"``).
        max_examples: example count encoded in the filenames.
        submod_layer: unused here; kept for signature parity with
            ``get_feature_overlap_across_tasks_2way`` (the layer is part of
            ``effects_k``).
        node_threshold: threshold encoded in the filenames.
        ckpt_num: checkpoint number encoded in the filenames.
        top_k: number of features taken.
        effects_k: key of the effects dict to rank, e.g. ``'m1_layer12_out'``.

    Returns:
        ``(lang2fidset, comparison)``. ``lang2fidset`` maps language -> set of
        feature indices. ``comparison`` is the first value of the
        ``comparison`` column of the last task's ``..._topk10.csv`` latents file.

    Raises:
        ValueError: if ``task_list`` is empty.
    """
    if not task_list:
        raise ValueError("task_list must contain at least one task")

    save_dir = f"./workspace/logs/ie_dicts_zeroshot/{version_num}"
    lang2fidset = {}
    featdf = None

    for task in task_list:
        effects = torch.load(
            f"{save_dir}/{task}_ckpt{ckpt_num}_thresh{node_threshold}_n{max_examples}.pt"
        )
        # This CSV is only read for its comparison label, so the topk10 file
        # is used regardless of ``top_k``.
        featdf = pd.read_csv(
            f"{save_dir}/latents_{task}_ckpt{ckpt_num}_thresh{node_threshold}_n{max_examples}_topk10.csv"
        )
        lang = task.split("_")[1]

        fids = set(torch.topk(effects[effects_k].abs(), k=top_k).indices.tolist())
        print(f"# of feats for {effects_k} for {task}: ", len(fids))

        lang2fidset[lang] = fids

    comparison = featdf["comparison"].unique().tolist()[0]
    return lang2fidset, comparison

def make_overlap_df(lang2fid, normalize=True):
    """Build a symmetric language x language overlap matrix of feature-id sets.

    Languages appear in the fixed display order arb, hin, fra, eng, spa, por,
    keeping only those present in ``lang2fid``. Any other languages follow in
    sorted order, so no language raises KeyError or is silently dropped.

    Args:
        lang2fid: mapping language code -> set of feature ids.
        normalize: if True, each cell is the Jaccard index |A & B| / |A | B|
            (0.0 when both sets are empty); otherwise it is the raw
            intersection size.

    Returns:
        pandas.DataFrame with the ordered languages as both index and columns.
    """
    preferred = ["arb", "hin", "fra", "eng", "spa", "por"]
    langs = [lang for lang in preferred if lang in lang2fid]
    langs += sorted(lang for lang in lang2fid if lang not in preferred)

    mat = np.zeros((len(langs), len(langs)), dtype=float)
    for i, la in enumerate(langs):
        for j, lb in enumerate(langs):
            inter = lang2fid[la] & lang2fid[lb]
            if normalize:
                union = lang2fid[la] | lang2fid[lb]
                mat[i, j] = len(inter) / len(union) if union else 0.0
            else:
                mat[i, j] = len(inter)
    return pd.DataFrame(mat, index=langs, columns=langs)

def plot_heatmap(df, top_k, title=None, save_path=None, fontsize=16):
    """Draw an overlap matrix as an annotated heatmap on a fixed [0, top_k] scale.

    The colour scale is pinned to ``vmin=0, vmax=top_k``, the largest possible
    overlap of two top-k sets, so heatmaps with the same ``top_k`` are
    comparable. The annotation text switches from white to black at
    ``top_k / 2``, the midpoint of that same fixed scale, so text contrast
    follows the cell colour. Previously the switch point was half the data
    maximum, which mismatched the pinned scale.

    Args:
        df: DataFrame whose index/columns label the rows/columns.
        top_k: upper bound of the colour scale.
        title: optional figure title.
        save_path: if given, the figure is saved there with a tight bbox.
        fontsize: font size for tick labels and annotations (title gets +2).
    """
    values = df.to_numpy()
    n_rows, n_cols = values.shape

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(values, aspect='equal', vmin=0, vmax=top_k)
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(df.columns, rotation=45, ha='right', fontsize=fontsize)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(df.index, fontsize=fontsize)
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label('Feature count')

    thresh = top_k / 2.0
    for i in range(n_rows):
        for j in range(n_cols):
            val = values[i, j]
            label = f"{int(val)}" if float(val).is_integer() else f"{val:.2f}"
            ax.text(
                j, i, label,
                ha='center', va='center',
                color="white" if val < thresh else "black",
                fontsize=fontsize
            )

    if title:
        ax.set_title(title, fontsize=fontsize + 2)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # callers draw many heatmaps in a loop


import os

# Languages per task: SV-G task lists use only svg_langs, the others use langs.
langs = ["eng", "fra", "spa", "por", "arb", "hin"]
svg_langs = ["por", "arb", "hin"]

tasks = ["SV-#", "SV-G", "SV-P"]
# Passed as max_examples, i.e. the n{...} part of the saved-effects filenames.
task2len = {
    "SV-#": 100,
    "SV-G": 100,
    "SV-P": 290
}

effects_k_list = ['m0_layer12_out', 'm1_layer12_out', 'm2_layer12_out']
overlap_3way_dir = "figs/multilingual_overlap_3way"
# savefig does not create missing directories.
os.makedirs(overlap_3way_dir, exist_ok=True)

for top_k in [10, 100]:
    for task in tasks:
        print("#" * 40)
        print("task: ", task)
        task_langs = svg_langs if task == "SV-G" else langs
        task_list = [f"multiblimp_{lang}_{task}" for lang in task_langs]

        for version_num in ["454"]:
            print("-" * 20)
            print("version_num: ", version_num)
            for i, effects_k in enumerate(effects_k_list):
                print("-" * 20)
                print("effects_k: ", effects_k)
                lang2fid, comparison = get_feature_overlap_across_tasks_3way(
                    task_list=task_list,
                    version_num="version_" + version_num,
                    max_examples=task2len[task],
                    submod_layer=12,
                    top_k=top_k,
                    effects_k=effects_k
                )

                overlap_df = make_overlap_df(lang2fid, normalize=False)
                # Assumes the label reads "<m0> vs. <m1> vs. <m2>" in the same
                # order as effects_k_list (TODO confirm); fall back to the key.
                comparison_parts = comparison.split(" vs. ")
                checkpoint_name = comparison_parts[i] if i < len(comparison_parts) else effects_k

                stem = (
                    f"multilang_overlap_topk{top_k}_{task}_{i}"
                    .replace("#", "num")
                    .replace(" ", "_")
                    .replace(".", "")
                )
                final_path = os.path.join(overlap_3way_dir, stem + ".png")
                plot_heatmap(overlap_df, top_k=top_k, title=f"Feature set overlap - {checkpoint_name}", save_path=final_path)
        
        

## Annotation Table Pythia 3-way

In [None]:
# Pythia-1B, run version 446.
out = table_generator_three_way(version_num=446, model_name="Pythia-1B")
print(out)

## Annotation Table OLMo 3-way

In [None]:
# OLMo-1B, run version 447.
out = table_generator_three_way(version_num=447, model_name="OLMo-1B")
print(out)

## Annotation Table BLOOM 3-way (CLAMS & MultiBLiMP)

In [None]:
# BLOOM-1B (run version 454): one table per annotation file, separated by blank lines.
for task_name in (
    "clams_fraeng",
    "multiblimp_eng",
    "multiblimp_fra",
    "multiblimp_hin",
):
    out = table_generator_three_way_bloom(
        annotation_filename=f"annotation_{task_name}.csv",
        version_num=454,
        model_name="BLOOM-1B",
    )
    print(out)
    print("\n" * 6)

## Annotation Table Pythia 2-way

In [None]:
# Pythia-1B 2-way table built from the BLiMP annotation file.
pythia_2way_csv = "workspace/annotation_pythia_2way_blimp.csv"
out = table_generator_two_way(annotation_filename=pythia_2way_csv, model_name="Pythia-1B")
print(out)

## Annotation Table BLOOM 2-way

In [None]:
# BLOOM-1B 2-way table built from the CLAMS fra/eng annotation file.
bloom_2way_csv = "workspace/annotation_bloom_2way_clamsfraeng.csv"
out = table_generator_two_way_bloom(annotation_filename=bloom_2way_csv, model_name="BLOOM-1B")
print(out)

## Correlation of RelIE & RelDec with Delta M1 Patch / Delta M2 Patch

In [None]:
def clean_scatter(df, x_col, y_col, z_thresh=3, ylim=(0, 10)):
    """Scatter ``y_col`` against ``x_col`` after dropping outliers in ``y_col``.

    A row is kept when ``y_col`` is not NaN and its z-score has absolute value
    strictly below ``z_thresh``. The z-score uses the population std (ddof=0),
    as ``scipy.stats.zscore`` does. NaNs are excluded from the mean/std, so a
    single NaN no longer turns every z-score into NaN and empties the plot.
    If the non-NaN values have zero spread, none of them is treated as an
    outlier.

    Args:
        df: source DataFrame (not modified).
        x_col: column plotted on the x axis.
        y_col: column plotted on the y axis and used for outlier filtering.
        z_thresh: keep rows with |z| < z_thresh (default 3).
        ylim: y-axis limits, or None to leave them automatic (default (0, 10)).
    """
    y_all = df[y_col].to_numpy(dtype=float)
    valid = ~np.isnan(y_all)
    keep = valid.copy()
    if valid.any():
        std = y_all[valid].std()
        if std > 0:
            z = np.abs((y_all - y_all[valid].mean()) / std)
            keep = valid & (z < z_thresh)
    kept = df[keep]

    fig, ax = plt.subplots()
    ax.scatter(kept[x_col], kept[y_col])
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_title(f'{y_col} vs {x_col}')
    plt.show()

def gather_corr_results(top_k, version_list, model_name, do_corr_plot=False):
    """Combine per-subtask ablation correlation CSVs and print a LaTeX summary.

    For each version in ``version_list`` and each subject-verb subtask below,
    this reads
    ``./workspace/logs/ie_dicts_zeroshot/version_{v}/ablation-task_subjectverb-{subtask}-topk{top_k}-corr.csv``.
    Missing files only trigger a warning. The loaded CSVs are concatenated,
    task ids are mapped to LaTeX display names, and the combined table is
    written to
    ``./workspace/results/combined_ablation_corr_{model_name}_topk{top_k}.csv``
    (the directory is created if needed). A LaTeX table with per-comparison
    averages and an overall average appended is then printed.

    If ``do_corr_plot`` is True, the first two existing
    ``...-topk{top_k}-deltasummary.csv`` files are plotted with
    ``clean_scatter`` (RelDec and RelIE vs. the LogProbDiff ratio).

    Args:
        top_k: top-k value encoded in the input/output filenames.
        version_list: run versions (ints or strings).
        model_name: label used in the output filename.
        do_corr_plot: whether to draw the delta-summary scatter plots.
    """
    base_dir = "./workspace/logs/ie_dicts_zeroshot"
    subtasks = [
        "distractor_agreement_relational_noun",
        "distractor_agreement_relative_clause",
        "irregular_plural_subject_verb_agreement_1",
        "regular_plural_subject_verb_agreement_1"
    ]
    subtask2name = {
        "subjectverb-distractor_agreement_relational_noun": "\\texttt{Distractor Relational Noun}",
        "subjectverb-distractor_agreement_relative_clause": "\\texttt{Distractor Relative Clause}",
        "subjectverb-irregular_plural_subject_verb_agreement_1": "\\texttt{Irregular Plural Subject}",
        "subjectverb-regular_plural_subject_verb_agreement_1": "\\texttt{Regular Plural Subject}",
    }
    ratio_col = "LogProbDiff Δ Ratio Abs(M2/M1)"
    num_cols = [
        'rho(Δ Ratio Abs(M2/M1) - RelDec)',
        'rho(Δ Ratio Abs(M2/M1) - RelIE)'
    ]

    all_dfs = []
    plt_cnt = 0
    for version_num in version_list:
        save_dir = os.path.join(base_dir, "version_" + str(version_num))
        for subtask in subtasks:
            corr_path = os.path.join(
                save_dir,
                f"ablation-task_subjectverb-{subtask}-topk{top_k}-corr.csv"
            )
            if os.path.isfile(corr_path):
                all_dfs.append(pd.read_csv(corr_path))
            else:
                print(f"Warning: file not found: {corr_path}")

            if do_corr_plot and plt_cnt < 2:
                summary_path = os.path.join(
                    save_dir,
                    f"ablation-task_subjectverb-{subtask}-topk{top_k}-deltasummary.csv"
                )
                # Check the file that is actually read (previously corr_path was checked).
                if os.path.isfile(summary_path):
                    summary_df = pd.read_csv(summary_path)
                    clean_scatter(df=summary_df, y_col=ratio_col, x_col="RelDec")
                    clean_scatter(df=summary_df, y_col=ratio_col, x_col="RelIE")
                    plt_cnt += 1
                else:
                    print(f"Warning: file not found: {summary_path}")

    if not all_dfs:
        print("No CSV files were loaded. Check your paths.")
        return

    combined = pd.concat(all_dfs, ignore_index=True)
    combined["task"] = combined["task"].replace(subtask2name)
    combined = combined[['comparison', 'task'] + num_cols]

    output_dir = os.path.join(".", "workspace", "results")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"combined_ablation_corr_{model_name}_topk{top_k}.csv")
    combined.to_csv(output_path, index=False)
    print(f"Combined CSV written to: {output_path}")

    # Per-comparison averages, labelled "Avg <comparison>".
    group_means = (
        combined
        .groupby('comparison')[num_cols]
        .mean()
        .reset_index()
    )
    group_means['comparison'] = 'Avg ' + group_means['comparison']
    group_means['task'] = '-'
    group_means = group_means[['comparison', 'task'] + num_cols]

    # Overall average across every loaded row.
    overall = combined[num_cols].mean().to_frame().T
    overall['comparison'] = 'Avg'
    overall['task'] = '-'
    overall = overall[['comparison', 'task'] + num_cols]

    final = pd.concat([combined, group_means, overall], ignore_index=True)
    print(final.to_latex(index=False, float_format="%.3f"))
        
# Pythia runs (versions given as ints) and OLMo runs (versions given as strings).
gather_corr_results(top_k=10, version_list=[218, 265, 219], model_name="pythia", do_corr_plot=False)
gather_corr_results(
    top_k=10,
    version_list=["440", "430", "434", "433", "436"],
    model_name="olmo",
    do_corr_plot=False,
)