# Evaluate circuit faithfulness

In [None]:
import os
import sys
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import pandas as pd
import matplotlib.pyplot as plt
from transformer_lens import HookedTransformer
from src.graph import Graph

# def display_name(task):
#     if '-comma' in task:
#         task = task[:-6]
#     return display_name_dict[task]

def n_edges(model_name: str) -> int:
    model = HookedTransformer.from_pretrained(
        model_name,
        center_writing_weights=False,
        center_unembed=False,
        fold_ln=False,
    )
    model.cfg.use_split_qkv_input = True
    model.cfg.use_attn_result = True
    model.cfg.use_hook_mlp_in = True
    model.cfg.ungroup_grouped_query_attention = True
    g = Graph.from_model(model)
    return len(g.edges)

model_name = "meta-llama/Llama-3.2-1B"
model_name_noslash = 'Llama-3.2-1B-Instruct'
total_edges = n_edges(model_name)
print(f"Total number of edges: {total_edges}")

In [None]:
import plotly.subplots as sp
import plotly.graph_objects as go
import pandas as pd
import os

task_names = ['temp_toxicity_samp', 'adv-bias-3']
param_dict = {'temp_toxicity_samp': ['Toxicity', ['EAP', 'EAP-IG']],
                     'adv-bias-3': ['Name Bias', ['EAP', 'EAP-IG', 'EAP-IG-KL']]}

trace_colors = {
    'EAP': '#FFA15A', 
    'EAP-IG': '#19D3F3',
    'EAP-IG-KL': '#FF6692'
}

rows = 1
cols = len(task_names)
fig = sp.make_subplots(rows=rows, cols=cols, subplot_titles=[param_dict[task][0] for task in task_names])

# Iterate over tasks and add traces
for idx, task_name in enumerate(task_names):
    df = pd.read_csv(f'./inputs/{task_name}.csv')
    
    # Normalize faithfulness
    for losstype in param_dict[task_name][1]:
        df[f'normalized_faithfulness_{losstype}'] = (
            df[f'loss_{losstype}'] - df['corrupted_baseline']
        ) / (df['baseline'] - df['corrupted_baseline'])

    # Save processed data
    df.to_csv(f'./outputs/{task_name}_graph.csv', index=False)

    # Determine subplot position
    row = idx // cols + 1
    col = idx % cols + 1

    # Add traces with legend grouping
    show_legend = idx == 1  # Show legend only for the first subplot
    for losstype in param_dict[task_name][1]:
        fig.add_trace(go.Scatter(
            x=df[f'edges_{losstype}'], y=df[f'normalized_faithfulness_{losstype}'], mode='lines',
            name=losstype, line=dict(color=trace_colors[losstype]), legendgroup=losstype, showlegend=show_legend
        ), row=row, col=col)

    # Allow each subplot to have independent y-axis
    fig.update_yaxes(matches=None, row=row, col=col)

# Global layout settings
fig.update_layout(
    height=400, width=1000, title_text="",
    showlegend=True,
    legend=dict(
        orientation="h",  # Horizontal legend
        yanchor="bottom",  # Anchor to the bottom
        y=-0.4,  # Position below the plot
        xanchor="center",
        x=0.5
    )
)
fig.update_yaxes(range=[-0.1, 1.0], row=1, col=1)
fig.update_yaxes(range=[-0.1, 1.0], row=1, col=2)
fig.update_yaxes(nticks=8, row=1, col=1)
fig.update_xaxes(nticks=8, row=1, col=1)
fig.update_xaxes(nticks=8, row=1, col=2)

fig.update_layout(
    plot_bgcolor='white',  # White background
    paper_bgcolor='white'  # White figure background
)

# Ensure grid and zero lines are light gray
fig.update_xaxes(showgrid=True, gridcolor='lightgray', zeroline=True, zerolinecolor='lightgray')
fig.update_yaxes(showgrid=True, gridcolor='lightgray', zeroline=True, zerolinecolor='lightgray')

# X and Y labels (custom positioning)
fig.update_xaxes(title_text=f'Edges included (/{total_edges})')
fig.update_yaxes(title_text='Normalized faithfulness', row=1, col=1)

# Show figure
fig.show()


In [312]:
fig.write_image(f"./output/normalized_faithfulness.svg")

