## Setup

In [1]:
import os
import json
import glob
import torch
import re
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
#import seaborn as sns
import matplotlib.pyplot as plt
from utils.data_processing import (
    load_edge_scores_into_dictionary,
    get_ckpts,
    load_metrics,
    compute_ged,
    compute_weighted_ged,
    compute_gtd,
    compute_jaccard_similarity_to_reference,
    compute_jaccard_similarity,
    aggregate_metrics_to_tensors_step_number,
    get_ckpts
)

## Retrieve & Process Data

### Circuit Data

In [None]:
folder_path = 'results/graphs/pythia-160m-seed1/ioi'
df = load_edge_scores_into_dictionary(folder_path)

Processing file 1/133: results/graphs/pythia-160m-seed1/ioi/57000.json
Processing file 2/133: results/graphs/pythia-160m-seed1/ioi/95000.json
Processing file 3/133: results/graphs/pythia-160m-seed1/ioi/107000.json
Processing file 4/133: results/graphs/pythia-160m-seed1/ioi/34000.json
Processing file 5/133: results/graphs/pythia-160m-seed1/ioi/6000.json
Processing file 6/133: results/graphs/pythia-160m-seed1/ioi/37000.json
Processing file 7/133: results/graphs/pythia-160m-seed1/ioi/39000.json
Processing file 8/133: results/graphs/pythia-160m-seed1/ioi/104000.json
Processing file 9/133: results/graphs/pythia-160m-seed1/ioi/59000.json
Processing file 10/133: results/graphs/pythia-160m-seed1/ioi/67000.json
Processing file 11/133: results/graphs/pythia-160m-seed1/ioi/111000.json
Processing file 12/133: results/graphs/pythia-160m-seed1/ioi/76000.json
Processing file 13/133: results/graphs/pythia-160m-seed1/ioi/5000.json
Processing file 14/133: results/graphs/pythia-160m-seed1/ioi/42000.json


In [16]:
from utils.data_processing import read_json_file

def load_faithfulness_scores_into_df(folder_path, seed_name='seed1234'):
    file_paths = glob.glob(f'{folder_path}/*.json')

    # Create an empty DataFrame to store all edge scores
    all_sizes = pd.DataFrame()

    for i, file_path in enumerate(file_paths):
        print(f'Processing file {i+1}/{len(file_paths)}: {file_path}')
        data = read_json_file(file_path)
        sizes = data.keys()
        scores = [data[size] for size in sizes]

        # Extract checkpoint name from the filename
        checkpoint_name = int(os.path.basename(file_path).replace('.json', ''))
        #checkpoint_name = f'step {checkpoint_name}'

        checkpoint_df = pd.DataFrame({'size': sizes, 'faithfulness_score': scores, 'checkpoint': checkpoint_name, 'seed': seed_name})
        all_sizes = pd.concat([all_sizes, checkpoint_df])

        #ensure size and checkpoint are integer columns
        all_sizes['size'] = all_sizes['size'].astype(int)
        all_sizes['checkpoint'] = all_sizes['checkpoint'].astype(int)


    # sort by checkpoint and then by size
    all_sizes = all_sizes.sort_values(by=['seed', 'checkpoint', 'size'])
    return all_sizes

In [22]:
df_ff_seed_1234 = load_faithfulness_scores_into_df('results/faithfulness/pythia-160m/ioi')
df_ff_seed_1 = load_faithfulness_scores_into_df('results/faithfulness/pythia-160m-seed1/ioi', seed_name='seed1')
df_ff_seed_2 = load_faithfulness_scores_into_df('results/faithfulness/pythia-160m-seed2/ioi', seed_name='seed2')
df_ff_seed_3 = load_faithfulness_scores_into_df('results/faithfulness/pythia-160m-seed3/ioi', seed_name='seed3')

