In [73]:
import os
import json
import glob
import re
import sys
from typing import Dict, List

import torch
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import chart_studio
from chart_studio import plotly as py

from IPython.display import display, clear_output, HTML

sys.path.append("..")
from utils.data_processing import (
    load_edge_scores_into_dictionary,
    compute_weighted_jaccard_similarity,
    compute_weighted_jaccard_similarity_to_reference,
    compute_ewma_weighted_jaccard_similarity,
    generate_in_circuit_df_files
)
from utils.result_plotting import plot_head_circuit_scores, plot_graph_metric

In [74]:
# Optional: configure chart_studio credentials. Ignore this cell if you are not
# using chart_studio. If you do want to publish graphs, simply include upload=True
# in the plotly function.
#
# The credentials are read from the local file ../auth/chart_studio_api_key.txt,
# which should contain the username and API key separated by a comma.
CHART_STUDIO_KEY_PATH = "../auth/chart_studio_api_key.txt"

if os.path.isfile(CHART_STUDIO_KEY_PATH):
    with open(CHART_STUDIO_KEY_PATH) as f:
        # Split on the first comma only, so a comma inside the key cannot break unpacking.
        username, api_key = f.read().strip().split(",", 1)

    # Strip whitespace around each field (e.g. "user, key").
    username = username.strip()
    api_key = api_key.strip()

    chart_studio.tools.set_credentials_file(username=username, api_key=api_key)
else:
    # Publishing is optional, so a missing key file must not stop a full notebook run.
    print(f"No chart_studio credentials found at {CHART_STUDIO_KEY_PATH}; skipping chart_studio setup.")

In [75]:
# Notebook configuration. MODEL_NAME and TASK select the results files that are
# loaded and appear in plot titles; PERFORMANCE_METRIC is the key used to pick
# the performance series out of the loaded metrics.
MODEL_NAME = "pythia-410m"
TASK = "ioi"
PERFORMANCE_METRIC = "logit_diff"

## Graph Metrics

In [76]:
# ONLY NEEDS TO BE RUN IF EAP IS REPEATED FOR MODEL/TASK OR NEW CHECKPOINTS ARE ADDED

# generate_in_circuit_df_files('/mnt/hdd-0/circuits-over-time/results/graphs', start_checkpoint=3000, limit_to_model=MODEL_NAME, limit_to_task=TASK)
# clear_output()

In [77]:
# Load the circuit graph dataframe for the configured model/task from file.
in_circuit_df = pd.read_feather(f'/mnt/hdd-0/circuits-over-time/results/graphs/{MODEL_NAME}/{TASK}/in_circuit_edges.feather')
# Count the rows recorded for each checkpoint; the count is stored as 'num_edges'.
edge_count_df = in_circuit_df.groupby('checkpoint').size().reset_index(name='num_edges')

# Load performance metrics, e.g. logit diff, from file.
# map_location='cpu' lets the file load on machines without a GPU even if it
# contains tensors that were saved from a CUDA device.
# NOTE(review): torch.load is pickle-based, so only point this at trusted results files.
perf_metrics_by_model = torch.load(
    '/mnt/hdd-0/circuits-over-time/results/task_performance_metrics/all_models_task_performance.pt',
    map_location='cpu',
)

# The following can be replaced with any dictionary with (checkpoint: metric) structure,
# e.g. from baselines or other model task runs.
perf_metric_dict = perf_metrics_by_model[MODEL_NAME][TASK][PERFORMANCE_METRIC]

In [78]:
in_circuit_df.head()

Unnamed: 0,edge,score,in_circuit,checkpoint
0,a12.h8->a13.h5<q>,-0.004456,True,4000
1,a11.h4->m13,0.005157,True,4000
2,a11.h4->a13.h5<q>,0.023438,True,4000
3,a11.h4->m14,0.004608,True,4000
4,a11.h4->a12.h0<q>,-0.009949,True,4000


### Graph Size

