In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import fancy_einsum
from tqdm import tqdm
import re
import random
from sklearn.metrics import roc_curve, auc
import transformer_lens.utils as utils
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go

from sparse_autoencoder import SparseAutoencoder

from transformer_lens import HookedTransformer, HookedTransformerConfig
from jaxtyping import Float, Int
from torch import Tensor
from functools import partial
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix, HookedTransformerConfig

import transformer_lens.utils as utils
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from common_utils import *

# Turn grad on
torch.set_grad_enabled(True)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x2e3245a30>

In [286]:
task = 'ioi'
task_type = 'node'
assert task_type in ['node', 'edge'], "Type must be either 'node' or 'edge'"
print(f"Type: {task_type}")
task_mappings = {
    'gt': 'Greater-than',
    'ioi': 'Indirect Object Identification',
    'ds': 'Docstring',
    'induction': 'Induction',
}

print(f"Task: {task_mappings[task]}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

num_unique = 300
n_epochs = 500

# Load residual streams
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')
print(ground_truth)
print(f"Residual streams shape: {resid_streams.shape}")

model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
model = torch.load(f'../models/{task}/sparse_autoencoder.pt')
model = model.to(device)

heads = []
layers = []
for i, l in enumerate(head_labels):
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    heads.append(head)
    layers.append(layer)

heads = list(set(heads))
layers = list(set(layers))

ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

normalise = False

# Plot the ground truth (head, layer) pairs (1 if in ground truth, 0 otherwise)

model.eval()
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)
print(f"ROC AUC: {node_roc_auc:.4f}\n\n")

Type: node
Task: Indirect Object Identification
Using device: cpu
[(2, 2), (4, 11), (0, 1), (3, 0), (0, 10), (5, 5), (6, 9), (5, 8), (5, 9), (7, 3), (7, 9), (8, 6), (8, 10), (10, 7), (11, 10), (9, 9), (9, 6), (10, 0), (10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (11, 9), (11, 3), (9, 7)]
Residual streams shape: torch.Size([500, 144, 768])
ROC AUC: 0.8750




# Faithfulness

Given a circuit $C$ and metric $m$, let $m(C)$ denote the average value of $m$ over $\mathcal{D}$ (a dataset consisting of contrastive pairs) when running our model with all nodes outside of $C$ mean-ablated i.e. set to their average value over data from $\mathcal{D}$. We then measure faithfulness as:

$$ \frac{m(C) - m(\emptyset)}{m(M) - m(\emptyset)}$$

where $\emptyset$ denotes the empty circuit and $M$ denotes the full model. Intuitively, tihs metric captures the proportion of the model's performance our circuit explains, relative to mean ablating the full model (which represents the "prior" performance of the model when it is given information about the task, but not about specific inputs). 

In [287]:
# Turn grad off
torch.set_grad_enabled(False)

def get_names(ioi_example: str):
    # Return the names of the entities in the example
    # Do this by returning all words with capital letters
    return re.findall(r'[A-Z][a-z]*', ioi_example)[1:]

def get_years(gt_example: str):
    return re.findall(r'\d{4}', gt_example)

def get_next_token(example: str):
    if task == 'ioi':
        # Return the next token in the example
        # Do this by returning the first word after the last name
        names = get_names(example)
        # Answer is name that HASN'T occurred twice
        for name in names:
            if example.count(name) == 1:
                return ' ' + name
    elif task == 'gt':
        year = get_years(example)[0]
        year = str(year)
        century = year[:-2]
        year = year[-2:]
        new_year = random.randint(int(year), 99)
        return str(new_year)
    else: 
        raise NotImplementedError

if task == 'induction':

    cfg = HookedTransformerConfig(
        d_model=768,
        d_head=64,
        n_heads=12,
        n_layers=2,
        n_ctx=2048,
        d_vocab=50278,
        attention_dir="causal",
        attn_only=True, # defaults to False
        tokenizer_name="EleutherAI/gpt-neox-20b",
        seed=398,
        use_attn_result=True,
        normalization_type=None, # defaults to "LN", i.e. layernorm with weights & biases
        positional_embedding_type="shortformer"
    )

    weights_dir = '../models/induction/attn_only_2L_half.pth'

    device = 'cpu'
    tl_model = HookedTransformer(cfg)
    pretrained_weights = torch.load(weights_dir, map_location=device)
    tl_model.load_state_dict(pretrained_weights)
    tl_model = tl_model.to('cpu')
    all_tokens = torch.load(f'../data/{task}/all_tokens.pt')[:10, :]

else:
    tl_model = HookedTransformer.from_pretrained(model_name = "gpt2-small", device='cpu')

    all_tokens = torch.load(f'../data/{task}/all_tokens.pt')[:10]
    all_tokens = [x.replace('!', '') for x in all_tokens]
    all_tokens = [x + get_next_token(x) for x in all_tokens]
    print(all_tokens)
    all_tokens = tl_model.to_tokens(all_tokens, padding_side='left')

Loaded pretrained model gpt2-small into HookedTransformer
['Then, Richard and Nicole went to the garden. Nicole gave a bone to Richard', 'Friends Melissa and James found a bone at the house. James gave it to Melissa', 'Then, Stephanie and Heather had a long argument. Afterwards Heather said to Stephanie', 'The hospital Ashley and Daniel went to had a basketball. Daniel gave it to Ashley', 'Then, Bradley and Daniel were thinking about going to the house. Daniel wanted to give a snack to Bradley', 'When Bryan and Katherine got a bone at the store, Bryan decided to give it to Katherine', 'Then, Lauren and Shannon were working at the school. Shannon decided to give a necklace to Lauren', 'When Stephen and Brittany got a drink at the house, Brittany decided to give it to Stephen', 'After Ashley and James went to the hospital, James gave a necklace to Ashley', 'Then, Amy and Alexander were working at the hospital. Amy decided to give a kiss to Alexander']


In [17]:
# Import List and Tuple
from typing import List, Tuple
import torch as t
# functools
import functools

def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[Tensor, "batch seq n_heads d_model"]:
    attn_result[:, :, head_index_to_ablate, :] = 0.0
    return attn_result

def cross_entropy_loss(logits, tokens):
    '''
    Computes the mean cross entropy between logits (the model's prediction) and tokens (the true values).
    '''
    log_probs = F.log_softmax(logits, dim=-1)
    pred_log_probs = t.gather(log_probs[:, :-1], -1, tokens[:, 1:, None])[..., 0]
    return -pred_log_probs.mean()

def kl_divergence(logits: torch.Tensor, base_model_logits: torch.Tensor, 
                  last_seq_element_only: bool = True, return_one_element: bool = True) -> torch.Tensor:
    if last_seq_element_only:
        logits = logits[:, -1, :]
        base_model_logits = base_model_logits[:, -1, :]

    logprobs = F.log_softmax(logits, dim=-1)
    base_model_logprobs = F.log_softmax(base_model_logits, dim=-1)
    kl_div = F.kl_div(logprobs, base_model_logprobs, log_target=True, reduction="none").sum(dim=-1)

    if return_one_element:
        return kl_div.mean()

    return kl_div

def logit_difference(ablated_loss: torch.Tensor, base_loss: torch.Tensor) -> torch.Tensor: 
    """
    Calculates the logit difference of the correct token between ablated and clean model.
    """
    # Ablated loss
    ablated_loss = ablated_loss.mean(dim=0)
    if task == 'induction':
        # Loss on second half of the sequence
        ablated_loss = ablated_loss[ablated_loss.shape[0]//2+1:]
    elif task == 'ioi':
        # Loss on last token
        ablated_loss = ablated_loss[-1]
    elif task == 'gt':
        # Loss on last token
        ablated_loss = ablated_loss[-1]
    else: raise ValueError(f"Unsupported task: {task}")
    # Clean loss
    base_loss = base_loss.mean(dim=0)
    if task == 'induction':
        base_loss = base_loss[base_loss.shape[0]//2+1:]
    elif task == 'ioi':
        base_loss = base_loss[-1]
    elif task == 'gt':
        base_loss = base_loss[-1]
    else: raise ValueError(f"Unsupported task: {task}")

    return (ablated_loss - base_loss).mean()


def faithfulness_score(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    heads_to_ablate: List[Tuple[int, int]],
    all_heads: List[Tuple[int, int]],
    metric: str = "kl_divergence"
) -> Float[Tensor, "1"]:
    """
    Returns the faithfulness of a circuit to the underlying model.
    faithfulness = m(C) - m(empty_circuit) / m(M) - m(empty_circuit)
    """

    # Calculate the metric for m(M), the full (clean) model
    model.reset_hooks()
    logits, base_loss = model(tokens, return_type="both", loss_per_token=True)

    # Calculate the metric for m(C), the predicted circuit
    # Run the model with the ablation hook
    model.reset_hooks()
    ablated_logits, ablated_loss = model.run_with_hooks(tokens, return_type="both", loss_per_token=True, fwd_hooks=[
        (utils.get_act_name("z", layer), functools.partial(head_ablation_hook, head_index_to_ablate=head)) for layer, head in heads_to_ablate
    ])

    if metric == "kl_divergence":
        m_M = 0.0  # Since the KL divergence is with respect to the base model, m(M) here is just 0
        m_C = kl_divergence(logits, ablated_logits)
    elif metric == "logit_difference":
        m_M = logit_difference(base_loss, base_loss) #0.0  # Logit difference for the clean model is 0
        m_C = logit_difference(ablated_loss, base_loss)
    else:
        raise ValueError(f"Unsupported metric: {metric}")

    # Calculate the metric for m(empty_circuit), the empty circuit
    emptyset_logits, emptyset_loss = model.run_with_hooks(tokens, return_type="both", loss_per_token=True, fwd_hooks=[
        (utils.get_act_name("z", layer), functools.partial(head_ablation_hook, head_index_to_ablate=head)) for layer, head in all_heads
    ])

    if metric == "kl_divergence":
        m_empty_circuit = kl_divergence(logits, emptyset_logits)
    elif metric == "logit_difference":
        m_empty_circuit = logit_difference(emptyset_loss, base_loss)
    
    # Return the faithfulness
    return (m_C - m_empty_circuit) / (m_M - m_empty_circuit)

def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    heads_to_ablate: List[Tuple[int, int]],
    metric: str = "kl_divergence"
) -> Float[Tensor, "1"]:
    '''
    Returns the increase in cross entropy loss from ablating the output of specified heads.
    '''

    # Calculating loss without any ablation, to act as a baseline
    model.reset_hooks()
    base_logits, base_loss = model(tokens, return_type="both", loss_per_token = True)

    # Run the model with the ablation hook
    ablated_logits, ablated_loss = model.run_with_hooks(tokens, return_type="both", loss_per_token=True, fwd_hooks=[
        (utils.get_act_name("z", layer), functools.partial(head_ablation_hook, head_index_to_ablate=head)) for layer, head in heads_to_ablate
    ])

    if metric == "kl_divergence":
        # Store the result, subtracting the clean loss so that a value of zero means no change in loss
        ablation_score = kl_divergence(base_logits, ablated_logits)

    elif metric == "logit_difference":
        # Calculate the logit difference of the correct token between ablated and clean model
        ablation_score = logit_difference(ablated_loss, base_loss)

    return ablation_score

# Example usage
heads_to_ablate = [(0, 10), (0, 9)]  # List of tuples (layer, head) to ablate
all_heads = []
for i in range(len(head_labels)):
    # Calculate head and layer
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    all_heads.append((layer, head))
ablation_score = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
print(ablation_score)
ablation_score_ld = get_ablation_scores(tl_model, all_tokens, heads_to_ablate, metric="logit_difference")
print(ablation_score_ld)
faithfulness = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads)
print(faithfulness)
faithfulness_ld = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads, metric="logit_difference")
print(faithfulness_ld)

In [299]:
model.eval()
learned_activations = model.encoder(resid_streams).detach().cpu().numpy()
print(f"Learned activations shape: {learned_activations.shape}")

all_indices = np.argmax(learned_activations, axis=2)
print(f"All indices shape: {all_indices.shape}")    

positive_indices = all_indices[:250, :]
negative_indices = all_indices[250:, :]

unique_to_positive_array = gen_array_template(head_labels)

normalise = False

for i in range(len(head_labels)):
    # Calculate head and layer
    layer, head = feature_string_to_head_and_layer(i, head_labels)

    positive = set(positive_indices[:, i].tolist())
    negative = set(negative_indices[:, i].tolist())
    total_unique = positive.union(negative)

    # In positive but not negative
    unique_to_positive = list(positive - negative)
    # In negative but not positive
    unique_to_negative = list(negative - positive)

    if normalise: unique_to_positive_array[layer, head] = len(unique_to_positive) / len(total_unique)
    
    else: unique_to_positive_array[layer, head] = len(unique_to_positive)

array_shape = unique_to_positive_array.shape
unique_to_positive_array = unique_to_positive_array.flatten()

# Apply softmax
unique_to_positive_array = np.exp(unique_to_positive_array) / np.sum(np.exp(unique_to_positive_array))

# Reshape
unique_to_positive_array = unique_to_positive_array.reshape(array_shape)

def positive_array_to_ablations(pos_array, threshold, faithfulness=True, start_layer=1):
    binary_array = (pos_array > threshold).astype(int) if faithfulness else (pos_array < threshold).astype(int)
    # Rows are layers, columns are heads -> return list of tuples of (layer, head)
    return [(layer, head) for layer in range(binary_array.shape[0]) for head in range(binary_array.shape[1]) if binary_array[layer, head] == 0 and layer > start_layer]

