## Setup

In [25]:
import glob
import json
import os
import re

import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
#import seaborn as sns
import torch
from IPython.display import display, clear_output, HTML

from utils.data_processing import (
    load_graphs_for_models,
    get_ckpts,
    load_metrics,
    load_edge_scores_into_dictionary,
    compute_ged,
    compute_weighted_ged,
    compute_gtd,
    compute_jaccard_similarity_to_reference,
    compute_jaccard_similarity,
    compute_weighted_jaccard_similarity,
    compute_weighted_jaccard_similarity_to_reference,
    compute_ewma_weighted_jaccard_similarity,
)
from utils.result_plotting import plot_graph_metric

In [2]:
# =============================================================================
# Optional: uncomment the two lines below to switch plotly's default renderer to
# 'png', so figures are embedded as static images rather than interactive output.
# NOTE(review): the kaleido import suggests static export depends on the kaleido
# package being installed — confirm before enabling.
#import kaleido
#pio.renderers.default = 'png' # USE IF MAKING GRAPHS FOR NOTEBOOK EXPORT
# =============================================================================

In [16]:
# Experiment selection. These three keys index the nested performance-metric
# dictionary (model -> task -> metric); TASK also names the cached graph pickle.
MODEL_NAME = 'pythia-160m'
TASK = 'greater_than'
PERFORMANCE_METRIC = 'prob_diff'

## Retrieve & Process Data

### Circuit Data

In [19]:
# Rebuilds the cached edge table that the next cell reads with pd.read_pickle.
# only needs to be run if another model is added or graph files change in some way
# Steps when uncommented: load every graph under results/graphs for TASK, keep
# only rows flagged in_circuit, pickle the result, then clear the loader's output.
# all_graphs = load_graphs_for_models('results/graphs', TASK)
# all_graphs = all_graphs[all_graphs['in_circuit'] == True]
# all_graphs.to_pickle(f'results/all_minimal_graphs_{TASK}.pkl')
# clear_output()

Processing file 1/143: results/graphs/pythia-160m/greater_than/57000.json
                  edge         score  in_circuit  checkpoint
0      input->a0.h0<q> -5.264282e-04       False       57000
1      input->a0.h0<k> -4.720688e-05       False       57000
2      input->a0.h0<v>  1.821518e-04       False       57000
3      input->a0.h1<q>  2.944469e-05       False       57000
4      input->a0.h1<k> -1.919270e-05       False       57000
...                ...           ...         ...         ...
32342   a11.h8->logits  1.606531e-08       False       57000
32343   a11.h9->logits  5.625188e-07       False       57000
32344  a11.h10->logits -3.667083e-09       False       57000
32345  a11.h11->logits -1.611188e-07       False       57000
32346      m11->logits  7.202148e-03        True       57000

[32347 rows x 4 columns]
Processing file 2/143: results/graphs/pythia-160m/greater_than/141000.json
                  edge         score  in_circuit  checkpoint
0      input->a0.h0<q> -5.264282

In [21]:
# Load the cached in-circuit edge table from its pickle file.
# NOTE(review): unpickling can execute arbitrary code — only load pickles that
# this project produced itself.
#
# The frame is built as one non-mutating chain: the original filtered the frame
# and then called rename/sort_values with inplace=True on that filtered slice,
# which is the pattern pandas warns about (SettingWithCopyWarning).
all_graphs = (
    pd.read_pickle(f'results/all_minimal_graphs_{TASK}.pkl')
    # Drop checkpoints before step 4000.
    # NOTE(review): the reason for the 4000 cutoff is not stated here — confirm.
    .loc[lambda df: df['checkpoint'] >= 4000]
    # Older pickles name the column 'subfolder'; this is a no-op when the
    # column is already called 'model'.
    .rename(columns={'subfolder': 'model'})
    .sort_values(by=['model', 'checkpoint'])
)

# Group by checkpoint and model and sum every column. Summing the boolean
# 'in_circuit' column yields the number of edges per (checkpoint, model).
# Note that sum() also concatenates the string 'edge' column, as seen in the
# displayed output; that column is not meaningful in subgraph_df.
subgraph_df = all_graphs.groupby(['checkpoint', 'model']).sum().reset_index()

subgraph_df.head()

Unnamed: 0,edge,score,in_circuit,checkpoint,model
102570,a4.h6->m8,0.002136,True,4000,pythia-160m
103967,a4.h6->m7,0.001678,True,4000,pythia-160m
104749,a5.h9->m7,0.002609,True,4000,pythia-160m
106296,m0->a4.h6<v>,0.005341,True,4000,pythia-160m
106311,m0->a4.h11<v>,0.00351,True,4000,pythia-160m


