In [None]:
import os
import math
import json
import time
import torch
import torch.nn.functional as F
import pandas as pd
import random
from random import shuffle
from collections import defaultdict
from rich import print as rprint
from functools import partial

import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from transformers import LlamaForCausalLM

from eap.metrics import logit_diff, direct_logit
from eap.graph import Graph
from eap.dataset import EAPDataset
from eap.attribute import attribute
from eap.evaluate import evaluate_graph, evaluate_baseline,get_circuit_logits

In [None]:
# Hugging Face checkpoint to load; `model_name` (the repo basename) names output folders.
MODEL_PATH = "meta-llama/Llama-2-7b-chat-hf"
model_name = MODEL_PATH.rsplit("/", 1)[-1]

# Load without folding LayerNorm or centering weights / unembedding.
model = HookedTransformer.from_pretrained(
    MODEL_PATH,
    device="cuda:0",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

# Expose split q/k/v inputs, per-head attention results and MLP-input hooks,
# which edge-level attribution expects.
for cfg_flag in ("use_split_qkv_input", "use_attn_result", "use_hook_mlp_in"):
    setattr(model.cfg, cfg_flag, True)

# If model has Grouped-Query Attention (GQA), disable split q/k/v inputs instead:
# model.cfg.use_split_qkv_input = False

In [None]:
### Temporal Dataset ###
# Builds clean/corrupted prompt pairs for one target year. Each clean prompt asks
# about a subject in `target_time`; each corrupted prompt asks about the same
# subject in another year whose object differs. The label columns hold the first
# token id of each object.

# Set target year for clean run
target_time = "1999"
target_category = "sports"
folder_path = './data/Temporal'

matched_file = None
for filename in os.listdir(folder_path):
    if filename.endswith('.json') and f'time_{target_category}' in filename:
        matched_file = os.path.join(folder_path, filename)
        break

if matched_file is None:
    raise FileNotFoundError(f"No file matching category '{target_category}' found.")

# Load the matched JSON file
with open(matched_file, 'r') as f:
    data_json = json.load(f)
print(f"Loaded file: {matched_file}")

# Extract prompt templates and samples
prompt_template = data_json["prompt_templates"][0]
samples = data_json["samples"]

dataset_rows = []

for sample in samples:
    if sample["time"] != target_time:
        continue
    subject = sample["subject"]
    time_clean = sample["time"]
    object_clean = sample["object"]

    # The clean object is shared by every pair built for this sample: tokenize it once.
    clean_token_ids = model.tokenizer(object_clean, add_special_tokens=False).input_ids
    if not clean_token_ids:
        continue  # no token to use as the clean label

    # Corrupted partners: same subject, different year, different object.
    for corrupted_sample in samples:
        if corrupted_sample["subject"] != subject or corrupted_sample["time"] == time_clean:
            continue
        time_corrupted = corrupted_sample["time"]
        object_corrupted = corrupted_sample["object"]
        if object_clean == object_corrupted:
            continue

        corrupted_token_ids = model.tokenizer(object_corrupted, add_special_tokens=False).input_ids
        if not corrupted_token_ids:
            continue  # no token to use as the corrupted label

        # For Phi models, use index [1] of the token ids instead of [0].
        dataset_rows.append({
            "clean": prompt_template.format(time=time_clean, subject=subject),
            "corrupted": prompt_template.format(time=time_corrupted, subject=subject),
            "country_idx": clean_token_ids[0],  # First token of object in clean run
            "corrupted_country_idx": corrupted_token_ids[0]  # First token of object in corrupted run
        })

if not dataset_rows:
    print(f"Warning: no clean/corrupted pairs found for time '{target_time}'.")

# Create DataFrame and save as CSV
out_csv_path = f'./data/{target_time}_temporal_knowledge_{target_category}.csv'
df = pd.DataFrame(dataset_rows)
df.to_csv(out_csv_path, index=False)

print(f"Filtered dataset created and saved to '{out_csv_path}'")

In [None]:
############################# Temporal Dataset ###########################
### For More Detailed Circuit, Use this version for Dataset Generation ###
# Same clean/corrupted pairing as the plain temporal dataset, but rows are
# grouped by subject and every subject is written to its own CSV.

# Set target year and category. `folder_path` is set here rather than inherited
# from an earlier cell, because the time-invariant cell repoints it to './data/Invariant'.
target_time = "1999"
target_category = "sports"
folder_path = './data/Temporal'
output_dir = './data'

matched_file = None
for filename in os.listdir(folder_path):
    if filename.endswith('.json') and f'time_{target_category}' in filename:
        matched_file = os.path.join(folder_path, filename)
        break

if not matched_file:
    raise FileNotFoundError(f"No file matching category '{target_category}' found.")

with open(matched_file, 'r') as f:
    data_json = json.load(f)
print(f"Loaded file: {matched_file}")

prompt_template = data_json["prompt_templates"][0]
samples = data_json["samples"]

subject_rows = defaultdict(list)

for sample in samples:
    if sample["time"] != target_time:
        continue
    subject = sample["subject"]
    time_clean = sample["time"]
    object_clean = sample["object"]

    # Tokenized once per clean sample instead of once per corrupted partner.
    clean_ids = model.tokenizer(object_clean, add_special_tokens=False).input_ids
    if not clean_ids:
        continue  # no token to use as the clean label

    # Corrupted partners: same subject, different year, different object.
    for corrupted in samples:
        if corrupted["subject"] != subject or corrupted["time"] == time_clean:
            continue
        object_corr = corrupted["object"]
        if object_clean == object_corr:
            continue

        corr_ids = model.tokenizer(object_corr, add_special_tokens=False).input_ids
        if not corr_ids:
            continue  # no token to use as the corrupted label

        subject_rows[subject].append({
            "clean": prompt_template.format(time=time_clean, subject=subject),
            "corrupted": prompt_template.format(time=corrupted["time"], subject=subject),
            "country_idx": clean_ids[0],
            "corrupted_country_idx": corr_ids[0]
        })

# subject_rows only gains keys through append, so every list here is non-empty.
for subject, rows in subject_rows.items():
    df_subj = pd.DataFrame(rows)
    # File-name form of the subject: spaces -> underscores, lower-cased.
    safe_subj = subject.replace(" ", "_").lower()
    out_path = os.path.join(
        output_dir,
        f"{target_time}_{target_category}_{safe_subj}.csv"
    )
    df_subj.to_csv(out_path, index=False)
    print(f"Saved {len(rows)} rows for subject '{subject}' → {out_path}")

In [None]:
### Time Invariant Dataset ###
# Pairs every sample with every other sample whose object differs. Only pairs
# whose clean and corrupted prompts tokenize to the same length are kept.

target_category = "roman_numerals"
folder_path = './data/Invariant'

matched_file = None
for filename in os.listdir(folder_path):
    if filename.endswith('.json') and target_category in filename:
        matched_file = os.path.join(folder_path, filename)
        break

if matched_file is None:
    raise FileNotFoundError(f"No file matching category '{target_category}' found.")

# Load the matched JSON file
with open(matched_file, 'r') as f:
    data_json = json.load(f)
print(f"Loaded file: {matched_file}")

# Extract the prompt template and samples
prompt_template = data_json["prompt_templates"][0]
samples = data_json["samples"]

# Prevent BOS token auto-insertion. NOTE: this is stored on the model config and
# stays in effect for later cells. It does not change the length check below,
# because clean and corrupted strings are tokenized the same way.
model.cfg.default_prepend_bos = False

dataset_rows = []

for sample in samples:
    subject_clean = sample["subject"]
    object_clean = sample["object"]
    clean_str = prompt_template.format(subject_clean)

    # Corrupted candidates: every sample (any subject) whose object differs.
    corrupted_candidates = [s for s in samples if s["object"] != object_clean]
    if not corrupted_candidates:
        continue

    # Shuffle candidates (only changes row order; no seed is set).
    shuffle(corrupted_candidates)

    # Clean-side tokenization is the same for all candidates: compute it once.
    clean_tokens = model.to_str_tokens(clean_str)
    clean_obj_token_ids = model.tokenizer(object_clean, add_special_tokens=False).input_ids
    if not clean_obj_token_ids:
        # No label token. Writing None would make the id column float in the CSV.
        continue
    clean_obj_idx = clean_obj_token_ids[0]

    for corr_samp in corrupted_candidates:
        object_corr = corr_samp["object"]
        corrupted_str = prompt_template.format(corr_samp["subject"])

        # Skip pairs with different token lengths (attention-mask shape mismatch).
        if len(model.to_str_tokens(corrupted_str)) != len(clean_tokens):
            continue

        corr_obj_token_ids = model.tokenizer(object_corr, add_special_tokens=False).input_ids
        if not corr_obj_token_ids:
            continue

        # For Phi models, use index [1] of the token ids instead of [0].
        dataset_rows.append({
            "clean": clean_str,
            "corrupted": corrupted_str,
            "object_clean": object_clean,
            "object_corrupted": object_corr,
            "country_idx": clean_obj_idx,
            "corrupted_country_idx": corr_obj_token_ids[0]
        })

print(f"Total rows with same token length: {len(dataset_rows)}")

# Save the dataset to CSV
out_csv_path = f'./data/{target_category}.csv'
df = pd.DataFrame(dataset_rows)
df.to_csv(out_csv_path, index=False)
print(f"Filtered dataset saved to {out_csv_path}")

In [None]:
### Temporal Dataset with Alias ###
# Same pairing rule as the plain temporal dataset, but a year that has an entry
# in `alias_time_templates` is rendered with that year's alias template (filled
# with the subject only). Pairs whose prompts have different token lengths are
# dropped.

target_time = "2009"
target_category = "sports"
folder_path = './data/Temporal'

matched_file = None
for filename in os.listdir(folder_path):
    if filename.endswith('.json') and f'time_{target_category}' in filename:
        matched_file = os.path.join(folder_path, filename)
        break

if matched_file is None:
    raise FileNotFoundError(f"No file matching category '{target_category}' found.")

# Load the matched JSON file
with open(matched_file, 'r') as f:
    data_json = json.load(f)
print(f"Loaded file: {matched_file}")

prompt_template = data_json["prompt_templates"][0]
samples = data_json["samples"]
alias_time_list = data_json.get("alias_time_templates", [])
alias_dict = {item["year"]: item["template"] for item in alias_time_list}

dataset_rows = []
# Prevent BOS insertion. NOTE: stored on the model config, so it may affect later cells.
model.cfg.default_prepend_bos = False

for sample in samples:
    if sample["time"] != target_time:
        continue
    subject = sample["subject"]
    time_clean = sample["time"]
    object_clean = sample["object"]

    # Clean prompt. The generic template is filled with the same `time=`/`subject=`
    # keywords the other temporal cells use. Assuming it has named {time}/{subject}
    # fields, positional arguments would raise KeyError.
    alias_prompt_template = alias_dict.get(time_clean)
    if alias_prompt_template:
        clean_str = alias_prompt_template.format(subject)
    else:
        clean_str = prompt_template.format(time=time_clean, subject=subject)

    # Corrupted samples: same subject, different year, different object.
    corrupted_candidates = [
        s for s in samples
        if s["subject"] == subject and s["time"] != time_clean and s["object"] != object_clean
    ]
    shuffle(corrupted_candidates)

    # Clean-side tokenization is shared by every candidate: compute it once.
    clean_tokens = model.to_str_tokens(clean_str)
    clean_obj_token_ids = model.tokenizer(object_clean, add_special_tokens=False).input_ids
    if not clean_obj_token_ids:
        continue
    clean_obj_idx = clean_obj_token_ids[0]

    for corrupted_sample in corrupted_candidates:
        time_corrupted = corrupted_sample["time"]
        object_corrupted = corrupted_sample["object"]

        # Alias template for the corrupted year if one exists, else the generic template.
        alias_prompt_corrupted = alias_dict.get(time_corrupted)
        if alias_prompt_corrupted:
            corrupted_str = alias_prompt_corrupted.format(subject)
        else:
            corrupted_str = prompt_template.format(time=time_corrupted, subject=subject)

        # skip to pass attention mask shape mismatch
        if len(model.to_str_tokens(corrupted_str)) != len(clean_tokens):
            continue

        corrupted_obj_token_ids = model.tokenizer(object_corrupted, add_special_tokens=False).input_ids
        if not corrupted_obj_token_ids:
            continue

        # For Phi models, use index [1] of the token ids instead of [0].
        dataset_rows.append({
            "clean": clean_str,
            "corrupted": corrupted_str,
            "object_clean": object_clean,
            "object_corrupted": object_corrupted,
            "country_idx": clean_obj_idx,
            "corrupted_country_idx": corrupted_obj_token_ids[0]
        })

df = pd.DataFrame(dataset_rows)
out_csv_path = f'./data/{target_time}_temporal_knowledge_{target_category}_alias.csv'
df.to_csv(out_csv_path, index=False)
print(f"Total rows with same token length: {len(dataset_rows)}")
print(f"Filtered dataset created and saved to '{out_csv_path}'")

In [None]:
### Load dataset into EAPDataset ###
# Pick ONE of the CSVs written by the dataset cells. The active line loads the
# plain temporal dataset. The alternatives use the same paths their generating
# cells write to.
dataset_path = f'./data/{target_time}_temporal_knowledge_{target_category}.csv'

# Per-subject dataset (file names use the lower-cased, underscore-separated subject):
# subject = "Nicolas Anelka"
# dataset_path = f'./data/{target_time}_{target_category}_{subject.replace(" ", "_").lower()}.csv'

# Time-invariant dataset:
# dataset_path = f'./data/{target_category}.csv'

# Temporal dataset with alias templates:
# dataset_path = f'./data/{target_time}_temporal_knowledge_{target_category}_alias.csv'

dataset = EAPDataset(filename=dataset_path, task='fact-retrieval')
dataloader = dataset.to_dataloader(batch_size=1)

In [None]:
# Score every edge of the model's computational graph with EAP-IG on `dataloader`,
# keep the highest-|score| edges, and save the resulting circuit as JSON and PNG.
TOP_N_EDGES = 5000  # edges kept by apply_topn (ranked by absolute score)
IG_STEPS = 100      # integration steps for EAP-IG

graph_dir = f'./graphs/{model_name}'
os.makedirs(graph_dir, exist_ok=True)  # make sure the output folder exists before writing

g = Graph.from_model(model)
start_time = time.time()

attribute(
    model,
    g,
    dataloader,
    partial(logit_diff, loss=True, mean=True),  # mean logit difference, used as a loss
    method='EAP-IG',
    ig_steps=IG_STEPS
)

g.apply_topn(TOP_N_EDGES, absolute=True)
g.prune_dead_nodes()
g.to_json(f'{graph_dir}/graph_{target_time}_{target_category}.json')
# Time-invariant naming (no year): g.to_json(f'{graph_dir}/{target_category}.json')

gz = g.to_graphviz()
gz.draw(f'{graph_dir}/graph_{target_time}_{target_category}.png', prog='dot')
# Time-invariant naming (no year): gz.draw(f'{graph_dir}/{target_category}.png', prog='dot')

end_time = time.time()
execution_time = end_time - start_time
print(f"Execution_Time: {execution_time} seconds")

In [None]:
def get_component_logits(logits, model, answer_token, top_k=10):
    """Print how the model ranks `answer_token` at the last position, plus its top-k predictions.

    Args:
        logits: logits tensor. A leading batch dimension of size 1 is removed with
            ``utils.remove_batch_dim``; after that a 2-D ``(position, vocab)``
            tensor is required.
        model: object providing ``to_string(token)`` for decoding ids (e.g. a
            HookedTransformer).
        answer_token: vocabulary id of the expected answer.
        top_k: number of highest-probability tokens to list (clamped to the
            vocabulary size).

    Returns:
        None. Results are printed; the answer line uses ``rich`` markup.

    Raises:
        ValueError: if the logits are not 2-D after removing the batch dimension
            (e.g. a batch larger than 1 was passed).
    """
    logits = utils.remove_batch_dim(logits)
    if logits.dim() != 2:
        raise ValueError(
            "expected logits of shape (pos, vocab) after removing a size-1 batch "
            f"dimension, got shape {tuple(logits.shape)}"
        )
    probs = logits.softmax(dim=-1)
    token_probs = probs[-1]  # distribution at the final position
    answer_str_token = model.to_string(answer_token)
    sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)
    # Rank = position of the answer id in descending-probability order (0 = top).
    correct_rank = torch.nonzero(sorted_token_values == answer_token).item()
    # Format spec: `: <8` pads to 8 characters; `5.2f` / `6.2%` fix the decimal places.
    rprint(
        f"Performance on answer token:\n[b]Rank: {correct_rank: <8} Logit: {logits[-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]"
    )
    for i in range(min(top_k, sorted_token_values.numel())):
        print(
            f"Top {i}th token. Logit: {logits[-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|"
        )
    
