# Evaluate circuit faithfulness

In [9]:
import os
import sys
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from transformer_lens import HookedTransformer
from src.graph import Graph


# def display_name(task):
#     if '-comma' in task:
#         task = task[:-6]
#     return display_name_dict[task]

def n_edges(model_name: str) -> int:
    model = HookedTransformer.from_pretrained(
        model_name,
        center_writing_weights=False,
        center_unembed=False,
        fold_ln=False,
    )
    model.cfg.use_split_qkv_input = True
    model.cfg.use_attn_result = True
    model.cfg.use_hook_mlp_in = True
    model.cfg.ungroup_grouped_query_attention = True
    g = Graph.from_model(model)
    return len(g.edges)

def edge_counter(g: Graph) -> int:
    total_edges = 0
    for node_str in g.nodes:
        node = g.nodes[node_str]
        for p_edge in node.parent_edges:
            if p_edge.in_graph:
                total_edges += 1
    return total_edges

model_name = "meta-llama/Llama-3.2-1B"
model_name_noslash = 'Llama-3.2-1B-Instruct'
total_edges = n_edges(model_name)
print(f"Total number of edges: {total_edges}")

Loaded pretrained model meta-llama/Llama-3.2-1B into HookedTransformer
Total number of edges: 386713


In [2]:
import plotly.subplots as sp
import plotly.graph_objects as go
import pandas as pd
import os

task_names = ['temp_toxicity_samp', 'adv-bias-3']
param_dict = {'temp_toxicity_samp': ['Toxicity', ['EAP', 'EAP-IG']],
                     'adv-bias-3': ['Name Bias', ['EAP', 'EAP-IG', 'EAP-IG-KL']]}

trace_colors = {
    'EAP': '#FFA15A', 
    'EAP-IG': '#19D3F3',
    'EAP-IG-KL': '#FF6692'
}

rows = 1
cols = len(task_names)
fig = sp.make_subplots(rows=rows, cols=cols, subplot_titles=[param_dict[task][0] for task in task_names])

# Iterate over tasks and add traces
for idx, task_name in enumerate(task_names):
    df = pd.read_csv(f'./inputs/{task_name}.csv')
    
    # Normalize faithfulness
    for losstype in param_dict[task_name][1]:
        df[f'normalized_faithfulness_{losstype}'] = (
            df[f'loss_{losstype}'] - df['corrupted_baseline']
        ) / (df['baseline'] - df['corrupted_baseline'])

    # Save processed data
    df.to_csv(f'./outputs/{task_name}_graph.csv', index=False)

    # Determine subplot position
    row = idx // cols + 1
    col = idx % cols + 1

    # Add traces with legend grouping
    show_legend = idx == 1  # Show legend only for the first subplot
    for losstype in param_dict[task_name][1]:
        fig.add_trace(go.Scatter(
            x=df[f'edges_{losstype}'], y=df[f'normalized_faithfulness_{losstype}'], mode='lines',
            name=losstype, line=dict(color=trace_colors[losstype]), legendgroup=losstype, showlegend=show_legend
        ), row=row, col=col)

    # Allow each subplot to have independent y-axis
    fig.update_yaxes(matches=None, row=row, col=col)

# Global layout settings
fig.update_layout(
    height=400, width=1000, title_text="",
    showlegend=True,
    legend=dict(
        orientation="h",  # Horizontal legend
        yanchor="bottom",  # Anchor to the bottom
        y=-0.4,  # Position below the plot
        xanchor="center",
        x=0.5
    )
)
fig.update_yaxes(range=[-0.1, 1.0], row=1, col=1)
fig.update_yaxes(range=[-0.1, 1.0], row=1, col=2)
fig.update_yaxes(nticks=8, row=1, col=1)
fig.update_xaxes(nticks=8, row=1, col=1)
fig.update_xaxes(nticks=8, row=1, col=2)

fig.update_layout(
    plot_bgcolor='white',  # White background
    paper_bgcolor='white'  # White figure background
)

# Ensure grid and zero lines are light gray
fig.update_xaxes(showgrid=True, gridcolor='lightgray', zeroline=True, zerolinecolor='lightgray')
fig.update_yaxes(showgrid=True, gridcolor='lightgray', zeroline=True, zerolinecolor='lightgray')

# X and Y labels (custom positioning)
fig.update_xaxes(title_text=f'Edges included (/{total_edges})')
fig.update_yaxes(title_text='Normalized faithfulness', row=1, col=1)

