In [1]:
import os
import torch
import pandas as pd
import plotly.express as px
from utils.cspa_main import (
    display_cspa_grids
)
# Pick the CUDA accelerator when one is present; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Plot Results for CSPA Experiments

In [2]:
def format_cspa_data_for_plots(checkpoint_dict, top_k=5):
    """Reshape per-checkpoint CSPA head scores into a long-form DataFrame of the top heads.

    Args:
        checkpoint_dict: Mapping ``checkpoint -> scores`` where ``scores`` is iterable as
            ``scores[layer][head]`` (e.g. a 2-D tensor; presumably (n_layers, n_heads) —
            TODO confirm against the code that writes ``all_checkpoints.pt``).
        top_k: Number of heads to keep, ranked by their score summed over all
            checkpoints (highest first). Defaults to 5.

    Returns:
        pandas.DataFrame with columns ``Head`` (e.g. ``"L3H7"``), ``Checkpoint`` and
        ``CSPA Score`` (Python float): one row per (head, checkpoint), heads in
        descending-total order, checkpoints sorted ascending within each head.
        An empty input yields an empty frame that still has those three columns.
    """
    # Pivot to head -> {checkpoint: score}. float() turns 0-d tensor elements into plain
    # numbers so the resulting column is numeric rather than an object column of tensors.
    head_values_dict = {}
    for checkpoint, scores in checkpoint_dict.items():
        for layer_idx, layer in enumerate(scores):
            for head_idx, value in enumerate(layer):
                head_key = f"L{layer_idx}H{head_idx}"
                head_values_dict.setdefault(head_key, {})[checkpoint] = float(value)

    # Rank heads by their total score across all checkpoints and keep the top_k.
    head_sums = {head_key: sum(values.values()) for head_key, values in head_values_dict.items()}
    top_heads = sorted(head_sums, key=head_sums.get, reverse=True)[:top_k]

    # Long-form rows (one per head/checkpoint), the layout plotly.express.line expects.
    plot_data = [
        {'Head': head_key, 'Checkpoint': checkpoint, 'CSPA Score': head_values_dict[head_key][checkpoint]}
        for head_key in top_heads
        for checkpoint in sorted(head_values_dict[head_key])
    ]

    # Explicit columns keep the schema stable even when plot_data is empty.
    return pd.DataFrame(plot_data, columns=['Head', 'Checkpoint', 'CSPA Score'])


def load_checkpoints(target_directory, map_location="cpu"):
    """Load every ``all_checkpoints.pt`` found in a subdirectory of ``target_directory``.

    The tree under ``target_directory`` is walked recursively. Any subdirectory that
    directly contains a file named ``all_checkpoints.pt`` contributes one entry.

    Args:
        target_directory: Root directory to search (e.g. ``'results/cspa'``).
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so results saved
            from a GPU run can still be loaded and plotted on a CPU-only machine. Pass
            ``None`` to restore tensors to the device they were saved from.

    Returns:
        List of ``(subdirectory_name, loaded_object)`` tuples, sorted by subdirectory name.
    """
    checkpoints_list = []
    for root, subdirs, _files in os.walk(target_directory):
        for subdir in subdirs:
            checkpoint_path = os.path.join(root, subdir, 'all_checkpoints.pt')
            if os.path.isfile(checkpoint_path):
                # NOTE(review): torch.load unpickles the file (unless weights_only=True is in
                # effect), which can execute arbitrary code; only load trusted result folders.
                checkpoints = torch.load(checkpoint_path, map_location=map_location)
                checkpoints_list.append((subdir, checkpoints))

    # os.walk order depends on the filesystem, so sort for a reproducible plot order.
    checkpoints_list.sort(key=lambda item: item[0])
    return checkpoints_list


In [6]:
# Load every model variant's CSPA results, in subfolder-name order, and plot how the
# highest-scoring heads evolve across training checkpoints.
target_directory = 'results/cspa'
loaded_checkpoints = sorted(load_checkpoints(target_directory), key=lambda entry: entry[0])

for model_shortname, checkpoints in loaded_checkpoints:
    print(f"Subfolder: {model_shortname}")

    model_df = format_cspa_data_for_plots(checkpoints)
    fig = px.line(
        model_df,
        x='Checkpoint',
        y='CSPA Score',
        color='Head',
        markers=True,
        title=f'Top 5 Head CSPA Score Over Checkpoints ({model_shortname})',
    )
    # The same fixed y-range is applied to every model's figure, so they are comparable.
    fig.update_yaxes(range=[-0.1, 0.8], tickformat=".1%")
    fig.show()

Subfolder: pythia-160m


Subfolder: pythia-160m-alldropout


Subfolder: pythia-160m-attndropout


Subfolder: pythia-160m-data-seed1


Subfolder: pythia-160m-data-seed2


Subfolder: pythia-160m-data-seed3


Subfolder: pythia-160m-hiddendropout


Subfolder: pythia-160m-weight-seed1


Subfolder: pythia-160m-weight-seed2


Subfolder: pythia-160m-weight-seed3


In [4]:
# For each model variant, plot the CSPA score averaged over all heads at each checkpoint.
for model_shortname, checkpoints in loaded_checkpoints:
    print(f"Subfolder: {model_shortname}")

    # Mean over every score entry instead of sum / 144 (a hard-coded head count). This is
    # identical for 144-entry tensors and stays correct for other model sizes.
    # Assumes each value is a floating-point tensor/array exposing .mean().
    cspa_means = {
        ckpt_key: float(checkpoints[ckpt_key].mean())
        for ckpt_key in sorted(checkpoints)
    }

    cspa_df = pd.DataFrame(list(cspa_means.items()), columns=['Checkpoint', 'CSPA Score'])

    fig = px.line(cspa_df, x='Checkpoint', y='CSPA Score', title=f'CSPA Score Over Checkpoints ({model_shortname})')

    # Fixed y-range (0 to 1.2%) shared by all models, shown as percentages.
    fig.update_layout(yaxis=dict(range=[0, 0.012], tickformat=".1%"))

    fig.show()

Subfolder: pythia-160m


Subfolder: pythia-160m-alldropout


Subfolder: pythia-160m-attndropout


Subfolder: pythia-160m-data-seed1


Subfolder: pythia-160m-data-seed2


Subfolder: pythia-160m-data-seed3


Subfolder: pythia-160m-hiddendropout


Subfolder: pythia-160m-weight-seed1


Subfolder: pythia-160m-weight-seed2


Subfolder: pythia-160m-weight-seed3