def CRS(
    baseline_perf: float,
    circuit_perf: float,
    alpha=1.0,
    sf_bothpos=1.0,   # baseline > 0,  circuit >= 0
    sf_bothneg=0.5,   # baseline <= 0, circuit < 0
    sf_bneg_cpos=0.8, # baseline <= 0, circuit >= 0
    sf_bpos_cneg=0.6, # baseline > 0,  circuit < 0
    eps=1e-9
) -> float:
    """Circuit score: one number comparing a circuit's performance to the full model's.

    The score is ``100 * sign_factor * exp(-alpha * dist / max(|baseline|, eps))``:

    * If the baseline is positive and the circuit matches or exceeds it, the
      score is exactly 100.
    * ``dist`` is the shortfall ``max(0, baseline - circuit)`` when the baseline
      is positive, and ``|circuit|`` otherwise.
    * ``sign_factor`` is one of the four ``sf_*`` weights, selected by whether
      the baseline is positive and whether the circuit is non-negative.

    With the default weights the result lies in [0, 100].

    Args:
        baseline_perf: performance of the full model.
        circuit_perf: performance of the circuit.
        alpha: decay rate of the exponential; larger values penalize distance harder.
        sf_bothpos, sf_bothneg, sf_bneg_cpos, sf_bpos_cneg: sign-case weights.
        eps: floor applied to |baseline| so normalization never divides by ~0.

    Returns:
        The score as a float.

    NOTE(review): for a non-positive baseline the distance is |circuit|. A circuit
    equal to the baseline is therefore penalized, and so is a positive circuit,
    in proportion to its magnitude. Confirm this is the intended definition.
    """
    base, circ = baseline_perf, circuit_perf
    baseline_positive = bool(base > 0)
    circuit_nonneg = bool(circ >= 0)

    # A circuit that matches or beats a positive baseline gets full marks.
    if baseline_positive and circ >= base:
        return 100.0

    sign_weight = {
        (True, True): sf_bothpos,
        (True, False): sf_bpos_cneg,
        (False, True): sf_bneg_cpos,
        (False, False): sf_bothneg,
    }[(baseline_positive, circuit_nonneg)]

    if baseline_positive:
        gap = max(0.0, base - circ)
    else:
        gap = max(0.0, abs(circ))

    # Normalize by |baseline|, floored at eps.
    scale = abs(base)
    if not scale > eps:
        scale = eps

    return 100.0 * sign_weight * math.exp(-alpha * (gap / scale))