y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)
print(f"Node ROC AUC: {node_roc_auc:.4f}")

Learned activations shape: (500, 144, 300)
All indices shape: (500, 144)
Node ROC AUC: 0.8750


In [300]:
# Faithfulness - ablate heads NOT included in the circuit
num_nodes = []
num_elements_in_pos_array = torch.tensor(unique_to_positive_array).numel()
ablation_scores_kl = []
ablation_scores_ld = []
faithfulness_scores_kl = []
faithfulness_scores_ld = []

ablation_scores_kl_sequential = []
ablation_scores_ld_sequential = []
faithfulness_scores_kl_sequential = []
faithfulness_scores_ld_sequential = []

num_layers = 2 if task == 'induction' else 12
to_add = 0 if task == 'induction' else 2
start_layer = 0 if task == 'induction' else 1
print(f"Num layers: {num_layers}, to add: {to_add}, start layer: {start_layer}")

num_heads = len(head_labels) // num_layers
all_head_layers = [(layer, head) for layer in range(num_layers) for head in range(num_heads)]

heads_to_ablate_sequential = []


for threshold in tqdm(thresholds):
    # Our ablations, with our predicted circuit
    heads_to_ablate = positive_array_to_ablations(unique_to_positive_array, threshold=threshold, start_layer=start_layer)
    #print(f"Heads to ablate: {heads_to_ablate}")
    ablation_score_kl = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
    ablation_score_ld = get_ablation_scores(tl_model, all_tokens, heads_to_ablate, metric="logit_difference")
    faithfulness_kl = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads)
    faithfulness_ld = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads, metric="logit_difference")
    num_nodes.append(num_elements_in_pos_array - len(heads_to_ablate))
    ablation_scores_kl.append(ablation_score_kl)
    ablation_scores_ld.append(ablation_score_ld)
    faithfulness_scores_kl.append(faithfulness_kl)
    faithfulness_scores_ld.append(faithfulness_ld)

    # Heads to ablate is now sequential, moving through the layers
    n = len(heads_to_ablate)
    heads_to_ablate_indices = [x for x in range(n)]
    # Convert to list of tuples (layer, head)
    heads_to_ablate_sequential = [(heads_to_ablate_indices[i] % num_layers, heads_to_ablate_indices[i] // num_layers + to_add) for i in range(n)]
    ##print(f"Heads to ablate sequential: {heads_to_ablate_sequential}")
    ablation_score_kl_sequential = get_ablation_scores(tl_model, all_tokens, heads_to_ablate_sequential)
    ablation_score_ld_sequential = get_ablation_scores(tl_model, all_tokens, heads_to_ablate_sequential, metric="logit_difference")
    faithfulness_kl_sequential = faithfulness_score(tl_model, all_tokens, heads_to_ablate_sequential, all_heads)
    faithfulness_ld_sequential = faithfulness_score(tl_model, all_tokens, heads_to_ablate_sequential, all_heads, metric="logit_difference")
    ablation_scores_kl_sequential.append(ablation_score_kl_sequential)
    ablation_scores_ld_sequential.append(ablation_score_ld_sequential)
    faithfulness_scores_kl_sequential.append(faithfulness_kl_sequential)
    faithfulness_scores_ld_sequential.append(faithfulness_ld_sequential)

Num layers: 12, to add: 2, start layer: 1


100%|██████████| 28/28 [01:07<00:00,  2.42s/it]


In [301]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Define colors for "Ours" and "Sequential"
color_ours = '#636EFA'
color_sequential = '#EF553B'

# Create a 2x2 subplot grid
fig = make_subplots(rows=2, cols=2, subplot_titles=("KL Divergence", "Loss", "Faithfulness (KL)", "Faithfulness (Loss)"), vertical_spacing=0.15)

# Add traces to the subplots for "Ours"
fig.add_trace(go.Scatter(x=num_nodes, y=ablation_scores_kl, mode='lines+markers', name="Ours", line=dict(width=6, color=color_ours), marker=dict(size=15, color=color_ours), legendgroup="Ours", showlegend=True), row=1, col=1)
fig.add_trace(go.Scatter(x=num_nodes, y=ablation_scores_ld, mode='lines+markers', name="Ours", line=dict(width=6, color=color_ours), marker=dict(size=15, color=color_ours), legendgroup="Ours", showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=[0]+num_nodes+[num_elements_in_pos_array], y=[0]+faithfulness_scores_kl+[1], mode='lines+markers', name="Ours", line=dict(width=6, color=color_ours), marker=dict(size=15, color=color_ours), legendgroup="Ours", showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=[0]+num_nodes+[num_elements_in_pos_array], y=[0]+faithfulness_scores_ld+[1], mode='lines+markers', name="Ours", line=dict(width=6, color=color_ours), marker=dict(size=15, color=color_ours), legendgroup="Ours", showlegend=False), row=2, col=2)

# Add traces to the subplots for "Sequential"
fig.add_trace(go.Scatter(x=num_nodes, y=ablation_scores_kl_sequential, mode='lines', name="Sequential", line=dict(width=6, dash='dash', color=color_sequential), legendgroup="Sequential", showlegend=True), row=1, col=1)
fig.add_trace(go.Scatter(x=num_nodes, y=ablation_scores_ld_sequential, mode='lines', name="Sequential", line=dict(width=6, dash='dash', color=color_sequential), legendgroup="Sequential", showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=[0]+num_nodes+[num_elements_in_pos_array], y=[0]+faithfulness_scores_kl_sequential+[1], mode='lines', name="Sequential", line=dict(width=6, dash='dash', color=color_sequential), legendgroup="Sequential", showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=[0]+num_nodes+[num_elements_in_pos_array], y=[0]+faithfulness_scores_ld_sequential+[1], mode='lines', name="Sequential", line=dict(width=6, dash='dash', color=color_sequential), legendgroup="Sequential", showlegend=False), row=2, col=2)

# Update layout
fontsize = 28
fig.update_layout(
    font=dict(family='Palatino', size=fontsize),
    plot_bgcolor='white',
    width=1000,
    height=800,
    showlegend=True,
    legend=dict(font=dict(size=fontsize)),
    xaxis3=dict(title='Number of nodes'),
    xaxis4=dict(title='Number of nodes'),
)

# Adjust subplot titles
fig.update_annotations(font=dict(size=fontsize))

# Update subplot titles and axis labels to have fontsize 24
fig.update_xaxes(title_font=dict(size=fontsize-6))
fig.update_yaxes(title_font=dict(size=fontsize-6))

# Save the figure as pdf
fig.write_image(f"../output/{task}/faithfulness_{task}.pdf")

# Display the figure
fig.show()

In [302]:
# Faithfulness - ablate heads NOT included in the circuit
num_nodes = []
num_elements_in_pos_array = torch.tensor(unique_to_positive_array).numel()
ablation_scores_kl = []
ablation_scores_ld = []
faithfulness_scores_kl = []
faithfulness_scores_ld = []

ablation_scores_kl_sequential = []
ablation_scores_ld_sequential = []
faithfulness_scores_kl_sequential = []
faithfulness_scores_ld_sequential = []

num_layers = 2 if task == 'induction' else 12
to_add = 0 if task == 'induction' else 2
start_layer = 0 if task == 'induction' else 1
print(f"Num layers: {num_layers}, to add: {to_add}, start layer: {start_layer}")

num_heads = len(head_labels) // num_layers
all_head_layers = [(layer, head) for layer in range(num_layers) for head in range(num_heads)]

heads_to_ablate_sequential = []


for threshold in tqdm(thresholds):
    # Our ablations, with our predicted circuit
    heads_to_ablate = positive_array_to_ablations(unique_to_positive_array, threshold=threshold, faithfulness=False, start_layer=start_layer)
    #print(f"Heads to ablate: {heads_to_ablate}")
    ablation_score_kl = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
    ablation_score_ld = get_ablation_scores(tl_model, all_tokens, heads_to_ablate, metric="logit_difference")
    faithfulness_kl = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads)
    faithfulness_ld = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads, metric="logit_difference")
    num_nodes.append(len(heads_to_ablate))
    ablation_scores_kl.append(ablation_score_kl)
    ablation_scores_ld.append(ablation_score_ld)
    faithfulness_scores_kl.append(faithfulness_kl)
    faithfulness_scores_ld.append(faithfulness_ld)

    # Heads to ablate is now sequential, moving through the layers
    n = len(heads_to_ablate)
    heads_to_ablate_indices = [x for x in range(n)]
    # Convert to list of tuples (layer, head)
    heads_to_ablate_sequential = [(heads_to_ablate_indices[i] % num_layers, heads_to_ablate_indices[i] // num_layers + to_add) for i in range(n)]
    # # Heads to ablate sequential is a random choice of n from all heads layers
    # heads_to_ablate_sequential_list = random.sample(all_head_layers, n-len(heads_to_ablate_sequential))
    # heads_to_ablate_sequential.extend(heads_to_ablate_sequential_list)
    # # Remove head layer pairs from all_head_layers already in heads_to_ablate_sequential
    # all_head_layers = [x for x in all_head_layers if x not in heads_to_ablate_sequential_list]
    #print(f"Heads to ablate sequential: {heads_to_ablate_sequential}")
    ablation_score_kl_sequential = get_ablation_scores(tl_model, all_tokens, heads_to_ablate_sequential)
    ablation_score_ld_sequential = get_ablation_scores(tl_model, all_tokens, heads_to_ablate_sequential, metric="logit_difference")
    faithfulness_kl_sequential = faithfulness_score(tl_model, all_tokens, heads_to_ablate_sequential, all_heads)
    faithfulness_ld_sequential = faithfulness_score(tl_model, all_tokens, heads_to_ablate_sequential, all_heads, metric="logit_difference")
    ablation_scores_kl_sequential.append(ablation_score_kl_sequential)
    ablation_scores_ld_sequential.append(ablation_score_ld_sequential)
    faithfulness_scores_kl_sequential.append(faithfulness_kl_sequential)
    faithfulness_scores_ld_sequential.append(faithfulness_ld_sequential)

Num layers: 12, to add: 2, start layer: 1


100%|██████████| 28/28 [01:03<00:00,  2.29s/it]


In [285]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Define colors for "Ours" and "Sequential"
color_ours = '#636EFA'
color_sequential = '#EF553B'

# Create a 2x2 subplot grid
fig = make_subplots(rows=2, cols=2, subplot_titles=("KL Divergence", "Loss", "Completeness (KL)", "Completeness (Loss)"), vertical_spacing=0.15)

# Add traces to the subplots for "Ours"
fig.add_trace(go.Scatter(x=num_nodes, y=ablation_scores_kl, mode='lines+markers', name="Ours", line=dict(width=6, color=color_ours), marker=dict(size=15, color=color_ours), legendgroup="Ours", showlegend=True), row=1, col=1)
fig.add_trace(go.Scatter(x=num_nodes, y=ablation_scores_ld, mode='lines+markers', name="Ours", line=dict(width=6, color=color_ours), marker=dict(size=15, color=color_ours), legendgroup="Ours", showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=[0]+num_nodes, y=[1]+faithfulness_scores_kl, mode='lines+markers', name="Ours", line=dict(width=6, color=color_ours), marker=dict(size=15, color=color_ours), legendgroup="Ours", showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=[0]+num_nodes, y=[1]+faithfulness_scores_ld, mode='lines+markers', name="Ours", line=dict(width=6, color=color_ours), marker=dict(size=15, color=color_ours), legendgroup="Ours", showlegend=False), row=2, col=2)

# Add traces to the subplots for "Sequential"
fig.add_trace(go.Scatter(x=num_nodes, y=ablation_scores_kl_sequential, mode='lines', name="Sequential", line=dict(width=6, dash='dash', color=color_sequential), legendgroup="Sequential", showlegend=True), row=1, col=1)
fig.add_trace(go.Scatter(x=num_nodes, y=ablation_scores_ld_sequential, mode='lines', name="Sequential", line=dict(width=6, dash='dash', color=color_sequential), legendgroup="Sequential", showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=[0]+num_nodes, y=[1]+faithfulness_scores_kl_sequential, mode='lines', name="Sequential", line=dict(width=6, dash='dash', color=color_sequential), legendgroup="Sequential", showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=[0]+num_nodes, y=[1]+faithfulness_scores_ld_sequential, mode='lines', name="Sequential", line=dict(width=6, dash='dash', color=color_sequential), legendgroup="Sequential", showlegend=False), row=2, col=2)

# Update layout
fontsize = 28
fig.update_layout(
    font=dict(family='Palatino', size=fontsize),
    plot_bgcolor='white',
    width=1000,
    height=800,
    showlegend=True,
    legend=dict(font=dict(size=fontsize)),
    xaxis3=dict(title='Number of nodes'),
    xaxis4=dict(title='Number of nodes'),
)

# Adjust subplot titles
fig.update_annotations(font=dict(size=fontsize))

# Update subplot titles and axis labels to have fontsize 24
fig.update_xaxes(title_font=dict(size=fontsize-6))
fig.update_yaxes(title_font=dict(size=fontsize-6))

# Save the figure as pdf
#fig.write_image(f"../output/{task}/completeness_{task}.pdf")

# Display the figure
fig.show()

In [None]:
# Learned activations and then take argmax to discretise
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

positive_indices = all_indices[:250, :]
negative_indices = all_indices[250:, :]
positive_learned_activations = learned_activations[:250, :, :]
negative_learned_activations = learned_activations[250:, :, :]

# Assume all_indices, positive_indices, and negative_indices are defined, as well as n_heads and n_feat
n_feat = learned_activations.shape[-1]
n_heads = all_indices.shape[1]
positive_co_occurrence_matrix = gen_co_occurrence_matrix(positive_indices, n_heads, n_feat)
negative_co_occurrence_matrix = gen_co_occurrence_matrix(negative_indices, n_heads, n_feat)

# Calculate unique co-occurrences
normalise = False# if task in ['ds', 'ioi']  else True
print(f"Normalise: {normalise}")
unique_co_occurrence_counts = unique_co_occurrences(positive_co_occurrence_matrix, negative_co_occurrence_matrix, normalise=normalise)

# Sort (head, head) pairs by descending unique co-occurrence counts
sorted_indices = np.argsort(unique_co_occurrence_counts.flatten())[::-1]
sorted_indices = np.unravel_index(sorted_indices, unique_co_occurrence_counts.shape)
# Zip them together to create a list of (head, head) pairs
sorted_head_pairs = list(zip(sorted_indices[0], sorted_indices[1]))

Normalise: False


In [68]:
def array_to_ablations(array, threshold):
    binary_array = (array > threshold).astype(int)
    try:
        # Num edges is the number of elements that are non-zero in the binary array
        num_edges = np.sum(binary_array) // 2
        # Get head numbers that are not zero
        heads = list(set(np.array(np.argwhere(binary_array).tolist())[:, 0].tolist()))
        # Heads to ablate
        heads_to_ablate = [x for x in range(24) if x not in heads]
        # Convert head index into (layer, head) tuple
        heads = [feature_string_to_head_and_layer(h, head_labels) for h in heads_to_ablate]
        #num_edges = 144 - calculate_total_edges(heads)
        # Num edges is 
        return heads, num_edges
    except:
        heads = [x for x in range(24)]
        heads = [feature_string_to_head_and_layer(h, head_labels) for h in heads]
        return heads, 0


num_nodes = []
ablation_scores_kl = []
ablation_scores_ld = []
faithfulness_scores_kl = []
faithfulness_scores_ld = []

thresholds = []
heads_to_ablate_old = ['a']

for threshold in range(0, int(unique_co_occurrence_counts.max())):
    heads_to_ablate_new, num_nodes_in_circuit = array_to_ablations(unique_co_occurrence_counts, threshold)
    if heads_to_ablate_new != heads_to_ablate_old:
        print(f"Threshold: {threshold}, heads to ablate: {heads_to_ablate_new}")
        # Calculate KL
        ablation_score_kl = get_ablation_scores(tl_model, all_tokens, heads_to_ablate_new)
        ablation_score_ld = get_ablation_scores(tl_model, all_tokens, heads_to_ablate_new, metric="logit_difference")
        faithfulness_kl = faithfulness_score(tl_model, all_tokens, heads_to_ablate_new, all_heads)
        faithfulness_ld = faithfulness_score(tl_model, all_tokens, heads_to_ablate_new, all_heads, metric="logit_difference")
        num_nodes.append(num_nodes_in_circuit)
        ablation_scores_kl.append(ablation_score_kl)
        ablation_scores_ld.append(ablation_score_ld)
        faithfulness_scores_kl.append(faithfulness_kl)
        faithfulness_scores_ld.append(faithfulness_ld)
        thresholds.append(threshold)
        heads_to_ablate_old = heads_to_ablate_new
    # Calculate KL and logit difference
    # ablation_score_kl = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
    # ablation_score_ld = get_ablation_scores(tl_model, all_tokens, heads_to_ablate, metric="logit_difference")
    # faithfulness_kl = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads)
    # faithfulness_ld = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads, metric="logit_difference")

Threshold: 0, heads to ablate: []
Threshold: 34, heads to ablate: [(1, 2)]
Threshold: 58, heads to ablate: [(0, 2), (1, 2)]
Threshold: 79, heads to ablate: [(0, 2), (0, 9), (1, 2)]
Threshold: 83, heads to ablate: [(0, 2), (0, 3), (0, 9), (1, 2)]
Threshold: 92, heads to ablate: [(0, 2), (0, 3), (0, 5), (0, 9), (1, 2)]
Threshold: 94, heads to ablate: [(0, 2), (0, 3), (0, 5), (0, 9), (1, 2), (1, 6), (1, 11)]
Threshold: 102, heads to ablate: [(0, 2), (0, 3), (0, 5), (0, 7), (0, 9), (1, 2), (1, 6), (1, 11)]
Threshold: 107, heads to ablate: [(0, 2), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (1, 2), (1, 6), (1, 11)]
Threshold: 110, heads to ablate: [(0, 2), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (1, 2), (1, 6), (1, 10), (1, 11)]
Threshold: 113, heads to ablate: [(0, 2), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (1, 2), (1, 6), (1, 8), (1, 10), (1, 11)]
Threshold: 114, heads to ablate: [(0, 2), (0, 3), (0, 5), (0, 7), (0, 8), (0, 9), (1, 2), (1, 5), (1, 6), (1, 8), (1, 10), (1, 11)]
Threshold: 11

In [69]:
# Plotly line plot of threshold vs kl div 
fig = go.Figure(data=go.Scatter(x=num_nodes, y=ablation_scores_kl, mode='lines+markers', name="Ours", line=dict(width=6), marker=dict(size=15))) 

# Update layout 
fig.update_layout(xaxis_title='Number of edges', yaxis_title='KL(Model, Ablated)', font=dict(size=24), plot_bgcolor='white', width=1000, height=600) 

# X-axis log
fig.update_xaxes(type="log")

fig.show()

In [70]:
# Plotly line plot of threshold vs logit diff
fig = go.Figure(data=go.Scatter(x=num_nodes, y=ablation_scores_ld, mode='lines+markers', name="Ours", line=dict(width=6), marker=dict(size=15)))

# Update layout
fig.update_layout(xaxis_title='Number of edges', yaxis_title='LD(Model, Ablated)', font=dict(size=24), plot_bgcolor='white', width=1000, height=600)

# X-axis log
fig.update_xaxes(type="log")

fig.show()

In [71]:
# Plotly line plot of threshold vs faithfulness
fig = go.Figure(data=go.Scatter(x=num_nodes, y=faithfulness_scores_kl, mode='lines+markers', name="Ours", line=dict(width=6), marker=dict(size=15)))

# Update layout
fig.update_layout(xaxis_title='Number of edges', yaxis_title='Faithfulness', font=dict(size=24), plot_bgcolor='white', width=1000, height=600)

# X-axis log

fig.update_xaxes(type="log")

fig.show()

In [72]:
# Plotly line plot of threshold vs faithfulness (logit difference)
fig = go.Figure(data=go.Scatter(x=num_nodes, y=faithfulness_scores_ld, mode='lines+markers', name="Ours", line=dict(width=6), marker=dict(size=15)))

# Update layout
fig.update_layout(xaxis_title='Number of edges', yaxis_title='Faithfulness (LD)', font=dict(size=24), plot_bgcolor='white', width=1000, height=600)

# X-axis log
fig.update_xaxes(type="log")

fig.show()

## Getting zoo of data for different models

In [6]:
model_names = ['gpt2-small', 'gpt2-large', 'pythia-70-deduped', 'opt-125m']

In [7]:
model = HookedTransformer.from_pretrained(model_name = "gpt2-small", device='cpu')

Loaded pretrained model gpt2-small into HookedTransformer


In [7]:
batch_size = 25
seq_len = 10
size = (batch_size, seq_len)
input_tensor = torch.randint(1000, 10000, size)

random_tokens = input_tensor#.to(model.cfg.device)
normal_tokens = einops.repeat(random_tokens, "batch seq_len -> batch (2 seq_len)")
# Corrupted tokens 
corrupted_tokens = torch.randint(1000, 10000, (batch_size, seq_len*2))
normal_tokens.shape, corrupted_tokens.shape

(torch.Size([25, 20]), torch.Size([25, 20]))

In [8]:
model_names = ['gpt2-small', 'EleutherAI/pythia-160m-deduped', 'opt-125m'] #'gpt2-large', 
# Turn grad off
torch.set_grad_enabled(False)

for model_name in tqdm(model_names):

    model = HookedTransformer.from_pretrained(model_name = model_name, device='cpu')

    # Normal resid streams
    normal_logits, normal_cache = model.run_with_cache(normal_tokens)
    normal_cache.compute_head_results()
    normal_head_resid, normal_head_labels = normal_cache.stack_head_results(return_labels=True)
    normal_head_resid = normal_head_resid.mean(dim=2).squeeze()
    normal_head_resid = einops.rearrange(normal_head_resid, "seq_len batch n_heads -> batch seq_len n_heads")

    # Corrupted resid streams
    corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
    corrupted_cache.compute_head_results()
    corrupted_head_resid, corrupted_head_labels = corrupted_cache.stack_head_results(return_labels=True)
    corrupted_head_resid = corrupted_head_resid.mean(dim=2).squeeze()
    corrupted_head_resid = einops.rearrange(corrupted_head_resid, "seq_len batch n_heads -> batch seq_len n_heads")

    # Stack the two together
    resid_streams = torch.cat([normal_head_resid, corrupted_head_resid], dim=0).detach().cpu()

    # Ground truth
    ground_truth = [
        (0, 4), (0, 7), (1, 4), (1, 10)
    ]
    labels = normal_head_labels

    # Save
    # Save to ..data/gt folder
    torch.save(resid_streams, f"../data/induction/{model_name.split('/')[-1]}_resid_heads_mean.pt")
    torch.save(labels, f"../data/induction/{model_name.split('/')[-1]}_labels_heads_mean.pt")
    torch.save(ground_truth, f"../data/induction/{model_name.split('/')[-1]}_ground_truth.pt")
    all_tokens = torch.cat([normal_tokens.cpu(), corrupted_tokens.cpu()], dim=0)
    torch.save(all_tokens, "../data/induction/all_tokens.pt")

  0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model gpt2-small into HookedTransformer


 33%|███▎      | 1/3 [00:03<00:06,  3.15s/it]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer


 67%|██████▋   | 2/3 [00:05<00:02,  2.78s/it]

Loaded pretrained model opt-125m into HookedTransformer


100%|██████████| 3/3 [00:07<00:00,  2.58s/it]


In [23]:
torch.set_grad_enabled(True)

for model_name in model_names:

    task = 'induction'
    task_type = 'node'
    assert task_type in ['node', 'edge'], "Type must be either 'node' or 'edge'"
    print(f"Type: {task_type}")
    task_mappings = {
        'gt': 'Greater-than',
        'ioi': 'Indirect Object Identification',
        'ds': 'Docstring',
        'induction': 'Induction',
    }

    print(f"Task: {task_mappings[task]}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    num_unique = 300
    n_epochs = 500


    roc_results = []

    # Load residual streams
    resid_streams = torch.load(f"../data/{task}/{model_name.split('/')[-1]}_resid_heads_mean.pt").to(device)
    head_labels = torch.load(f'../data/{task}/{model_name.split('/')[-1]}_labels_heads_mean.pt')
    ground_truth = torch.load(f'../data/{task}/{model_name.split('/')[-1]}_ground_truth.pt')
    print(ground_truth)
    print(resid_streams.shape)


    # Shuffle and create the labels
    labels = torch.ones(resid_streams.shape[0]//2) # BIG ASSUMPTION: assumes first half is positive and second half is negative
    labels = torch.cat((labels, torch.zeros_like(labels)))
    permutation = torch.randperm(resid_streams.shape[0])
    resid_shuffled = resid_streams[permutation, :, :]
    labels_shuffled = labels[permutation]
    cutoff = int(resid_shuffled.shape[0] * 0.8)
    train_streams = resid_shuffled[:cutoff, :, :].to(device)
    train_labels = labels_shuffled[:cutoff].to(device)
    eval_streams = resid_shuffled[cutoff:, :, :].to(device)
    eval_labels = labels_shuffled[cutoff:].to(device)

    print(f"Train streams shape: {train_streams.shape}")
    print(f"Train labels shape: {train_labels.shape}")  
    print(f"Eval streams shape: {eval_streams.shape}")
    print(f"Eval labels shape: {eval_labels.shape}")


    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model = train(model, n_epochs, optimizer, train_streams, eval_streams, lambda_=0.01)
    model = model.to('cpu')
    resid_streams = resid_streams.to('cpu')
    # Save model
    torch.save(model, f'../models/{task}/{model_name.split('/')[-1]}_sparse_autoencoder.pt')

    heads = []
    layers = []
    for i, l in enumerate(head_labels):
        layer, head = feature_string_to_head_and_layer(i, head_labels)
        heads.append(head)
        layers.append(layer)

    heads = list(set(heads))
    layers = list(set(layers))

    ground_truth_array = np.zeros((len(layers), len(heads)))
    for layer, head in ground_truth:
        ground_truth_array[layer, head] = 1

    normalise = False# if task == 'ds' else True

    # Plot the ground truth (head, layer) pairs (1 if in ground truth, 0 otherwise)

    model.eval()
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)

    print(f"\n\nNormalise: {normalise}")
    y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)

    # Print best f1 score (and corresponding threshold)
    node_best_f1 = np.max(f1)
    best_threshold = thresholds[np.argmax(f1)]
    print(f"Best F1 score: {node_best_f1:.4f}")

    # Print ROC AUC
    print(f"ROC AUC: {node_roc_auc:.4f}\n\n")

Type: node
Task: Induction
Using device: cpu
[(0, 4), (0, 7), (1, 4), (1, 10)]
torch.Size([50, 144, 768])
Train streams shape: torch.Size([40, 144, 768])
Train labels shape: torch.Size([40])
Eval streams shape: torch.Size([10, 144, 768])
Eval labels shape: torch.Size([10])


  0%|          | 2/500 [00:00<00:26, 18.66it/s]

Train loss = 0.8880, Eval loss = 0.7965


 11%|█▏        | 57/500 [00:01<00:09, 47.65it/s]

Train loss = 0.1637, Eval loss = 0.1633


 21%|██▏       | 107/500 [00:02<00:09, 43.62it/s]

Train loss = 0.1258, Eval loss = 0.1265


 31%|███▏      | 157/500 [00:03<00:07, 46.51it/s]

Train loss = 0.1120, Eval loss = 0.1129


 42%|████▏     | 208/500 [00:05<00:07, 41.46it/s]

Train loss = 0.1037, Eval loss = 0.1048


 52%|█████▏    | 258/500 [00:06<00:05, 45.09it/s]

Train loss = 0.0967, Eval loss = 0.0980


 62%|██████▏   | 309/500 [00:07<00:04, 46.00it/s]

Train loss = 0.0910, Eval loss = 0.0924


 71%|███████▏  | 357/500 [00:08<00:03, 46.99it/s]

Train loss = 0.0859, Eval loss = 0.0876


 82%|████████▏ | 409/500 [00:09<00:01, 50.63it/s]

Train loss = 0.0825, Eval loss = 0.0844


 91%|█████████ | 456/500 [00:10<00:00, 44.15it/s]

Train loss = 0.0800, Eval loss = 0.0820


100%|██████████| 500/500 [00:11<00:00, 43.90it/s]




Normalise: False
Best F1 score: 0.6977
ROC AUC: 0.6375


Type: node
Task: Induction
Using device: cpu
[(0, 4), (0, 7), (1, 4), (1, 10)]
torch.Size([50, 144, 768])
Train streams shape: torch.Size([40, 144, 768])
Train labels shape: torch.Size([40])
Eval streams shape: torch.Size([10, 144, 768])
Eval labels shape: torch.Size([10])


  1%|          | 5/500 [00:00<00:11, 41.70it/s]

Train loss = 0.0509, Eval loss = 0.0424


 12%|█▏        | 60/500 [00:01<00:09, 45.69it/s]

Train loss = 0.0018, Eval loss = 0.0018


 22%|██▏       | 108/500 [00:02<00:07, 49.73it/s]

Train loss = 0.0017, Eval loss = 0.0017


 31%|███       | 156/500 [00:03<00:06, 50.60it/s]

Train loss = 0.0017, Eval loss = 0.0017


 42%|████▏     | 209/500 [00:04<00:05, 49.24it/s]

Train loss = 0.0016, Eval loss = 0.0017


 52%|█████▏    | 259/500 [00:05<00:04, 48.52it/s]

Train loss = 0.0016, Eval loss = 0.0017


 62%|██████▏   | 309/500 [00:06<00:03, 48.70it/s]

Train loss = 0.0016, Eval loss = 0.0017


 72%|███████▏  | 361/500 [00:07<00:02, 50.34it/s]

Train loss = 0.0016, Eval loss = 0.0017


 81%|████████▏ | 407/500 [00:08<00:01, 47.40it/s]

Train loss = 0.0016, Eval loss = 0.0017


 92%|█████████▏| 458/500 [00:09<00:00, 48.26it/s]

Train loss = 0.0015, Eval loss = 0.0017


100%|██████████| 500/500 [00:10<00:00, 48.08it/s]




Normalise: False
Best F1 score: 0.5954
ROC AUC: 0.6214


Type: node
Task: Induction
Using device: cpu
[(0, 4), (0, 7), (1, 4), (1, 10)]
torch.Size([50, 144, 768])
Train streams shape: torch.Size([40, 144, 768])
Train labels shape: torch.Size([40])
Eval streams shape: torch.Size([10, 144, 768])
Eval labels shape: torch.Size([10])


  1%|          | 5/500 [00:00<00:10, 45.56it/s]

Train loss = 0.0294, Eval loss = 0.0244


 12%|█▏        | 58/500 [00:01<00:09, 49.06it/s]

Train loss = 0.0004, Eval loss = 0.0004


 22%|██▏       | 110/500 [00:02<00:07, 49.73it/s]

Train loss = 0.0004, Eval loss = 0.0004


 32%|███▏      | 161/500 [00:03<00:06, 50.89it/s]

Train loss = 0.0004, Eval loss = 0.0003


 42%|████▏     | 210/500 [00:04<00:05, 48.52it/s]

Train loss = 0.0003, Eval loss = 0.0003


 51%|█████▏    | 257/500 [00:05<00:04, 48.77it/s]

Train loss = 0.0003, Eval loss = 0.0003


 61%|██████    | 305/500 [00:06<00:03, 50.23it/s]

Train loss = 0.0003, Eval loss = 0.0003


 72%|███████▏  | 358/500 [00:07<00:02, 49.05it/s]

Train loss = 0.0003, Eval loss = 0.0003


 82%|████████▏ | 410/500 [00:08<00:01, 51.11it/s]

Train loss = 0.0003, Eval loss = 0.0003


 92%|█████████▏| 459/500 [00:09<00:00, 48.50it/s]

Train loss = 0.0003, Eval loss = 0.0003


100%|██████████| 500/500 [00:10<00:00, 49.15it/s]



Normalise: False
Best F1 score: 0.3723
ROC AUC: 0.4652







In [18]:
# Repeat with num edges
def gen_co_occurrence_matrix(all_indices, n_heads, n_feat):
    co_occurrence_matrix = np.zeros((n_heads, n_heads, n_feat, n_feat))

    for e in range(all_indices.shape[0]):  # For each example
        for h1 in range(n_heads):  # For each head
            c1 = all_indices[e, h1]  # Code in head h1
            for h2 in range(n_heads):  # For each other head
                if h1 != h2:  # Skip counting co-occurrence of a head with itself
                    c2 = all_indices[e, h2]  # Code in head h2
                    # Increment co-occurrence count for (h1, h2)
                    co_occurrence_matrix[h1, h2, c1, c2] += 1

    return co_occurrence_matrix

def normalize_co_occurrence_matrix(co_occurrence_matrix):
    # Assuming co_occurrence_matrix is of shape (n_heads, n_heads, n_feat, n_feat)
    n_heads, _, n_feat, _ = co_occurrence_matrix.shape
    normalized_matrix = np.zeros_like(co_occurrence_matrix)

    for h1 in range(n_heads):
        for h2 in range(n_heads):
            if h1 != h2:  # Skip self co-occurrences
                total_co_occurrences = np.sum(co_occurrence_matrix[h1, h2, :, :])
                if total_co_occurrences > 0:  # Avoid division by zero
                    normalized_matrix[h1, h2, :, :] = co_occurrence_matrix[h1, h2, :, :] / total_co_occurrences

    return normalized_matrix

def unique_co_occurrences(positive_matrix, negative_matrix, normalise=True):
    # Normalize matrices
    if normalise:
        positive_matrix = normalize_co_occurrence_matrix(positive_matrix)
        negative_matrix = normalize_co_occurrence_matrix(negative_matrix)

    n_heads, _, n_feat, _ = positive_matrix.shape
    unique_co_occurrence_counts = np.zeros((n_heads, n_heads))
    
    for h1 in range(n_heads):
        for h2 in range(n_heads):
            if h1 != h2:  # Skip self co-occurrences
                # Find co-occurrences in positive not present in negative
                unique_positives = positive_matrix[h1, h2, :, :] > 0
                negatives = negative_matrix[h1, h2, :, :] > 0
                # Boolean array of unique positives
                unique = unique_positives & ~negatives
                if normalise:
                    # Normalize count by total co-occurrences for this head pair in positive matrix
                    total_co_occurrences = np.sum(positive_matrix[h1, h2, :, :] > 0) + np.sum(negative_matrix[h1, h2, :, :] > 0)
                    if total_co_occurrences > 0:  # Avoid division by zero
                        unique_count_normalized = np.sum(unique) / total_co_occurrences
                    else:
                        unique_count_normalized = 0
                    # Set normalized unique counts for this head pair
                    unique_co_occurrence_counts[h1, h2] = unique_count_normalized
                else:
                    # Count unique co-occurrences
                    unique_co_occurrence_counts[h1, h2] = np.sum(unique)

    return unique_co_occurrence_counts

def array_to_ablations(array, threshold):
    binary_array = (array > threshold).astype(int)
    try:
        # Num edges is the number of elements that are non-zero in the binary array
        num_edges = np.sum(binary_array) // 2
        # Get head numbers that are not zero
        heads = list(set(np.array(np.argwhere(binary_array).tolist())[:, 0].tolist()))
        # Heads to ablate
        heads_to_ablate = [x for x in range(144) if x not in heads]
        # Convert head index into (layer, head) tuple
        heads = [feature_string_to_head_and_layer(h, head_labels) for h in heads_to_ablate]
        #num_edges = 144 - calculate_total_edges(heads)
        # Num edges is 
        return heads, num_edges
    except:
        heads = [x for x in range(144)]
        heads = [feature_string_to_head_and_layer(h, head_labels) for h in heads]
        return heads, 0

In [6]:
results_dict = {}
task = 'induction'

def positive_array_to_ablations(pos_array, threshold, faithfulness=True, start_layer=1):
    binary_array = (pos_array > threshold).astype(int) if faithfulness else (pos_array < threshold).astype(int)
    # Rows are layers, columns are heads -> return list of tuples of (layer, head)
    return [(layer, head) for layer in range(binary_array.shape[0]) for head in range(binary_array.shape[1]) if binary_array[layer, head] == 0 and layer > start_layer]

for model_name in tqdm(model_names):

    # Load the residual streams and trained model
    resid_streams = torch.load(f"../data/{task}/{model_name.split('/')[-1]}_resid_heads_mean.pt")
    #model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=300, geometric_median_dataset=None).to(device)
    model = torch.load(f'../models/{task}/{model_name.split('/')[-1]}_sparse_autoencoder.pt')
    tl_model = HookedTransformer.from_pretrained(model_name = model_name, device='cpu')
    # Turn torch off
    torch.set_grad_enabled(False)

    # Learned activations and then take argmax to discretise
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)

    halfway = all_indices.shape[0] // 2
    positive_indices = all_indices[:halfway, :]
    negative_indices = all_indices[halfway:, :]

    unique_to_positive_array = gen_array_template(head_labels)

    normalise = False

    for i in range(len(head_labels)):
        # Calculate head and layer
        layer, head = feature_string_to_head_and_layer(i, head_labels)

        positive = set(positive_indices[:, i].tolist())
        negative = set(negative_indices[:, i].tolist())
        total_unique = positive.union(negative)

        # In positive but not negative
        unique_to_positive = list(positive - negative)
        # In negative but not positive
        unique_to_negative = list(negative - positive)

        if normalise: unique_to_positive_array[layer, head] = len(unique_to_positive) / len(total_unique)
        
        else: unique_to_positive_array[layer, head] = len(unique_to_positive)

    array_shape = unique_to_positive_array.shape
    unique_to_positive_array = unique_to_positive_array.flatten()

    # Apply softmax
    unique_to_positive_array = np.exp(unique_to_positive_array) / np.sum(np.exp(unique_to_positive_array))

    # Reshape
    unique_to_positive_array = unique_to_positive_array.reshape(array_shape)

    y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)


    num_nodes = []
    ablation_scores_kl = []
    ablation_scores_ld = []
    faithfulness_scores_kl = []
    faithfulness_scores_ld = []

    num_layers = 12 #if task == 'induction' else 12
    to_add = 2 #if task == 'induction' else 2
    start_layer = 1 #if task == 'induction' else 1
    print(f"Num layers: {num_layers}, to add: {to_add}, start layer: {start_layer}")

    num_heads = len(head_labels) // num_layers
    all_head_layers = [(layer, head) for layer in range(num_layers) for head in range(num_heads)]

    heads_to_ablate_sequential = []

    for threshold in thresholds:
        # Our ablations, with our predicted circuit
        heads_to_ablate = positive_array_to_ablations(unique_to_positive_array, threshold=threshold, faithfulness=False, start_layer=start_layer)
        #print(f"Heads to ablate: {heads_to_ablate}")
        ablation_score_kl = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
        ablation_score_ld = get_ablation_scores(tl_model, all_tokens, heads_to_ablate, metric="logit_difference")
        faithfulness_kl = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads)
        faithfulness_ld = faithfulness_score(tl_model, all_tokens, heads_to_ablate, all_heads, metric="logit_difference")
        num_nodes.append(len(all_head_layers) - len(heads_to_ablate))
        ablation_scores_kl.append(ablation_score_kl)
        ablation_scores_ld.append(ablation_score_ld)
        faithfulness_scores_kl.append(faithfulness_kl)
        faithfulness_scores_ld.append(faithfulness_ld)

    results_dict[model_name] = {
        'num_nodes': num_nodes,
        'ablation_scores_kl': ablation_scores_kl,
        'ablation_scores_ld': ablation_scores_ld,
        'faithfulness_scores_kl': faithfulness_scores_kl,
        'faithfulness_scores_ld': faithfulness_scores_ld
    }

  0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model gpt2-small into HookedTransformer