In [79]:
# Graph size: plot the per-checkpoint edge count together with the performance metric.
plot_graph_metric(
    edge_count_df,
    'num_edges',
    perf_metric_dict,
    f'Graph Size for {MODEL_NAME}',
    # x axis
    x_axis_col='checkpoint',
    log_x=True,
    # y axes (y_ranges presumably holds the left and right axis ranges; confirm in plot_graph_metric)
    left_y_title="Edge Count",
    right_y_title="Logit Diff",
    y_ranges=((0, 1500), (0, 6)),
    # output
    output_path="/mnt/hdd-0/circuits-over-time/results/plots/graph_metrics/",
)

### Graph Similarity

In [80]:
# Weighted Jaccard similarity computed from the in-circuit edges, plotted against
# the 'checkpoint_2' column of the result alongside the performance metric.
weighted_jaccard_results = compute_weighted_jaccard_similarity(in_circuit_df)

plot_graph_metric(
    weighted_jaccard_results,
    'jaccard_similarity',
    perf_metric_dict,
    f'Jaccard Similarity for {MODEL_NAME}',
    # x axis
    x_axis_col='checkpoint_2',
    log_x=True,
    # y axes and legend
    left_y_title="Jaccard Similarity",
    y_ranges=((0, 1), (0, 6)),
    metric_legend_name="Jaccard Sim",
    # output
    output_path="/mnt/hdd-0/circuits-over-time/results/plots/graph_metrics/",
)

In [81]:
# Weighted Jaccard similarity of every checkpoint's graph to a fixed reference checkpoint.
comparison_checkpoint = 5000
jaccard_reference_results = compute_weighted_jaccard_similarity_to_reference(
    in_circuit_df, comparison_checkpoint
)

plot_graph_metric(
    jaccard_reference_results,
    'jaccard_similarity',
    perf_metric_dict,
    f'Weighted Jaccard Similarity to Checkpoint {comparison_checkpoint} for {MODEL_NAME}',
    # x axis
    x_axis_col='checkpoint',
    log_x=True,
    # y axes and legend
    left_y_title="Jaccard Similarity",
    y_ranges=((0, 1), (0, 6)),
    metric_legend_name="Jaccard Sim",
    # output
    output_path="/mnt/hdd-0/circuits-over-time/results/plots/graph_metrics/",
)

In [82]:
# Same comparison as for the earlier reference checkpoint, now against checkpoint 143000.
comparison_checkpoint = 143000
jaccard_reference_results = compute_weighted_jaccard_similarity_to_reference(
    in_circuit_df, comparison_checkpoint
)

plot_graph_metric(
    jaccard_reference_results,
    'jaccard_similarity',
    perf_metric_dict,
    f'Weighted Jaccard Similarity to Checkpoint {comparison_checkpoint} for {MODEL_NAME}',
    # x axis
    x_axis_col='checkpoint',
    log_x=True,
    # y axes and legend
    left_y_title="Jaccard Similarity",
    y_ranges=((0, 1), (0, 6)),
    metric_legend_name="Jaccard Sim",
    # output
    output_path="/mnt/hdd-0/circuits-over-time/results/plots/graph_metrics/",
)

In [83]:
# EWMA-based graph change rate. The result gets its own name so it does not
# overwrite the reference-similarity frame held in jaccard_reference_results.
# NOTE(review): alpha is presumably the EWMA smoothing factor; confirm in utils.data_processing.
ewma_jaccard_results = compute_ewma_weighted_jaccard_similarity(in_circuit_df, alpha=0.1)
plot_graph_metric(
    ewma_jaccard_results,
    'ewma_change_rate',
    perf_metric_dict,
    f'Exponential Weighted Average Graph Change Rate for {MODEL_NAME}',
    y_ranges=((0, 1), (0, 6)),
    # The plotted column is the change rate, so label the axis and legend accordingly
    # (they previously read "Jaccard Similarity" / "Jaccard Sim").
    left_y_title="EWMA Change Rate",
    x_axis_col='checkpoint_2',
    log_x=True,
    metric_legend_name="EWMA Change Rate",
    output_path="/mnt/hdd-0/circuits-over-time/results/plots/graph_metrics/",
)