In [None]:
# Compare the full model (baseline) with the circuit `g` on the same dataloader,
# averaging the per-example logit difference over the dataset.
per_example_logit_diff = partial(logit_diff, loss=False, mean=False)

baseline = evaluate_baseline(model, dataloader, per_example_logit_diff).mean().item()
results = evaluate_graph(model, g, dataloader, per_example_logit_diff).mean().item()

print(f"Original performance was {baseline}; the circuit's performance is {results}")

# Collapse both scores into a single 0-100 circuit score (CRS).
crs_weights = dict(
    alpha=1.0,
    sf_bothpos=1.0,     # baseline>0,  circuit>=0
    sf_bothneg=0.5,     # baseline<=0, circuit<0
    sf_bneg_cpos=0.8,   # baseline<=0, circuit>=0
    sf_bpos_cneg=0.6,   # baseline>0,  circuit<0
)
final_metric = CRS(baseline_perf=baseline, circuit_perf=results, **crs_weights)
print(f"Circuit Single Score (0~100) = {final_metric:.2f}")

In [None]:
### Simplifying graph with threshold ###
# Prunes the attributed graph `g` IN PLACE:
# - edges are filtered by score threshold;
# - nodes not attached to a surviving edge are marked out of the graph;
# - everything out of the graph is deleted from `g.nodes` / `g.edges`.
# Re-run the attribution cell before reusing `g` at its original size.