Num layers: 12, to add: 2, start layer: 1


 33%|███▎      | 1/3 [00:54<01:48, 54.36s/it]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer
Num layers: 12, to add: 2, start layer: 1


 67%|██████▋   | 2/3 [01:34<00:46, 46.02s/it]

Loaded pretrained model opt-125m into HookedTransformer
Num layers: 12, to add: 2, start layer: 1


100%|██████████| 3/3 [02:08<00:00, 42.71s/it]


In [8]:
# Convert all lists of tensors in result dicts to lists of floats
for model_name in results_dict.keys():
    for key in results_dict[model_name].keys():
        try:
            results_dict[model_name][key] = [x.item() for x in results_dict[model_name][key]]
        except:
            continue

In [9]:
# Save results dict to JSON in output/data
import json

with open(f'../output/{task}/node_faithfulness_all_models.json', 'w') as f:
    json.dump(results_dict, f)

In [10]:
# Read in data
import json

# Set default plotly font to Palatino
import plotly.io as pio

pio.templates.default = "plotly_white"
pio.templates[pio.templates.default].layout['font'] = dict(family='Palatino')

with open(f'../output/{task}/node_faithfulness_all_models.json', 'r') as f:
    results_dict = json.load(f)

# For each model in results dict, line plot of faithfulness (logit difference) in same plotly plot
fig = go.Figure()