Unnamed: 0,checkpoint,model,edge,score,in_circuit
0,4000,pythia-160m,a4.h6->m8a4.h6->m7a5.h9->m7m0->a4.h6<v>m0->a4....,0.117012,32
1,5000,pythia-160m,a5.h9->m6a4.h11->m8a4.h6->m8a4.h6->m7a5.h9->m7...,0.125679,34
2,6000,pythia-160m,m7->m9a7.h10->m9m8->m9a5.h9->m9m6->m9a4.h6->m1...,0.100861,40
3,7000,pythia-160m,a4.h11->m8a5.h0->m7a5.h9->a7.h10<v>a5.h9->m7m0...,0.125198,29
4,8000,pythia-160m,m7->m9a7.h10->m9m8->m9a5.h9->m9a4.h6->m10m7->m...,0.116714,57


In [23]:
# Alphabetically ordered list of the distinct model names present in subgraph_df.
models = sorted(subgraph_df['model'].unique().tolist())

### Performance Data

In [39]:
# ONLY NEED TO RUN WHEN NEW CHECKPOINTS ARE ADDED
#
# Rebuilds the file that the next cell reads with torch.load. When uncommented,
# it loads the raw metrics from 'results', converts every per-checkpoint value
# with .item(), pairs the values positionally with the checkpoint list from
# get_ckpts(schedule="exp_plus_detail"), and stores the resulting
# {checkpoint: value} dict back under model -> task -> metric.
# NOTE(review): zip() silently truncates if the metric list and the checkpoint
# list differ in length — confirm they always match.

# perf_metrics_by_model = dict()

# directory_path = 'results'
# perf_metrics_by_model = load_metrics(directory_path)

# ckpts = get_ckpts(schedule="exp_plus_detail")

# for model in perf_metrics_by_model.keys():
#     for task in perf_metrics_by_model[model].keys():
#         for metric in perf_metrics_by_model[model][task].keys():
#             perf_metric = perf_metrics_by_model[model][task][metric]
#             perf_metric = [x.item() for x in perf_metric]
#             perf_metric_dict = dict(zip(ckpts, perf_metric))
#             perf_metrics_by_model[model][task][metric] = perf_metric_dict

# # save dictionary to file
# torch.save(perf_metrics_by_model, 'results/task_performance_metrics/all_models_task_performance.pt')

In [40]:
# Read the cached performance metrics, nested as model -> task -> metric.
perf_metrics_by_model = torch.load(
    'results/task_performance_metrics/all_models_task_performance.pt'
)

# Select the {checkpoint: value} series for the configured model, task and metric.
# Any other dictionary with (checkpoint: metric) structure can be substituted
# here, e.g. from baselines or other model task runs.
task_metrics = perf_metrics_by_model[MODEL_NAME][TASK]
perf_metric_dict = task_metrics[PERFORMANCE_METRIC]