### Bias heatmaps

In [7]:
import torch

# Specify the path to the .pth file
pth_file_path = 'vul_heads_og_method_bias.pth'

# Load the tensor from the .pth file
act_patching_tensor = torch.load(pth_file_path, map_location=torch.device('cpu'))
act_patching_tensor.requires_grad = False
# Print the loaded tensor
act_patching_tensor.shape


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



torch.Size([16, 32])

In [8]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.graph import Graph


graph_path = "./full_adv_bias_EAP-IG_step3000_2357edges.json"

g = Graph.from_json(graph_path)

In [211]:
import pandas as pd

heatmap_dict = dict()

for node_str in g.nodes:
    node = g.nodes[node_str]
    if node.in_graph and "a" in node_str:
        print(node)
        heatmap_dict[node_str] = 0
        for p_edge in node.parent_edges:
            if p_edge.in_graph:
                heatmap_dict[node_str] += p_edge.score
                print(p_edge)
                print(p_edge.parent)
                print(p_edge.child)
        for c_edge in node.child_edges:
            if c_edge.in_graph:
                pass
                # print(c_edge)
                # print(c_edge.parent)
                # print(c_edge.child)

heatmap_dict

Node(a0.h2, in_graph: True)
Edge(input->a0.h2<q>, score: 0.0006232046871446073, in_graph: True)
Node(input, in_graph: True)
Node(a0.h2, in_graph: True)
Edge(input->a0.h2<k>, score: 0.0015502066817134619, in_graph: True)
Node(input, in_graph: True)
Node(a0.h2, in_graph: True)
Node(a0.h3, in_graph: True)
Edge(input->a0.h3<q>, score: 0.0010811270913109183, in_graph: True)
Node(input, in_graph: True)
Node(a0.h3, in_graph: True)
Node(a0.h6, in_graph: True)
Edge(input->a0.h6<k>, score: 0.0011417785426601768, in_graph: True)
Node(input, in_graph: True)
Node(a0.h6, in_graph: True)
Node(a0.h24, in_graph: True)
Edge(input->a0.h24<q>, score: 0.0004902833607047796, in_graph: True)
Node(input, in_graph: True)
Node(a0.h24, in_graph: True)
Node(a0.h25, in_graph: True)
Edge(input->a0.h25<q>, score: -0.0006833298248238862, in_graph: True)
Node(input, in_graph: True)
Node(a0.h25, in_graph: True)
Node(a0.h29, in_graph: True)
Edge(input->a0.h29<q>, score: 0.0015183762880042195, in_graph: True)
Node(input,

{'a0.h2': 0.002173411368858069,
 'a0.h3': 0.0010811270913109183,
 'a0.h6': 0.0011417785426601768,
 'a0.h24': 0.0004902833607047796,
 'a0.h25': -0.0006833298248238862,
 'a0.h29': 0.0015183762880042195,
 'a1.h3': 0.0007961218361742795,
 'a1.h9': 0.0008253492414951324,
 'a1.h12': -0.0004866595845669508,
 'a1.h20': 0.001112618891056627,
 'a1.h26': -0.0007599727832712233,
 'a1.h28': 0.000645423773676157,
 'a2.h4': 0.0011084941506851465,
 'a2.h5': 0.0005551895592361689,
 'a2.h6': 0.0022937243338674307,
 'a2.h14': 0.0007389038219116628,
 'a2.h15': 0.0005782812950201333,
 'a2.h22': 0.00047686375910416245,
 'a2.h25': -0.0005167612689547241,
 'a2.h27': -0.0005597788258455694,
 'a3.h0': -0.001348192774457857,
 'a3.h1': 0.00012399174738675356,
 'a3.h4': 0.0010873244609683752,
 'a3.h7': 0.0010975822806358337,
 'a3.h9': 0.0005862407269887626,
 'a3.h22': -0.0005706677911803126,
 'a4.h0': -0.002283307461766526,
 'a4.h1': -0.0006837844266556203,
 'a4.h2': -0.001426564995199442,
 'a4.h3': 0.001309645129

In [212]:
import torch

# Create a tensor to store the heatmap values
heatmap_tensor = torch.zeros((16, 32))  # Assuming a maximum of 16 layers and 32 heads

# Populate the tensor with values from heatmap_dict
for key, value in heatmap_dict.items():
    layer, head = map(int, key[1:].split('.h'))
    heatmap_tensor[layer, head] = value

heatmap_tensor

tensor([[ 0.0000e+00,  0.0000e+00,  2.1734e-03,  1.0811e-03,  0.0000e+00,
          0.0000e+00,  1.1418e-03,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  4.9028e-04,
         -6.8333e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.5184e-03,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  7.9612e-04,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  8.2535e-04,
          0.0000e+00,  0.0000e+00, -4.8666e-04,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.1126e-03,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -7.5997e-04,  0.0000e+00,  6.4542e-04,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.00

In [213]:
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

tensor1 = act_patching_tensor
tensor2 = heatmap_tensor

fig = make_subplots(
    rows=1, cols=2,
    horizontal_spacing=0.05, 
    subplot_titles=("", ""),
)

# Left heatmap
fig.add_trace(
    go.Heatmap(
        z=tensor1.numpy(),
        colorscale="RdBu",
        zmid=0,
        showscale=True,
        colorbar=dict(
            title="",
            thickness=20,
            # Position the left colorbar a bit to the right of the left domain
            x=0.46,        # 0 = far left, 1 = far right
            xanchor="left"
        ),
    ),
    row=1, col=1
)

# Right heatmap
fig.add_trace(
    go.Heatmap(
        z=tensor2.numpy(),
        colorscale="RdBu",
        zmid=0,
        showscale=True,
        colorbar=dict(
            title="",
            thickness=20,
            # Place the right colorbar well to the right
            x=1.01,
            xanchor="left"
        ),
    ),
    row=1, col=2
)

# -- Make the left subplot narrower and shifted left
fig.update_xaxes(domain=[0.0, 0.45], row=1, col=1)  # 30% of total width
# -- Shift the right subplot further right
fig.update_xaxes(domain=[0.55, 1.0], row=1, col=2) # starts at 0.45, ends at 0.8

fig.update_yaxes(title_text="Layer", row=1, col=1)
fig.update_yaxes(title_text="Layer", row=1, col=2)
fig.update_xaxes(title_text="Head", row=1, col=1)
fig.update_xaxes(title_text="Head", row=1, col=2)
fig.update_yaxes(autorange="reversed", row=1, col=1)
fig.update_yaxes(autorange="reversed", row=1, col=2)

fig.update_layout(
    width=1500,                # <--- Force a smaller figure width in pixels
    height=500,
    font=dict(size=14)
)
fig.show()

In [214]:
fig.write_image("bias_heatmap.svg")

### Toxicity heatmaps

In [9]:
# Specify the path to the .pth file
pth_file_path = 'final_toxicity_tensor_0.pth'

# Load the tensor from the .pth file
act_patching_tensor = torch.load(pth_file_path, map_location=torch.device('cpu'))
act_patching_tensor.requires_grad = False
# Print the loaded tensor
act_patching_tensor.shape


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



torch.Size([16, 32])

In [316]:
graph_path = "./toxicity_EAP_step12000_11853edges.json"

g = Graph.from_json(graph_path)

In [317]:
import pandas as pd

heatmap_dict = dict()

for node_str in g.nodes:
    node = g.nodes[node_str]
    if node.in_graph and "a" in node_str:
        print(node)
        heatmap_dict[node_str] = 0
        for p_edge in node.parent_edges:
            if p_edge.in_graph:
                heatmap_dict[node_str] += p_edge.score
                print(p_edge)
                print(p_edge.parent)
                print(p_edge.child)
        for c_edge in node.child_edges:
            if c_edge.in_graph:
                pass
                # print(c_edge)
                # print(c_edge.parent)
                # print(c_edge.child)

heatmap_dict

Node(a0.h0, in_graph: True)
Edge(input->a0.h0<k>, score: 0.022330591455101967, in_graph: True)
Node(input, in_graph: True)
Node(a0.h0, in_graph: True)
Edge(input->a0.h0<q>, score: 0.0007328155334107578, in_graph: True)
Node(input, in_graph: True)
Node(a0.h0, in_graph: True)
Edge(input->a0.h0<v>, score: -0.0005695744184777141, in_graph: True)
Node(input, in_graph: True)
Node(a0.h0, in_graph: True)
Node(a0.h1, in_graph: True)
Edge(input->a0.h1<q>, score: -0.0009942230535671115, in_graph: True)
Node(input, in_graph: True)
Node(a0.h1, in_graph: True)
Edge(input->a0.h1<k>, score: -0.00830574706196785, in_graph: True)
Node(input, in_graph: True)
Node(a0.h1, in_graph: True)
Edge(input->a0.h1<v>, score: 0.0010062727378681302, in_graph: True)
Node(input, in_graph: True)
Node(a0.h1, in_graph: True)
Node(a0.h2, in_graph: True)
Edge(input->a0.h2<q>, score: -0.0060812318697571754, in_graph: True)
Node(input, in_graph: True)
Node(a0.h2, in_graph: True)
Edge(input->a0.h2<k>, score: -0.011140699498355

{'a0.h0': 0.02249383257003501,
 'a0.h1': -0.008293697377666831,
 'a0.h2': -0.016595377295743674,
 'a0.h3': -0.008078590326476842,
 'a0.h4': 0.009157457039691508,
 'a0.h5': -0.009801684529520571,
 'a0.h6': -0.005132298218086362,
 'a0.h7': 0.006465522106736898,
 'a0.h8': 0.001052121282555163,
 'a0.h9': 0.0026952028274536133,
 'a0.h10': 0.0014006445417180657,
 'a0.h11': -0.0010506869293749332,
 'a0.h12': -0.002224645344540477,
 'a0.h13': 0.0057983361184597015,
 'a0.h14': -0.000484858377603814,
 'a0.h15': -0.003959375433623791,
 'a0.h16': 0.000522392860148102,
 'a0.h17': 0.0011023316765204072,
 'a0.h18': -0.004075650125741959,
 'a0.h19': 0.0007447243551723659,
 'a0.h20': -0.003921865951269865,
 'a0.h22': 0.00457946490496397,
 'a0.h24': -0.000793035258539021,
 'a0.h27': -0.001701776753179729,
 'a0.h28': 0.0011482920963317156,
 'a0.h30': -0.00509767746552825,
 'a1.h0': -0.003025828074896708,
 'a1.h1': 0.0025858654407784343,
 'a1.h2': -0.0023098314995877445,
 'a1.h3': -0.013575639808550477,
 

In [318]:
import torch

# Create a tensor to store the heatmap values
heatmap_tensor = torch.zeros((16, 32))  # Assuming a maximum of 16 layers and 32 heads

# Populate the tensor with values from heatmap_dict
for key, value in heatmap_dict.items():
    layer, head = map(int, key[1:].split('.h'))
    heatmap_tensor[layer, head] = value

heatmap_tensor

tensor([[ 2.2494e-02, -8.2937e-03, -1.6595e-02, -8.0786e-03,  9.1575e-03,
         -9.8017e-03, -5.1323e-03,  6.4655e-03,  1.0521e-03,  2.6952e-03,
          1.4006e-03, -1.0507e-03, -2.2246e-03,  5.7983e-03, -4.8486e-04,
         -3.9594e-03,  5.2239e-04,  1.1023e-03, -4.0757e-03,  7.4472e-04,
         -3.9219e-03,  0.0000e+00,  4.5795e-03,  0.0000e+00, -7.9304e-04,
          0.0000e+00,  0.0000e+00, -1.7018e-03,  1.1483e-03,  0.0000e+00,
         -5.0977e-03,  0.0000e+00],
        [-3.0258e-03,  2.5859e-03, -2.3098e-03, -1.3576e-02,  2.0040e-02,
          2.5965e-02, -1.2066e-02,  2.3528e-03,  9.2063e-04, -3.1961e-03,
          8.3352e-03,  5.3239e-03, -2.3245e-03,  1.2307e-03,  1.2243e-03,
          6.2842e-04,  7.6238e-04, -8.0170e-03, -6.0820e-03,  6.4064e-04,
          4.9053e-04, -6.1219e-03, -2.8733e-03,  2.1688e-02, -1.7642e-03,
         -2.0066e-02, -1.1349e-03,  8.3673e-04, -4.4289e-04, -2.9220e-03,
         -3.3788e-03,  7.8253e-03],
        [-6.4866e-03, -1.4211e-03, -1.17

In [319]:
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

tensor1 = act_patching_tensor
tensor2 = heatmap_tensor

fig = make_subplots(
    rows=1, cols=2,
    horizontal_spacing=0.05, 
    subplot_titles=("", ""),
)

# Left heatmap
fig.add_trace(
    go.Heatmap(
        z=tensor1.numpy(),
        colorscale="RdBu",
        zmid=0,
        showscale=True,
        colorbar=dict(
            title="",
            thickness=20,
            # Position the left colorbar a bit to the right of the left domain
            x=0.46,        # 0 = far left, 1 = far right
            xanchor="left"
        ),
    ),
    row=1, col=1
)

# Right heatmap
fig.add_trace(
    go.Heatmap(
        z=tensor2.numpy(),
        colorscale="RdBu",
        zmid=0,
        showscale=True,
        colorbar=dict(
            title="",
            thickness=20,
            # Place the right colorbar well to the right
            x=1.01,
            xanchor="left"
        ),
    ),
    row=1, col=2
)

# -- Make the left subplot narrower and shifted left
fig.update_xaxes(domain=[0.0, 0.45], row=1, col=1)  # 30% of total width
# -- Shift the right subplot further right
fig.update_xaxes(domain=[0.55, 1.0], row=1, col=2) # starts at 0.45, ends at 0.8

fig.update_yaxes(title_text="Layer", row=1, col=1)
fig.update_yaxes(title_text="Layer", row=1, col=2)
fig.update_xaxes(title_text="Head", row=1, col=1)
fig.update_xaxes(title_text="Head", row=1, col=2)
fig.update_yaxes(autorange="reversed", row=1, col=1)
fig.update_yaxes(autorange="reversed", row=1, col=2)

fig.update_layout(
    width=1500,                # <--- Force a smaller figure width in pixels
    height=500,
    font=dict(size=14)
)
fig.show()

In [320]:
fig.write_image("toxicity_heatmap.svg")

# Debiasing graphs

In [None]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.graph import Graph
import pandas as pd

In [None]:
toxicity_graph_path = "./toxicity_EAP_step12000_11853edges.json"
bias_graph_path = "./full_adv_bias_EAP-IG_step3000_2357edges.json"

toxicity_graph = Graph.from_json(toxicity_graph_path)
bias_graph = Graph.from_json(bias_graph_path)
ablated_bias = Graph.from_json(bias_graph_path)

num_edges = 0
for node_str in bias_graph.nodes:
    toxicity_node = toxicity_graph.nodes[node_str]
    bias_node = bias_graph.nodes[node_str]
    ablated_bias_node = ablated_bias.nodes[node_str]
    
    for bias_edge, toxicity_edge, ablated_bias_edge in zip(bias_node.parent_edges, toxicity_node.parent_edges, ablated_bias_node.parent_edges):
        assert bias_edge.name == toxicity_edge.name == ablated_bias_edge.name
        # keeping edge in bias and not in toxicity
        if bias_edge.in_graph and toxicity_edge.in_graph:
            num_edges += 1
            ablated_bias_edge.in_graph = False

print("number of common edges", num_edges)
total_edges = 0
for node_str in ablated_bias.nodes:
    node = ablated_bias.nodes[node_str]
    for p_edge in node.parent_edges:
        if p_edge.in_graph:
            total_edges += 1
print("number of edges in bias circuit after turning off common", total_edges)

In [None]:
ablated_bias.to_json("./ablated_bias_EAP-IG_step3000_2357edges_with_toxicity_EAP_step12000_11853edges.json")

In [None]:
# turn all edges in bias circuit off to create empty graph
false_bias = Graph.from_json(bias_graph_path)
num_edges = 0
for node_str in bias_graph.nodes:
    toxicity_node = toxicity_graph.nodes[node_str]
    bias_node = bias_graph.nodes[node_str]
    ablated_bias_node = false_bias.nodes[node_str]
    
    for bias_edge, toxicity_edge, ablated_bias_edge in zip(bias_node.parent_edges, toxicity_node.parent_edges, ablated_bias_node.parent_edges):
        assert bias_edge.name == toxicity_edge.name == ablated_bias_edge.name
        num_edges += 1
        ablated_bias_edge.in_graph = False

false_bias.to_json("./no_edges_bias.json")