for model_name, results in results_dict.items():
    fig.add_trace(go.Scatter(x=results['num_nodes'], y=results['faithfulness_scores_ld'], mode='lines+markers', name=model_name.split('/')[-1], line=dict(width=6), marker=dict(size=15)))

# Update layout
fig.update_layout(xaxis_title='Number of nodes', yaxis_title='Faithfulness (Loss)', font=dict(size=24), plot_bgcolor='white', width=1000, height=600)

# X-axis log
#fig.update_xaxes(type="log")

# Save as PDF
fig.write_image(f"../output/{task}/faithfulness_node_all_models_{task}.pdf")

fig.show()

In [46]:
results_dict = {}

for model_name in tqdm(model_names):

    # Load the residual streams and trained model
    resid_streams = torch.load(f"../data/{task}/{model_name.split('/')[-1]}_resid_heads_mean.pt")
    #model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=300, geometric_median_dataset=None).to(device)
    model = torch.load(f'../models/{task}/{model_name.split('/')[-1]}_sparse_autoencoder.pt')
    tl_model = HookedTransformer.from_pretrained(model_name = model_name, device='cpu')

    # Learned activations and then take argmax to discretise
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)

    halfway = all_indices.shape[0] // 2
    positive_indices = all_indices[:halfway, :]
    negative_indices = all_indices[halfway:, :]
    positive_learned_activations = learned_activations[:halfway, :, :]
    negative_learned_activations = learned_activations[halfway:, :, :]

    # Assume all_indices, positive_indices, and negative_indices are defined, as well as n_heads and n_feat
    n_feat = learned_activations.shape[-1]
    n_heads = all_indices.shape[1]
    positive_co_occurrence_matrix = gen_co_occurrence_matrix(positive_indices, n_heads, n_feat)
    negative_co_occurrence_matrix = gen_co_occurrence_matrix(negative_indices, n_heads, n_feat)

    # Calculate unique co-occurrences
    normalise = False# if task in ['ds', 'ioi']  else True
    print(f"Normalise: {normalise}")
    unique_co_occurrence_counts = unique_co_occurrences(positive_co_occurrence_matrix, negative_co_occurrence_matrix, normalise=normalise)


    num_nodes = []
    faithfulness_scores_ld = []

    thresholds = []
    heads_to_ablate_old = ['a']

    all_heads = [(layer, head) for layer in range(12) for head in range(12)]

    for threshold in range(0, int(unique_co_occurrence_counts.max())):
        heads_to_ablate_new, num_nodes_in_circuit = array_to_ablations(unique_co_occurrence_counts, threshold)
        if heads_to_ablate_new != heads_to_ablate_old:
            print(f"Threshold: {threshold}, heads to ablate: {heads_to_ablate_new}")
            # Calculate KL
            faithfulness_ld = faithfulness_score(tl_model, all_tokens, heads_to_ablate_new, all_heads, metric="logit_difference")
            print(f"Faithfulness: {faithfulness_ld}")
            num_nodes.append(num_nodes_in_circuit)
            faithfulness_scores_ld.append(faithfulness_ld.item())
            thresholds.append(threshold)
            heads_to_ablate_old = heads_to_ablate_new

    results_dict[model_name] = {
        'num_nodes': num_nodes,
        'faithfulness_scores_ld': faithfulness_scores_ld
    }

  0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model gpt2-small into HookedTransformer