{1: -8.466382860206068e-05,
 2: -8.481797704007477e-05,
 4: -9.149761899607256e-05,
 8: -0.00011807034024968743,
 16: -0.00014619850844610482,
 32: -1.1029071174561977e-05,
 64: -0.000151446380186826,
 128: -0.0022255955263972282,
 256: -0.005887055769562721,
 512: -0.02106129564344883,
 1000: -0.17926901578903198,
 2000: 0.14408066868782043,
 3000: 0.30164390802383423,
 4000: 0.5214207172393799,
 5000: 0.6283944845199585,
 6000: 0.42528706789016724,
 7000: 0.6616678237915039,
 8000: 0.5040611624717712,
 9000: 0.5024839639663696,
 10000: 0.5974746346473694,
 11000: 0.6558032035827637,
 12000: 0.4376274645328522,
 13000: 0.5673839449882507,
 14000: 0.5760918855667114,
 15000: 0.673001766204834,
 20000: 0.585519552230835,
 25000: 0.6561374664306641,
 30000: 0.7069724202156067,
 35000: 0.8510973453521729,
 40000: 0.7858152389526367,
 45000: 0.8430061340332031,
 50000: 0.8594885468482971,
 55000: 0.8246071934700012,
 60000: 0.8277291059494019,
 65000: 0.886686384677887,
 70000: 0.863932073

In [41]:
# Show which training checkpoints have a recorded performance value.
perf_metric_dict.keys()

dict_keys([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 80000, 90000, 100000, 110000, 120000, 130000, 140000])

## Visualize Results

### Graph Size

In [43]:
# One figure per model: per-checkpoint edge count (the summed 'in_circuit'
# column of subgraph_df) plotted together with the task performance metric.
for model in models:
    is_current_model = subgraph_df['model'] == model
    model_df = subgraph_df.loc[is_current_model].copy()
    plot_graph_metric(
        model_df,
        'in_circuit',
        perf_metric_dict,  # note that this will interpolate missing values
        f'Graph Size for {model}',
        left_y_title="Edge Count",
        y_range=200,
        x_axis_col='checkpoint',
        log_x=True,
        disable_title=False,  # Optional: Set to True to disable the title for publishing
    )

### Graph Similarity

In [44]:
# One figure per model: weighted Jaccard similarity of each checkpoint's circuit
# to the previous checkpoint's, plotted together with the task performance metric.
for model in models:
    is_current_model = all_graphs['model'] == model
    model_df = all_graphs.loc[is_current_model].copy()
    weighted_jaccard_results = compute_weighted_jaccard_similarity(model_df)
    plot_graph_metric(
        weighted_jaccard_results,
        'jaccard_similarity',
        perf_metric_dict,
        f'Weighted Jaccard Similarity to Previous Checkpoint for {model}',
        left_y_title="Weighted Jaccard Similarity",
        y_range=1,
        x_axis_col='checkpoint_2',
        log_x=False,
        disable_title=True,
    )

In [13]:
# Reference checkpoint (training step) that every other checkpoint's circuit is
# compared against.
comparison_checkpoint = 5000

# One figure per model: weighted Jaccard similarity of each checkpoint's circuit
# to the reference checkpoint's, plotted together with the task performance metric.
for model in models:
    model_df = all_graphs[all_graphs['model'] == model].copy()
    jaccard_reference_results = compute_weighted_jaccard_similarity_to_reference(model_df, comparison_checkpoint)
    plot_graph_metric(
        jaccard_reference_results,
        'jaccard_similarity',
        # Pass the flat {checkpoint: value} dict, as the other plotting cells do.
        # perf_metrics_by_model[model] is still nested by task and metric, so it
        # is not a (checkpoint: metric) mapping.
        perf_metric_dict,
        f'Weighted Jaccard Similarity to Step {comparison_checkpoint} for {model}',
        left_y_title="Weighted Jaccard Similarity",
        y_range=1,
        x_axis_col='checkpoint',
        log_x=True,
        disable_title=True
    )

In [14]:
# Reference checkpoint (training step) that every other checkpoint's circuit is
# compared against.
comparison_checkpoint = 143000

# One figure per model: weighted Jaccard similarity of each checkpoint's circuit
# to the reference checkpoint's, plotted together with the task performance metric.
for model in models:
    model_df = all_graphs[all_graphs['model'] == model].copy()
    jaccard_reference_results = compute_weighted_jaccard_similarity_to_reference(model_df, comparison_checkpoint)
    plot_graph_metric(
        jaccard_reference_results,
        'jaccard_similarity',
        # Pass the flat {checkpoint: value} dict, as the other plotting cells do.
        # perf_metrics_by_model[model] is still nested by task and metric, so it
        # is not a (checkpoint: metric) mapping.
        perf_metric_dict,
        f'Weighted Jaccard Similarity to Step {comparison_checkpoint} for {model}',
        left_y_title="Weighted Jaccard Similarity",
        y_range=1,
        x_axis_col='checkpoint',
        log_x=True,
        disable_title=True
    )

In [15]:
# One figure per model: EWMA-smoothed weighted change rate between consecutive
# checkpoints (smoothing factor alpha=0.1), plotted together with the task
# performance metric.
# compute_ewma_weighted_jaccard_similarity is imported in the notebook's import cell.
for model in models:
    model_df = all_graphs[all_graphs['model'] == model].copy()
    ewma_results = compute_ewma_weighted_jaccard_similarity(model_df, alpha=0.1)
    plot_graph_metric(
        ewma_results,
        'ewma_change_rate',
        # Pass the flat {checkpoint: value} dict, as the other plotting cells do.
        # perf_metrics_by_model[model] is still nested by task and metric, so it
        # is not a (checkpoint: metric) mapping.
        perf_metric_dict,
        f'EWMA Weighted Change Rate for {model}',
        left_y_title="Weighted Change Rate",
        y_range=1.0,
        x_axis_col='checkpoint_2',
        log_x=False,
        disable_title=True
    )