# Show figure
fig.show()
fig.write_image(f"./outputs/normalized_faithfulness.svg")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


### Bias heatmaps

In [3]:
# Specify the path to the .pth file
pth_file_path = './inputs/vul_heads_og_method_bias.pth'

# Load the tensor from the .pth file
act_patching_tensor = torch.load(pth_file_path, map_location=torch.device('cpu'))
act_patching_tensor.requires_grad = False
# Print the loaded tensor
act_patching_tensor.shape

graph_path = "./inputs/adv-bias-3_EAP-IG-KL_step7000_6908edges.json"
g = Graph.from_json(graph_path)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [4]:
heatmap_dict = dict()

for node_str in g.nodes:
    node = g.nodes[node_str]
    if node.in_graph and "a" in node_str:
        heatmap_dict[node_str] = 0
        for p_edge in node.parent_edges:
            if p_edge.in_graph:
                heatmap_dict[node_str] += p_edge.score
        for c_edge in node.child_edges:
            if c_edge.in_graph:
                pass

# Create a tensor to store the heatmap values
heatmap_tensor = torch.zeros((16, 32))  # Assuming a maximum of 16 layers and 32 heads

# Populate the tensor with values from heatmap_dict
for key, value in heatmap_dict.items():
    layer, head = map(int, key[1:].split('.h'))
    heatmap_tensor[layer, head] = value

In [5]:
tensor1 = act_patching_tensor
tensor2 = heatmap_tensor

fig = make_subplots(
    rows=1, cols=2,
    horizontal_spacing=0.05, 
    subplot_titles=("", ""),
)

# Left heatmap
fig.add_trace(
    go.Heatmap(
        z=tensor1.numpy(),
        colorscale="RdBu",
        zmid=0,
        showscale=True,
        colorbar=dict(
            title="",
            thickness=20,
            # Position the left colorbar a bit to the right of the left domain
            x=0.46,        # 0 = far left, 1 = far right
            xanchor="left"
        ),
    ),
    row=1, col=1
)

# Right heatmap
fig.add_trace(
    go.Heatmap(
        z=tensor2.numpy(),
        colorscale="RdBu",
        zmid=0,
        showscale=True,
        colorbar=dict(
            title="",
            thickness=20,
            # Place the right colorbar well to the right
            x=1.01,
            xanchor="left"
        ),
    ),
    row=1, col=2
)

# -- Make the left subplot narrower and shifted left
fig.update_xaxes(domain=[0.0, 0.45], row=1, col=1)  # 30% of total width
# -- Shift the right subplot further right
fig.update_xaxes(domain=[0.55, 1.0], row=1, col=2) # starts at 0.45, ends at 0.8

fig.update_yaxes(title_text="Layer", row=1, col=1)
fig.update_yaxes(title_text="Layer", row=1, col=2)
fig.update_xaxes(title_text="Head", row=1, col=1)
fig.update_xaxes(title_text="Head", row=1, col=2)
fig.update_yaxes(autorange="reversed", row=1, col=1)
fig.update_yaxes(autorange="reversed", row=1, col=2)

fig.update_layout(
    width=1500,                # <--- Force a smaller figure width in pixels
    height=500,
    font=dict(size=14)
)
fig.show()

fig.write_image("./outputs/bias_heatmap.svg")

### Toxicity heatmaps

In [6]:
# Specify the path to the .pth file
pth_file_path = './inputs/final_toxicity_tensor_0.pth'

# Load the tensor from the .pth file
act_patching_tensor = torch.load(pth_file_path, map_location=torch.device('cpu'))
act_patching_tensor.requires_grad = False
# Print the loaded tensor
act_patching_tensor.shape

# load graph
graph_path = "./inputs/toxicity-samp_EAP-IG_step19000_18922edges.json"
g = Graph.from_json(graph_path)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [7]:
heatmap_dict = dict()

for node_str in g.nodes:
    node = g.nodes[node_str]
    if node.in_graph and "a" in node_str:
        heatmap_dict[node_str] = 0
        for p_edge in node.parent_edges:
            if p_edge.in_graph:
                heatmap_dict[node_str] += p_edge.score
        for c_edge in node.child_edges:
            if c_edge.in_graph:
                pass

# Create a tensor to store the heatmap values
heatmap_tensor = torch.zeros((16, 32))  # Assuming a maximum of 16 layers and 32 heads

# Populate the tensor with values from heatmap_dict
for key, value in heatmap_dict.items():
    layer, head = map(int, key[1:].split('.h'))
    heatmap_tensor[layer, head] = value

