In [1]:
import os
import json
import glob
import torch
import re
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
#import seaborn as sns
import matplotlib.pyplot as plt
from utils.data_processing import (
    load_edge_scores_into_dictionary,
    get_ckpts,
    load_metrics,
    compute_ged,
    compute_weighted_ged,
    compute_gtd,
    compute_jaccard_similarity_to_reference,
    compute_jaccard_similarity,
    aggregate_metrics_to_tensors_step_number,
    get_ckpts
)

In [13]:
def read_json_file(file_path):
    """Parse the JSON document stored at ``file_path`` and return the decoded object.

    The file is opened explicitly as UTF-8 (the encoding JSON interchange uses)
    instead of the platform's default locale encoding, so non-ASCII content is
    decoded the same way on every machine.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file (dict, list, str, number, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)



def load_nodes_into_dictionary(folder_path):
    """Load the ``nodes`` section of every ``*.json`` file in a folder into one table.

    Despite the name, the result is a pandas DataFrame, not a dictionary.
    Each file must be named ``<integer>.json``; the integer is stored as that
    file's ``checkpoint`` value. Each file's top-level ``'nodes'`` mapping
    contributes one row per entry: the key goes to ``node`` and the value to
    ``in_circuit`` (presumably a boolean inclusion flag -- the values are
    stored as-is, without conversion).

    Args:
        folder_path: Directory that directly contains the JSON files.

    Returns:
        DataFrame with columns ``node``, ``in_circuit`` and ``checkpoint``,
        sorted by ``checkpoint``. If the folder has no ``*.json`` files an
        empty DataFrame with those three columns is returned.

    Raises:
        ValueError: If a file name (without ``.json``) is not an integer.
        KeyError: If a file has no top-level ``'nodes'`` entry.
    """
    file_paths = glob.glob(f'{folder_path}/*.json')

    # One frame per file, concatenated once after the loop. Re-concatenating
    # the accumulated frame on every iteration copies all earlier rows again.
    checkpoint_frames = []

    for i, file_path in enumerate(file_paths):
        print(f'Processing file {i+1}/{len(file_paths)}: {file_path}')
        nodes = read_json_file(file_path)['nodes']

        # The file stem is the checkpoint number, e.g. '57000.json' -> 57000.
        file_stem = os.path.splitext(os.path.basename(file_path))[0]
        checkpoint_name = int(file_stem)

        checkpoint_frames.append(pd.DataFrame({
            'node': list(nodes.keys()),
            'in_circuit': list(nodes.values()),
            'checkpoint': checkpoint_name,
        }))

    if not checkpoint_frames:
        # Nothing matched: hand back an empty table that still has the expected
        # columns so callers can filter on 'checkpoint' without a KeyError.
        return pd.DataFrame(columns=['node', 'in_circuit', 'checkpoint'])

    all_nodes = pd.concat(checkpoint_frames)
    return all_nodes.sort_values('checkpoint')

In [68]:
folder_path = 'results/graphs/pythia-160m/ioi'

# Inclusive checkpoint window used to filter the loaded tables.
MIN_CHECKPOINT = 4000
MAX_CHECKPOINT = 78000

node_df = load_nodes_into_dictionary(folder_path)
edge_df = load_edge_scores_into_dictionary(folder_path)

# Keep explicit copies of the filtered rows: later cells add columns to
# edge_df, and assigning into a boolean-mask slice can raise
# SettingWithCopyWarning.
# NOTE(review): node_df only gets the lower bound while edge_df is also capped
# at MAX_CHECKPOINT -- confirm the asymmetry is intentional.
node_df = node_df[node_df['checkpoint'] >= MIN_CHECKPOINT].copy()
edge_df = edge_df[edge_df['checkpoint'].between(MIN_CHECKPOINT, MAX_CHECKPOINT)].copy()

Processing file 1/97: results/graphs/pythia-160m/ioi/57000.json
Processing file 2/97: results/graphs/pythia-160m/ioi/34000.json
Processing file 3/97: results/graphs/pythia-160m/ioi/6000.json
Processing file 4/97: results/graphs/pythia-160m/ioi/37000.json
Processing file 5/97: results/graphs/pythia-160m/ioi/39000.json
Processing file 6/97: results/graphs/pythia-160m/ioi/59000.json
Processing file 7/97: results/graphs/pythia-160m/ioi/67000.json
Processing file 8/97: results/graphs/pythia-160m/ioi/16.json
Processing file 9/97: results/graphs/pythia-160m/ioi/76000.json
Processing file 10/97: results/graphs/pythia-160m/ioi/1.json
Processing file 11/97: results/graphs/pythia-160m/ioi/5000.json
Processing file 12/97: results/graphs/pythia-160m/ioi/42000.json
Processing file 13/97: results/graphs/pythia-160m/ioi/77000.json
Processing file 14/97: results/graphs/pythia-160m/ioi/80000.json
Processing file 15/97: results/graphs/pythia-160m/ioi/63000.json
Processing file 16/97: results/graphs/pythi

In [69]:
# Split each 'source->target' edge label into its two endpoints, then show how
# many distinct target endpoints exist (this counts targets, not edges).
# n=1 splits on the first '->' only, so a label containing a second '->' cannot
# produce a third column and break the two-column assignment.
edge_df[['source', 'target']] = edge_df['edge'].str.split('->', n=1, expand=True)
len(edge_df['target'].unique())

445

In [70]:
# Express every edge score as a fraction of its checkpoint's total score.
# edge_df carries the columns edge, score, in_circuit and checkpoint at this point.
# The denominator is the sum over ALL rows of a checkpoint; the in_circuit flag
# is not consulted here.
# NOTE(review): an earlier comment said only in_circuit edges should be summed,
# and scores are signed so they can partially cancel in the total -- confirm
# that dividing by the unfiltered signed sum is what is wanted.
def _share_of_checkpoint_total(checkpoint_scores):
    """Divide each score by the sum of the scores in the same checkpoint group."""
    return checkpoint_scores / checkpoint_scores.sum()


edge_df['normalized_score'] = (
    edge_df.groupby('checkpoint')['score'].transform(_share_of_checkpoint_total)
)

edge_df.head(10)

Unnamed: 0,edge,score,in_circuit,checkpoint,source,target,normalized_score
10784,a0.h10->a7.h5<v>,2.473593e-06,False,4000,a0.h10,a7.h5<v>,1.567943e-05
10783,a0.h10->a7.h5<k>,3.635883e-06,False,4000,a0.h10,a7.h5<k>,2.304687e-05
10782,a0.h10->a7.h5<q>,-2.369285e-06,False,4000,a0.h10,a7.h5<q>,-1.501825e-05
10774,a0.h10->a7.h2<k>,-5.550683e-07,False,4000,a0.h10,a7.h2<k>,-3.518426e-06
10781,a0.h10->a7.h4<v>,-4.544854e-07,False,4000,a0.h10,a7.h4<v>,-2.880859e-06
10778,a0.h10->a7.h3<v>,-4.544854e-07,False,4000,a0.h10,a7.h3<v>,-2.880859e-06
10779,a0.h10->a7.h4<q>,5.252659e-07,False,4000,a0.h10,a7.h4<q>,3.329517e-06
10777,a0.h10->a7.h3<k>,3.306195e-08,False,4000,a0.h10,a7.h3<k>,2.095707e-07
10776,a0.h10->a7.h3<q>,4.991889e-07,False,4000,a0.h10,a7.h3<q>,3.164222e-06
10775,a0.h10->a7.h2<v>,1.117587e-06,False,4000,a0.h10,a7.h2<v>,7.084079e-06


In [71]:
# Total normalized score leaving each source node (summed over every row of
# edge_df, i.e. across all checkpoints), as a two-column table ordered from the
# largest total to the smallest.
source_scores = (
    edge_df.groupby(['source'])['normalized_score']
    .sum()
    .rename('source_score')
    .rename_axis('node')
    .reset_index()
    .sort_values('source_score', ascending=False)
)
source_scores.head(30)

Unnamed: 0,node,source_score
131,a8.h9,22.564401
144,input,21.913746
145,m0,19.913437
25,a10.h1,2.314663
75,a4.h11,2.256133
104,a6.h6,1.937911
149,m2,1.823485
112,a7.h2,1.300857
56,a2.h6,0.812817
103,a6.h5,0.798604


In [72]:
# Number of distinct source nodes per checkpoint, counting only the edges whose
# in_circuit value equals True.
in_circuit_edges = edge_df.loc[edge_df['in_circuit'].eq(True)]
unique_sources = in_circuit_edges.groupby('checkpoint')['source'].nunique()

In [74]:
# Show the first 40 per-checkpoint counts.
unique_sources.iloc[:40]

checkpoint
5000     16
6000     20
7000     16
8000     22
9000     17
10000    21
11000    34
12000    42
13000    48
14000    46
15000    38
16000    40
17000    39
18000    42
19000    43
20000    34
21000    54
22000    38
23000    38
24000    52
25000    42
26000    58
27000    44
28000    52
29000    43
30000    50
31000    42
32000    44
33000    49
34000    59
35000    48
36000    63
37000    37
38000    44
39000    53
40000    61
41000    60
42000    46
43000    69
44000    48
Name: source, dtype: int64