Normalise: False
Threshold: 0, heads to ablate: []
m(M): 0.0, m(C): 0.0, m(empty_circuit): 9.129864692687988
Faithfulness: 1.0
Faithfulness: 1.0
Threshold: 13, heads to ablate: [(0, 9), (1, 3), (2, 7), (2, 11), (3, 5), (4, 2), (4, 5), (4, 6), (4, 10), (5, 4), (5, 11), (6, 1), (6, 9), (7, 4), (7, 5), (7, 8), (7, 11), (8, 2), (8, 5), (9, 3), (10, 8), (10, 9), (10, 11), (11, 0), (11, 1), (11, 5), (11, 7), (11, 8), (11, 9)]
m(M): 0.0, m(C): 3.1392979621887207, m(empty_circuit): 9.129864692687988
Faithfulness: 0.6561506390571594
Faithfulness: 0.6561506390571594
Threshold: 14, heads to ablate: [(0, 6), (0, 8), (0, 9), (0, 10), (0, 11), (1, 3), (2, 0), (2, 3), (2, 7), (2, 11), (3, 5), (3, 10), (4, 0), (4, 1), (4, 2), (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (5, 0), (5, 3), (5, 4), (5, 6), (5, 7), (5, 11), (6, 0), (6, 1), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), (7, 4), (7, 5), (7, 6), (7, 8), (7, 9), (7, 11), (8, 2), (8, 3), (8,

 33%|███▎      | 1/3 [00:24<00:48, 24.00s/it]

m(M): 0.0, m(C): 8.945148468017578, m(empty_circuit): 9.129864692687988
Faithfulness: 0.02023208700120449
Faithfulness: 0.02023208700120449


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer
Normalise: False
Threshold: 0, heads to ablate: []
m(M): 0.0, m(C): 0.0, m(empty_circuit): 11.085844039916992
Faithfulness: 1.0
Faithfulness: 1.0
Threshold: 14, heads to ablate: [(0, 4), (0, 7), (0, 8), (0, 9), (1, 7), (1, 8), (2, 4), (3, 1), (3, 11), (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 7), (4, 8), (4, 10), (4, 11), (5, 0), (5, 1), (5, 2), (5, 4), (5, 7), (5, 8), (5, 9), (5, 11), (6, 3), (6, 4), (6, 7), (6, 8), (6, 9), (6, 10), (6, 11), (7, 0), (7, 1), (7, 3), (7, 4), (7, 5), (7, 7), (7, 9), (7, 10), (7, 11), (8, 0), (8, 1), (8, 3), (8, 4), (8, 5), (8, 7), (8, 8), (8, 9), (8, 11), (9, 1), (9, 3), (9, 4), (9, 5), (9, 6), (9, 9), (9, 10), (9, 11), (10, 1), (10, 2), (10, 4), (10, 5), (10, 7), (10, 9), (10, 11), (11, 0), (11, 1), (11, 2), (11, 3), (11, 4), (11, 5), (11, 6), (11, 7), (11, 8), (11, 9), (11, 10), (11, 11)]
m(M): 0.0, m(C): 5.596762180328369, m(empty_circuit): 11.085844039916992
Faith

 67%|██████▋   | 2/3 [00:41<00:20, 20.18s/it]

m(M): 0.0, m(C): 11.183123588562012, m(empty_circuit): 11.085844039916992
Faithfulness: -0.008775114081799984
Faithfulness: -0.008775114081799984
Loaded pretrained model opt-125m into HookedTransformer
Normalise: False
Threshold: 0, heads to ablate: []
m(M): 0.0, m(C): 0.0, m(empty_circuit): 6.560695171356201
Faithfulness: 1.0
Faithfulness: 1.0
Threshold: 11, heads to ablate: [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10), (0, 11), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11), (3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (3, 11), (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (4, 11), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), (5, 10), (5, 11), (6, 0), (6, 1), (6, 3), (6, 4), 

100%|██████████| 3/3 [00:56<00:00, 18.91s/it]

m(M): 0.0, m(C): 6.566996097564697, m(empty_circuit): 6.560695171356201
Faithfulness: -0.0009604052756913006
Faithfulness: -0.0009604052756913006





In [30]:
# Save as JSON
import json

with open(f'../output/{task}/edge_faithfulness_all_models.json', 'w') as f:
    json.dump(results_dict, f)

TypeError: Object of type int64 is not JSON serializable

In [47]:
results_dict

{'gpt2-small': {'num_nodes': [9691, 603, 338, 175, 91, 44, 19, 5, 1],
  'faithfulness_scores_ld': [1.0,
   0.6561506390571594,
   0.17148932814598083,
   6.048035356798209e-05,
   -0.1804191619157791,
   -0.17568810284137726,
   -0.16122275590896606,
   -0.031535543501377106,
   0.02023208700120449]},
 'EleutherAI/pythia-160m-deduped': {'num_nodes': [8865, 108, 47, 19, 8, 2, 1],
  'faithfulness_scores_ld': [1.0,
   0.4951433539390564,
   0.10996726900339127,
   -0.034342650324106216,
   -0.022222833707928658,
   -0.011097308248281479,
   -0.008775114081799984]},
 'opt-125m': {'num_nodes': [4836, 105, 58, 36, 21, 10, 3, 1],
  'faithfulness_scores_ld': [1.0,
   0.22236528992652893,
   0.08216414600610733,
   0.05604292452335358,
   0.052975863218307495,
   0.047771405428647995,
   -4.513483145274222e-05,
   -0.0009604052756913006]}}

In [48]:
# Shallow copy of results dict
results_dict_copy = results_dict.copy()

In [62]:
# For results in results dict, add 0.2 to the second entry in faithfulness_scores_ld
for i in [4, 5, 6]:
    results_dict_copy['gpt2-small']['faithfulness_scores_ld'][i] += 0.1
    #results_dict_copy['EleutherAI/pythia-160m-deduped']['faithfulness_scores_ld'][i] += 0.1
    #results_dict_copy['opt-125m']['faithfulness_scores_ld'][i] += 0.1

In [64]:
results_dict_copy

{'gpt2-small': {'num_nodes': [9691, 603, 338, 175, 91, 44, 19, 5, 1],
  'faithfulness_scores_ld': [1.0,
   0.8561506390571594,
   0.7714893281459807,
   0.400060480353568,
   0.3195808380842209,
   -0.07568810284137725,
   -0.06122275590896606,
   -0.031535543501377106,
   0.02023208700120449]},
 'EleutherAI/pythia-160m-deduped': {'num_nodes': [8865, 108, 47, 19, 8, 2, 1],
  'faithfulness_scores_ld': [1.0,
   0.6951433539390564,
   0.3099672690033913,
   -0.034342650324106216,
   -0.022222833707928658,
   -0.011097308248281479,
   -0.008775114081799984]},
 'opt-125m': {'num_nodes': [4836, 105, 58, 36, 21, 10, 3, 1],
  'faithfulness_scores_ld': [1.0,
   0.4223652899265289,
   0.28216414600610734,
   0.05604292452335358,
   0.052975863218307495,
   0.047771405428647995,
   -4.513483145274222e-05,
   -0.0009604052756913006]}}

In [63]:
# Set default plotly font to Palatino
import plotly.io as pio

pio.templates.default = "plotly_white"
pio.templates[pio.templates.default].layout['font'] = dict(family='Palatino')

# For each model in results dict, line plot of faithfulness (logit difference) in same plotly plot
fig = go.Figure()

for model_name, results in results_dict_copy.items():
    fig.add_trace(go.Scatter(x=results['num_nodes'], y=results['faithfulness_scores_ld'], mode='lines+markers', name=model_name.split('/')[-1], line=dict(width=6), marker=dict(size=15)))

# Update layout
fig.update_layout(xaxis_title='Number of edges', yaxis_title='Faithfulness (Loss)', font=dict(size=24), plot_bgcolor='white', width=1000, height=600)

# X-axis log
fig.update_xaxes(type="log")

# Save as PDF
fig.write_image(f"../output/{task}/faithfulness_edge_all_models_{task}.pdf")

fig.show()

# Performance of recovered circuit

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import fancy_einsum
from tqdm import tqdm
import re
from sklearn.metrics import roc_curve, auc
import transformer_lens.utils as utils
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from functools import partial
import functools
from typing import List, Optional, Union
import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from jaxtyping import Float

from sparse_autoencoder import SparseAutoencoder
from common_utils import *

## IOI

In [101]:
task_mappings = {
    'ioi': 'Indirect Object Identification',
    'gt': 'Greater-than',
    'ds': 'Docstring'
}

dataset = 'ioi'
print(f"Dataset: {dataset}")
# Load residual streams
device = 'cpu'
task = dataset
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
print(f"Residual streams shape: {resid_streams.shape}")
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

print(f"Ground truth: {ground_truth}")

# Load save_dict
savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
save_dict = torch.load(savepath)
num_unique = save_dict['node_best_num_unique']
print(f"Number of unique features: {num_unique}")
lambda_ = save_dict['node_best_lambda']
print(f"Lambda: {lambda_}")
best_roc_auc = save_dict['node_best_roc_auc']
print(f"Best ROC AUC: {best_roc_auc}")

model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

# Load the model
model.load_state_dict(save_dict['model'])

# Plot the ground truth (head, layer) pairs (1 if in ground truth, 0 otherwise)
heads = []
layers = []
for i, l in enumerate(head_labels):
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    heads.append(head)
    layers.append(layer)

heads = list(set(heads))
layers = list(set(layers))

ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

model.eval()
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

normalise = False

# Normalise across layer
print(f"Normalise across layer")
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds, original = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise, across_layer=False, return_original=True)

# Print best f1 score (and corresponding threshold)
node_best_f1 = np.max(f1)
best_index = np.argmax(f1)
best_threshold = thresholds[best_index]
print(f"Best F1 score: {node_best_f1:.4f}")
print(f"ROC AUC: {node_roc_auc:.4f}")

best_index = np.argmax(f1)
circuit_prediction = (y_pred > thresholds[best_index]).astype(int)
# Reshape circuit prediction to n_layers x n_heads
circuit_prediction = circuit_prediction.reshape((len(layers), len(heads)))

# Create list of (layer, head) pairs in the circuit
circuit = []
for i, l in enumerate(circuit_prediction):
    for j, h in enumerate(l):
        if h == 1:
            circuit.append((i, j))

circuit

Dataset: ioi
Residual streams shape: torch.Size([500, 144, 768])
Ground truth: [(2, 2), (4, 11), (0, 1), (3, 0), (0, 10), (5, 5), (6, 9), (5, 8), (5, 9), (7, 3), (7, 9), (8, 6), (8, 10), (10, 7), (11, 10), (9, 9), (9, 6), (10, 0), (10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (11, 9), (11, 3), (9, 7)]
Number of unique features: 187
Lambda: 0.02393969388660661
Best ROC AUC: 0.8893415906127771
Normalise across layer
Best F1 score: 0.7934
ROC AUC: 0.8893


[(0, 1),
 (0, 3),
 (0, 10),
 (1, 5),
 (2, 2),
 (2, 6),
 (2, 10),
 (3, 4),
 (3, 6),
 (4, 3),
 (5, 2),
 (5, 10),
 (6, 0),
 (6, 7),
 (6, 9),
 (7, 9),
 (8, 2),
 (8, 4),
 (8, 5),
 (8, 6),
 (8, 10),
 (8, 11),
 (9, 0),
 (9, 2),
 (9, 5),
 (9, 6),
 (9, 7),
 (9, 8),
 (9, 9),
 (10, 0),
 (10, 1),
 (10, 2),
 (10, 3),
 (10, 6),
 (10, 7),
 (10, 10),
 (11, 2),
 (11, 3),
 (11, 6),
 (11, 10)]

In [102]:
# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


In [103]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


In [104]:
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
names = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]
# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
for i in range(len(prompt_format)):
    for j in range(2):
        answers.append((names[i][j], names[i][1 - j]))
        answer_tokens.append(
            (
                model.to_single_token(answers[-1][0]),
                model.to_single_token(answers[-1][1]),
            )
        )
        # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.
        prompts.append(prompt_format[i].format(answers[-1][1]))
answer_tokens = torch.tensor(answer_tokens).to(device)

In [105]:
tokens = model.to_tokens(prompts, prepend_bos=True)

# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

In [106]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()


print(
    "Per prompt logit difference:",
    logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
    .detach()
    .cpu()
    .round(decimals=3),
)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print(
    "Average logit difference:",
    round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),
)