tau = 0.1  # edge-score threshold passed to apply_threshold (absolute=False)

g.apply_threshold(threshold=tau, absolute=False)

# Nodes that are an endpoint of at least one edge still in the graph.
important_node_ids = {
    name
    for edge in g.edges.values() if edge.in_graph
    for name in (edge.parent.name, edge.child.name)
}

for node_name in list(g.nodes):
    node = g.nodes[node_name]
    if node_name not in important_node_ids:
        node.in_graph = False
    if not node.in_graph:
        del g.nodes[node_name]

for edge_name in list(g.edges):
    if not g.edges[edge_name].in_graph:
        del g.edges[edge_name]

graph_dir = f'./graphs/{model_name}'
os.makedirs(graph_dir, exist_ok=True)  # make sure the output folder exists before drawing

gz = g.to_graphviz()
gz.draw(f'{graph_dir}/simplified_graph_{target_time}_{target_category}.png', prog='dot')
# gz = g.to_graphviz_enhanced(score_threshold=0.3, threshold_type="below", highlight_nodes=["a15.h0", "a18.h3"],)
# gz.draw(f'{graph_dir}/simplified_graph_{target_time}_{target_category}.png', prog='dot')

remain_node_count = len(g.nodes)
remain_edge_count = len(g.edges)

print("Simplified circuit creation complete!")
print(f"Simplified Graph: {remain_node_count} nodes, {remain_edge_count} edges")