Processing file 1/140: results/faithfulness/pythia-160m/ioi/57000.json
Processing file 2/140: results/faithfulness/pythia-160m/ioi/141000.json
Processing file 3/140: results/faithfulness/pythia-160m/ioi/95000.json
Processing file 4/140: results/faithfulness/pythia-160m/ioi/107000.json
Processing file 5/140: results/faithfulness/pythia-160m/ioi/34000.json
Processing file 6/140: results/faithfulness/pythia-160m/ioi/6000.json
Processing file 7/140: results/faithfulness/pythia-160m/ioi/37000.json
Processing file 8/140: results/faithfulness/pythia-160m/ioi/39000.json
Processing file 9/140: results/faithfulness/pythia-160m/ioi/104000.json
Processing file 10/140: results/faithfulness/pythia-160m/ioi/59000.json
Processing file 11/140: results/faithfulness/pythia-160m/ioi/67000.json
Processing file 12/140: results/faithfulness/pythia-160m/ioi/111000.json
Processing file 13/140: results/faithfulness/pythia-160m/ioi/76000.json
Processing file 14/140: results/faithfulness/pythia-160m/ioi/5000.json

In [23]:
# concatenate all seeds
df_ff = pd.concat([df_ff_seed_1234, df_ff_seed_1, df_ff_seed_2, df_ff_seed_3])
df_ff.head()

Unnamed: 0,size,faithfulness_score,checkpoint,seed
4,1,1.0,4000,seed1234
3,3,2.28125,4000,seed1234
2,6,1.414062,4000,seed1234
1,13,1.046875,4000,seed1234
0,25,1.578125,4000,seed1234


In [30]:
# graph faithfulness scores for each seed for checkpoint 143000
checkpoint = 52000
sub_df_ff = df_ff[df_ff['checkpoint'] == checkpoint]
fig = px.line(sub_df_ff, x='size', y='faithfulness_score', color='seed')
fig.update_layout(title=f'Faithfulness scores for each seed at checkpoint {checkpoint}', xaxis_title='Size', yaxis_title='Faithfulness score')
fig.show()

In [26]:
import pandas as pd
import plotly.express as px

# Assuming `df` is your DataFrame
# Pivot the DataFrame to get the seeds as rows, sizes as columns, and faithfulness scores as values
heatmap_data = df_ff.pivot(index='seed', columns='size', values='faithfulness_score')

# Create the heatmap
fig = px.imshow(heatmap_data,
                labels=dict(x="Size", y="Seed", color="Faithfulness Score"),
                x=heatmap_data.columns,  # Explicitly specify x-axis categories to ensure correct order
                y=heatmap_data.index,    # Explicitly specify y-axis categories to ensure correct order
                aspect="auto")           # Let plotly choose the aspect ratio for best fit or use 'equal' for square cells

# Optional: Update layout for a better presentation
fig.update_layout(
    title="Heatmap of Faithfulness Scores",
    xaxis_title="Size",
    yaxis_title="Seed",
    xaxis_nticks=36)  # Adjust this for the number of sizes you have for better tick distribution

fig.show()


ValueError: Index contains duplicate entries, cannot reshape

In [18]:
df_ff_seed_1['checkpoint'].unique()

array([  4000,   5000,   6000,   7000,   8000,   9000,  10000,  11000,
        12000,  13000,  14000,  15000,  16000,  17000,  18000,  19000,
        20000,  21000,  22000,  23000,  24000,  25000,  26000,  27000,
        28000,  29000,  30000,  31000,  32000,  33000,  34000,  35000,
        36000,  37000,  38000,  39000,  40000,  41000,  42000,  43000,
        44000,  45000,  46000,  47000,  48000,  49000,  50000,  51000,
        52000,  53000,  54000,  55000,  56000,  57000,  58000,  59000,
        60000,  61000,  62000,  63000,  64000,  65000,  66000,  67000,
        68000,  69000,  70000,  71000,  72000,  73000,  74000,  75000,
        76000,  77000,  78000,  79000,  80000,  81000,  82000,  83000,
        84000,  85000,  86000,  87000,  88000,  89000,  90000,  91000,
        92000,  93000,  94000,  95000,  96000,  97000,  98000,  99000,
       100000, 101000, 102000, 103000, 104000, 105000, 106000, 107000,
       108000, 109000, 110000, 111000, 112000, 113000, 114000, 115000,
      