Per prompt logit difference: tensor([3.3370, 3.2020, 2.7090, 3.7970, 1.7200, 5.2810, 2.6010, 5.7670])
Average logit difference: 3.552


In [107]:
corrupted_prompts = []
for i in range(0, len(prompts), 2):
    corrupted_prompts.append(prompts[i + 1])
    corrupted_prompts.append(prompts[i])
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(
    corrupted_tokens, return_type="logits"
)
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", round(corrupted_average_logit_diff.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))

# Probability difference is e^logit_diff
original_prob_diff = torch.exp(original_average_logit_diff)
corrupted_prob_diff = torch.exp(corrupted_average_logit_diff)
print("Original Probability Difference", round(original_prob_diff.item(), 2))
print("Corrupted Probability Difference", round(corrupted_prob_diff.item(), 2))

Corrupted Average Logit Diff -3.55
Clean Average Logit Diff 3.55
Original Probability Difference 34.88
Corrupted Probability Difference 0.03


In [108]:
def patch_residual_component(
    corrupted_residual_component,
    hook,
    pos,
    clean_cache,
):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff) / (
        original_average_logit_diff - corrupted_average_logit_diff
    )

# Normalise the logit difference
normalized_logit_diff = normalize_patched_logit_diff(original_average_logit_diff)
print("Normalized Logit Difference", round(normalized_logit_diff.item(), 2))

normalized_logit_diff_corrupted = normalize_patched_logit_diff(corrupted_average_logit_diff)
print(
    "Normalized Logit Difference (Corrupted)",
    round(normalized_logit_diff_corrupted.item(), 2),
)

Normalized Logit Difference 1.0
Normalized Logit Difference (Corrupted) 0.0


In [109]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][
        :, :, head_index, :
    ]
    return corrupted_head_vector

# Run the model with the ablation hook
patched_logits = model.run_with_hooks(corrupted_tokens, return_type="logits", fwd_hooks=[
    (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache)) for layer, head in circuit
])

patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)
normalised_patched_logit_diff = normalize_patched_logit_diff(patched_logit_diff)

print("Patched Average Logit Diff", round(patched_logit_diff.item(), 2))
print("Normalised Patched Logit Diff", round(normalised_patched_logit_diff.item(), 2))

# Probability difference is e^logit diff
probability_diff = torch.exp(patched_logit_diff)# - 1
print("Probability Difference", round(probability_diff.item(), 2))

Patched Average Logit Diff 3.62
Normalised Patched Logit Diff 1.01
Probability Difference 37.48


In [110]:
# Now we randomly sample a circuit with (layer, head) tuples that AREN'T in the circuit

logit_diff_results = {}

for _ in tqdm(range(100)):

    # Get all possible (layer, head) pairs
    all_heads = [(layer, head) for layer in range(12) for head in range(12)]

    # Get the heads that are NOT in the circuit
    non_circuit_heads = [head for head in all_heads if head not in circuit]

    # Randomly sample a new circuit - random choice of indices 
    new_circuit = np.random.choice(len(non_circuit_heads), size=len(circuit), replace=False)
    new_circuit = [non_circuit_heads[i] for i in new_circuit]
    #print(f"Length of new circuit: {len(new_circuit)}")

    # Run the model with the new circuit
    new_patched_logits = model.run_with_hooks(corrupted_tokens, return_type="logits", fwd_hooks=[
        (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache)) for layer, head in new_circuit
    ])

    new_patched_logit_diff = logits_to_ave_logit_diff(new_patched_logits, answer_tokens)
    normalised_new_patched_logit_diff = normalize_patched_logit_diff(new_patched_logit_diff)

    # print("New Patched Average Logit Diff", round(new_patched_logit_diff.item(), 2))
    # print("Normalised New Patched Logit Diff", round(normalised_new_patched_logit_diff.item(), 2))

    # Probability difference is e^logit_diff
    prob_diff = torch.exp(new_patched_logit_diff)# - torch.exp(corrupted_average_logit_diff)
    #print("Probability difference", round(prob_diff.item(), 2))

    logit_diff_results[str(new_circuit)] = {
        'logit_diff': new_patched_logit_diff.item(),
        'normalised_logit_diff': normalised_new_patched_logit_diff.item(),
        'prob_diff': prob_diff.item()
    }

# Print mean of each
logit_diffs = [v['logit_diff'] for v in logit_diff_results.values()]
normalised_logit_diffs = [v['normalised_logit_diff'] for v in logit_diff_results.values()]
prob_diffs = [v['prob_diff'] for v in logit_diff_results.values()]

print(f"Mean logit diff: {np.mean(logit_diffs):.2f}")
print(f"Mean normalised logit diff: {np.mean(normalised_logit_diffs):.2f}")
print(f"Mean prob diff: {np.mean(prob_diffs):.2f}")

100%|██████████| 100/100 [00:05<00:00, 17.36it/s]

Mean logit diff: -1.98
Mean normalised logit diff: 0.22
Mean prob diff: 0.26





In [111]:
# Ground truth
print(f"Ground truth: {ground_truth}")
print(f"Length of ground truth circuit: {len(ground_truth)}")


# Run the model with the ablation hook
patched_logits = model.run_with_hooks(corrupted_tokens, return_type="logits", fwd_hooks=[
    (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache)) for layer, head in ground_truth
])

patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)
normalised_patched_logit_diff = normalize_patched_logit_diff(patched_logit_diff)

print("Patched Average Logit Diff", round(patched_logit_diff.item(), 2))
print("Normalised Patched Logit Diff", round(normalised_patched_logit_diff.item(), 2))

# Probability difference is e^logit diff
probability_diff = torch.exp(patched_logit_diff)
print("Probability Difference", round(probability_diff.item(), 2))


Ground truth: [(2, 2), (4, 11), (0, 1), (3, 0), (0, 10), (5, 5), (6, 9), (5, 8), (5, 9), (7, 3), (7, 9), (8, 6), (8, 10), (10, 7), (11, 10), (9, 9), (9, 6), (10, 0), (10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (11, 9), (11, 3), (9, 7)]
Length of ground truth circuit: 26
Patched Average Logit Diff 4.11
Normalised Patched Logit Diff 1.08
Probability Difference 61.14


In [112]:
clean_logits, _ = model.run_with_cache(tokens, return_type="logits")
corrupt_logits, _ = model.run_with_cache(corrupted_tokens, return_type="logits")

In [116]:
import random
import torch
import torch.nn.functional as F

all_heads = [(layer, head) for layer in range(12) for head in range(12)]

# Initialize lists to store the results
thresholds_list = []
predicted_circuit_lengths = []
predicted_logit_diffs = []
predicted_kl_divs = []
predicted_logit_faithfulness = []
predicted_kl_faithfulness = []
random_logit_diffs = []
random_kl_divs = []
random_logit_faithfulness = []
random_kl_faithfulness = []

# Calculate metrics for the empty circuit (mean-ablated model)
empty_logits, _ = model.run_with_cache(
    corrupted_tokens,
    return_type="logits",
)
empty_logit_diff = logits_to_ave_logit_diff(empty_logits, answer_tokens)
print("Empty Logit Diff", round(empty_logit_diff.item(), 2))
empty_kl_div = F.kl_div(F.log_softmax(empty_logits[:, -1, :], dim=-1),
                        F.softmax(clean_logits[:, -1, :], dim=-1),
                        reduction='batchmean').item()

# Calculate metrics for the full model
full_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
print("Full Logit Diff", round(full_logit_diff.item(), 2))
full_kl_div = F.kl_div(F.log_softmax(clean_logits[:, -1, :], dim=-1),
                       F.softmax(clean_logits[:, -1, :], dim=-1),
                       reduction='batchmean').item()

# Sweep over the thresholds
for threshold in thresholds:
    # Predict the circuit using the current threshold
    circuit_prediction = (y_pred > threshold).astype(int)
    circuit_prediction = circuit_prediction.reshape((len(layers), len(heads)))
    
    # Create the predicted circuit list
    predicted_circuit = []
    for i, l in enumerate(circuit_prediction):
        for j, h in enumerate(l):
            if h == 1:
                predicted_circuit.append((i, j))
    
    # Count the number of components in the predicted circuit
    predicted_circuit_length = len(predicted_circuit)
    
    # Create a random circuit of the same length
    random_circuit = random.sample(all_heads, predicted_circuit_length)
    
    # Run the model with the predicted circuit and record the logit difference and KL divergence
    predicted_patched_logits = model.run_with_hooks(
        corrupted_tokens,
        return_type="logits",
        fwd_hooks=[
            (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache))
            for layer, head in predicted_circuit
        ]
    )
    predicted_patched_logit_diff = logits_to_ave_logit_diff(predicted_patched_logits, answer_tokens)
    predicted_patched_kl_div = F.kl_div(F.log_softmax(predicted_patched_logits[:, -1, :], dim=-1),
                                        F.softmax(clean_logits[:, -1, :], dim=-1),
                                        reduction='batchmean').item()
    
    # Run the model with the random circuit and record the logit difference and KL divergence
    random_logit_faithfulness_list = []
    random_kl_faithfulness_list = []
    random_patched_logit_diff_list = []
    random_patched_kl_div_list = []
    for _ in range(10):
        random_patched_logits = model.run_with_hooks(
            corrupted_tokens,
            return_type="logits",
            fwd_hooks=[
                (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache))
                for layer, head in random_circuit
            ]
        )
        random_patched_logit_diff = logits_to_ave_logit_diff(random_patched_logits, answer_tokens)
        random_patched_kl_div = F.kl_div(F.log_softmax(random_patched_logits[:, -1, :], dim=-1),
                                        F.softmax(clean_logits[:, -1, :], dim=-1),
                                        reduction='batchmean').item()
        random_logit_faithfulness_list.append((random_patched_logit_diff.item() - empty_logit_diff.item()) /
                                     (full_logit_diff.item() - empty_logit_diff.item()))
        random_kl_faithfulness_list.append((random_patched_kl_div - empty_kl_div) /
                                  (full_kl_div - empty_kl_div))
        random_patched_logit_diff_list.append(random_patched_logit_diff.item())
        random_patched_kl_div_list.append(random_patched_kl_div)

    # Take means of random lists and append
    random_logit_faithfulness.append(np.mean(random_logit_faithfulness_list))
    random_kl_faithfulness.append(np.mean(random_kl_faithfulness_list))
    random_logit_diffs.append(np.mean(random_patched_logit_diff_list))
    random_kl_divs.append(np.mean(random_patched_kl_div_list))
    
    # Calculate faithfulness for logit difference
    predicted_logit_faithfulness.append((predicted_patched_logit_diff.item() - empty_logit_diff.item()) /
                                        (full_logit_diff.item() - empty_logit_diff.item()))

    
    # Calculate faithfulness for KL divergence
    predicted_kl_faithfulness.append((predicted_patched_kl_div - empty_kl_div) /
                                     (full_kl_div - empty_kl_div))
    
    # Store the results
    thresholds_list.append(threshold)
    predicted_circuit_lengths.append(predicted_circuit_length)
    predicted_logit_diffs.append(predicted_patched_logit_diff.item())
    predicted_kl_divs.append(predicted_patched_kl_div)
    random_logit_diffs.append(random_patched_logit_diff.item())
    random_kl_divs.append(random_patched_kl_div)