In [8]:
tensor1 = act_patching_tensor
tensor2 = heatmap_tensor

fig = make_subplots(
    rows=1, cols=2,
    horizontal_spacing=0.05, 
    subplot_titles=("", ""),
)

# Left heatmap
fig.add_trace(
    go.Heatmap(
        z=tensor1.numpy(),
        colorscale="RdBu",
        zmid=0,
        showscale=True,
        colorbar=dict(
            title="",
            thickness=20,
            # Position the left colorbar a bit to the right of the left domain
            x=0.46,        # 0 = far left, 1 = far right
            xanchor="left"
        ),
    ),
    row=1, col=1
)

# Right heatmap
fig.add_trace(
    go.Heatmap(
        z=tensor2.numpy(),
        colorscale="RdBu",
        zmid=0,
        showscale=True,
        colorbar=dict(
            title="",
            thickness=20,
            # Place the right colorbar well to the right
            x=1.01,
            xanchor="left"
        ),
    ),
    row=1, col=2
)

# -- Make the left subplot narrower and shifted left
fig.update_xaxes(domain=[0.0, 0.45], row=1, col=1)  # 30% of total width
# -- Shift the right subplot further right
fig.update_xaxes(domain=[0.55, 1.0], row=1, col=2) # starts at 0.45, ends at 0.8

fig.update_yaxes(title_text="Layer", row=1, col=1)
fig.update_yaxes(title_text="Layer", row=1, col=2)
fig.update_xaxes(title_text="Head", row=1, col=1)
fig.update_xaxes(title_text="Head", row=1, col=2)
fig.update_yaxes(autorange="reversed", row=1, col=1)
fig.update_yaxes(autorange="reversed", row=1, col=2)

fig.update_layout(
    width=1500,                # <--- Force a smaller figure width in pixels
    height=500,
    font=dict(size=14)
)
fig.show()
fig.write_image("./outputs/toxicity_heatmap.svg")

# Debiasing graphs

In [11]:
toxicity_graph_path = "./inputs/toxicity-samp_EAP-IG_step19000_18922edges.json"
bias_graph_path = "./inputs/adv-bias-3_EAP-IG-KL_step7000_6908edges.json"

toxicity_graph = Graph.from_json(toxicity_graph_path)
bias_graph = Graph.from_json(bias_graph_path)
ablated_bias = Graph.from_json(bias_graph_path)

num_edges = 0
for node_str in bias_graph.nodes:
    toxicity_node = toxicity_graph.nodes[node_str]
    bias_node = bias_graph.nodes[node_str]
    ablated_bias_node = ablated_bias.nodes[node_str]
    
    for bias_edge, toxicity_edge, ablated_bias_edge in zip(bias_node.parent_edges, toxicity_node.parent_edges, ablated_bias_node.parent_edges):
        assert bias_edge.name == toxicity_edge.name == ablated_bias_edge.name
        # keeping edge in bias and not in toxicity
        if bias_edge.in_graph and toxicity_edge.in_graph:
            num_edges += 1
            ablated_bias_edge.in_graph = False

print("number of common edges:", num_edges)
print("number of edges before pruning dead nodes:", edge_counter(ablated_bias))
ablated_bias.prune_dead_nodes(prune_childless=True, prune_parentless=True)
print("number of edges after pruning dead nodes:", edge_counter(ablated_bias))

ablated_bias.to_json("./outputs/ablated_bias-3_EAP-IG-KL_step7000_6908edges_toxicity-samp_EAP-IG_step19000_18922edges.json")

number of common edges: 4718
number of edges before pruning dead nodes: 2190
number of edges after pruning dead nodes: 1253


In [12]:
# turn all edges in bias circuit off to create empty graph
false_bias = Graph.from_json(bias_graph_path)
num_edges = 0
for node_str in bias_graph.nodes:
    toxicity_node = toxicity_graph.nodes[node_str]
    bias_node = bias_graph.nodes[node_str]
    ablated_bias_node = false_bias.nodes[node_str]
    
    for bias_edge, toxicity_edge, ablated_bias_edge in zip(bias_node.parent_edges, toxicity_node.parent_edges, ablated_bias_node.parent_edges):
        assert bias_edge.name == toxicity_edge.name == ablated_bias_edge.name
        num_edges += 1
        ablated_bias_edge.in_graph = False

print("check empty graph has no edges:", edge_counter(false_bias))
false_bias.to_json("./outputs/no_edges_bias.json")

check empty graph has no edges: 0