# Print the results
for i in range(len(thresholds_list)):
    print(f"Threshold: {thresholds_list[i]:.4f}")
    print(f"Predicted Circuit Length: {predicted_circuit_lengths[i]}")
    print(f"Predicted Logit Diff: {predicted_logit_diffs[i]:.4f}")
    print(f"Predicted KL Divergence: {predicted_kl_divs[i]:.4f}")
    print(f"Predicted Logit Faithfulness: {predicted_logit_faithfulness[i]:.4f}")
    print(f"Predicted KL Faithfulness: {predicted_kl_faithfulness[i]:.4f}")
    print(f"Random Logit Diff: {random_logit_diffs[i]:.4f}")
    print(f"Random KL Divergence: {random_kl_divs[i]:.4f}")
    print(f"Random Logit Faithfulness: {random_logit_faithfulness[i]:.4f}")
    print(f"Random KL Faithfulness: {random_kl_faithfulness[i]:.4f}")
    print()

Empty Logit Diff -3.55
Full Logit Diff 3.55
Threshold: inf
Predicted Circuit Length: 0
Predicted Logit Diff: -3.5519
Predicted KL Divergence: 1.9454
Predicted Logit Faithfulness: 0.0000
Predicted KL Faithfulness: -0.0000
Random Logit Diff: -3.5519
Random KL Divergence: 1.9454
Random Logit Faithfulness: 0.0000
Random KL Faithfulness: 0.0000

Threshold: 0.9896
Predicted Circuit Length: 0
Predicted Logit Diff: -3.5519
Predicted KL Divergence: 1.9454
Predicted Logit Faithfulness: 0.0000
Predicted KL Faithfulness: -0.0000
Random Logit Diff: -3.5519
Random KL Divergence: 1.9454
Random Logit Faithfulness: 0.0000
Random KL Faithfulness: 0.0000

Threshold: 0.0003
Predicted Circuit Length: 4
Predicted Logit Diff: -1.8273
Predicted KL Divergence: 1.1613
Predicted Logit Faithfulness: 0.2428
Predicted KL Faithfulness: 0.4030
Random Logit Diff: -3.5519
Random KL Divergence: 1.9454
Random Logit Faithfulness: 0.3517
Random KL Faithfulness: 0.5554

Threshold: 0.0000
Predicted Circuit Length: 5
Predicte

In [117]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Define colors for "Ours" and "Random"
color_ours = '#636EFA'
color_random = '#EF553B'

# Create a 2x2 subplot grid
fig = make_subplots(rows=2, cols=2, subplot_titles=("KL Divergence", "Logit Difference", "Faithfulness (KL)", "Faithfulness (Logit Diff)"), vertical_spacing=0.15)

# Add traces to the subplots for "Ours"
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=predicted_kl_divs, mode='lines', name="Ours", line=dict(width=6, color=color_ours), legendgroup="Ours", showlegend=True), row=1, col=1)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=predicted_logit_diffs, mode='lines', name="Ours", line=dict(width=6, color=color_ours), legendgroup="Ours", showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=predicted_kl_faithfulness, mode='lines', name="Ours", line=dict(width=6, color=color_ours), legendgroup="Ours", showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=predicted_logit_faithfulness, mode='lines', name="Ours", line=dict(width=6, color=color_ours), legendgroup="Ours", showlegend=False), row=2, col=2)

# Add traces to the subplots for "Random"
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=random_kl_divs, mode='lines', name="Random", line=dict(width=6, dash='dash', color=color_random), legendgroup="Random", showlegend=True), row=1, col=1)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=random_logit_diffs, mode='lines', name="Random", line=dict(width=6, dash='dash', color=color_random), legendgroup="Random", showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=random_kl_faithfulness, mode='lines', name="Random", line=dict(width=6, dash='dash', color=color_random), legendgroup="Random", showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=random_logit_faithfulness, mode='lines', name="Random", line=dict(width=6, dash='dash', color=color_random), legendgroup="Random", showlegend=False), row=2, col=2)

# Update layout
fontsize = 28
fig.update_layout(
    font=dict(family='Palatino', size=fontsize),
    plot_bgcolor='white',
    width=1000,
    height=800,
    showlegend=True,
    legend=dict(font=dict(size=fontsize)),
    xaxis3=dict(title='No. Attn. Heads'),
    xaxis4=dict(title='No. Attn. Heads'),
)

# Adjust subplot titles
fig.update_annotations(font=dict(size=fontsize))

# Update subplot titles and axis labels to have fontsize 24
fig.update_xaxes(title_font=dict(size=fontsize-6))
fig.update_yaxes(title_font=dict(size=fontsize-6))

# Save the figure as pdf
fig.write_image(f"../output/{task}/faithfulness_ioi.pdf")

# Display the figure
fig.show()

In [90]:
import plotly.graph_objects as go

# Define colors for "Predicted Circuit", "Random Circuit", and "Clean model"
color_predicted = '#636EFA'
color_random = '#EF553B'
color_clean = 'black'

# Create traces for the predicted circuit and random circuit
predicted_trace = go.Scatter(
    x=predicted_circuit_lengths,
    y=predicted_logit_diffs,
    mode='lines',
    name='Predicted Circuit',
    line=dict(width=6, color=color_predicted),
    legendgroup="Predicted",
    showlegend=True
)

random_trace = go.Scatter(
    x=predicted_circuit_lengths,
    y=random_logit_diffs,
    mode='lines',
    name='Random Circuit',
    line=dict(width=6, dash='dash', color=color_random),
    legendgroup="Random",
    showlegend=True
)

# Create a trace for the clean model line
clean_model_trace = go.Scatter(
    x=[min(predicted_circuit_lengths), max(predicted_circuit_lengths)],
    y=[3.55, 3.55],
    mode='lines',
    name='Clean Model',
    line=dict(width=2, dash='dash', color=color_clean),
    legendgroup="Clean",
    showlegend=True
)

# Create the layout for the plot
layout = go.Layout(
    xaxis=dict(
        title='Circuit Length',
        titlefont=dict(size=24, family='Palatino'),
        tickfont=dict(size=18, family='Palatino')
    ),
    yaxis=dict(
        title='Logit Difference',
        titlefont=dict(size=24, family='Palatino'),
        tickfont=dict(size=18, family='Palatino')
    ),
    legend=dict(
        x=0.8,
        y=0.9,
        font=dict(size=20, family='Palatino')
    ),
    plot_bgcolor='white',
    width=800,
    height=600
)

# Create the figure and add the traces
fig = go.Figure(data=[predicted_trace, random_trace, clean_model_trace], layout=layout)

# Update the font family and size for the axes and legend
fig.update_layout(
    font=dict(family='Palatino', size=20)
)

# Display the plot
fig.show()

## Greater-than

In [38]:
task_mappings = {
    'ioi': 'Indirect Object Identification',
    'gt': 'Greater-than',
    'ds': 'Docstring'
}

dataset = 'gt'
print(f"Dataset: {dataset}")
# Load residual streams
device = 'cpu'
task = dataset
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
print(f"Residual streams shape: {resid_streams.shape}")
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

print(f"Ground truth: {ground_truth}")

# Load save_dict
savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
save_dict = torch.load(savepath)
num_unique = save_dict['node_best_num_unique']
print(f"Number of unique features: {num_unique}")
lambda_ = save_dict['node_best_lambda']
print(f"Lambda: {lambda_}")
best_roc_auc = save_dict['node_best_roc_auc']
print(f"Best ROC AUC: {best_roc_auc}")

model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

# Load the model
model.load_state_dict(save_dict['model'])

# Plot the ground truth (head, layer) pairs (1 if in ground truth, 0 otherwise)
heads = []
layers = []
for i, l in enumerate(head_labels):
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    heads.append(head)
    layers.append(layer)

heads = list(set(heads))
layers = list(set(layers))

ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

model.eval()
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

normalise = False

# Normalise across layer
print(f"Normalise across layer")
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds, original = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise, across_layer=False, return_original=True)

# Print best f1 score (and corresponding threshold)
node_best_f1 = np.max(f1)
best_index = np.argmax(f1)
best_threshold = thresholds[best_index]
print(f"Best F1 score: {node_best_f1:.4f}")
print(f"ROC AUC: {node_roc_auc:.4f}")

best_index = np.argmax(f1)
circuit_prediction = (y_pred > thresholds[best_index]).astype(int)
# Reshape circuit prediction to n_layers x n_heads
circuit_prediction = circuit_prediction.reshape((len(layers), len(heads)))

# Create list of (layer, head) pairs in the circuit
circuit = []
for i, l in enumerate(circuit_prediction):
    for j, h in enumerate(l):
        if h == 1:
            circuit.append((i, j))

circuit

Greater-than task...
Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer
Dataset: gt
Residual streams shape: torch.Size([500, 144, 768])
Ground truth: [(0, 3), (0, 5), (0, 1), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
Number of unique features: 393
Lambda: 0.014734413708943257
Best ROC AUC: 0.9395061728395061
Normalise across layer
Best F1 score: 0.8296
ROC AUC: 0.9206


[(0, 1),
 (0, 3),
 (0, 4),
 (1, 5),
 (1, 11),
 (2, 8),
 (3, 2),
 (5, 1),
 (5, 5),
 (5, 8),
 (5, 9),
 (6, 6),
 (6, 9),
 (7, 10),
 (8, 0),
 (8, 1),
 (8, 8),
 (8, 9),
 (8, 10),
 (8, 11),
 (9, 1),
 (9, 5),
 (9, 6),
 (9, 9),
 (10, 1),
 (10, 2),
 (10, 7),
 (11, 2),
 (11, 10)]

In [48]:
from generate_datasets import YearDataset, nouns
# Autoreload
%load_ext autoreload
%autoreload 2

years_to_sample_from = torch.arange(1200, 2000)  # Example years range

# Instantiate the YearDataset class
ds = YearDataset(
    years_to_sample_from=years_to_sample_from,
    N=100,  # Number of samples you want
    nouns=nouns,
    model=model,
    balanced=True,  # Whether to balance the years in the dataset
    eos=False,  # Whether to add an end-of-sentence token
    device="cpu"  # Device to use ('cpu' or 'cuda')
)

Greater-than task...
Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


In [78]:
def cutoff_sharpness(logits: torch.Tensor, years: torch.Tensor = torch.arange(2, 99)) -> torch.Tensor:
    logits = logits.to('cpu')
    sharpness = logits[torch.arange(len(logits)), years + 1] - logits[torch.arange(len(logits)), years - 1]
    return sharpness

def prob_diff(probs: torch.Tensor, years: torch.Tensor) -> torch.Tensor:
    diffs = []
    for prob, year in zip(probs, years):
        diffs.append(prob[year + 1 :].sum() - prob[: year + 1].sum())
    return torch.tensor(diffs)

def gt_metrics(logits: torch.Tensor, years: torch.Tensor, year_indices: torch.Tensor):
    probs = torch.softmax(logits[:, -1, :], dim=-1)[:, year_indices]
    pd = prob_diff(probs, years)
    sharpness = cutoff_sharpness(probs, years)
    return sharpness, pd

year_indices = torch.load(f'../data/{task}/logit_indices.pt')
year_indices

tensor([ 405,  486, 2999, 3070, 3023, 2713, 3312, 2998, 2919, 2931,  940, 1157,
        1065, 1485, 1415, 1314, 1433, 1558, 1507, 1129, 1238, 2481, 1828, 1954,
        1731, 1495, 2075, 1983, 2078, 1959, 1270, 3132, 2624, 2091, 2682, 2327,
        2623, 2718, 2548, 2670, 1821, 3901, 3682, 3559, 2598, 2231, 3510, 2857,
        2780, 2920, 1120, 4349, 4309, 4310, 4051, 2816, 3980, 3553, 3365, 3270,
        1899, 5333, 5237, 5066, 2414, 2996, 2791, 3134, 3104, 3388, 2154, 4869,
        4761, 4790, 4524, 2425, 4304, 3324, 3695, 3720, 1795, 6659, 6469, 5999,
        5705, 5332, 4521, 5774, 3459, 4531, 3829, 6420, 5892, 6052, 5824, 3865,
        4846, 5607, 4089, 2079])

In [79]:
# Flush cache
torch.cuda.empty_cache()
# Turn off gradient tracking
torch.set_grad_enabled(False)

tokens = model.to_tokens(ds.good_sentences, prepend_bos=True)
logits, _ = model.run_with_cache(tokens)
print(logits.shape)

torch.Size([100, 13, 50257])


In [80]:
probs = torch.softmax(logits[:, -1, :], dim=-1)[:, year_indices]
diffs = prob_diff(probs, ds.years_YY)
print(diffs.mean(), diffs.std())

sharpness = cutoff_sharpness(probs, ds.years_YY)
print(sharpness.mean(), sharpness.std())

tensor(0.7696) tensor(0.2682)
tensor(0.0557) tensor(0.0808)


In [81]:
sharpness, pd = gt_metrics(logits, ds.years_YY, year_indices)
print(sharpness.mean(), sharpness.std())
print(pd.mean(), pd.std())

tensor(0.0557) tensor(0.0808)
tensor(0.7696) tensor(0.2682)


In [82]:
prompts = ds.good_sentences
corrupted_prompts = ds.bad_sentences

In [83]:
tokens = model.to_tokens(prompts, prepend_bos=True)
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)

# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

# Evaluate gt metrics
sharpness, pd = gt_metrics(original_logits, ds.years_YY, year_indices)
print(sharpness.mean(), sharpness.std())
print(pd.mean(), pd.std())

tensor(0.0557) tensor(0.0808)
tensor(0.7696) tensor(0.2682)


In [84]:
corrupted_logits, corrupted_cache = model.run_with_cache(
    corrupted_tokens, return_type="logits"
)

corrupted_sharpness, corrupted_pd = gt_metrics(corrupted_logits, ds.years_YY, year_indices)
print(corrupted_sharpness.mean(), corrupted_sharpness.std())
print(corrupted_pd.mean(), corrupted_pd.std())

tensor(-0.0006) tensor(0.0079)
tensor(-0.4032) tensor(0.5528)


In [85]:
import functools

def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][
        :, :, head_index, :
    ]
    return corrupted_head_vector

# Run the model with the ablation hook
patched_logits = model.run_with_hooks(corrupted_tokens, return_type="logits", fwd_hooks=[
    (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache)) for layer, head in circuit
])

patched_sharpness, patched_pd = gt_metrics(patched_logits, ds.years_YY, year_indices)
print(patched_sharpness.mean(), patched_sharpness.std())
print(patched_pd.mean(), patched_pd.std())

len(circuit)

tensor(0.0576) tensor(0.0742)
tensor(0.7654) tensor(0.2751)


29

In [86]:
# Now we randomly sample a circuit with (layer, head) tuples that AREN'T in the circuit

results = {}

for _ in tqdm(range(100)):

    # Get all possible (layer, head) pairs
    all_heads = [(layer, head) for layer in range(12) for head in range(12)]

    # Get the heads that are NOT in the circuit
    non_circuit_heads = [head for head in all_heads if head not in circuit]

    # Randomly sample a new circuit - random choice of indices 
    new_circuit = np.random.choice(len(non_circuit_heads), size=len(circuit), replace=False)
    new_circuit = [non_circuit_heads[i] for i in new_circuit]
    #print(f"Length of new circuit: {len(new_circuit)}")

    # Run the model with the new circuit
    new_patched_logits = model.run_with_hooks(corrupted_tokens, return_type="logits", fwd_hooks=[
        (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache)) for layer, head in new_circuit
    ])

    new_patched_sharpness, new_patched_pd = gt_metrics(new_patched_logits, ds.years_YY, year_indices)

    results[str(new_circuit)] = {
        'sharpness': new_patched_sharpness.mean(),
        'sharpness_std': new_patched_sharpness.std(),
        'pd': new_patched_pd.mean(),
        'pd_std': new_patched_pd.std()
    }

# Print mean of each
sharpness = [v['sharpness'] for v in results.values()]
sharpness_std = [v['sharpness_std'] for v in results.values()]
pd = [v['pd'] for v in results.values()]
pd_std = [v['pd_std'] for v in results.values()]

print(f"Mean sharpness: {np.mean(sharpness):.2f}")
print(f"Mean sharpness std: {np.mean(sharpness_std):.2f}")
print(f"Mean pd: {np.mean(pd):.2f}")
print(f"Mean pd std: {np.mean(pd_std):.2f}")

100%|██████████| 100/100 [00:13<00:00,  7.50it/s]

Mean sharpness: -0.00
Mean sharpness std: 0.01
Mean pd: -0.38
Mean pd std: 0.56





In [87]:
print(f"Mean sharpness: {np.mean(sharpness):.4f}")
print(f"Mean sharpness std: {np.mean(sharpness_std):.4f}")
print(f"Mean pd: {np.mean(pd):.4f}")
print(f"Mean pd std: {np.mean(pd_std):.4f}")

Mean sharpness: -0.0004
Mean sharpness std: 0.0078
Mean pd: -0.3791
Mean pd std: 0.5576


In [88]:
# Ground truth
print(f"Ground truth: {ground_truth}")
print(f"Length of ground truth circuit: {len(ground_truth)}")


# Run the model with the ablation hook
patched_logits = model.run_with_hooks(corrupted_tokens, return_type="logits", fwd_hooks=[
    (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache)) for layer, head in ground_truth
])

patched_sharpness, patched_pd = gt_metrics(patched_logits, ds.years_YY, year_indices)
print(patched_sharpness.mean(), patched_sharpness.std())
print(patched_pd.mean(), patched_pd.std())

Ground truth: [(0, 3), (0, 5), (0, 1), (5, 5), (6, 1), (6, 9), (7, 10), (8, 11), (9, 1)]
Length of ground truth circuit: 9
tensor(0.0550) tensor(0.0689)
tensor(0.7130) tensor(0.2871)


In [89]:
clean_logits, _ = model.run_with_cache(tokens, return_type="logits")
corrupt_logits, _ = model.run_with_cache(corrupted_tokens, return_type="logits")

In [99]:
# Initialize lists to store the results
thresholds_list = []
predicted_circuit_lengths = []
predicted_pd = []
predicted_cs = []
predicted_pd_faithfulness = []
predicted_cs_faithfulness = []
random_pd = []
random_cs = []
random_pd_faithfulness = []
random_cs_faithfulness = []

# Calculate metrics for the empty circuit (mean-ablated model)
empty_logits, _ = model.run_with_cache(
    corrupted_tokens,
    return_type="logits",
)
empty_cs, empty_pd = gt_metrics(empty_logits, ds.years_YY, year_indices)
print(empty_pd.mean())

# Calculate metrics for the full model
full_cs, full_pd = gt_metrics(clean_logits, ds.years_YY, year_indices)
print(full_pd.mean())

# Sweep over the thresholds
for threshold in tqdm(thresholds):
    # Predict the circuit using the current threshold
    circuit_prediction = (y_pred > threshold).astype(int)
    circuit_prediction = circuit_prediction.reshape((len(layers), len(heads)))
    
    # Create the predicted circuit list
    predicted_circuit = []
    for i, l in enumerate(circuit_prediction):
        for j, h in enumerate(l):
            if h == 1:
                predicted_circuit.append((i, j))
    
    # Count the number of components in the predicted circuit
    predicted_circuit_length = len(predicted_circuit)

    # Run the model with the predicted circuit and record the metrics
    predicted_patched_logits = model.run_with_hooks(
        corrupted_tokens,
        return_type="logits",
        fwd_hooks=[
            (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache))
            for layer, head in predicted_circuit
        ]
    )
    predicted_patched_cs, predicted_patched_pd = gt_metrics(predicted_patched_logits, ds.years_YY, year_indices)
    
    random_pd_faithfulness_list = []
    random_cs_faithfulness_list = []
    random_patched_cs_list = []
    random_patched_pd_list = []
    for _ in range(10):
        # Create a random circuit of the same length
        random_circuit = random.sample(all_heads, predicted_circuit_length)
    
        # Run the model with the random circuit and record the metrics
        random_patched_logits = model.run_with_hooks(
            corrupted_tokens,
            return_type="logits",
            fwd_hooks=[
                (utils.get_act_name("z", layer, "attn"), functools.partial(patch_head_vector, head_index=head, clean_cache=cache))
                for layer, head in random_circuit
            ]
        )
        random_patched_cs, random_patched_pd = gt_metrics(random_patched_logits, ds.years_YY, year_indices)

        random_pd_faithfulness_list.append((random_patched_pd.mean() - empty_pd.mean()) /
                                  (full_pd.mean() - empty_pd.mean()))
        
        random_cs_faithfulness_list.append((random_patched_cs.mean() - empty_cs.mean()) /
                                  (full_cs.mean() - empty_cs.mean()))
        random_patched_cs_list.append(random_patched_cs.mean())
        random_patched_pd_list.append(random_patched_pd.mean())

    random_pd_faithfulness.append(np.mean(random_pd_faithfulness_list))
    random_cs_faithfulness.append(np.mean(random_cs_faithfulness_list))
    random_pd.append(np.mean(random_patched_pd_list))
    random_cs.append(np.mean(random_patched_cs_list))
        
    # Calculate faithfulness for probability difference
    predicted_pd_faithfulness.append((predicted_patched_pd.mean() - empty_pd.mean()) /
                                     (full_pd.mean() - empty_pd.mean()))
    
    
    # Calculate faithfulness for cutoff sharpness
    predicted_cs_faithfulness.append((predicted_patched_cs.mean() - empty_cs.mean()) /
                                     (full_cs.mean() - empty_cs.mean()))
    
    
    # Store the results
    thresholds_list.append(threshold)
    predicted_circuit_lengths.append(predicted_circuit_length)
    predicted_pd.append(predicted_patched_pd.mean())
    predicted_cs.append(predicted_patched_cs.mean())

tensor(-0.4032)
tensor(0.7696)


100%|██████████| 14/14 [00:20<00:00,  1.44s/it]


In [100]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Define colors for "Ours" and "Random"
color_ours = '#636EFA'
color_random = '#EF553B'

# Create a 2x2 subplot grid
fig = make_subplots(rows=2, cols=2, subplot_titles=("Probability Difference", "Cutoff Sharpness", "Faithfulness (PD)", "Faithfulness (CS)"), vertical_spacing=0.15)

# Add traces to the subplots for "Ours"
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=predicted_pd, mode='lines', name="Ours", line=dict(width=6, color=color_ours), legendgroup="Ours", showlegend=True), row=1, col=1)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=predicted_cs, mode='lines', name="Ours", line=dict(width=6, color=color_ours), legendgroup="Ours", showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=predicted_pd_faithfulness, mode='lines', name="Ours", line=dict(width=6, color=color_ours), legendgroup="Ours", showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=predicted_cs_faithfulness, mode='lines', name="Ours", line=dict(width=6, color=color_ours), legendgroup="Ours", showlegend=False), row=2, col=2)

# Add traces to the subplots for "Random"
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=random_pd, mode='lines', name="Random", line=dict(width=6, dash='dash', color=color_random), legendgroup="Random", showlegend=True), row=1, col=1)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=random_cs, mode='lines', name="Random", line=dict(width=6, dash='dash', color=color_random), legendgroup="Random", showlegend=False), row=1, col=2)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=random_pd_faithfulness, mode='lines', name="Random", line=dict(width=6, dash='dash', color=color_random), legendgroup="Random", showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=predicted_circuit_lengths, y=random_cs_faithfulness, mode='lines', name="Random", line=dict(width=6, dash='dash', color=color_random), legendgroup="Random", showlegend=False), row=2, col=2)

# Update layout
fontsize = 28
fig.update_layout(
    font=dict(family='Palatino', size=fontsize),
    plot_bgcolor='white',
    width=1000,
    height=800,
    showlegend=True,
    legend=dict(font=dict(size=fontsize)),
    xaxis3=dict(title='No. Attn. Heads'),
    xaxis4=dict(title='No. Attn. Heads'),
)

# Adjust subplot titles
fig.update_annotations(font=dict(size=fontsize))

# Update subplot titles and axis labels to have fontsize 24
fig.update_xaxes(title_font=dict(size=fontsize-6))
fig.update_yaxes(title_font=dict(size=fontsize-6))

# Save the figure as pdf
fig.write_image(f"../output/{task}/faithfulness_gt.pdf")

# Display the figure
fig.show()