# Figures and data for paper

## Setup

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import fancy_einsum
from tqdm import tqdm
import re
from sklearn.metrics import roc_curve, auc
import transformer_lens.utils as utils
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go

from sparse_autoencoder import SparseAutoencoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Turn off autograd for everything that follows in this notebook session.
# Note: train() in this notebook calls loss.backward(), which requires gradient
# tracking during its forward pass.
torch.set_grad_enabled(False)

import plotly.io as pio

# Fetch the built-in "plotly" template and switch its font family to Palatino.
# NOTE(review): this appears to edit the registered template object itself rather
# than a copy — confirm that is intended.
template = pio.templates["plotly"]
template.layout.font.family = "Palatino"

# Make the modified template the default for figures
pio.templates.default = template

In [3]:
import transformer_lens.utils as utils
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def imshow(tensor, **kwargs):
    """Show `tensor` as a heatmap on a red-blue scale centred at zero.

    The input is converted with `utils.to_numpy`; any extra keyword
    arguments are forwarded unchanged to `px.imshow`.
    """
    values = utils.to_numpy(tensor)
    figure = px.imshow(
        values,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show()


def line(tensor, **kwargs):
    """Show `tensor` as a line plot, using its values for the y axis.

    Extra keyword arguments are forwarded unchanged to `px.line`.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Show a scatter plot of `y` against `x`.

    `xaxis`, `yaxis` and `caxis` become the display labels for the x, y and
    colour dimensions; remaining keyword arguments go straight to `px.scatter`.
    """
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y, x=x, labels=axis_labels, **kwargs)
    figure.show()

In [4]:
### FUNCTION DEFINITIONS ###

# Loss function is MSE (reconstruction loss) + L1 norm of the learned activations + similarity loss
def loss_fn(decoded_activations, learned_activations, resid_streams, resid_labels, lambda_=0.01, alpha_=0.5, verbose=False):
    """Combine reconstruction, sparsity and similarity terms into one loss.

    total = MSE(decoded, resid) + lambda_ * mean-L1(learned) + alpha_ * similarity

    The similarity term is only computed when `alpha_ > 0`; it splits
    `learned_activations` along its first axis into rows where
    `resid_labels == 1` (positive) and `resid_labels == 0` (negative).

    Returns:
        (total_loss, recon_loss, sparsity_loss, pos_sim_loss)
    """
    # Reconstruction term
    recon_loss = F.mse_loss(decoded_activations, resid_streams)

    # Sparsity term: L1 norm of every (example, position) vector, then averaged
    per_vector_l1 = torch.norm(
        einops.rearrange(learned_activations, 'b s n -> (b s) n'),
        p=1,
        dim=1,
    )
    sparsity_loss = torch.mean(per_vector_l1)

    # Similarity term
    if alpha_ > 0:
        positives = learned_activations[resid_labels == 1, :, :]
        negatives = learned_activations[resid_labels == 0, :, :]
        # The similarity helper takes (D, S, N) / (D, S, M), so reverse the axis order
        positives = einops.rearrange(positives, 'n s d -> d s n')
        negatives = einops.rearrange(negatives, 'n s d -> d s n')
        pos_sim_loss = calculate_similarity_loss(positives, negatives, verbose=verbose)
    else:
        pos_sim_loss = torch.tensor(0.0)

    total_loss = recon_loss + (lambda_ * sparsity_loss) + (alpha_ * pos_sim_loss)
    return total_loss, recon_loss, sparsity_loss, pos_sim_loss


def train(model, n_epochs, optimizer, train_streams, eval_streams, lambda_=0.01, alpha_=0.5, verbose=False,
          train_labels=None, eval_labels=None):
    """Run full-batch training of `model` for `n_epochs` optimiser steps.

    Args:
        model: module whose call returns (learned_activations, decoded_activations).
        n_epochs: number of optimiser steps; each step uses all of `train_streams`.
        optimizer: optimiser already bound to the model's parameters.
        train_streams, eval_streams: inputs for the training step and the periodic evaluation.
        lambda_, alpha_: weights forwarded to `loss_fn`.
        verbose: forwarded to `loss_fn` for the evaluation pass only.
        train_labels, eval_labels: labels forwarded to `loss_fn`. When left as
            None, the notebook-level variables of the same name are used, which
            is what this function previously relied on implicitly.

    Returns:
        The same `model` object after training.

    Raises:
        NameError: if labels are neither passed nor defined at notebook level.
    """
    # Resolve labels: explicit arguments win, otherwise fall back to notebook globals.
    notebook_scope = globals()
    if train_labels is None:
        if "train_labels" not in notebook_scope:
            raise NameError("name 'train_labels' is not defined")
        train_labels = notebook_scope["train_labels"]
    if eval_labels is None:
        if "eval_labels" not in notebook_scope:
            raise NameError("name 'eval_labels' is not defined")
        eval_labels = notebook_scope["eval_labels"]

    # Report roughly ten times per run; clamp to 1 so runs shorter than ten
    # epochs do not take a modulus by zero.
    log_every = max(1, n_epochs // 10)

    for epoch in tqdm(range(n_epochs)):
        model.train()
        optimizer.zero_grad()
        # The notebook disables autograd globally, so re-enable it for the
        # forward/backward pass; otherwise loss.backward() has no graph to walk.
        with torch.enable_grad():
            learned_activations, decoded_activations = model(train_streams)
            loss, recon_loss, sparsity_loss, pos_sim_loss = loss_fn(decoded_activations, learned_activations, train_streams,
                                                                    train_labels, lambda_=lambda_, alpha_=alpha_)
            loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            model.eval()
            with torch.no_grad():
                eval_learned_activations, eval_decoded_activations = model(eval_streams)
                eval_loss, _, _, eval_pos_sim_loss = loss_fn(eval_decoded_activations, eval_learned_activations,
                                                             eval_streams, eval_labels, lambda_=lambda_, alpha_=alpha_, verbose=verbose)
                print(f"Train loss = {loss.item():.4f}, Eval loss = {eval_loss.item():.4f}")
    return model

def feature_string_to_head_and_layer(feature_index, head_labels):
    """Translate the label at `feature_index` into a `(layer, head)` pair of ints.

    Labels containing 'mlp' (case-insensitive) take the layer from the text
    before the first underscore and are assigned head index 12. All other
    labels are parsed as 'L<layer>H<head>', e.g. 'L0H1' -> (0, 1).
    """
    label = head_labels[feature_index]

    if 'mlp' not in label.lower():
        # Layer: digits between 'L' and 'H'; head: digits following 'H'
        layer = int(re.findall(r'L(\d+)H', label)[0])
        head = int(re.findall(r'H(\d+)', label)[0])
        return layer, head

    # MLP label: layer number precedes the first underscore
    layer = int(label.split('_')[0])
    head = 12
    return layer, head

def gen_array_template(head_labels):
    """Return a zero array with one row per distinct layer and one column per distinct head.

    Layers and heads are obtained by parsing every entry of `head_labels`
    with `feature_string_to_head_and_layer`.
    """
    distinct_layers = set()
    distinct_heads = set()
    for position, _ in enumerate(head_labels):
        layer, head = feature_string_to_head_and_layer(position, head_labels)
        distinct_heads.add(head)
        distinct_layers.add(layer)

    return np.zeros((len(distinct_layers), len(distinct_heads)))

def softmax(x, axis):
    """Return the softmax of `x` taken along `axis`.

    The maximum along `axis` is subtracted before exponentiating so large
    inputs do not overflow.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exponentials = np.exp(shifted)
    normaliser = exponentials.sum(axis=axis, keepdims=True)
    return exponentials / normaliser

def gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=False, across_layer=False, n_positive=250):
    """Score each (layer, head) by how many codes appear only in positive examples, then compute ROC statistics.

    Args:
        all_indices: 2-D array of codes indexed [example, label position]; the
            first `n_positive` rows are treated as positive examples and the
            rest as negative.
        ground_truth_array: array indexed [layer, head]; flattened and used as
            the binary target for the ROC curve.
        head_labels: labels parsed by `feature_string_to_head_and_layer`.
        normalise: divide each count by the number of distinct codes seen for
            that label across both groups.
        across_layer: apply softmax along axis 0 of the [layer, head] array;
            otherwise softmax is applied to the flattened scores.
        n_positive: number of leading rows of `all_indices` that are positive
            examples (default 250, the split previously hard-coded here).

    Returns:
        (y_true, y_pred, fpr, tpr, roc_auc, f1, thresholds). Note that `f1` is
        the harmonic mean of `tpr` and `1 - fpr` at each threshold, not a
        precision/recall F1.
    """
    positive_indices = all_indices[:n_positive, :]
    negative_indices = all_indices[n_positive:, :]

    unique_to_positive_array = gen_array_template(head_labels)
    unique_to_negative_array = gen_array_template(head_labels)

    for i in range(len(head_labels)):
        layer, head = feature_string_to_head_and_layer(i, head_labels)

        positive = set(positive_indices[:, i].tolist())
        negative = set(negative_indices[:, i].tolist())

        # Codes seen in one group but never in the other
        n_unique_positive = len(positive - negative)
        n_unique_negative = len(negative - positive)

        if normalise:
            # Divide by the number of distinct codes seen in either group;
            # clamp to 1 so an empty input gives 0 rather than dividing by zero.
            n_distinct = max(len(positive | negative), 1)
            unique_to_positive_array[layer, head] = n_unique_positive / n_distinct
            unique_to_negative_array[layer, head] = n_unique_negative / n_distinct
        else:
            unique_to_positive_array[layer, head] = n_unique_positive
            unique_to_negative_array[layer, head] = n_unique_negative

    # Optional softmax along axis 0 (the layer axis) of the [layer, head] arrays
    if across_layer:
        unique_to_positive_array = softmax(unique_to_positive_array, axis=0)
        unique_to_negative_array = softmax(unique_to_negative_array, axis=0)

    y_true = ground_truth_array.flatten()
    y_pred = unique_to_positive_array.flatten()

    # Otherwise normalise the flattened scores with a single softmax
    if not across_layer:
        y_pred = softmax(y_pred, axis=0)

    fpr, tpr, thresholds = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)

    # Harmonic mean of TPR and (1 - FPR) per threshold. Where both are zero the
    # ratio is undefined; report 0 there instead of NaN so np.max(f1) stays usable.
    specificity = 1 - fpr
    denominator = tpr + specificity
    f1 = np.divide(
        2 * (tpr * specificity),
        denominator,
        out=np.zeros_like(denominator, dtype=float),
        where=denominator > 0,
    )

    return y_true, y_pred, fpr, tpr, roc_auc, f1, thresholds


def gen_co_occurrence_matrix(all_indices, n_heads, n_feat):
    """Count, for every ordered pair of distinct heads, how often each pair of codes co-occurs.

    `all_indices[e, h]` is read as the code of head `h` in example `e`. The
    result has shape (n_heads, n_heads, n_feat, n_feat); entry
    [h1, h2, c1, c2] is the number of examples in which head h1 has code c1
    while head h2 has code c2. Entries with h1 == h2 stay zero.
    """
    counts = np.zeros((n_heads, n_heads, n_feat, n_feat))

    # Ordered pairs of distinct heads, in the same order a nested loop would visit them
    head_pairs = [
        (first, second)
        for first in range(n_heads)
        for second in range(n_heads)
        if first != second
    ]

    for example in range(all_indices.shape[0]):
        for first, second in head_pairs:
            code_first = all_indices[example, first]
            code_second = all_indices[example, second]
            counts[first, second, code_first, code_second] += 1

    return counts

def normalize_co_occurrence_matrix(co_occurrence_matrix):
    """Scale each off-diagonal head pair's (n_feat, n_feat) block so it sums to 1.

    Expects shape (n_heads, n_heads, n_feat, n_feat). Blocks with h1 == h2,
    and blocks whose total is not positive, are left as zeros. The input is
    not modified.
    """
    n_heads, _, _, _ = co_occurrence_matrix.shape
    normalized_matrix = np.zeros_like(co_occurrence_matrix)

    for first in range(n_heads):
        for second in range(n_heads):
            if first == second:
                continue  # self co-occurrences are not normalised
            block = co_occurrence_matrix[first, second, :, :]
            block_total = np.sum(block)
            if block_total > 0:  # leave empty blocks at zero
                normalized_matrix[first, second, :, :] = block / block_total

    return normalized_matrix

def unique_co_occurrences(positive_matrix, negative_matrix, normalise=True):
    """For each ordered pair of distinct heads, count code pairs present in the positive matrix only.

    A code pair is "present" when its entry is > 0. With `normalise=True`
    both matrices are first passed through `normalize_co_occurrence_matrix`,
    and each count is divided by the number of present entries in the
    positive block plus the number in the negative block (0 if that sum is 0).

    Returns an (n_heads, n_heads) array; the diagonal stays zero.
    """
    if normalise:
        positive_matrix = normalize_co_occurrence_matrix(positive_matrix)
        negative_matrix = normalize_co_occurrence_matrix(negative_matrix)

    n_heads, _, _, _ = positive_matrix.shape
    unique_co_occurrence_counts = np.zeros((n_heads, n_heads))

    for first in range(n_heads):
        for second in range(n_heads):
            if first == second:
                continue  # skip self co-occurrences

            present_in_positive = positive_matrix[first, second, :, :] > 0
            present_in_negative = negative_matrix[first, second, :, :] > 0
            n_positive_only = np.sum(present_in_positive & ~present_in_negative)

            if not normalise:
                unique_co_occurrence_counts[first, second] = n_positive_only
                continue

            # Denominator: present entries in the positive block plus those in the negative block
            n_present = np.sum(present_in_positive) + np.sum(present_in_negative)
            if n_present > 0:
                unique_co_occurrence_counts[first, second] = n_positive_only / n_present
            else:
                unique_co_occurrence_counts[first, second] = 0

    return unique_co_occurrence_counts

def pairwise_cosine_similarities(pos_examples, neg_examples, eps=1e-12):
    """
    Mean cosine similarity between every pos/neg vector pair that shares a
    sequence position.

    pos_examples = (D, S, N)
    neg_examples = (D, S, M)

    For each position s, all N x M cosine similarities are computed and
    averaged; those per-position means are then averaged into one scalar.
    `eps` is added inside each square root and to the denominator.
    """
    lhs = pos_examples.permute(1, 2, 0)  # (S, N, D)
    rhs = neg_examples.permute(1, 0, 2)  # (S, D, M)

    # Vector lengths, shaped to broadcast against the (S, N, M) dot products
    lhs_norm = torch.sqrt(torch.einsum('snd,snd->sn', lhs, lhs) + eps).unsqueeze(-1)
    rhs_norm = torch.sqrt(torch.einsum('sdm,sdm->sm', rhs, rhs) + eps).unsqueeze(-2)

    # Dot product of every pos vector with every neg vector at each position
    pair_dots = torch.einsum('snd,sdm->snm', lhs, rhs)
    pair_cosines = pair_dots / (lhs_norm * rhs_norm + eps)

    # Average over the N x M pairs first, then over positions
    per_position_mean = torch.mean(pair_cosines, dim=(1, 2))
    return torch.mean(per_position_mean)

def calculate_similarity_loss(pos_examples, neg_examples, eps=1e-12, delta=1.0, verbose=False):
    """Return pos-neg similarity plus (delta minus pos-pos similarity).

    Both similarities come from `pairwise_cosine_similarities`. With
    `verbose=True` each of the two similarities is printed.
    """
    # Similarity between positive and negative examples
    pos_neg_scalar = pairwise_cosine_similarities(pos_examples, neg_examples, eps)
    if verbose:
        print(f"Pos-neg loss = {pos_neg_scalar.item():.4f}")

    # Similarity of positive examples with each other
    pos_pos_scalar = pairwise_cosine_similarities(pos_examples, pos_examples, eps)
    if verbose:
        print(f"Pos-pos loss = {pos_pos_scalar.item():.4f}")

    pos_pos_shortfall = delta - (pos_pos_scalar)
    return pos_neg_scalar + pos_pos_shortfall

def calculate_f1_score(y_true, y_pred):
    """Compute F1, precision, recall and the confusion counts for binary 0/1 arrays.

    Returns:
        (F1, Precision, Recall, TP, FP, TN, FN). Precision, recall and F1 are
        0 whenever their denominator is 0.
    """
    actual_positive = y_true == 1
    actual_negative = y_true == 0
    predicted_positive = y_pred == 1
    predicted_negative = y_pred == 0

    # Confusion-matrix counts
    TP = np.sum(actual_positive & predicted_positive)  # true positives
    FP = np.sum(actual_negative & predicted_positive)  # false positives
    FN = np.sum(actual_positive & predicted_negative)  # false negatives
    TN = np.sum(actual_negative & predicted_negative)  # true negatives

    if (TP + FP) > 0:
        Precision = TP / (TP + FP)
    else:
        Precision = 0

    if (TP + FN) > 0:
        Recall = TP / (TP + FN)
    else:
        Recall = 0

    if (Precision + Recall) > 0:
        F1 = 2 * (Precision * Recall) / (Precision + Recall)
    else:
        F1 = 0

    return F1, Precision, Recall, TP, FP, TN, FN

## PCA and UMAP plots of activations

In [None]:
# Collect sparse-autoencoder activations for each dataset; the plotting cells
# below read them from `activations`.
datasets = ['gt', 'ioi']
activations = {}

def head_and_layer_to_index(layer, head, head_labels):
    """Return the position in `head_labels` whose parsed (layer, head) matches, or None if there is none."""
    for i in range(len(head_labels)):
        parsed_layer, parsed_head = feature_string_to_head_and_layer(i, head_labels)
        if parsed_layer == layer and parsed_head == head:
            return i
    return None

for dataset in datasets:
    print(f"Dataset: {dataset}")
    # Load residual streams
    device = 'cpu'
    task = dataset
    resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
    head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
    ground_truth = torch.load(f'../data/{task}/ground_truth.pt')
    print(f"Ground truth = {ground_truth}\n")

    # Load save_dict
    savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
    save_dict = torch.load(savepath)
    num_unique = save_dict['node_best_num_unique']
    lambda_ = save_dict['node_best_lambda']
    best_roc_auc = save_dict['node_best_roc_auc']

    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

    # Load the model
    model.load_state_dict(save_dict['model'])

    model.eval()
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    # Set each but the maximum activation to zero
    #learned_activations[learned_activations < np.max(learned_activations, axis=2, keepdims=True)] = 0

    # Code with the largest activation for each (example, head)
    all_indices = np.argmax(learned_activations, axis=2)

    # Positions in head_labels of the ground-truth components. The pairs are
    # unpacked as (layer, head), matching how every other cell in this notebook
    # indexes ground_truth_array[layer, head]; the previous (head, layer)
    # unpacking here swapped the two.
    head_indices = []
    for layer, head in ground_truth:
        head_indices.append(head_and_layer_to_index(layer, head, head_labels))

    activations[dataset] = {'activations': learned_activations, 'indices': all_indices, 'ground_truth': ground_truth, 'head_labels': head_labels, 'head_indices': head_indices}

In [None]:
# Shape of `all_indices` as left by the most recent assignment in an earlier cell
all_indices.shape

In [None]:
# Each value in `activations` is a dict (it has no `.shape` of its own), so
# report the shapes of the two arrays it holds instead.
for k, v in activations.items():
    print(f"Dataset: {k}")
    print(f"activations: {v['activations'].shape}, indices: {v['indices'].shape}")

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA

# One row of subplots, one column per dataset in `activations`
fig = make_subplots(rows=1, cols=len(activations), subplot_titles=list(activations.keys()))

for i, (dataset, data) in enumerate(activations.items()):
    # Start from the full activation array for this dataset
    reshaped_data = data['activations'] #[:, head_number, :] #einops.rearrange(data['activations'], 'b s n -> (b s) n')

    # Keep only the positions listed in 'head_indices' along axis 1
    head_indices_to_keep = np.array(data['head_indices'])
    print(head_indices_to_keep)
    print(reshaped_data.shape)
    reshaped_data = reshaped_data[:, head_indices_to_keep, :]
    print(reshaped_data.shape)
    b, s, n = reshaped_data.shape
    # Merge the first two axes so every (example, kept head) pair becomes one row
    reshaped_data = einops.rearrange(reshaped_data, 'b s n -> (b s) n')
    print(reshaped_data.shape)
    
    # Project the rows onto their first two principal components
    pca = PCA(n_components=2)
    reduced_data = pca.fit_transform(reshaped_data)
    
    # Label the first half of the rows 1 and the second half 0
    # NOTE(review): assumes the first half of the examples are the positive ones — confirm against the data files
    labels = np.repeat([1, 0], b * s // 2)
    print(labels.shape)
    
    # Create a scatter trace for the current dataset, coloured by label
    trace = go.Scatter(x=reduced_data[:, 0], y=reduced_data[:, 1], mode='markers', marker=dict(color=labels, colorscale='Viridis', showscale=True), name=dataset)
    
    # Add the trace to the corresponding subplot
    fig.add_trace(trace, row=1, col=i+1)

# Set the figure title and hide the legend
fig.update_layout(title='PCA Reduction of Activations', showlegend=False)

# Update x-axis and y-axis labels for each subplot
for i in range(1, len(activations) + 1):
    fig.update_xaxes(title_text='PC1', row=1, col=i)
    fig.update_yaxes(title_text='PC2', row=1, col=i)

fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import umap
import numpy as np
import einops

# One row of subplots, one column per dataset in `activations`
fig = make_subplots(rows=1, cols=len(activations), subplot_titles=list(activations.keys()))

for i, (dataset, data) in enumerate(activations.items()):
    # Start from the full activation array and keep only the 'head_indices' positions along axis 1
    reshaped_data = data['activations']
    head_indices_to_keep = np.array(data['head_indices'])
    print(head_indices_to_keep)
    print(reshaped_data.shape)
    reshaped_data = reshaped_data[:, head_indices_to_keep, :]
    print(reshaped_data.shape)
    b, s, n = reshaped_data.shape
    # Merge the first two axes so every (example, kept head) pair becomes one row
    reshaped_data = einops.rearrange(reshaped_data, 'b s n -> (b s) n')
    print(reshaped_data.shape)
    
    # Embed the rows in 2 dimensions with UMAP (fixed random_state)
    umap_model = umap.UMAP(n_components=2, random_state=42)
    reduced_data = umap_model.fit_transform(reshaped_data)
    
    # Label the first half of the rows 1 and the second half 0
    # NOTE(review): assumes the first half of the examples are the positive ones — confirm against the data files
    labels = np.repeat([1, 0], b * s // 2)
    print(labels.shape)
    
    # Create a scatter trace for the current dataset, coloured by label
    trace = go.Scatter(x=reduced_data[:, 0], y=reduced_data[:, 1], mode='markers', marker=dict(color=labels, colorscale='Viridis', showscale=True), name=dataset)
    
    # Add the trace to the corresponding subplot
    fig.add_trace(trace, row=1, col=i+1)

# Set the figure title and hide the legend
fig.update_layout(title='UMAP Reduction of Activations', showlegend=False)

# Update x-axis and y-axis labels for each subplot
for i in range(1, len(activations) + 1):
    fig.update_xaxes(title_text='UMAP1', row=1, col=i)
    fig.update_yaxes(title_text='UMAP2', row=1, col=i)

fig.show()

In [None]:
# One row of subplots, one column per dataset in `activations`
fig = make_subplots(rows=1, cols=len(activations), subplot_titles=list(activations.keys()))

for i, (dataset, data) in enumerate(activations.items()):
    # Use the argmax code indices directly as the feature matrix
    reshaped_data = data['indices']

    # Embed the rows in 2 dimensions with UMAP (fixed random_state)
    umap_model = umap.UMAP(n_components=2, random_state=42)
    reduced_data = umap_model.fit_transform(reshaped_data)

    # Label the first half of the rows 1 and the second half 0
    labels = np.repeat([1, 0], data['indices'].shape[0] // 2)
    print(labels.shape)

    # Create a scatter trace for the current dataset, coloured by label
    trace = go.Scatter(x=reduced_data[:, 0], y=reduced_data[:, 1], mode='markers', marker=dict(color=labels, colorscale='Viridis', showscale=True), name=dataset)

    # Add the trace to the corresponding subplot
    fig.add_trace(trace, row=1, col=i+1)

# The embedding above is UMAP, so label the figure and axes accordingly
# (they previously said PCA / PC1 / PC2).
fig.update_layout(title='UMAP Reduction of Indices', showlegend=False)

# Update x-axis and y-axis labels for each subplot
for i in range(1, len(activations) + 1):
    fig.update_xaxes(title_text='UMAP1', row=1, col=i)
    fig.update_yaxes(title_text='UMAP2', row=1, col=i)

fig.show()

## ROC AUC for each dataset: softmax across layers vs. softmax over the flattened (layer, head) array

In [None]:
# Score each dataset twice: once with softmax across layers, once with softmax over the flattened scores
datasets = ['gt', 'ds', 'ioi']
dataset_results = {}

for dataset in datasets:
    print(f"Dataset: {dataset}")
    # Load residual streams, head labels and the ground-truth circuit
    device = 'cpu'
    task = dataset
    resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
    print(f"Residual streams shape: {resid_streams.shape}")
    head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
    ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

    # Load save_dict (trained weights plus the stored best hyperparameters / score)
    savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
    save_dict = torch.load(savepath)
    num_unique = save_dict['node_best_num_unique']
    print(f"Number of unique features: {num_unique}")
    lambda_ = save_dict['node_best_lambda']
    print(f"Lambda: {lambda_}")
    best_roc_auc = save_dict['node_best_roc_auc']
    print(f"Best ROC AUC: {best_roc_auc}")

    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

    # Load the model
    model.load_state_dict(save_dict['model'])

    # Build the ground-truth array indexed [layer, head] (1 if in ground truth, 0 otherwise).
    # This duplicates the sizing logic of gen_array_template.
    heads = []
    layers = []
    for i, l in enumerate(head_labels):
        layer, head = feature_string_to_head_and_layer(i, head_labels)
        heads.append(head)
        layers.append(layer)

    heads = list(set(heads))
    layers = list(set(layers))

    ground_truth_array = np.zeros((len(layers), len(heads)))
    for layer, head in ground_truth:
        ground_truth_array[layer, head] = 1

    # Code with the largest activation for each (example, head)
    model.eval()
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)

    # Count normalisation is switched on only for tasks other than 'ds' and 'ioi'
    normalise = False if task == 'ds' or task == 'ioi' else True

    # Normalise across layer
    print(f"Normalise across layer")
    y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise, across_layer=True)

    # Print best f1 score; best_threshold is computed but not used further in this cell
    node_best_f1 = np.max(f1)
    best_threshold = thresholds[np.argmax(f1)]
    print(f"Best F1 score: {node_best_f1:.4f}")
    print(f"ROC AUC: {node_roc_auc:.4f}")

    # Add to dataset results; first key is dataset, second key is the softmax variant
    dataset_results[dataset] = {}
    dataset_results[dataset]['across_layer'] = (node_roc_auc, node_best_f1)

    # Don't normalise across layer
    print(f"Don't normalise across layer")
    y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise, across_layer=False)

    # Print best f1 score; best_threshold is computed but not used further in this cell
    node_best_f1 = np.max(f1)
    best_threshold = thresholds[np.argmax(f1)]
    print(f"Best F1 score: {node_best_f1:.4f}")
    print(f"ROC AUC: {node_roc_auc:.4f}\n\n")

    # Add to dataset results; first key is dataset, second key is the softmax variant
    dataset_results[dataset]['not_across_layer'] = (node_roc_auc, node_best_f1)

In [None]:
import pandas as pd
import plotly.express as px

# Grouped bar chart of ROC AUC per dataset, one bar per softmax variant
data = []
# Display names for the dataset keys
task_mappings = {
    'gt': 'Greater-Than',
    'ds': 'Docstring',
    'ioi': 'IOI'
}

# Flatten dataset_results into one record per (dataset, softmax variant)
for dataset in datasets:
    dataset_data = dataset_results[dataset]
    for key, value in dataset_data.items():
        if key == 'across_layer':
            label = 'Across Layer'
        else:
            label = 'Across Heads'
        # value is the (ROC AUC, best F1) tuple stored by the previous cell
        data.append({'Dataset': task_mappings[dataset], 'Softmax': label, 'ROC AUC': value[0], 'F1': value[1]})

df = pd.DataFrame(data)
fig = px.bar(df, x='Dataset', y='ROC AUC', color='Softmax', barmode='group')

fig.update_layout(
    width=1000, height=600,
    margin=dict(l=50, r=50, t=100, b=100),
    #title=dict(text='ROC AUC for Different Datasets', font=dict(size=28)),
    legend_title_text='Softmax',
    legend=dict(font=dict(size=14), x=1.02, y=1),#, borderwidth=1),
    plot_bgcolor='white'
)

#fig.update_xaxes(title_text='Dataset', title_font=dict(size=24))
fig.update_yaxes(title_text='ROC AUC', title_font=dict(size=24), showgrid=True, gridwidth=1, gridcolor='lightgray')

# Remove x-axis label
fig.update_xaxes(title=None)

# Fix the bar width and the hover text
fig.update_traces(width=0.35, hovertemplate='ROC AUC: %{y:.3f}<extra></extra>')

# Update the font size of the x-axis and y-axis tick labels
fig.update_xaxes(tickfont=dict(size=24))
fig.update_yaxes(tickfont=dict(size=24))

# Make fontsize of legend bigger (replaces the size 14 set in update_layout above)
fig.update_layout(legend=dict(font=dict(size=22)))

# y-axis range is 0 to 1
fig.update_yaxes(range=[0, 1])

# Save the figure as a PDF
fig.write_image(f"../output/figures/softmax_layer_head.pdf")

fig.show()

## Bar chart of normalised vs unnormalised

In [None]:
tasks = ['gt', 'ds', 'ioi']

def get_roc_from_model(task, normalise=False):
    """Load the saved autoencoder for `task` and return its node-level ROC AUC.

    Args:
        task: dataset key used to build the data and model paths.
        normalise: when True, loads `sparse_autoencoder_norm.pt` instead of
            `sparse_autoencoder_dict.pt` and also passes `normalise=True` to
            `gen_softmaxed_unique_to_pos`.

    Returns:
        The ROC AUC reported by `gen_softmaxed_unique_to_pos` (softmax over the
        flattened scores, since `across_layer` is left at its default).
    """
    # Load residual streams, head labels and the ground-truth circuit
    device = 'cpu'
    resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
    head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
    ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

    # Pick the checkpoint that matches the normalisation setting
    if not normalise: savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
    else: savepath = f"../models/{task}/sparse_autoencoder_norm.pt"
    save_dict = torch.load(savepath)
    num_unique = save_dict['node_best_num_unique']

    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
    model.load_state_dict(save_dict['model'])

    # Ground-truth array indexed [layer, head]; gen_array_template provides the
    # zero array sized by the distinct layers and heads in head_labels.
    ground_truth_array = gen_array_template(head_labels)
    for layer, head in ground_truth:
        ground_truth_array[layer, head] = 1

    # Code with the largest activation for each (example, head)
    model.eval()
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)

    # Only the ROC AUC (fifth element of the returned tuple) is needed here
    node_roc_auc = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)[4]

    return node_roc_auc

# For each task store (ROC AUC without normalisation, ROC AUC with normalisation)
roc_results = {}
for task in tasks:
    auc_no_norm = get_roc_from_model(task, normalise=False)
    auc_norm = get_roc_from_model(task, normalise=True)
    roc_results[task] = (auc_no_norm, auc_norm)
    

In [None]:
# Bar chart, grouped by dataset, of norm vs. no norm
data = []
task_mappings = {
    'gt': 'Greater-Than',
    'ds': 'Docstring',
    'ioi': 'IOI'
}

for task in tasks:
    data.append({'Dataset': task_mappings[task], 'ROC AUC': roc_results[task][0], 'Normalised': 'No'})
    data.append({'Dataset': task_mappings[task], 'ROC AUC': roc_results[task][1], 'Normalised': 'Yes'})

df = pd.DataFrame(data)

fig = px.bar(df, x='Dataset', y='ROC AUC', color='Normalised', barmode='group')

fig.update_layout(
    width=1000, height=600,
    margin=dict(l=50, r=50, t=100, b=100),
    #title=dict(text='ROC AUC for Different Datasets', font=dict(size=28)),
    legend_title_text='Normalised',
    legend=dict(font=dict(size=14), x=1.02, y=1),#, borderwidth=1),
    plot_bgcolor='white'
)

#fig.update_xaxes(title_text='Dataset', title_font=dict(size=24))

fig.update_yaxes(title_text='ROC AUC', title_font=dict(size=24), showgrid=True, gridwidth=1, gridcolor='lightgray')

# Remove x-axis label
fig.update_xaxes(title=None)

fig.update_traces(width=0.35, hovertemplate='ROC AUC: %{y:.3f}<extra></extra>')

# Update the font size of the x-axis and y-axis labels
fig.update_xaxes(tickfont=dict(size=24))
fig.update_yaxes(tickfont=dict(size=24))

# Make fontsize of legend bigger
fig.update_layout(legend=dict(font=dict(size=22)))

# y-axis range is 0 to 1
fig.update_yaxes(range=[0, 1])

# Save
fig.write_image(f"../output/figures/norm.pdf")

fig.show()

## Side-by-side plots of the number of unique positive codes in each head, the binary plots at the best threshold and the ground truth circuit array (node level)

Requires:
* Ground truth array, stored in data
* Residual streams, stored in data
* Trained model, stored in models

For all tasks: docstring, greater-than, indirect object identification, tracr-reverse.

In [102]:
# Load residual streams, head labels and the ground-truth circuit for one task.
# `task` also selects the model checkpoint and the output folder used by the cells below.
device = 'cpu'
task = 'gt'
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
print(f"Residual streams shape: {resid_streams.shape}")
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

# Load save_dict (trained weights plus the stored best hyperparameters / score)
savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
save_dict = torch.load(savepath)
num_unique = save_dict['node_best_num_unique']
print(f"Number of unique features: {num_unique}")
lambda_ = save_dict['node_best_lambda']
print(f"Lambda: {lambda_}")
best_roc_auc = save_dict['node_best_roc_auc']
print(f"Best ROC AUC: {best_roc_auc}")


# Build an autoencoder whose input width is the last dimension of the residual
# streams and whose learned-feature count is the stored num_unique
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

# Load the model
model.load_state_dict(save_dict['model'])

Residual streams shape: torch.Size([500, 144, 768])
Number of unique features: 393
Lambda: 0.014734413708943257
Best ROC AUC: 0.9395061728395061


<All keys matched successfully>

In [103]:
# Build the ground-truth array indexed [layer, head]; sized by the number of
# distinct layers and heads parsed from head_labels
heads = []
layers = []
for i, l in enumerate(head_labels):
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    heads.append(head)
    layers.append(layer)

heads = list(set(heads))
layers = list(set(layers))

ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

# NOTE(review): this cell calls model.encoder(...) while other cells in this notebook
# use model(...)[0] — confirm the two give the same activations.
model.eval()
learned_activations = model.encoder(resid_streams).detach().cpu().numpy()
print(f"Learned activations shape: {learned_activations.shape}")

# Code with the largest activation for each (example, head)
all_indices = np.argmax(learned_activations, axis=2)
print(f"All indices shape: {all_indices.shape}")    

# First 250 rows are treated as positive examples, the remainder as negative
positive_indices = all_indices[:250, :]
negative_indices = all_indices[250:, :]

unique_to_positive_array = gen_array_template(head_labels)
unique_to_negative_array = gen_array_template(head_labels)

normalise = False

# Same counting logic as the loop inside gen_softmaxed_unique_to_pos
for i in range(len(head_labels)):
    # Calculate head and layer
    layer, head = feature_string_to_head_and_layer(i, head_labels)

    positive = set(positive_indices[:, i].tolist())
    negative = set(negative_indices[:, i].tolist())
    total_unique = positive.union(negative)

    # In positive but not negative
    unique_to_positive = list(positive - negative)
    # In negative but not positive
    unique_to_negative = list(negative - positive)

    if normalise:
        # Normalise by total number of unique indices
        unique_to_positive_array[layer, head] = len(unique_to_positive) / len(total_unique)
        unique_to_negative_array[layer, head] = len(unique_to_negative) / len(total_unique)
    
    else:
        # Set the values
        unique_to_positive_array[layer, head] = len(unique_to_positive)
        unique_to_negative_array[layer, head] = len(unique_to_negative)


# Heatmap of the transposed code indices
fig = px.imshow(all_indices.T, title='All indices')
fig.show()


Learned activations shape: (500, 144, 393)
All indices shape: (500, 144)


In [104]:
# Side-by-side heatmaps: unique-to-positive code counts next to the ground-truth circuit.
# Assuming unique_to_positive_array and ground_truth_array are numpy arrays
arrays_sequence = [unique_to_positive_array, ground_truth_array]

# Titles for each subplot
titles = ["Codes Unique to +ve Examples", "Ground-Truth Circuit"]

# Create a subplot layout: 1 row, 2 columns, with specified horizontal spacing
fig = make_subplots(rows=1, cols=2, subplot_titles=titles, horizontal_spacing=0.15)

# Add each array as a separate heatmap trace bound to its own color axis
# NOTE(review): the second trace asks for 'Reds' but coloraxis2 is configured with 'Blues'
# below — confirm which color scale is intended.
fig.add_trace(go.Heatmap(z=arrays_sequence[0], colorscale='Blues', coloraxis="coloraxis1"), row=1, col=1)
fig.add_trace(go.Heatmap(z=arrays_sequence[1], colorscale='Reds', coloraxis="coloraxis2"), row=1, col=2)

fontsize=24

# Manually specify the layout for each color axis to include a colorbar
fig.update_layout(
    coloraxis1=dict(colorscale='Blues', colorbar=dict(x=0.43)),
    coloraxis2=dict(colorscale='Blues', colorbar=dict(x=1)),
    width=900,  # Adjust the figure width
    height=400,  # Adjust the figure height to ensure the aspect ratio makes the heatmaps appear square
    margin=dict(l=50, r=50, t=50, b=50),  # Adjust margins if necessary
    title_font=dict(size=fontsize+4),  # Increase title font size
    font=dict(size=fontsize-4),  # Update global font size, affects tick labels and legend
)

# Update axes titles and reverse the y-axis with larger font sizes
# (only the left subplot gets a "Layer" y-axis title)
fig.update_xaxes(title_text="Head", title_font=dict(size=20), tickfont=dict(size=16), row=1, col=1)
fig.update_xaxes(title_text="Head", title_font=dict(size=20), tickfont=dict(size=16), row=1, col=2)
fig.update_yaxes(title_text="Layer", title_font=dict(size=20), tickfont=dict(size=16), autorange="reversed", row=1, col=1)
fig.update_yaxes(title_font=dict(size=fontsize), tickfont=dict(size=fontsize-4), autorange="reversed", row=1, col=2)

# For subplot titles
for annotation in fig['layout']['annotations']: 
    annotation['font'] = dict(size=fontsize)

# Save the figure. The two cells that follow write to this same path.
fig.write_image(f'../output/{task}/{task}_unique_to_positive_vs_ground_truth.pdf')

# Show the figure
fig.show()

In [105]:
# Variant of the side-by-side heatmaps with larger fonts and no y-axis tick labels.
# Assuming unique_to_positive_array and ground_truth_array are numpy arrays
arrays_sequence = [unique_to_positive_array, ground_truth_array]

# Titles for each subplot
titles = ["Codes Unique to +ve Examples", "Ground-Truth Circuit"]

# Create a subplot layout: 1 row, 2 columns, with specified horizontal spacing
fig = make_subplots(rows=1, cols=2, subplot_titles=titles, horizontal_spacing=0.2)

# Add each array as a separate heatmap trace bound to its own color axis
# NOTE(review): the second trace asks for 'Reds' but coloraxis2 is configured with 'Blues'
# below — confirm which color scale is intended.
fig.add_trace(go.Heatmap(z=arrays_sequence[0], colorscale='Blues', coloraxis="coloraxis1"), row=1, col=1)
fig.add_trace(go.Heatmap(z=arrays_sequence[1], colorscale='Reds', coloraxis="coloraxis2"), row=1, col=2)

fontsize=28

# Manually specify the layout for each color axis to include a colorbar
fig.update_layout(
    coloraxis1=dict(colorscale='Blues', colorbar=dict(x=0.41)),
    coloraxis2=dict(colorscale='Blues', colorbar=dict(x=1)),
    width=900,  # Adjust the figure width
    height=400,  # Adjust the figure height to ensure the aspect ratio makes the heatmaps appear square
    margin=dict(l=50, r=50, t=50, b=50),  # Adjust margins if necessary
    title_font=dict(size=fontsize+4),  # Increase title font size
    font=dict(size=fontsize-4),  # Update global font size, affects tick labels and legend
)

# Update axes titles and reverse the y-axis with larger font sizes
# (only the left subplot gets a "Layer" y-axis title)
fig.update_xaxes(title_text="Head", title_font=dict(size=24), tickfont=dict(size=20), row=1, col=1)
fig.update_xaxes(title_text="Head", title_font=dict(size=24), tickfont=dict(size=20), row=1, col=2)
fig.update_yaxes(title_text="Layer", title_font=dict(size=24), tickfont=dict(size=20), autorange="reversed", row=1, col=1)
fig.update_yaxes(title_font=dict(size=fontsize), tickfont=dict(size=fontsize-4), autorange="reversed", row=1, col=2)

# Remove y-axis tick labels for both
fig.update_yaxes(showticklabels=False, row=1, col=1)
fig.update_yaxes(showticklabels=False, row=1, col=2)

# For subplot titles
for annotation in fig['layout']['annotations']: 
    annotation['font'] = dict(size=fontsize)

# Save the figure. This is the same path the previous cell and the next cell write to.
fig.write_image(f'../output/{task}/{task}_unique_to_positive_vs_ground_truth.pdf')

# Show the figure
fig.show()

In [106]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Single-panel figure: heatmap of the unique-to-positive array with ground-truth cells outlined
fontsize = 28
fig = make_subplots(rows=1, cols=1)
fig.add_trace(go.Heatmap(z=unique_to_positive_array, colorscale='Blues'), row=1, col=1)

# Outline every cell whose ground-truth entry equals 1 with a red rectangle.
# Heatmap cells are centred on integer coordinates, hence the +/- 0.5 bounds.
for layer_idx, head_idx in np.argwhere(ground_truth_array == 1).tolist():
    fig.add_shape(
        type="rect",
        x0=head_idx - 0.5, y0=layer_idx - 0.5,
        x1=head_idx + 0.5, y1=layer_idx + 0.5,
        line=dict(color="Red", width=3),
        row=1, col=1,
    )

# Figure geometry and fonts
fig.update_layout(
    width=600,
    height=550,
    margin=dict(l=50, r=50, t=50, b=50),
    title_font=dict(size=fontsize+4),
    font=dict(size=fontsize-4),
)

# Axis titles; the y-axis is reversed and gets one integer tick per row of the ground-truth array
fig.update_xaxes(title_text="Head", title_font=dict(size=24), tickfont=dict(size=20), row=1, col=1)
fig.update_yaxes(
    title_text="Layer",
    title_font=dict(size=24),
    tickfont=dict(size=20),
    autorange="reversed",
    tickvals=np.arange(ground_truth_array.shape[0]),
    row=1, col=1,
)

# NOTE(review): this is the same output path that the preceding two-panel figure cell
# writes, so whichever cell runs last overwrites the other's PDF -- confirm intended.
fig.write_image(f'../output/{task}/{task}_unique_to_positive_vs_ground_truth.pdf')
fig.show()

## Co-occurrence plots

In [None]:
# Load the inputs for the selected task.
# map_location=device lets tensors/checkpoints that were saved from a GPU session be
# deserialised on this device; without it torch.load tries to restore them onto the
# original device and fails on a CPU-only machine before `.to(device)` is reached.
device = 'cpu'
task = 'gt'
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
print(f"Residual streams shape: {resid_streams.shape}")
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

# Load the saved dictionary holding the trained weights and the best node-level hyperparameters
savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
save_dict = torch.load(savepath, map_location=device)
num_unique = save_dict['node_best_num_unique']
print(f"Number of unique features: {num_unique}")
lambda_ = save_dict['node_best_lambda']
print(f"Lambda: {lambda_}")
best_roc_auc = save_dict['node_best_roc_auc']
print(f"Best ROC AUC: {best_roc_auc}")

# Rebuild the autoencoder with the saved dictionary size, then restore its weights
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
model.load_state_dict(save_dict['model'])

In [None]:
# Build the (layer, head) ground-truth grid: 1 where the head is in the ground truth, 0 otherwise
layer_head_pairs = [feature_string_to_head_and_layer(idx, head_labels) for idx, _ in enumerate(head_labels)]
heads = list({head for _, head in layer_head_pairs})
layers = list({layer for layer, _ in layer_head_pairs})

ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

# Encode the residual streams and discretise by taking the arg-max over the last axis
model.eval()
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

# Node-level scores and ROC statistics (normalisation is switched off for every task)
normalise = False
print(f"Normalise: {normalise}")
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(
    all_indices, ground_truth_array, head_labels, normalise=normalise
)

# Best F1 score and the threshold at which it is reached
node_best_f1 = np.max(f1)
best_threshold = thresholds[np.argmax(f1)]
print(f"Best F1 score: {node_best_f1:.4f}")

# ROC AUC
print(f"ROC AUC: {node_roc_auc:.4f}")

In [None]:
import numpy as np

def gen_co_occurrence_matrix(all_indices, n_heads, n_feat):
    """Count how often pairs of codes appear together across pairs of distinct heads.

    Args:
        all_indices: 2-D integer array indexed as [example, head]; each entry is the
            code selected for that head in that example and is used as an index into
            an axis of length ``n_feat``.
        n_heads: number of heads (columns of ``all_indices``) to pair up.
        n_feat: number of distinct codes.

    Returns:
        Float array of shape (n_heads, n_heads, n_feat, n_feat) where entry
        [h1, h2, c1, c2] is the number of examples in which head h1 has code c1 and
        head h2 has code c2. Slices with h1 == h2 are left at zero.
    """
    co_occurrence_matrix = np.zeros((n_heads, n_heads, n_feat, n_feat))

    for first_head, second_head in np.ndindex(n_heads, n_heads):
        if first_head == second_head:
            # A head is never paired with itself
            continue
        # Unbuffered add so repeated (c1, c2) pairs across examples each contribute +1
        np.add.at(
            co_occurrence_matrix[first_head, second_head],
            (all_indices[:, first_head], all_indices[:, second_head]),
            1,
        )

    return co_occurrence_matrix

def normalize_co_occurrence_matrix(co_occurrence_matrix):
    """Scale each off-diagonal head-pair slice so that its entries sum to one.

    Args:
        co_occurrence_matrix: array of shape (n_heads, n_heads, n_feat, n_feat).

    Returns:
        A new array of the same shape and dtype. For every pair h1 != h2 whose slice
        has a positive total, the slice is divided by that total; diagonal pairs and
        slices whose total is not positive are left as zeros. The input is not modified.
    """
    n_heads, _, n_feat, _ = co_occurrence_matrix.shape
    normalized_matrix = np.zeros_like(co_occurrence_matrix)

    for first_head, second_head in np.ndindex(n_heads, n_heads):
        if first_head == second_head:
            # Self pairs stay at zero
            continue
        pair_slice = co_occurrence_matrix[first_head, second_head, :, :]
        pair_total = np.sum(pair_slice)
        if pair_total > 0:
            # Only divide when there is something to normalise (avoids 0/0)
            normalized_matrix[first_head, second_head, :, :] = pair_slice / pair_total

    return normalized_matrix

def unique_co_occurrences(positive_matrix, negative_matrix, normalise=True):
    """Score each ordered head pair by the code pairs seen only in positive examples.

    Args:
        positive_matrix: co-occurrence array of shape (n_heads, n_heads, n_feat, n_feat)
            built from positive examples.
        negative_matrix: co-occurrence array of the same shape built from negative examples.
        normalise: when truthy, both matrices are first passed through
            ``normalize_co_occurrence_matrix`` and each pair's unique count is divided by
            the number of non-zero cells in the positive slice plus the number of
            non-zero cells in the negative slice (0 when that sum is 0). When falsy,
            the raw unique count is stored.

    Returns:
        Float array of shape (n_heads, n_heads); diagonal entries are zero.
    """
    if normalise:
        positive_matrix = normalize_co_occurrence_matrix(positive_matrix)
        negative_matrix = normalize_co_occurrence_matrix(negative_matrix)

    n_heads, _, n_feat, _ = positive_matrix.shape
    unique_co_occurrence_counts = np.zeros((n_heads, n_heads))

    for first_head, second_head in np.ndindex(n_heads, n_heads):
        if first_head == second_head:
            # Self pairs are never scored
            continue

        # Which (code, code) cells were ever observed in each class
        seen_in_positive = positive_matrix[first_head, second_head, :, :] > 0
        seen_in_negative = negative_matrix[first_head, second_head, :, :] > 0
        # Cells observed for positives and never for negatives
        n_unique = np.sum(seen_in_positive & ~seen_in_negative)

        if not normalise:
            unique_co_occurrence_counts[first_head, second_head] = n_unique
            continue

        # Denominator: observed cells in the positive slice plus observed cells in the negative slice
        n_observed = np.sum(seen_in_positive) + np.sum(seen_in_negative)
        unique_co_occurrence_counts[first_head, second_head] = n_unique / n_observed if n_observed > 0 else 0

    return unique_co_occurrence_counts

# Encode the residual streams and discretise: one code index per (example, head)
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

# Split into the first 250 rows and the remainder.
# NOTE(review): this presumes the dataset stores exactly 250 positive examples first,
# followed by the negatives -- confirm against how resid_heads_mean.pt was built.
positive_indices, negative_indices = all_indices[:250, :], all_indices[250:, :]
positive_learned_activations = learned_activations[:250, :, :]
negative_learned_activations = learned_activations[250:, :, :]

# Dimensions are read off the arrays: heads from the index array, codes from the activations
n_heads = all_indices.shape[1]
n_feat = learned_activations.shape[-1]
positive_co_occurrence_matrix, negative_co_occurrence_matrix = (
    gen_co_occurrence_matrix(class_indices, n_heads, n_feat)
    for class_indices in (positive_indices, negative_indices)
)

# Per head pair: how many code pairs occur for positives but never for negatives
normalise = False
print(f"Normalise: {normalise}")
unique_co_occurrence_counts = unique_co_occurrences(
    positive_co_occurrence_matrix,
    negative_co_occurrence_matrix,
    normalise=normalise,
)

In [None]:
# Rank (head, head) pairs by descending unique co-occurrence count
sorted_indices = np.argsort(unique_co_occurrence_counts.flatten())[::-1]
sorted_indices = np.unravel_index(sorted_indices, unique_co_occurrence_counts.shape)
# Zip the row and column indices back together into (head, head) pairs
sorted_head_pairs = list(zip(sorted_indices[0], sorted_indices[1]))

# Flatten the ranked pairs into a list of (layer, head) components, two per pair
circuit_components = []
for first_idx, second_idx in sorted_head_pairs:
    (l1, h1) = feature_string_to_head_and_layer(first_idx, head_labels)
    (l2, h2) = feature_string_to_head_and_layer(second_idx, head_labels)
    circuit_components.append((l1, h1))
    circuit_components.append((l2, h2))

# Number of leading components that vote for their (layer, head) cell
k = 105000

y_pred = np.zeros_like(ground_truth_array)
for (l, h) in circuit_components[:k]:
    y_pred[l, h] += 1

print(y_pred)

y_true = ground_truth_array.flatten()
y_pred = y_pred.flatten()

# Normalise y_pred with softmax
def softmax_edge(x):
    """Softmax along axis 0, shifted by the maximum so large counts cannot overflow exp()."""
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    exp_shifted = np.exp(x - np.max(x, axis=0))
    return exp_shifted / np.sum(exp_shifted, axis=0)

y_pred = softmax_edge(y_pred)

fpr, tpr, thresholds = roc_curve(y_true, y_pred)

# Calculate ROC AUC
roc_auc = auc(fpr, tpr)

# Score reported as "F1": note that this expression is the harmonic mean of TPR and
# (1 - FPR), evaluated at every ROC threshold.
f1 = 2 * (tpr * (1 - fpr)) / (tpr + (1 - fpr))

# Print
print(f"ROC AUC: {roc_auc:.4f}")
best_f1 = np.max(f1)
print(f"Best F1 score: {best_f1:.4f}")

In [None]:
# Human-readable task names used in figure titles
task_mappings = {
    'gt': 'Greater-than',
    'ioi': 'Indirect Object Identification',
    'ds': 'Docstring',
}

# Translate each ground-truth (layer, head) pair into its position in head_labels
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth_heads = [head_labels.index(f"L{layer}H{head}") for layer, head in ground_truth]

# Heatmap of unique co-occurrence scores; both axes use the same "Head i" tick labels
head_axis_labels = [f"Head {i}" for i in range(n_heads)]
fig = go.Figure(
    data=go.Heatmap(
        z=unique_co_occurrence_counts,
        x=head_axis_labels,
        y=head_axis_labels,
        colorscale='Viridis',
    )
)

fig.update_layout(
    title=f'Unique Positive Co-occurrences for {task_mappings[task]}',
    xaxis_title='Head 1',
    yaxis_title='Head 2',
    width=800,
    height=800,
)

fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Human-readable task names used in figure titles
task_mappings = {
    'gt': 'Greater-than',
    'ioi': 'Indirect Object Identification',
    'ds': 'Docstring',
}

# Translate each ground-truth (layer, head) pair into its position in head_labels
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth_heads = []
for layer, head in ground_truth:
    head_str = f"L{layer}H{head}"
    head_num = head_labels.index(head_str)
    ground_truth_heads.append(head_num)

# Two panels: learned scores on the left, ground-truth membership on the right
fig = make_subplots(rows=1, cols=2, subplot_titles=("Unique Positive Co-occurrences", "Ground Truth Heads"))
head_axis_labels = [f"Head {i}" for i in range(n_heads)]

# Left panel: heatmap of unique co-occurrences.
# The colorbar title uses the nested title=dict(text, side, font) form; the flat
# `titleside` / `titlefont` keys are deprecated and rejected by current Plotly releases.
heatmap = go.Heatmap(
    z=unique_co_occurrence_counts,
    x=head_axis_labels,
    y=head_axis_labels,
    colorscale='YlGnBu',
    colorbar=dict(
        title=dict(text='Count', side='right', font=dict(size=14)),
        tickfont=dict(size=12),
    ),
)
fig.add_trace(heatmap, row=1, col=1)

# Right panel: 1 where both heads are in the ground truth, 0.5 where exactly one is, 0 otherwise
is_ground_truth_head = np.isin(np.arange(n_heads), ground_truth_heads).astype(float)
ground_truth_matrix = (is_ground_truth_head[:, None] + is_ground_truth_head[None, :]) / 2

ground_truth_heatmap = go.Heatmap(
    z=ground_truth_matrix,
    x=head_axis_labels,
    y=head_axis_labels,
    colorscale=[[0, 'white'], [0.5, 'lightblue'], [1, 'blue']],
    showscale=False
)
fig.add_trace(ground_truth_heatmap, row=1, col=2)

fontsize = 28
fig.update_layout(
    title=dict(text=f'{task_mappings[task]}', font=dict(size=fontsize)),
    width=1200,
    height=600,
    plot_bgcolor='white',
    font=dict(size=fontsize-12),
    xaxis=dict(tickfont=dict(size=fontsize-12)),
    yaxis=dict(tickfont=dict(size=fontsize-12)),
    xaxis2=dict(tickfont=dict(size=fontsize-12)),
    yaxis2=dict(tickfont=dict(size=fontsize-12))
)

# Subplot titles are stored as layout annotations, so their font is set there
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(size=fontsize-6)

# Save the figure
fig.write_image(f'../output/{task}/{task}_unique_to_positive_vs_ground_truth_edges.pdf')

fig.show()

## ROC curves for the best hyperparameters 

Requires:
* Ground truth array, stored in data
* Residual streams, stored in data
* Trained model, stored in models

For all tasks: docstring, greater-than, indirect object identification, tracr-reverse.

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Human-readable task names used as subplot titles
task_mappings = {
    'gt': 'Greater-Than',
    'ds': 'Docstring',
    'ioi': 'Indirect Object Identification'
}

fontsize = 20
# One subplot per task, side by side
fig = make_subplots(rows=1, cols=3, subplot_titles=[f"{task_mappings[task]}" for task in ['gt', 'ds', 'ioi']], horizontal_spacing=0.05)

# One curve colour per task
colors = ['blue', 'green', 'red']

for i, task in enumerate(['gt', 'ds', 'ioi']):
    # Load the data for the current task. map_location=device lets tensors/checkpoints
    # saved from a GPU session be deserialised on this device instead of failing inside
    # torch.load on a CPU-only machine.
    resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
    head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
    ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

    # Saved dictionary with the trained weights and best node-level hyperparameters
    savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
    save_dict = torch.load(savepath, map_location=device)
    num_unique = save_dict['node_best_num_unique']
    lambda_ = save_dict['node_best_lambda']
    best_roc_auc = save_dict['node_best_roc_auc']

    # Rebuild the autoencoder for this task and restore its weights
    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
    model.load_state_dict(save_dict['model'])

    # (layer, head) ground-truth grid: 1 where the head is in the ground truth
    layer_head_pairs = [feature_string_to_head_and_layer(j, head_labels) for j, _ in enumerate(head_labels)]
    heads = list({head for _, head in layer_head_pairs})
    layers = list({layer for layer, _ in layer_head_pairs})
    ground_truth_array = np.zeros((len(layers), len(heads)))
    for layer, head in ground_truth:
        ground_truth_array[layer, head] = 1

    # Encode, discretise by arg-max, and compute the node-level ROC statistics
    model.eval()
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)
    normalise = False

    y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)

    # Each (fpr, tpr) point is emitted twice; the trace is drawn with right-angle steps ('hv')
    roc_curve_data = np.array([point for point in zip(fpr, tpr) for _ in range(2)])

    # ROC curve for this task, plus the dashed chance diagonal
    fig.add_trace(go.Scatter(x=roc_curve_data[:, 0], y=roc_curve_data[:, 1], mode='lines', name=f" Ours ({node_roc_auc:.2f})", line=dict(color=colors[i], width=4), line_shape='hv'), row=1, col=i+1)
    fig.add_shape(type='line', line=dict(dash='dash', width=4), x0=0, x1=1, y0=0, y1=1, row=1, col=i+1)
    fig.update_xaxes(title_text="False Positive Rate", range=[-0.01, 1.01], row=1, col=i+1)
    # Only the left-most panel carries the y-axis title
    if i == 0:
        fig.update_yaxes(title_text="True Positive Rate", range=[-0.01, 1.01], row=1, col=i+1)
    else:
        fig.update_yaxes(range=[-0.01, 1.01], row=1, col=i+1)

# Figure size, legend and fonts (tick labels slightly smaller than the global font)
fig.update_layout(width=1500, height=500, showlegend=True)
fig.update_layout(font=dict(size=fontsize))
fig.update_xaxes(tickfont=dict(size=fontsize-4))
fig.update_yaxes(tickfont=dict(size=fontsize-4))
# Subplot titles are stored as layout annotations, so their font is set there
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(size=fontsize+4)

# Save the figure
fig.write_image('../output/figures/roc_curves.pdf')

fig.show()

In [None]:
import json

def get_conmy_auc_curves(task: str, node: bool = True):
    """Load the two stored ACDC result files for a task and build plot-ready ROC data.

    Args:
        task: key into the file mapping below ('ioi', 'docstring' or 'greaterthan');
            also used as a key inside the JSON payload.
        node: when True read the ``node_fpr`` / ``node_tpr`` series, otherwise the
            ``edge_fpr`` / ``edge_tpr`` series.

    Returns:
        A 4-tuple ``(roc_curve_data, roc_auc, roc_curve_data_ld, roc_auc_ld)``. The
        first pair comes from the first mapped file read under the 'kl_div' key, the
        second pair from the second mapped file read under the 'logit_diff' key. Each
        curve is an array of (fpr, tpr) rows with every point repeated twice; each AUC
        is computed with ``auc`` on the sorted series.

    Raises:
        KeyError: if ``task`` is not in the file mapping or an expected key is absent
            from the JSON payload.
    """
    file_mapping = {
        'ioi': ['acdc-ioi-kl_div-False-0.json', 'acdc-ioi-logit_diff-False-0.json'],
        'docstring': ['acdc-docstring-kl_div-False-0.json', 'acdc-docstring-docstring_metric-False-0.json'],
        'greaterthan': ['acdc-greaterthan-kl_div-False-0.json', 'acdc-greaterthan-greaterthan-False-0.json']
    }

    task_type = 'node' if node else 'edge'

    def load_curve(filename, metric_key):
        """Read one result file and return (stepped curve points, AUC)."""
        with open(f"../data/conmy/{filename}") as f:
            data = json.load(f)

        results = data['trained']['random_ablation'][task][metric_key]['ACDC']
        # NOTE(review): fpr and tpr are sorted independently of each other, which only
        # keeps the original pairing if the stored points are already jointly
        # monotone -- confirm against the stored data.
        sorted_fpr = np.sort(np.array(results[f'{task_type}_fpr']))
        sorted_tpr = np.sort(np.array(results[f'{task_type}_tpr']))

        # Every point is emitted twice, matching the right-angle-step plotting downstream
        stepped_points = np.array([point for point in zip(sorted_fpr, sorted_tpr) for _ in range(2)])

        return stepped_points, auc(sorted_fpr, sorted_tpr)

    kl_file, second_file = file_mapping[task]

    # First file: KL-divergence metric
    roc_curve_data, roc_auc = load_curve(kl_file, 'kl_div')

    # Second file: always read under the 'logit_diff' key.
    # NOTE(review): for 'docstring' and 'greaterthan' the mapped file names mention
    # 'docstring_metric' / 'greaterthan' rather than 'logit_diff' -- verify that the
    # JSON payloads really expose a 'logit_diff' key for those tasks.
    roc_curve_data_ld, roc_auc_ld = load_curve(second_file, 'logit_diff')

    return roc_curve_data, roc_auc, roc_curve_data_ld, roc_auc_ld

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Human-readable task names used as subplot titles
task_mappings = {
    'gt': 'Greater-Than',
    'ds': 'Docstring',
    'ioi': 'Indirect Object Identification'
}

fontsize = 20
# One subplot per task, side by side
fig = make_subplots(rows=1, cols=3, subplot_titles=[f"{task_mappings[task]}" for task in ['gt', 'ds', 'ioi']], horizontal_spacing=0.05)

# Colours for "Ours" and "ACDC"
our_color = 'blue'
acdc_color = 'lightblue'

# Task keys as used by the stored ACDC result files
task_to_conmy = {
    'gt': 'greaterthan',
    'ds': 'docstring',
    'ioi': 'ioi'
}

for i, task in enumerate(['gt', 'ds', 'ioi']):
    # Load the data for the current task. map_location=device lets tensors/checkpoints
    # saved from a GPU session be deserialised on this device instead of failing inside
    # torch.load on a CPU-only machine.
    resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
    head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
    ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

    # Saved dictionary with the trained weights and best node-level hyperparameters
    savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
    save_dict = torch.load(savepath, map_location=device)
    num_unique = save_dict['node_best_num_unique']
    lambda_ = save_dict['node_best_lambda']
    best_roc_auc = save_dict['node_best_roc_auc']

    # Rebuild the autoencoder for this task and restore its weights
    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
    model.load_state_dict(save_dict['model'])

    # (layer, head) ground-truth grid: 1 where the head is in the ground truth
    layer_head_pairs = [feature_string_to_head_and_layer(j, head_labels) for j, _ in enumerate(head_labels)]
    heads = list({head for _, head in layer_head_pairs})
    layers = list({layer for layer, _ in layer_head_pairs})
    ground_truth_array = np.zeros((len(layers), len(heads)))
    for layer, head in ground_truth:
        ground_truth_array[layer, head] = 1

    # Encode, discretise by arg-max, and compute the node-level ROC statistics
    model.eval()
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)
    normalise = False

    y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)

    # Each (fpr, tpr) point is emitted twice; the trace is drawn with right-angle steps ('hv')
    roc_curve_data = np.array([point for point in zip(fpr, tpr) for _ in range(2)])

    # Our curve plus the dashed chance diagonal; legend entries are shown for the first panel only
    showlegend = (i == 0)
    fig.add_trace(go.Scatter(x=roc_curve_data[:, 0], y=roc_curve_data[:, 1], mode='lines', name="Ours", line=dict(color=our_color, width=4), line_shape='hv', legendgroup='Ours', showlegend=showlegend), row=1, col=i+1)
    fig.add_shape(type='line', line=dict(dash='dash', width=4), x0=0, x1=1, y0=0, y1=1, row=1, col=i+1)
    fig.update_xaxes(title_text="False Positive Rate", range=[-0.01, 1.01], row=1, col=i+1)
    # Only the left-most panel carries the y-axis title
    if i == 0:
        fig.update_yaxes(title_text="True Positive Rate", range=[-0.01, 1.01], row=1, col=i+1)
    else:
        fig.update_yaxes(range=[-0.01, 1.01], row=1, col=i+1)

    # ACDC node-level curves for the same task; only the second ("ld") curve is plotted
    roc_curve_data_conmy, roc_auc_conmy, roc_curve_data_ld_conmy, roc_auc_ld_conmy = get_conmy_auc_curves(task_to_conmy[task], node=True)
    fig.add_trace(go.Scatter(x=roc_curve_data_ld_conmy[:, 0], y=roc_curve_data_ld_conmy[:, 1], mode='lines', name="ACDC", line=dict(color=acdc_color, width=4), line_shape='hv', legendgroup='ACDC', showlegend=showlegend), row=1, col=i+1)

# Figure size, legend and fonts (tick labels slightly smaller than the global font)
fig.update_layout(width=1500, height=500, showlegend=True)
fig.update_layout(font=dict(size=fontsize))
fig.update_xaxes(tickfont=dict(size=fontsize-4))
fig.update_yaxes(tickfont=dict(size=fontsize-4))
# Subplot titles are stored as layout annotations, so their font is set there
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(size=fontsize+4)

# Save the figure
fig.write_image('../output/figures/node_roc_curves.pdf')

fig.show()

In [None]:
# Human-readable task names used as subplot titles
task_mappings = {
    'gt': 'Greater-Than',
    'ds': 'Docstring',
    'ioi': 'Indirect Object Identification'
}

# Read the stored edge-level (fpr, tpr) arrays for every task.
# NOTE(review): allow_pickle=True will execute pickled objects embedded in the file;
# only load .npy files produced by this project.
tasks = ['gt', 'ds', 'ioi']
edge_tpr_fpr = {}
for task in tasks:
    path = f'../output/{task}/fpr_tpr.npy'
    fpr, tpr = np.load(path, allow_pickle=True)
    edge_tpr_fpr[task] = (fpr, tpr)

fontsize = 20
# One subplot per task, side by side
fig = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=[f"{task_mappings[task]}" for task in ['gt', 'ds', 'ioi']],
    horizontal_spacing=0.05,
)

# Colours for "Ours" and "ACDC"
our_color = 'green'
acdc_color = 'lightgreen'

# Task keys as used by the stored ACDC result files
task_to_conmy = {
    'gt': 'greaterthan',
    'ds': 'docstring',
    'ioi': 'ioi'
}

for i, task in enumerate(['gt', 'ds', 'ioi']):
    panel_col = i + 1
    fpr, tpr = edge_tpr_fpr[task]

    # Each (fpr, tpr) point is emitted twice; the trace is drawn with right-angle steps ('hv')
    roc_curve_data = np.array([point for point in zip(fpr, tpr) for _ in range(2)])

    # Our curve plus the dashed chance diagonal; legend entries are shown for the first panel only
    showlegend = (i == 0)
    fig.add_trace(
        go.Scatter(
            x=roc_curve_data[:, 0], y=roc_curve_data[:, 1], mode='lines', name="Ours",
            line=dict(color=our_color, width=4), line_shape='hv',
            legendgroup='Ours', showlegend=showlegend,
        ),
        row=1, col=panel_col,
    )
    fig.add_shape(type='line', line=dict(dash='dash', width=4), x0=0, x1=1, y0=0, y1=1, row=1, col=panel_col)
    fig.update_xaxes(title_text="False Positive Rate", range=[-0.01, 1.01], row=1, col=panel_col)
    # Only the left-most panel carries the y-axis title
    if i == 0:
        fig.update_yaxes(title_text="True Positive Rate", range=[-0.01, 1.01], row=1, col=panel_col)
    else:
        fig.update_yaxes(range=[-0.01, 1.01], row=1, col=panel_col)

    # ACDC edge-level curves for the same task; only the second ("ld") curve is plotted
    roc_curve_data_conmy, roc_auc_conmy, roc_curve_data_ld_conmy, roc_auc_ld_conmy = get_conmy_auc_curves(task_to_conmy[task], node=False)
    fig.add_trace(
        go.Scatter(
            x=roc_curve_data_ld_conmy[:, 0], y=roc_curve_data_ld_conmy[:, 1], mode='lines', name="ACDC",
            line=dict(color=acdc_color, width=5, dash='dot'), line_shape='hv',
            legendgroup='ACDC', showlegend=showlegend,
        ),
        row=1, col=panel_col,
    )

# Figure size, legend and fonts (tick labels slightly smaller than the global font)
fig.update_layout(width=1500, height=500, showlegend=True)
fig.update_layout(font=dict(size=fontsize))
fig.update_xaxes(tickfont=dict(size=fontsize-4))
fig.update_yaxes(tickfont=dict(size=fontsize-4))
# Subplot titles are stored as layout annotations, so their font is set there
for annotation in fig.layout.annotations:
    annotation.font = dict(size=fontsize+4)

# Save the figure
fig.write_image('../output/figures/edge_roc_curves.pdf')

fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

device = 'cpu'

# Human-readable task names used as subplot titles
task_mappings = {
    'gt': 'Greater-Than',
    'ds': 'Docstring',
    'ioi': 'Indirect Object Identification'
}

fontsize = 20
# One subplot per task, side by side
fig = make_subplots(rows=1, cols=3, subplot_titles=[f"{task_mappings[task]}" for task in ['gt', 'ds', 'ioi']], horizontal_spacing=0.05)

# Colours for the node-level and edge-level curves.
# The ACDC traces below reuse the "our" colours with a dotted line; the acdc_* colours
# are defined here but not referenced by any trace in this cell.
our_color_node = 'blue'
acdc_color_node = 'lightblue'
our_color_edge = 'green'
acdc_color_edge = 'lightgreen'

# Task keys as used by the stored ACDC result files
task_to_conmy = {
    'gt': 'greaterthan',
    'ds': 'docstring',
    'ioi': 'ioi'
}

for i, task in enumerate(['gt', 'ds', 'ioi']):
    # ---- Node level (ours) ----
    # map_location=device lets tensors/checkpoints saved from a GPU session be
    # deserialised on this device instead of failing inside torch.load on a CPU-only machine.
    resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
    head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
    ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

    # Saved dictionary with the trained weights and best node-level hyperparameters
    savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
    save_dict = torch.load(savepath, map_location=device)
    num_unique = save_dict['node_best_num_unique']
    lambda_ = save_dict['node_best_lambda']
    best_roc_auc = save_dict['node_best_roc_auc']

    # Rebuild the autoencoder for this task and restore its weights
    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
    model.load_state_dict(save_dict['model'])

    # (layer, head) ground-truth grid: 1 where the head is in the ground truth
    layer_head_pairs = [feature_string_to_head_and_layer(j, head_labels) for j, _ in enumerate(head_labels)]
    heads = list({head for _, head in layer_head_pairs})
    layers = list({layer for layer, _ in layer_head_pairs})
    ground_truth_array = np.zeros((len(layers), len(heads)))
    for layer, head in ground_truth:
        ground_truth_array[layer, head] = 1

    # Encode, discretise by arg-max, and compute the node-level ROC statistics
    model.eval()
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)
    normalise = False

    y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)

    # Each (fpr, tpr) point is emitted twice; the trace is drawn with right-angle steps ('hv')
    node_roc_curve_data = np.array([point for point in zip(fpr, tpr) for _ in range(2)])

    # Legend entries are shown for the first panel only
    showlegend = (i == 0)
    fig.add_trace(go.Scatter(x=node_roc_curve_data[:, 0], y=node_roc_curve_data[:, 1], mode='lines', name="Node (Ours)", line=dict(color=our_color_node, width=6), line_shape='hv', legendgroup='Node-Ours', showlegend=showlegend), row=1, col=i+1)

    # ---- Node level (ACDC); only the second ("ld") curve is plotted ----
    node_roc_curve_data_conmy, node_roc_auc_conmy, node_roc_curve_data_ld_conmy, node_roc_auc_ld_conmy = get_conmy_auc_curves(task_to_conmy[task], node=True)
    fig.add_trace(go.Scatter(x=node_roc_curve_data_ld_conmy[:, 0], y=node_roc_curve_data_ld_conmy[:, 1], mode='lines', name="Node (ACDC)", line=dict(color=our_color_node, width=6, dash='dot'), line_shape='hv', legendgroup='Node-ACDC', showlegend=showlegend), row=1, col=i+1)

    # ---- Edge level (ours): stored (fpr, tpr) arrays ----
    # NOTE(review): allow_pickle=True will execute pickled objects embedded in the file;
    # only load .npy files produced by this project.
    path = f'../output/{task}/fpr_tpr.npy'
    tpr_fpr_arrays = np.load(path, allow_pickle=True)
    fpr, tpr = tpr_fpr_arrays

    edge_roc_curve_data = np.array([point for point in zip(fpr, tpr) for _ in range(2)])
    fig.add_trace(go.Scatter(x=edge_roc_curve_data[:, 0], y=edge_roc_curve_data[:, 1], mode='lines', name="Edge (Ours)", line=dict(color=our_color_edge, width=6), line_shape='hv', legendgroup='Edge-Ours', showlegend=showlegend), row=1, col=i+1)

    # ---- Edge level (ACDC); only the second ("ld") curve is plotted ----
    edge_roc_curve_data_conmy, edge_roc_auc_conmy, edge_roc_curve_data_ld_conmy, edge_roc_auc_ld_conmy = get_conmy_auc_curves(task_to_conmy[task], node=False)
    fig.add_trace(go.Scatter(x=edge_roc_curve_data_ld_conmy[:, 0], y=edge_roc_curve_data_ld_conmy[:, 1], mode='lines', name="Edge (ACDC)", line=dict(color=our_color_edge, width=6, dash='dot'), line_shape='hv', legendgroup='Edge-ACDC', showlegend=showlegend), row=1, col=i+1)

    # Dashed chance diagonal and axis ranges; only the left-most panel carries the y-axis title
    fig.add_shape(type='line', line=dict(dash='dash', width=4), x0=0, x1=1, y0=0, y1=1, row=1, col=i+1)
    fig.update_xaxes(title_text="False Positive Rate", range=[-0.01, 1.01], row=1, col=i+1)
    if i == 0:
        fig.update_yaxes(title_text="True Positive Rate", range=[-0.01, 1.01], row=1, col=i+1)
    else:
        fig.update_yaxes(range=[-0.01, 1.01], row=1, col=i+1)

# Figure size, legend and fonts (tick labels slightly smaller than the global font)
fig.update_layout(width=1500, height=500, showlegend=True)
fig.update_layout(font=dict(size=fontsize))
fig.update_xaxes(tickfont=dict(size=fontsize-4))
fig.update_yaxes(tickfont=dict(size=fontsize-4))
# Subplot titles are stored as layout annotations, so their font is set there
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(size=fontsize+4)

# Save the figure
fig.write_image('../output/figures/combined_roc_curves.pdf')

fig.show()

## How ROC+F1 varies with $k$ for edge-level detection

In [None]:
# Load the inputs for the selected task.
# map_location=device lets tensors/checkpoints that were saved from a GPU session be
# deserialised on this device; without it torch.load tries to restore them onto the
# original device and fails on a CPU-only machine before `.to(device)` is reached.
device = 'cpu'
task = 'ds'
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
print(f"Residual streams shape: {resid_streams.shape}")
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

# Load the saved dictionary holding the trained weights and the best node-level hyperparameters
savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
save_dict = torch.load(savepath, map_location=device)
num_unique = save_dict['node_best_num_unique']
print(f"Number of unique features: {num_unique}")
lambda_ = save_dict['node_best_lambda']
print(f"Lambda: {lambda_}")
best_roc_auc = save_dict['node_best_roc_auc']
print(f"Best ROC AUC: {best_roc_auc}")

# Rebuild the autoencoder with the saved dictionary size, then restore its weights
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
model.load_state_dict(save_dict['model'])

In [None]:
# Build the (layer, head) ground-truth grid: 1 where the head is in the ground truth, 0 otherwise
layer_head_pairs = [feature_string_to_head_and_layer(idx, head_labels) for idx, _ in enumerate(head_labels)]
heads = list({head for _, head in layer_head_pairs})
layers = list({layer for layer, _ in layer_head_pairs})

ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

# Encode the residual streams and discretise by taking the arg-max over the last axis
model.eval()
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

# Node-level scores and ROC statistics (normalisation is switched off for every task)
normalise = False
print(f"Normalise: {normalise}")
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(
    all_indices, ground_truth_array, head_labels, normalise=normalise
)

# Best F1 score and the threshold at which it is reached
node_best_f1 = np.max(f1)
best_threshold = thresholds[np.argmax(f1)]
print(f"Best F1 score: {node_best_f1:.4f}")

# ROC AUC
print(f"ROC AUC: {node_roc_auc:.4f}")

In [None]:
import numpy as np

def gen_co_occurrence_matrix(all_indices, n_heads, n_feat):
    """Count how often each pair of codes co-occurs for every ordered pair of distinct heads.

    Args:
        all_indices: Integer array indexed as ``[example, head]`` holding the code
            assigned to each head in each example.
        n_heads: Number of heads to consider (head indices ``0 .. n_heads - 1``).
        n_feat: Number of distinct codes; sets the size of the last two axes.

    Returns:
        Float array of shape ``(n_heads, n_heads, n_feat, n_feat)``. Entry
        ``[h1, h2, c1, c2]`` is the number of examples in which head ``h1`` has
        code ``c1`` while head ``h2`` has code ``c2``. Entries with ``h1 == h2``
        stay at zero.
    """
    all_indices = np.asarray(all_indices)
    co_occurrence_matrix = np.zeros((n_heads, n_heads, n_feat, n_feat))

    # All ordered head pairs (h1, h2) with h1 != h2, enumerated once up front.
    h1, h2 = np.nonzero(~np.eye(n_heads, dtype=bool))

    for codes in all_indices:  # one row of per-head codes per example
        # Within a single example every (h1, h2) pair occurs exactly once, so the
        # fancy-indexed in-place add never targets the same cell twice.
        co_occurrence_matrix[h1, h2, codes[h1], codes[h2]] += 1

    return co_occurrence_matrix

def normalize_co_occurrence_matrix(co_occurrence_matrix):
    """Turn per-head-pair co-occurrence counts into per-head-pair frequencies.

    Args:
        co_occurrence_matrix: Array of shape ``(n_heads, n_heads, n_feat, n_feat)``.

    Returns:
        A new array of the same shape in which every off-diagonal ``[h1, h2]``
        slice has been divided by its own sum. Slices whose sum is not positive,
        and all ``h1 == h2`` slices, are zero. The input is not modified.
    """
    co_occurrence_matrix = np.asarray(co_occurrence_matrix)
    n_heads, _, _, _ = co_occurrence_matrix.shape

    # An integer count matrix must not dictate the output dtype: the fractions
    # would be truncated to 0 on assignment. Floating inputs keep their dtype.
    if np.issubdtype(co_occurrence_matrix.dtype, np.inexact):
        out_dtype = co_occurrence_matrix.dtype
    else:
        out_dtype = np.float64
    normalized_matrix = np.zeros(co_occurrence_matrix.shape, dtype=out_dtype)

    for h1 in range(n_heads):
        for h2 in range(n_heads):
            if h1 == h2:  # Skip self co-occurrences
                continue
            pair_counts = co_occurrence_matrix[h1, h2]
            total_co_occurrences = np.sum(pair_counts)
            if total_co_occurrences > 0:  # Avoid division by zero
                normalized_matrix[h1, h2] = pair_counts / total_co_occurrences

    return normalized_matrix

def unique_co_occurrences(positive_matrix, negative_matrix, normalise=True):
    """Score each ordered head pair by the code pairs seen only in the positive set.

    A code pair counts as "seen" for a head pair when its entry is greater than zero.

    Args:
        positive_matrix: Co-occurrence array of shape
            ``(n_heads, n_heads, n_feat, n_feat)`` for the positive examples.
        negative_matrix: Same-shaped array for the negative examples.
        normalise: If true, both inputs are first passed through
            ``normalize_co_occurrence_matrix`` and each score is divided by the
            number of code pairs seen in the positive matrix plus the number seen
            in the negative matrix (0 when that sum is 0). If false, the score is
            the raw number of positive-only code pairs.

    Returns:
        Float array of shape ``(n_heads, n_heads)``; the diagonal is zero.
    """
    if normalise:
        positive_matrix = normalize_co_occurrence_matrix(positive_matrix)
        negative_matrix = normalize_co_occurrence_matrix(negative_matrix)

    n_heads, _, _, _ = positive_matrix.shape
    scores = np.zeros((n_heads, n_heads))

    for h1, h2 in np.ndindex(n_heads, n_heads):
        if h1 == h2:
            # A head is never paired with itself
            continue

        seen_in_positive = positive_matrix[h1, h2] > 0
        seen_in_negative = negative_matrix[h1, h2] > 0
        n_positive_only = np.sum(seen_in_positive & ~seen_in_negative)

        if not normalise:
            scores[h1, h2] = n_positive_only
            continue

        n_seen = np.sum(seen_in_positive) + np.sum(seen_in_negative)
        # Guard the empty head pair instead of dividing by zero
        scores[h1, h2] = n_positive_only / n_seen if n_seen > 0 else 0

    return scores

# Run the autoencoder and discretise every (example, head) to its strongest feature
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = learned_activations.argmax(axis=2)

# Rows 0-249 are used as the positive set and the remaining rows as the negative set
# (assumes the dataset stores its 250 positive examples first - TODO confirm).
positive_indices, negative_indices = all_indices[:250], all_indices[250:]
positive_learned_activations, negative_learned_activations = learned_activations[:250], learned_activations[250:]

# Axis 1 of the activations is the head axis, axis 2 the learned-feature axis
n_heads, n_feat = learned_activations.shape[1:]
positive_co_occurrence_matrix = gen_co_occurrence_matrix(positive_indices, n_heads, n_feat)
negative_co_occurrence_matrix = gen_co_occurrence_matrix(negative_indices, n_heads, n_feat)

# Per head pair: number of code pairs that occur in the positive set only
normalise = False  # if task in ['ds', 'ioi']  else True
print(f"Normalise: {normalise}")
unique_co_occurrence_counts = unique_co_occurrences(
    positive_co_occurrence_matrix,
    negative_co_occurrence_matrix,
    normalise=normalise,
)

In [None]:
# Rank every (head, head) pair from the highest to the lowest unique co-occurrence count
flat_order = np.argsort(unique_co_occurrence_counts.flatten())[::-1]
sorted_indices = np.unravel_index(flat_order, unique_co_occurrence_counts.shape)
sorted_head_pairs = list(zip(*sorted_indices))

# Flatten the ranked pairs into a ranked list of (layer, head) components:
# each pair contributes its first head, then its second head.
circuit_components = []
for pair in sorted_head_pairs:
    for head_index in pair:
        component_layer, component_head = feature_string_to_head_and_layer(head_index, head_labels)
        circuit_components.append((component_layer, component_head))

# Number of ranked components to keep
k = 1050

# Score each (layer, head) by how many times it appears among the top-k components
y_pred = np.zeros_like(ground_truth_array)
for component_layer, component_head in circuit_components[:k]:
    y_pred[component_layer, component_head] += 1

# Flatten labels and scores for the ROC computation
y_true = ground_truth_array.flatten()
y_pred = y_pred.flatten()

# Normalise y_pred with softmax
def softmax_edge(x):
    """Softmax of `x` along axis 0, computed in a numerically stable way.

    The maximum along axis 0 is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps `np.exp` from overflowing to inf
    (and the ratio from becoming NaN) when the scores are large.

    Args:
        x: Array-like of scores.

    Returns:
        Float array of the same shape whose entries along axis 0 sum to 1.
        An empty input is returned as an empty float array.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    shifted_exp = np.exp(x - np.max(x, axis=0))
    return shifted_exp / np.sum(shifted_exp, axis=0)
y_pred = softmax_edge(y_pred)

# ROC curve of the softmaxed scores against the ground-truth circuit
fpr, tpr, thresholds = roc_curve(y_true, y_pred)

# Area under the ROC curve
roc_auc = auc(fpr, tpr)

# Score reported as "F1": harmonic mean of the true-positive rate and 1 - FPR at each threshold.
# NOTE(review): this is not the precision/recall F1 - confirm it is the metric intended for the paper.
tnr = 1 - fpr
f1 = 2 * (tpr * tnr) / (tpr + tnr)

# Report both metrics
print(f"ROC AUC: {roc_auc:.4f}")
best_f1 = np.max(f1)
print(f"Best F1 score: {best_f1:.4f}")

In [None]:
# Sweep k (the number of top-ranked components kept) and record ROC AUC and the best
# "F1" for each value. A coarser step is used for long rankings to keep the sweep short.
step = 100 if len(circuit_components) > 10000 else 10
k_values = np.arange(1, len(circuit_components), step)
roc_auc_values = []
f1_values = []

y_true = ground_truth_array.flatten()

# The counts for k are the counts for the previous k plus the newly admitted
# components, so keep a running grid instead of recounting from scratch each time.
component_counts = np.zeros_like(ground_truth_array)
n_counted = 0

for k in k_values:
    for (l, h) in circuit_components[n_counted:k]:
        component_counts[l, h] += 1
    n_counted = k

    y_pred = softmax_edge(component_counts.flatten())

    fpr, tpr, thresholds = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)
    roc_auc_values.append(roc_auc)

    # Harmonic mean of TPR and 1 - FPR at each threshold; keep the best one
    f1 = 2 * (tpr * (1 - fpr)) / (tpr + (1 - fpr))
    f1_values.append(np.max(f1))

print(f"Best F1 score across k: {np.max(f1_values):.4f}")
print(f"Best ROC AUC across k: {np.max(roc_auc_values):.4f} (at k={k_values[np.argmax(roc_auc_values)]})")

In [None]:
import plotly.graph_objects as go

task_mappings = {
    'gt': 'Greater-than',
    'ioi': 'Indirect Object Identification',
    'ds': 'Docstring',
}

# One colour per metric, shared by the sweep curve and its node-level baseline
roc_auc_color = '#1F77B4'  # Blue
f1_score_color = '#FF7F0E'  # Orange
fontsize=38

# Edge-level ROC AUC and F1 as a function of k (solid lines)
fig = go.Figure()
for curve_name, curve_values, curve_color in (
    ('ROC AUC', roc_auc_values, roc_auc_color),
    ('F1 Score', f1_values, f1_score_color),
):
    fig.add_trace(go.Scatter(x=k_values, y=curve_values, mode='lines', name=curve_name,
                             line=dict(color=curve_color, width=5), marker=dict(size=0, color=curve_color)))

tick_font = dict(size=fontsize-8)
fig.update_layout(
    title=dict(text=f'{task_mappings[task]}', font=dict(size=fontsize)),
    xaxis=dict(title='k', tickfont=tick_font, gridcolor='lightgray', gridwidth=1),
    yaxis=dict(title='Score', tickfont=tick_font, gridcolor='lightgray', gridwidth=1),
    legend=dict(font=dict(size=fontsize-8)),
    width=800,
    height=500,
    plot_bgcolor='white',
    margin=dict(l=20, r=20, t=60, b=20),
)

# Axis titles use fontsize-4 (tick labels and legend use fontsize-8)
axis_title_font = dict(size=fontsize-4)
fig.update_xaxes(title_font=axis_title_font)
fig.update_yaxes(title_font=axis_title_font)

# Dashed horizontal baselines at the node-level scores: best F1 first, then ROC AUC
for baseline_value, baseline_color in ((node_best_f1, f1_score_color), (node_roc_auc, roc_auc_color)):
    fig.add_shape(type="line", x0=k_values[0], y0=baseline_value, x1=k_values[-1]+1, y1=baseline_value,
                  line=dict(color=baseline_color, width=4, dash="dash"))

# Save as pdf
fig.write_image(f"../output/{task}/{task}_k_effect.pdf")
fig.show()

## Table showing distribution of `n_feat` and `lambda` with their AUROC across datasets

Requires:
* Optuna study

For all tasks.

In [149]:
# Load the optuna studies (they're pickle files saved with joblib)
import pickle
import joblib

# One Optuna study per task, keyed by task name.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code - only load
# study files produced by this project.
tasks = ['gt', 'ds', 'ioi']
studies = dict.fromkeys(tasks)
for task in tasks:
    study = studies[task] = joblib.load(f'../output/{task}/optuna_study.pkl')

In [150]:
# Best hyperparameters found by each study, printed in the order ds, gt, ioi
for _study_key in ('ds', 'gt', 'ioi'):
    print(studies[_study_key].best_params)

{'num_unique': 270, 'lambda_': 0.06716069945728732}
{'num_unique': 246, 'lambda_': 0.010708647930936725}
{'num_unique': 379, 'lambda_': 0.021682888772690315}


In [151]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd

# ACDC mean performance per task, drawn as dashed reference lines in the figure built
# in this cell. Index 0 is the value drawn on the num_unique panel and index 1 the
# value drawn on the lambda panel.
# NOTE(review): it is not visible here why the baseline differs between the two panels - confirm.
acdc_means = {
    'ds': [0.938, 0.825],
    'gt': [0.766, 0.783],
    'ioi': [0.777, 0.424]
}


def param_metric_lists_from_study(study):
    """Extract the objective value and the two tuned hyperparameters of every scored trial.

    Trials whose ``values`` is missing or empty (for example trials that failed or
    were pruned before reporting a value) are skipped, so the three returned lists
    always line up index by index.

    Args:
        study: Object exposing ``trials``; each trial provides ``values`` and a
            ``params`` dict with the keys ``'num_unique'`` and ``'lambda_'``.

    Returns:
        Tuple ``(metrics, num_unique_list, lambda_list)`` of equal-length lists,
        where ``metrics`` holds the first objective value of each kept trial.
    """
    scored_trials = [trial for trial in study.trials if trial.values]

    metrics = [trial.values[0] for trial in scored_trials]
    num_unique_list = [trial.params['num_unique'] for trial in scored_trials]
    lambda_list = [trial.params['lambda_'] for trial in scored_trials]

    return metrics, num_unique_list, lambda_list

# Bin edges are taken from the Greater-Than study and shared by every task so that the
# three curves sit on a common x-axis.
# NOTE(review): trials of the other studies that fall outside the Greater-Than range get
# no bin (pd.cut returns NaN for them) and therefore do not contribute to any bin mean.
metrics, num_unique_list, lambda_list = param_metric_lists_from_study(studies['gt'])

# Create a figure with two subplots side-by-side
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.05)

# Ten equal-width bins for num_unique, each labelled with its left edge
num_unique_bins = np.linspace(min(num_unique_list), max(num_unique_list), 11)
num_unique_labels = [f'{num_unique_bins[i]:.0f}' for i in range(len(num_unique_bins)-1)]

# Ten log-spaced bins for lambda, each labelled with its left edge
lambda_bins = np.logspace(np.log10(min(lambda_list)), np.log10(max(lambda_list)), 11)
# The log10 -> power round trip can shift the outer edges by a rounding error; pin them
# to the exact extremes so the smallest and largest lambda are not pushed out of range.
lambda_bins[0], lambda_bins[-1] = min(lambda_list), max(lambda_list)
lambda_labels = [f'{lambda_bins[i]:.3f}' for i in range(len(lambda_bins)-1)]

# One colour per study, in loop order (blue, orange, green)
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
acdc_colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

# Define task mappings
task_mappings = {'gt': 'Greater-Than', 'ds': 'Docstring', 'ioi': 'IOI'}

# Iterate over each study and add traces to the subplots
for i, study_name in enumerate(['gt', 'ioi', 'ds']):
    metrics, num_unique_list, lambda_list = param_metric_lists_from_study(studies[study_name])

    # Sort number of unique features, then apply the same permutation to metrics
    num_unique_list, metrics_num_unique = zip(*sorted(zip(num_unique_list, metrics)))

    # Mean ROC AUC per num_unique bin. include_lowest=True keeps the trial that sits
    # exactly on the lowest bin edge; by default pd.cut leaves the first bin open on
    # the left and would silently drop it.
    binned_data_num_unique = pd.DataFrame({'Number of Unique Features': pd.cut(num_unique_list, bins=num_unique_bins, labels=num_unique_labels, include_lowest=True),
                                           'ROC AUC': metrics_num_unique})
    mean_roc_auc_num_unique = binned_data_num_unique.groupby('Number of Unique Features', observed=False)['ROC AUC'].mean()

    # Add a trace for each study to the first subplot
    fig.add_trace(go.Scatter(x=mean_roc_auc_num_unique.index, y=mean_roc_auc_num_unique.values, mode='lines+markers', name=task_mappings[study_name], line=dict(color=colors[i], width=4), marker=dict(size=8)), row=1, col=1)
    # Dashed ACDC baseline in the first subplot; x runs from the first (0) to the last (9) bin category
    fig.add_shape(type='line', x0=0, x1=9, y0=acdc_means[study_name][0], y1=acdc_means[study_name][0], line=dict(color=colors[i], width=3, dash='dash'), row=1, col=1)
    # Shapes do not appear in the legend, so add an empty dashed trace that carries the ACDC legend entry
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name=f'ACDC ({task_mappings[study_name]})', line=dict(color=colors[i], width=3, dash='dash'), legendgroup=study_name, showlegend=True), row=1, col=1)

    # Sort lambda values, then apply the same permutation to metrics
    lambda_list, metrics_lambda = zip(*sorted(zip(lambda_list, metrics)))

    # Mean ROC AUC per lambda bin (lowest edge included, as above)
    binned_data_lambda = pd.DataFrame({'Lambda': pd.cut(lambda_list, bins=lambda_bins, labels=lambda_labels, include_lowest=True),
                                       'ROC AUC': metrics_lambda})
    mean_roc_auc_lambda = binned_data_lambda.groupby('Lambda', observed=False)['ROC AUC'].mean()

    # Add a trace for each study to the second subplot
    fig.add_trace(go.Scatter(x=mean_roc_auc_lambda.index, y=mean_roc_auc_lambda.values, mode='lines+markers', name=task_mappings[study_name], line=dict(color=colors[i], width=4), marker=dict(size=8), showlegend=False), row=1, col=2)
    fig.add_shape(type='line', x0=0, x1=9, y0=acdc_means[study_name][1], y1=acdc_means[study_name][1], line=dict(color=colors[i], width=3, dash='dash'), row=1, col=2)
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name=f'ACDC ({task_mappings[study_name]})', line=dict(color=colors[i], width=3, dash='dash'), legendgroup=study_name, showlegend=False), row=1, col=2)

# Same y-range and grid on both panels
fig.update_yaxes(range=[-0.01, 1.01], showgrid=True, gridwidth=1, gridcolor='lightgray')

# Update the layout
fig.update_layout(
    #title=dict(text='ROC AUC vs. Hyperparameters', font=dict(size=24)),
    height=400,
    width=1200,
    legend=dict(x=1.05, y=0.95, font=dict(size=16)),
    plot_bgcolor='white',
    margin=dict(l=60, r=60, t=80, b=60),
    xaxis=dict(gridcolor='lightgray', gridwidth=1),
    xaxis2=dict(gridcolor='lightgray', gridwidth=1),
    # Axis-title font is set through title.font; the flat `titlefont` attribute is deprecated
    yaxis=dict(gridcolor='lightgray', gridwidth=1, range=[-0.01, 1.01], showgrid=True, title=dict(text='Mean ROC AUC', font=dict(size=16)), tickfont=dict(size=14)),
)

# Update the x-axis labels
fig.update_xaxes(title_text="Learned Features", row=1, col=1, tickfont=dict(size=14), title_font=dict(size=16))
fig.update_xaxes(title_text="Lambda", row=1, col=2, type='category', tickfont=dict(size=12), title_font=dict(size=16))

# Save the figure
fig.write_image(f'../output/figures/roc_auc_vs_params.pdf')

fig.show()

In [140]:
# ACDC mean performance per task, drawn as dashed reference lines in the hyperparameter
# figure of this section. Index 0 is the value drawn on the num_unique panel and index 1
# the value drawn on the lambda panel.
# NOTE(review): it is not visible here why the baseline differs between the two panels - confirm.
acdc_means = {
    'ds': [0.938, 0.825],
    'gt': [0.766, 0.783],
    'ioi': [0.777, 0.424]
}

In [199]:
colors = {'ioi': '#EF553B', 'gt': '#636EFA', 'ds': '#00CC96'}  # Define colors for each task

# Create a figure with two subplots stacked vertically (num_unique on top, lambda below)
fig = make_subplots(rows=2, cols=1, shared_yaxes=True, horizontal_spacing=0.05)

# Load the values that define the bin edges explicitly. This cell previously read
# whatever `num_unique_list` / `lambda_list` an earlier cell had left behind; every
# cell in this notebook section that fills them finishes on the Docstring study, so
# loading that study here reproduces the same edges without relying on hidden state.
metrics, num_unique_list, lambda_list = param_metric_lists_from_study(studies['ds'])

# Ten equal-width bins for num_unique, each labelled with its left edge
num_unique_bins = np.linspace(min(num_unique_list), max(num_unique_list), 11)
num_unique_labels = [f'{num_unique_bins[i]:.0f}' for i in range(len(num_unique_bins)-1)]

# Ten log-spaced bins for lambda, each labelled with its left edge
lambda_bins = np.logspace(np.log10(min(lambda_list)), np.log10(max(lambda_list)), 11)
# The log10 -> power round trip can shift the outer edges by a rounding error; pin them
# to the exact extremes so the smallest and largest lambda are not pushed out of range.
lambda_bins[0], lambda_bins[-1] = min(lambda_list), max(lambda_list)
lambda_labels = [f'{lambda_bins[i]:.3f}' for i in range(len(lambda_bins)-1)]

# Define task mappings
task_mappings = {'gt': 'Greater-Than', 'ds': 'Docstring', 'ioi': 'IOI'}

# Iterate over each study and add traces to the subplots
# NOTE(review): trials that fall outside the shared bin range get no bin (pd.cut returns
# NaN for them) and therefore do not contribute to any bin mean.
for study_name in ['gt', 'ioi', 'ds']:
    metrics, num_unique_list, lambda_list = param_metric_lists_from_study(studies[study_name])

    # Sort number of unique features, then apply the same permutation to metrics
    num_unique_list, metrics_num_unique = zip(*sorted(zip(num_unique_list, metrics)))

    # Mean ROC AUC per num_unique bin. include_lowest=True keeps the trial that sits
    # exactly on the lowest bin edge; by default pd.cut leaves the first bin open on
    # the left and would silently drop it.
    binned_data_num_unique = pd.DataFrame({'Number of Unique Features': pd.cut(num_unique_list, bins=num_unique_bins, labels=num_unique_labels, include_lowest=True),
                                           'ROC AUC': metrics_num_unique})
    mean_roc_auc_num_unique = binned_data_num_unique.groupby('Number of Unique Features', observed=False)['ROC AUC'].mean()

    # Add a trace for each study to the first subplot
    fig.add_trace(go.Scatter(x=mean_roc_auc_num_unique.index, y=mean_roc_auc_num_unique.values, mode='lines+markers', name=task_mappings[study_name], line=dict(color=colors[study_name], width=4), marker=dict(size=8)), row=1, col=1)
    # Dashed ACDC baseline in the first subplot; x runs from the first (0) to the last (9) bin category
    fig.add_shape(type='line', x0=0, x1=9, y0=acdc_means[study_name][0], y1=acdc_means[study_name][0], line=dict(color=colors[study_name], width=3, dash='dash'), row=1, col=1)
    # Shapes do not appear in the legend, so add an empty dashed trace that carries the ACDC legend entry
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name=f'ACDC ({study_name.upper()})', line=dict(color=colors[study_name], width=3, dash='dash'), legendgroup=study_name, showlegend=True), row=1, col=1)

    # Sort lambda values, then apply the same permutation to metrics
    lambda_list, metrics_lambda = zip(*sorted(zip(lambda_list, metrics)))

    # Mean ROC AUC per lambda bin (lowest edge included, as above)
    binned_data_lambda = pd.DataFrame({'Lambda': pd.cut(lambda_list, bins=lambda_bins, labels=lambda_labels, include_lowest=True),
                                       'ROC AUC': metrics_lambda})
    mean_roc_auc_lambda = binned_data_lambda.groupby('Lambda', observed=False)['ROC AUC'].mean()

    # Add a trace for each study to the second subplot
    fig.add_trace(go.Scatter(x=mean_roc_auc_lambda.index, y=mean_roc_auc_lambda.values, mode='lines+markers', name=task_mappings[study_name], line=dict(color=colors[study_name], width=4), marker=dict(size=8), showlegend=False), row=2, col=1)
    fig.add_shape(type='line', x0=0, x1=9, y0=acdc_means[study_name][1], y1=acdc_means[study_name][1], line=dict(color=colors[study_name], width=3, dash='dash'), row=2, col=1)
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name=f'ACDC ({study_name.upper()})', line=dict(color=colors[study_name], width=3, dash='dash'), legendgroup=study_name, showlegend=False), row=2, col=1)

# Same y-range and grid on both panels. The original second call addressed row=1, col=2,
# a subplot that does not exist in this 2x1 grid, so it matched no axis.
fig.update_yaxes(range=[-0.01, 1.01], showgrid=True, gridwidth=1, gridcolor='lightgray')

# Update the layout for publication quality
fig.update_layout(
    font=dict(family="Palatino", size=24, color="black"),
    legend=dict(x=1.0, y=0.65, bgcolor='rgba(255, 255, 255, 0.8)', bordercolor="Black", borderwidth=1, font=dict(size=24)),
    plot_bgcolor='white',
    width=900,
    height=700,
    margin=dict(l=60, r=60, t=80, b=60),
    xaxis=dict(gridcolor='lightgray', gridwidth=1),
    xaxis2=dict(gridcolor='lightgray', gridwidth=1, tickfont=dict(size=18)),
    # Axis-title font is set through title.font; the flat `titlefont` attribute is deprecated
    yaxis=dict(gridcolor='lightgray', gridwidth=1, range=[-0.01, 1.01], showgrid=True, title=dict(text='Mean ROC AUC', font=dict(size=24)), tickfont=dict(size=24)),
    yaxis2=dict(gridcolor='lightgray', gridwidth=1, range=[-0.01, 1.01], showgrid=True, title=dict(text='Mean ROC AUC', font=dict(size=24)), tickfont=dict(size=24)),
   #mathjax='TeX'
)

# Update the x-axis labels
fig.update_xaxes(title_text='Dim. bottleneck', row=1, col=1, tickfont=dict(size=20), title_font=dict(size=24))
fig.update_xaxes(title_text='Lambda (Sparsity penalty)', row=2, col=1, type='category', tickfont=dict(size=18), title_font=dict(size=24))

# Save the figure
fig.write_image(f'../output/figures/roc_auc_vs_params.pdf')

fig.show()

## Table showing AUROC and F1 for one set of hyperparameters across datasets (including threshold, which isn't a training hyperparameter)

Requires:
* Models trained with the same selected hyperparameters (likely the maximising value chosen from the previous table)
* Ground truth array, stored in data
* Residual streams, stored in data

For all tasks: docstring, greater-than, indirect object identification, tracr-reverse.

In [119]:
import joblib

# Load the optuna studies (they're pickle files saved with joblib)
# One Optuna study per task, keyed by task name.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code - only load
# study files produced by this project.
tasks = ['gt', 'ds', 'ioi']
studies = dict.fromkeys(tasks)
for task in tasks:
    study = studies[task] = joblib.load(f'../output/{task}/optuna_study.pkl')

In [120]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd

def param_metric_lists_from_study(study):
    """Extract the objective value and the two tuned hyperparameters of every scored trial.

    Trials whose ``values`` is missing or empty (for example trials that failed or
    were pruned before reporting a value) are skipped, so the three returned lists
    always line up index by index.

    Args:
        study: Object exposing ``trials``; each trial provides ``values`` and a
            ``params`` dict with the keys ``'num_unique'`` and ``'lambda_'``.

    Returns:
        Tuple ``(metrics, num_unique_list, lambda_list)`` of equal-length lists,
        where ``metrics`` holds the first objective value of each kept trial.
    """
    scored_trials = [trial for trial in study.trials if trial.values]

    metrics = [trial.values[0] for trial in scored_trials]
    num_unique_list = [trial.params['num_unique'] for trial in scored_trials]
    lambda_list = [trial.params['lambda_'] for trial in scored_trials]

    return metrics, num_unique_list, lambda_list

# Define task mappings
task_mappings = {'gt': 'Greater-Than', 'ds': 'Docstring', 'ioi': 'IOI'}


def _range_scaled_distance(values, target):
    """Absolute distance of each value from `target`, divided by the spread (max - min) of `values`.

    Scaling makes distances of differently sized parameters comparable. When all
    values are equal the unscaled distance is returned.
    """
    values = np.asarray(values, dtype=float)
    spread = values.max() - values.min()
    distance = np.abs(values - target)
    return distance / spread if spread > 0 else distance


# Best num_unique and lambda of each study
best_num_unique_list = []
best_lambda_list = []

for study_name in ['gt', 'ioi', 'ds']:
    metrics, num_unique_list, lambda_list = param_metric_lists_from_study(studies[study_name])

    # Trial with the highest objective value
    best_index = np.argmax(metrics)

    best_num_unique_list.append(num_unique_list[best_index])
    best_lambda_list.append(lambda_list[best_index])

# Shared hyperparameter target: the average of the per-study best values
avg_num_unique = np.mean(best_num_unique_list)
avg_lambda = np.mean(best_lambda_list)

# One table row per study
table_data = []

for study_name in ['gt', 'ioi', 'ds']:
    metrics, num_unique_list, lambda_list = param_metric_lists_from_study(studies[study_name])

    # Pick the trial closest to the shared target. Each distance is scaled by its
    # parameter's spread first: num_unique differs between trials by whole units while
    # lambda differs by hundredths, so adding the raw differences would ignore lambda.
    num_unique_diffs = _range_scaled_distance(num_unique_list, avg_num_unique)
    lambda_diffs = _range_scaled_distance(lambda_list, avg_lambda)
    closest_index = np.argmin(num_unique_diffs + lambda_diffs)

    # ROC AUC of that trial; the studies store no F1, so the column is a placeholder
    roc_auc = metrics[closest_index]
    f1 = 0.0  # Replace with the actual F1 value if available

    table_data.append([task_mappings[study_name], roc_auc, f1])

# Create a pandas DataFrame from the table_data
df = pd.DataFrame(table_data, columns=['Study/Dataset', 'ROC AUC', 'F1'])
df

Unnamed: 0,Study/Dataset,ROC AUC,F1
0,Greater-Than,0.834568,0.0
1,IOI,0.790743,0.0
2,Docstring,0.798077,0.0


In [173]:
import plotly.graph_objects as go
import torch

import plotly.io as pio

# Define the template with Palatino font
template = pio.templates["plotly"]
template.layout.font.family = "Palatino"

# Set the modified template as the default
pio.templates.default = template


# Define the tasks (Docstring is currently left out of this figure)
tasks = ['ioi', 'gt']#, 'ds']


colors = {'ioi': '#EF553B', 'gt': '#636EFA', 'ds': '#00CC96'}  # Define colors for each task

# Create a figure: one F1-vs-threshold curve per task
fig = go.Figure()

# Filled in during the 'ioi' iteration and reused as the x-values of the 'gt' curve (see below)
ioi_thresholds = []

for j, task in enumerate(tasks):
    # Load residual streams
    device = 'cpu'
    resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
    print(f"Residual streams shape: {resid_streams.shape}")
    head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
    ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

    # Load save_dict (model weights plus the saved 'node_best_*' hyperparameters and score)
    savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
    save_dict = torch.load(savepath)
    num_unique = save_dict['node_best_num_unique']
    print(f"Number of unique features: {num_unique}")
    lambda_ = save_dict['node_best_lambda']
    print(f"Lambda: {lambda_}")
    best_roc_auc = save_dict['node_best_roc_auc']
    print(f"Best ROC AUC: {best_roc_auc}")
    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
    # Load the model
    model.load_state_dict(save_dict['model'])
    # savepath = f"../models/{task}/sparse_autoencoder.pt"
    # model = torch.load(savepath)


    # Build the ground truth grid over (layer, head) pairs (1 if in ground truth, 0 otherwise)
    heads = []
    layers = []
    for i, l in enumerate(head_labels):
        layer, head = feature_string_to_head_and_layer(i, head_labels)
        heads.append(head)
        layers.append(layer)
    heads = list(set(heads))
    layers = list(set(layers))
    ground_truth_array = np.zeros((len(layers), len(heads)))
    for layer, head in ground_truth:
        ground_truth_array[layer, head] = 1

    # Discretise each (example, head) to the index of its strongest learned feature
    model.eval()
    learned_activations = model(resid_streams)[0].detach().cpu().numpy()
    all_indices = np.argmax(learned_activations, axis=2)
    normalise = False# if task == 'ds' else True
    print(f"Normalise: {normalise}")
    y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)

    # Print best f1 score (and corresponding threshold)
    node_best_f1 = np.max(f1)
    best_threshold = thresholds[np.argmax(f1)]
    print(f"Best F1 score: {node_best_f1:.4f}")
    # Print best threshold in scientific notation
    print(f"Best Threshold: {best_threshold:.2e}")
    # Print ROC AUC
    print(f"ROC AUC: {node_roc_auc:.4f}")

    # Remember the IOI thresholds...
    if task == 'ioi':
        ioi_thresholds = thresholds

    # ...and plot the Greater-Than F1 values against them instead of against their own thresholds.
    # NOTE(review): from here on `thresholds` no longer belongs to this task's `f1`; the curve and
    # the marker below are then positioned with IOI x-values, and the two arrays may differ in
    # length. `best_threshold` above was computed before the swap. Confirm this is intended.
    if task == 'gt':
        thresholds = ioi_thresholds


    # Add traces for each task with circle markers
    fig.add_trace(go.Scatter(x=thresholds, y=f1, mode='lines+markers', name=task.upper(),
                             line=dict(color=colors[task], width=5),
                             marker=dict(symbol='circle', size=12)))
    # fig.add_trace(go.Scatter(x=thresholds, y=f1, mode='lines+markers', name=task.upper(), line=dict(color=colors[j], width=3), marker=dict(symbol='circle', size=8)))

    # Add a single "Best F1" legend entry (first task only); the per-task markers below are hidden from the legend
    if j == 0:
        fig.add_trace(go.Scatter(x=[best_threshold], y=[node_best_f1], mode='markers', name='Best F1', marker=dict(color='lightgreen', size=24, symbol='x')))

    # Add a light-green 'x' marker at the best F1 of each task.
    # NOTE(review): the variable and trace name say "ROC AUC", but the index is the argmax of `f1`.
    max_roc_auc_index = np.argmax(f1)
    fig.add_trace(go.Scatter(x=[thresholds[max_roc_auc_index]], y=[f1[max_roc_auc_index]], mode='markers', name=f'{task.upper()} Max ROC AUC', marker=dict(color='lightgreen', size=24, symbol='x'), legendgroup=task.upper(), showlegend=False))

# Update layout for publication quality
# NOTE(review): `fontsize` is not used below; the layout hard-codes size=24.
fontsize = 26
fig.update_layout(
    xaxis_title='Log Threshold',
    yaxis_title='F1 Score',
    font=dict(
        family="Palatino",
        size=24,
        color="black"
    ),
    legend=dict(
        x=0.8,
        y=0.65,
        bgcolor='rgba(255, 255, 255, 0.8)',
        bordercolor="Black",
        borderwidth=1
    ),
    plot_bgcolor='white',
    width=900,
    height=700
)

# Set x-axis to log scale
fig.update_xaxes(type='log')

# Save the figure
fig.write_image(f'../output/figures/f1_vs_threshold.pdf')
fig.show()

Residual streams shape: torch.Size([500, 144, 768])
Number of unique features: 187
Lambda: 0.02393969388660661
Best ROC AUC: 0.8893415906127771
Normalise: False
Best F1 score: 0.7934
Best Threshold: 1.37e-11
ROC AUC: 0.8893
Residual streams shape: torch.Size([500, 144, 768])
Number of unique features: 393
Lambda: 0.014734413708943257
Best ROC AUC: 0.9395061728395061
Normalise: False
Best F1 score: 0.8296
Best Threshold: 2.07e-08
ROC AUC: 0.9206


## Our performance when using easy negatives

Requires:
* List of performance (AUROC and F1 score at hyperparameters + threshold chosen from above) across `np.arange(0, 2500, 250)` for the number of easy negatives from Pile

For all tasks: docstring, greater-than, indirect object identification, tracr-reverse.

In [None]:
# ROC AUC of repeated runs: one row per entry of `num_easy_negatives`, one column per run.
# The raw runs keep their own name so that `auc_list` is not rebound from the nested
# list to its row means inside the same cell; the raw values stay available afterwards.
auc_runs = np.array([
    [0.7817796610169491, 0.83751629726206, 0.8221968709256845, 0.8076923076923077, 0.8047588005215124],
    [0.7705345501955672, 0.7486962190352021, 0.7632007822685788, 0.7641786179921773, 0.7772164276401564],
    [0.7884615384615383, 0.7648305084745762, 0.7858539765319426, 0.7465775749674055, 0.7555410691003911],
    [0.7529335071707952, 0.806877444589309, 0.773794002607562, 0.7588005215123859, 0.7692307692307692],
    [0.7750977835723598, 0.7501629726205997, 0.7869947848761407, 0.7752607561929596, 0.7632007822685788],
])

# Per-row spread and mean across the runs
auc_list_std = auc_runs.std(axis=1)
auc_list = auc_runs.mean(axis=1)
num_easy_negatives = [0, 250, 500, 750, 1000]
auc_list, auc_list_std, num_easy_negatives

In [None]:
# Plotly's qualitative palette; the fifth colour is used for the curve and its error bars
colors = px.colors.qualitative.Plotly
series_color = colors[4]

# Styling shared by both axes
axis_style = dict(showgrid=True, gridwidth=1, gridcolor='lightgray',
                  tickfont=dict(size=18), title_font=dict(size=24))

# Mean ROC AUC against the number of easy negatives, with the per-row std as error bars
error_bars = dict(type='data', array=auc_list_std, visible=True, color=series_color, thickness=4.5)
auc_trace = go.Scatter(
    x=num_easy_negatives,
    y=auc_list,
    mode='lines+markers',
    line=dict(color=series_color, width=6),
    marker=dict(size=8),
    error_y=error_bars,
)

fig = go.Figure()
fig.add_trace(auc_trace)

fig.update_xaxes(title_text="Number of Easy Negatives", range=[-50, 1050], **axis_style)
fig.update_yaxes(title_text="ROC AUC", range=[0.7, 0.85], **axis_style)

fig.update_layout(
    title=dict(text="", x=0.5, y=0.95, font=dict(size=28)),
    width=1000,
    height=500,
    showlegend=False,
    margin=dict(l=100, r=100, t=100, b=100),
    plot_bgcolor='white',
)


# Save the figure
fig.write_image(f'../output/figures/roc_auc_vs_num_easy_negatives.pdf')
fig.show()

## Entropy instead of co-occurrence

In [None]:
# Load the cached residual streams, head labels and ground-truth circuit for one task.
# `map_location=device` places every tensor on `device` while it is being unpickled, so
# a file that was saved from a CUDA session can still be opened on a CPU-only machine
# (without it, torch.load fails before `.to(device)` is ever reached).
device = 'cpu'
task = 'ioi'
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
print(f"Residual streams shape: {resid_streams.shape}")
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt', map_location=device)
ground_truth = torch.load(f'../data/{task}/ground_truth.pt', map_location=device)

# Load the saved checkpoint dict: model weights under 'model', plus the selected
# hyperparameters and score under the 'node_best_*' keys.
savepath = f"../models/{task}/sparse_autoencoder_dict.pt"
save_dict = torch.load(savepath, map_location=device)
num_unique = save_dict['node_best_num_unique']
print(f"Number of unique features: {num_unique}")
lambda_ = save_dict['node_best_lambda']
print(f"Lambda: {lambda_}")
best_roc_auc = save_dict['node_best_roc_auc']
print(f"Best ROC AUC: {best_roc_auc}")


# Rebuild the autoencoder with the saved number of learned features...
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

# ...and restore its weights
model.load_state_dict(save_dict['model'])

In [None]:
# Build the ground-truth grid over (layer, head): 1 if the pair is listed in
# `ground_truth`, 0 otherwise. The grid is sized by the number of distinct layers
# and heads that appear in `head_labels`.
layer_head_pairs = [feature_string_to_head_and_layer(i, head_labels) for i, _ in enumerate(head_labels)]
heads = list(set([pair_head for _, pair_head in layer_head_pairs]))
layers = list(set([pair_layer for pair_layer, _ in layer_head_pairs]))

ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

# Discretise: each (example, head) is represented by the index of its strongest
# learned feature.
model.eval()
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = learned_activations.argmax(axis=2)

normalise = False  # if task in ['ds', 'ioi'] else True
print(f"Normalise: {normalise}")
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(
    all_indices,
    ground_truth_array,
    head_labels,
    normalise=normalise,
)

# Node-level best F1 and the threshold at which it is reached
node_best_f1 = np.max(f1)
best_threshold = thresholds[np.argmax(f1)]
print(f"Best F1 score: {node_best_f1:.4f}")

# Node-level ROC AUC
print(f"ROC AUC: {node_roc_auc:.4f}")

In [None]:
def gen_co_occurrence_matrix(all_indices, learned_activations):
    """
    Estimate, for every ordered pair of distinct heads, the joint distribution of their codes.

    Args:
        all_indices: 2-D integer array indexed as [example, head] holding the code
            assigned to each head in each example.
        learned_activations: only its last dimension is read; it sets the number of
            possible codes.

    Returns:
        Array of shape (n_heads, n_heads, n_feat, n_feat). Entry [h1, h2, c1, c2] is the
        fraction of examples in which head h1 has code c1 while head h2 has code c2.
        Slices with h1 == h2 are left at zero, as are all slices when there are no examples.
    """
    n_heads = all_indices.shape[1]
    n_feat = learned_activations.shape[-1]
    joint_counts = np.zeros((n_heads, n_heads, n_feat, n_feat))

    # Ordered pairs of distinct heads; a head is never paired with itself.
    head_pairs = [(a, b) for a in range(n_heads) for b in range(n_heads) if a != b]

    # Count how often each (code, code) combination appears for each head pair.
    for example_codes in all_indices:
        for a, b in head_pairs:
            joint_counts[a, b, example_codes[a], example_codes[b]] += 1

    # Turn counts into probabilities, skipping empty slices to avoid dividing by zero.
    for a, b in head_pairs:
        pair_total = np.sum(joint_counts[a, b, :, :])
        if pair_total > 0:
            joint_counts[a, b, :, :] /= pair_total

    return joint_counts

# The first 250 examples are treated as the positive class and the rest as negative.
# NOTE(review): the cut-off is hard-coded; confirm it matches the dataset layout.
n_positive = 250

positive_indices, negative_indices = all_indices[:n_positive, :], all_indices[n_positive:, :]
positive_learned_activations = learned_activations[:n_positive, :, :]
negative_learned_activations = learned_activations[n_positive:, :, :]

# One joint code distribution per head pair, estimated separately for each class
positive_co_occurrence_matrix = gen_co_occurrence_matrix(positive_indices, positive_learned_activations)
negative_co_occurrence_matrix = gen_co_occurrence_matrix(negative_indices, negative_learned_activations)

In [None]:
def calculate_entropy(matrix):
    """
    Shannon entropy (in bits) of the code co-occurrence distribution of each head pair.

    `matrix` is a 4-D array indexed as [h1, h2, c1, c2]. For every pair with h1 != h2
    the slice matrix[h1, h2] is read as a probability table and reduced to
    H = -sum(p * log2(p)) over its strictly positive entries. Diagonal pairs
    (h1 == h2) are skipped and stay at 0.

    Returns an (n_heads, n_heads) array of entropies.
    """
    head_count, _, _, _ = matrix.shape
    pairwise_entropy = np.zeros((head_count, head_count))

    for src in range(head_count):
        for dst in range(head_count):
            if src == dst:
                continue
            probs = matrix[src, dst, :, :].ravel()
            # Zero entries contribute nothing and would make log2 blow up.
            nonzero_probs = probs[probs > 0]
            pairwise_entropy[src, dst] = -np.sum(nonzero_probs * np.log2(nonzero_probs))

    return pairwise_entropy

def softmax_edge(x):
    """Softmax along axis 1 of `x`, so that every row sums to 1."""
    # Subtracting the row maximum leaves the result unchanged but keeps exp() from overflowing.
    row_max = np.max(x, axis=1, keepdims=True)
    exponentials = np.exp(x - row_max)
    return exponentials / exponentials.sum(axis=1, keepdims=True)

# Calculate entropy for each head pair, separately for the positive and negative examples
positive_entropy_matrix = calculate_entropy(positive_co_occurrence_matrix)
negative_entropy_matrix = calculate_entropy(negative_co_occurrence_matrix)

# Difference entropy matrix: the double negation works out to (positive - negative),
# so an entry is > 0 when that head pair's joint code distribution has higher
# entropy on the positive examples than on the negative ones.
entropy_matrix = -(negative_entropy_matrix - positive_entropy_matrix)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

task_mappings = {
    'gt': 'Greater-than',
    'ioi': 'Indirect Object Identification',
    'ds': 'Docstring',
}

# Map each ground-truth (layer, head) pair to its position in head_labels,
# which is the head index used along both axes of entropy_matrix.
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
print(head_labels)
ground_truth_heads = [head_labels.index(f"L{layer}H{head}") for layer, head in ground_truth]

# Take the head count from the matrix being plotted rather than relying on an
# `n_heads` variable left in the kernel by some other cell.
n_heads = entropy_matrix.shape[0]
head_ticks = [f"Head {i}" for i in range(n_heads)]

# Create the subplots
fig = make_subplots(rows=1, cols=2, subplot_titles=("Entropy of Co-occurrence Probabilities", "Ground Truth Heads"))

# Heatmap of entropy matrix. The colorbar title uses the nested title dict;
# the flat `titleside` / `titlefont` attributes are deprecated.
entropy_heatmap = go.Heatmap(
    z=entropy_matrix,
    x=head_ticks,
    y=head_ticks,
    colorscale='Viridis',
    colorbar=dict(
        title=dict(text='Entropy', side='right', font=dict(size=14)),
        tickfont=dict(size=12),
    ),
)
fig.add_trace(entropy_heatmap, row=1, col=1)

# Ground truth heads visualization: 1 where both heads of a pair are in the
# circuit, 0.5 where exactly one is, 0 where neither is.
in_circuit = np.isin(np.arange(n_heads), ground_truth_heads).astype(float)
ground_truth_matrix = (in_circuit[:, None] + in_circuit[None, :]) / 2

ground_truth_heatmap = go.Heatmap(
    z=ground_truth_matrix,
    x=head_ticks,
    y=head_ticks,
    colorscale=[[0, 'white'], [0.5, 'lightblue'], [1, 'blue']],
    showscale=False
)
fig.add_trace(ground_truth_heatmap, row=1, col=2)

fontsize = 28
tick_style = dict(tickfont=dict(size=fontsize-12))
fig.update_layout(
    title=dict(text=f'{task_mappings[task]}', font=dict(size=fontsize)),
    width=1200,
    height=600,
    plot_bgcolor='white',
    font=dict(size=fontsize-12),
    xaxis=tick_style,
    yaxis=tick_style,
    xaxis2=tick_style,
    yaxis2=tick_style,
)

# Set subplot titles to be larger
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(size=fontsize-6)

# Save the figure
fig.write_image(f'../output/{task}/{task}_entropy_vs_ground_truth_edges.pdf')
fig.show()

In [None]:
# Rank (head, head) pairs by their entropy difference, largest first
flat_order = np.argsort(entropy_matrix.flatten())[::-1]
sorted_indices = np.unravel_index(flat_order, entropy_matrix.shape)
# Pair up the row and column indices into a list of (head, head) pairs
sorted_head_pairs = list(zip(*sorted_indices))

# Every ranked pair contributes both of its heads, as (layer, head) tuples, in rank order
circuit_components = []
for first_idx, second_idx in sorted_head_pairs:
    layer_a, head_a = feature_string_to_head_and_layer(first_idx, head_labels)
    layer_b, head_b = feature_string_to_head_and_layer(second_idx, head_labels)
    circuit_components.extend([(layer_a, head_a), (layer_b, head_b)])

# Score each head by how often it appears among the first k components
k = 1050

y_true = ground_truth_array.flatten()
head_counts = np.zeros_like(ground_truth_array)
for layer_idx, head_idx in circuit_components[:k]:
    head_counts[layer_idx, head_idx] += 1
y_pred = head_counts.flatten()

# Normalise y_pred with softmax
# def softmax(x): return np.exp(x) / np.sum(np.exp(x), axis=0)
# y_pred = softmax(y_pred)

fpr, tpr, thresholds = roc_curve(y_true, y_pred)
roc_auc = auc(fpr, tpr)

# Score labelled "F1" here: the harmonic mean of TPR and (1 - FPR) at each threshold
specificity = 1 - fpr
f1 = 2 * (tpr * specificity) / (tpr + specificity)

print(f"ROC AUC: {roc_auc:.4f}")
best_f1 = np.max(f1)

In [None]:
# For varying values of k, calculate ROC AUC and F1 score.
# The sweep uses a coarser step when there are many ranked components.
step = 100 if len(circuit_components) > 10000 else 10
k_values = np.arange(1, len(circuit_components), step)
roc_auc_values = []
f1_values = []

y_true = ground_truth_array.flatten()

# k only grows, so keep one running count grid and add just the components
# that are new since the previous k instead of recounting from scratch.
running_counts = np.zeros_like(ground_truth_array)
n_counted = 0

for k in k_values:
    for (l, h) in circuit_components[n_counted:k]:
        running_counts[l, h] += 1
    n_counted = k
    y_pred = running_counts.flatten()

    # y_pred = softmax(y_pred)

    fpr, tpr, thresholds = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)
    roc_auc_values.append(roc_auc)

    # NOTE: the score labelled "F1" is the harmonic mean of TPR and (1 - FPR).
    # A threshold with TPR == 0 and FPR == 1 gives 0/0; ignore that NaN so it
    # does not mask the best score.
    with np.errstate(divide='ignore', invalid='ignore'):
        f1 = 2 * (tpr * (1 - fpr)) / (tpr + (1 - fpr))
    f1_values.append(np.nanmax(f1))

print(f"Best F1 score across k: {np.nanmax(f1_values):.4f}")
print(f"Best ROC AUC across k: {np.max(roc_auc_values):.4f} (at k={k_values[np.argmax(roc_auc_values)]})")

In [None]:
task_mappings = {
    'gt': 'Greater-than',
    'ioi': 'Indirect Object Identification',
    'ds': 'Docstring',
}

# ROC AUC and F1 score as a function of k
fig = go.Figure()
for series_name, series_values in (('ROC AUC', roc_auc_values), ('F1 Score', f1_values)):
    fig.add_trace(go.Scatter(x=k_values, y=series_values, mode='lines', name=series_name))
fig.update_layout(
    title=f'{task_mappings[task]} (Entropy)',
    xaxis_title='k',
    yaxis_title='Score',
    width=800,
    height=400,
)

# Node-level reference levels, drawn as dashed horizontal lines
reference_lines = [
    ("Node-Level Best F1 Score", node_best_f1, "red"),
    ("Node-Level ROC AUC", node_roc_auc, "blue"),
]

# Shapes get no legend entry, so add an empty trace per line to stand in for it
for label, _, colour in reference_lines:
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name=label, line=dict(color=colour, dash="dash")))

for label, level, colour in reference_lines:
    fig.add_shape(
        type="line",
        x0=k_values[0],
        y0=level,
        x1=k_values[-1]+1,
        y1=level,
        line=dict(color=colour, dash="dash"),
        name=label,
    )

# Save as pdf
fig.write_image(f"../output/{task}/{task}_unique_co_occurrences_entropy.pdf")
fig.show()

## Contrastive loss

In [None]:
import plotly.graph_objects as go
import numpy as np

# Node-level ROC AUC scores, five per value of alpha (the contrastive loss strength).
# NOTE(review): these numbers are identical to `data_normalized` in the combined
# figure cell, so this is presumably the normalised variant; confirm.
data = {0.0: [0.7834093872229466, 0.7323989569752283, 0.7700456323337679, 0.7645045632333768, 0.7550521512385919],
        0.01: [0.7881355932203389, 0.772490221642764, 0.7907431551499349, 0.7273468057366362, 0.7491851368970013],
        0.02: [0.7405475880052151, 0.7477183833116036, 0.7517926988265972, 0.6967079530638851, 0.729954367666232],
        0.03: [0.7001303780964797, 0.7359843546284225, 0.7860169491525424, 0.7162646675358539, 0.7216427640156453],
        0.04: [0.7064863102998695, 0.6973598435462842, 0.7249022164276401, 0.7082790091264668, 0.7053455019556714],
        0.05: [0.7679269882659713, 0.7242503259452411, 0.6608539765319427, 0.6688396349413299, 0.7149608865710562],
        0.1: [0.6957301173402869, 0.747555410691004, 0.7002933507170795, 0.7591264667535853, 0.702411994784876],
        0.2: [0.5842568448500652, 0.7108865710560625, 0.6773142112125163, 0.7328878748370273, 0.5927314211212517],
        0.5: [0.7371251629726205, 0.8158409387222947, 0.7446219035202086, 0.6494458930899608, 0.6895371577574967],
        1.0: [0.7190352020860495, 0.7519556714471969, 0.7187092568448501, 0.7478813559322034, 0.6119621903520208]}

alphas = list(data.keys())
roc_auc_values = list(data.values())

# Mean and (population) standard deviation across the five scores per alpha
mean_roc_auc = [np.mean(values) for values in roc_auc_values]
std_roc_auc = [np.std(values) for values in roc_auc_values]

# A log axis has no position for x = 0, so Plotly would silently drop the
# alpha = 0 baseline. Draw it at a stand-in position to the left of the
# smallest positive alpha and label that tick "0".
positive_alphas = [a for a in alphas if a > 0]
zero_proxy = min(positive_alphas) / 2
x_positions = [a if a > 0 else zero_proxy for a in alphas]
exponents = range(int(np.floor(np.log10(min(positive_alphas)))),
                  int(np.ceil(np.log10(max(positive_alphas)))) + 1)
tick_values = [zero_proxy] + [10.0 ** e for e in exponents]
tick_labels = ['0'] + [f'{10.0 ** e:g}' for e in exponents]

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=x_positions,
    y=mean_roc_auc,
    hovertext=[f'alpha = {a:g}' for a in alphas],  # true alpha, since x is a stand-in for 0
    mode='lines+markers',
    name='Mean ROC AUC',
    line=dict(width=4),
    marker=dict(size=12),
    error_y=dict(
        type='data',
        array=std_roc_auc,
        visible=True,
        thickness=1.5,
        width=3,
        color='rgba(0, 0, 0, 0.5)'
    )
))

fig.update_layout(
    #title='Node-Level ROC AUC vs. Alpha Hyperparameter',
    xaxis_title='Alpha (contrastive loss strength)',
    yaxis_title='ROC AUC',
    font=dict(size=18),
    legend=dict(title='', x=0.8, y=0.95, bgcolor='rgba(255, 255, 255, 0.8)'),
    plot_bgcolor='white',
    width=800,
    height=600
)

# Log-scale x-axis with explicit ticks so the stand-in position reads as alpha = 0
fig.update_xaxes(type='log', tickmode='array', tickvals=tick_values, ticktext=tick_labels)

fig.show()

In [None]:
import plotly.graph_objects as go
import numpy as np

# Node-level ROC AUC scores, five per value of alpha (the contrastive loss strength).
# NOTE(review): these numbers are identical to `data_not_normalized` in the combined
# figure cell, so this is presumably the un-normalised variant; confirm.
data = {0.0: [0.8065514993481095, 0.8296936114732725, 0.8096479791395046, 0.8352346805736637, 0.811277705345502],
        0.01: [0.7811277705345503, 0.7527705345501955, 0.7896023468057367, 0.7258800521512385, 0.7205019556714471],
        0.02: [0.673076923076923, 0.6610169491525424, 0.7754237288135594, 0.7304432855280312, 0.7710234680573664],
        0.03: [0.7366362451108214, 0.7175684485006518, 0.7403846153846154, 0.7604302477183832, 0.7617340286831811],
        0.04: [0.6805736636245111, 0.6782920469361148, 0.7328878748370273, 0.7325619295958279, 0.7024119947848761],
        0.05: [0.76874185136897, 0.6773142112125163, 0.6960560625814863, 0.6965449804432855, 0.682529335071708],
        0.1: [0.7363102998696219, 0.7544002607561929, 0.7128422425032594, 0.7076271186440678, 0.6848109517601043],
        0.2: [0.7045306388526728, 0.740547588005215, 0.7007822685788788, 0.7830834419817472, 0.6678617992177314],
        0.5: [0.6470013037809649, 0.6870925684485005, 0.7154498044328552, 0.7447848761408083, 0.6564537157757497],
        1.0: [0.7244132985658409, 0.7346805736636245, 0.6844850065189048, 0.678129074315515, 0.8161668839634941]}

alphas = list(data.keys())
roc_auc_values = list(data.values())

# Mean and (population) standard deviation across the five scores per alpha
mean_roc_auc = [np.mean(values) for values in roc_auc_values]
std_roc_auc = [np.std(values) for values in roc_auc_values]

# A log axis has no position for x = 0, so Plotly would silently drop the
# alpha = 0 baseline. Draw it at a stand-in position to the left of the
# smallest positive alpha and label that tick "0".
positive_alphas = [a for a in alphas if a > 0]
zero_proxy = min(positive_alphas) / 2
x_positions = [a if a > 0 else zero_proxy for a in alphas]
exponents = range(int(np.floor(np.log10(min(positive_alphas)))),
                  int(np.ceil(np.log10(max(positive_alphas)))) + 1)
tick_values = [zero_proxy] + [10.0 ** e for e in exponents]
tick_labels = ['0'] + [f'{10.0 ** e:g}' for e in exponents]

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=x_positions,
    y=mean_roc_auc,
    hovertext=[f'alpha = {a:g}' for a in alphas],  # true alpha, since x is a stand-in for 0
    mode='lines+markers',
    name='Mean ROC AUC',
    line=dict(width=4),
    marker=dict(size=12),
    error_y=dict(
        type='data',
        array=std_roc_auc,
        visible=True,
        thickness=1.5,
        width=3,
        color='rgba(0, 0, 0, 0.5)'
    )
))

fig.update_layout(
    #title='Node-Level ROC AUC vs. Alpha Hyperparameter',
    xaxis_title='Alpha (contrastive loss strength)',
    yaxis_title='ROC AUC',
    font=dict(size=18),
    legend=dict(title='', x=0.8, y=0.95, bgcolor='rgba(255, 255, 255, 0.8)'),
    plot_bgcolor='white',
    width=800,
    height=600
)

# Log-scale x-axis with explicit ticks so the stand-in position reads as alpha = 0
fig.update_xaxes(type='log', tickmode='array', tickvals=tick_values, ticktext=tick_labels)

fig.show()

In [None]:
import plotly.graph_objects as go
import numpy as np

# Node-level ROC AUC scores, five per value of alpha, for the normalised and
# un-normalised variants.
data_normalized = {0.0: [0.7834093872229466, 0.7323989569752283, 0.7700456323337679, 0.7645045632333768, 0.7550521512385919],
                   0.01: [0.7881355932203389, 0.772490221642764, 0.7907431551499349, 0.7273468057366362, 0.7491851368970013],
                   0.02: [0.7405475880052151, 0.7477183833116036, 0.7517926988265972, 0.6967079530638851, 0.729954367666232],
                   0.03: [0.7001303780964797, 0.7359843546284225, 0.7860169491525424, 0.7162646675358539, 0.7216427640156453],
                   0.04: [0.7064863102998695, 0.6973598435462842, 0.7249022164276401, 0.7082790091264668, 0.7053455019556714],
                   0.05: [0.7679269882659713, 0.7242503259452411, 0.6608539765319427, 0.6688396349413299, 0.7149608865710562],
                   0.1: [0.6957301173402869, 0.747555410691004, 0.7002933507170795, 0.7591264667535853, 0.702411994784876],
                   0.2: [0.5842568448500652, 0.7108865710560625, 0.6773142112125163, 0.7328878748370273, 0.5927314211212517],
                   0.5: [0.7371251629726205, 0.8158409387222947, 0.7446219035202086, 0.6494458930899608, 0.6895371577574967],
                   1.0: [0.7190352020860495, 0.7519556714471969, 0.7187092568448501, 0.7478813559322034, 0.6119621903520208]}

data_not_normalized = {0.0: [0.8065514993481095, 0.8296936114732725, 0.8096479791395046, 0.8352346805736637, 0.811277705345502],
                       0.01: [0.7811277705345503, 0.7527705345501955, 0.7896023468057367, 0.7258800521512385, 0.7205019556714471],
                       0.02: [0.673076923076923, 0.6610169491525424, 0.7754237288135594, 0.7304432855280312, 0.7710234680573664],
                       0.03: [0.7366362451108214, 0.7175684485006518, 0.7403846153846154, 0.7604302477183832, 0.7617340286831811],
                       0.04: [0.6805736636245111, 0.6782920469361148, 0.7328878748370273, 0.7325619295958279, 0.7024119947848761],
                       0.05: [0.76874185136897, 0.6773142112125163, 0.6960560625814863, 0.6965449804432855, 0.682529335071708],
                       0.1: [0.7363102998696219, 0.7544002607561929, 0.7128422425032594, 0.7076271186440678, 0.6848109517601043],
                       0.2: [0.7045306388526728, 0.740547588005215, 0.7007822685788788, 0.7830834419817472, 0.6678617992177314],
                       0.5: [0.6470013037809649, 0.6870925684485005, 0.7154498044328552, 0.7447848761408083, 0.6564537157757497],
                       1.0: [0.7244132985658409, 0.7346805736636245, 0.6844850065189048, 0.678129074315515, 0.8161668839634941]}

# Both dictionaries list the same alphas in the same order, so one x-axis serves both traces
alphas = list(data_normalized.keys())
roc_auc_values_normalized = list(data_normalized.values())
roc_auc_values_not_normalized = list(data_not_normalized.values())

# Mean and (population) standard deviation across the five scores per alpha
mean_roc_auc_normalized = [np.mean(values) for values in roc_auc_values_normalized]
std_roc_auc_normalized = [np.std(values) for values in roc_auc_values_normalized]

mean_roc_auc_not_normalized = [np.mean(values) for values in roc_auc_values_not_normalized]
std_roc_auc_not_normalized = [np.std(values) for values in roc_auc_values_not_normalized]

# A log axis has no position for x = 0, so Plotly would silently drop the
# alpha = 0 baseline from both curves. Draw it at a stand-in position to the
# left of the smallest positive alpha and label that tick "0".
positive_alphas = [a for a in alphas if a > 0]
zero_proxy = min(positive_alphas) / 2
x_positions = [a if a > 0 else zero_proxy for a in alphas]
exponents = range(int(np.floor(np.log10(min(positive_alphas)))),
                  int(np.ceil(np.log10(max(positive_alphas)))) + 1)
tick_values = [zero_proxy] + [10.0 ** e for e in exponents]
tick_labels = ['0'] + [f'{10.0 ** e:g}' for e in exponents]
alpha_hovertext = [f'alpha = {a:g}' for a in alphas]  # true alpha, since x is a stand-in for 0

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=x_positions,
    y=mean_roc_auc_normalized,
    hovertext=alpha_hovertext,
    mode='lines+markers',
    name='Normalised',
    line=dict(width=4, color='purple'),
    marker=dict(size=12, color='purple'),
    error_y=dict(
        type='data',
        array=std_roc_auc_normalized,
        visible=True,
        thickness=1.5,
        width=3,
        color='rgba(128, 0, 128, 0.5)'  # Purple with transparency
    )
))

fig.add_trace(go.Scatter(
    x=x_positions,
    y=mean_roc_auc_not_normalized,
    hovertext=alpha_hovertext,
    mode='lines+markers',
    name='Not Normalised',
    line=dict(width=4, color='green'),
    marker=dict(size=12, color='green'),
    error_y=dict(
        type='data',
        array=std_roc_auc_not_normalized,
        visible=True,
        thickness=1.5,
        width=3,
        color='rgba(0, 128, 0, 0.5)'  # Green with transparency
    )
))

fig.update_layout(
    title='Node-Level ROC AUC vs. Alpha',
    xaxis_title='Alpha',
    yaxis_title='ROC AUC',
    font=dict(size=18),
    legend=dict(title='', x=0.8, y=0.95, bgcolor='rgba(255, 255, 255, 0.8)'),
    plot_bgcolor='white',
    width=800,
    height=600
)

# Log-scale x-axis with explicit ticks so the stand-in position reads as alpha = 0
fig.update_xaxes(type='log', tickmode='array', tickvals=tick_values, ticktext=tick_labels)

# Save the figure
fig.write_image('../output/figures/roc_auc_vs_alpha.pdf')

fig.show()

## Full comparison grouped bar chart

In [178]:
# List the hex codes of Plotly's default qualitative palette (candidate colours for the bar chart)
print(px.colors.qualitative.Plotly)

['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']


In [183]:
# List the rgb strings of the sequential "Blues" scale (candidate colours for the bar chart)
print(px.colors.sequential.Blues)

['rgb(247,251,255)', 'rgb(222,235,247)', 'rgb(198,219,239)', 'rgb(158,202,225)', 'rgb(107,174,214)', 'rgb(66,146,198)', 'rgb(33,113,181)', 'rgb(8,81,156)', 'rgb(8,48,107)']


In [190]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Data
tasks = ['Docstring', 'Greater-than', 'IOI']
methods = ['ACDC', 'HISP', 'SP', 'Ours']

# ROC AUC scores of our method, five per task and level ('N' = node, 'E' = edge)
our_results = {
    'IOI': {
        'N': [0.8463168187744459, 0.8829856584093873, 0.8559322033898304, 0.8391460234680574, 0.8402868318122556],
        'E': [0.8393089960886572, 0.8265971316818774, 0.866199478487614, 0.8455019556714471, 0.8223598435462841],
    },
    'Docstring': {
        'N': [0.9134615384615384, 0.8974358974358975, 0.9038461538461539, 0.9358974358974359, 0.9262820512820513],
        'E': [0.951923076923077, 0.9166666666666666, 0.9070512820512822, 0.8942307692307693, 0.9006410256410257],
    },
    'Greater-than': {
        'N': [0.8238683127572016, 0.8172839506172839, 0.8740740740740741, 0.7378600823045267, 0.9078189300411523],
        'E': [0.8304526748971194, 0.8814814814814815, 0.8551440329218106, 0.876954732510288, 0.8366255144032922],
    }
}

# Baseline AUCs keyed by '<method>(<level>)', four values each.
# NOTE(review): presumably one value per evaluation metric; confirm against the source of these numbers.
data = {
    'Docstring': {
        'ACDC(E)': [0.982, 0.972, 0.906, 0.929],
        'HISP(E)': [0.805, 0.821, 0.805, 0.821],
        'SP(E)': [0.937, 0.942, 0.428, 0.482],
        'ACDC(N)': [0.950, 0.938, 0.837, 0.825],
        'HISP(N)': [0.881, 0.889, 0.881, 0.889],
        'SP(N)': [0.928, 0.941, 0.420, 0.398]
    },
    'Greater-than': {
        'ACDC(E)': [0.853, 0.461, 0.701, 0.491],
        'HISP(E)': [0.693, 0.706, 0.693, 0.706],
        'SP(E)': [0.806, 0.812, 0.163, 0.639],
        'ACDC(N)': [0.890, 0.766, 0.887, 0.783],
        'HISP(N)': [0.642, 0.631, 0.642, 0.631],
        'SP(N)': [0.827, 0.811, 0.134, 0.522]
    },
    'IOI': {
        'ACDC(E)': [0.869, 0.589, 0.539, 0.447],
        'HISP(E)': [0.789, 0.836, 0.792, 0.836],
        'SP(E)': [0.823, 0.707, 0.486, 0.393],
        'ACDC(N)': [0.880, 0.777, 0.458, 0.424],
        'HISP(N)': [0.668, 0.728, 0.671, 0.728],
        'SP(N)': [0.842, 0.797, 0.605, 0.479]
    }
}

# Calculate average AUC values and standard deviations.
# The loop variable is `task_name` rather than `task`, so the notebook-wide `task`
# string that other cells use for file paths is not overwritten.
data_avg = {}
data_std = {}
for task_name in tasks:
    data_avg[task_name] = {}
    data_std[task_name] = {}
    for method in methods:
        data_avg[task_name][method] = {}
        data_std[task_name][method] = {}
        for level in ['E', 'N']:
            if method == 'Ours':
                scores = our_results[task_name][level]
            else:
                scores = data[task_name][f'{method}({level})']
            data_avg[task_name][method][level] = np.mean(scores)
            data_std[task_name][method][level] = np.std(scores)

# One shade of blue per method; zip stops at the shorter list, so the spare shade is unused
plotly_colours = ['rgb(198,219,239)', 'rgb(158,202,225)', 'rgb(107,174,214)', 'rgb(66,146,198)', 'rgb(33,113,181)']
colors = dict(zip(methods, plotly_colours))

# Create subplots
fig = make_subplots(rows=1, cols=2, subplot_titles=('Edge-Level AUCs', 'Node-Level AUCs'), horizontal_spacing=0.05)

# Add traces for each method and level
for method in methods:
    for level, col in zip(['E', 'N'], [1, 2]):
        fig.add_trace(go.Bar(
            name=method,
            x=tasks,
            y=[data_avg[task_name][method][level] for task_name in tasks],
            error_y=dict(
                type='data',
                array=[data_std[task_name][method][level] for task_name in tasks],
                visible=True,
            ),
            marker_color=colors[method],
            showlegend=level == 'E'  # Only show legend for edge-level
        ), row=1, col=col)

# Set font sizes
fontsize = 28
fig.update_layout(
    yaxis=dict(title='AUC', tickfont=dict(size=fontsize-4)),
    legend=dict(
        title='Method',
        font=dict(size=fontsize-4),
        tracegroupgap=10
    ),
    font=dict(size=fontsize-4),
    plot_bgcolor='rgba(0,0,0,0)',
    width=1400,
    height=500,
    # Set background to white
    paper_bgcolor='white',
    # Remove grid
    xaxis=dict(showgrid=False),
)

# Update subplot title font size
for i in fig['layout']['annotations']:
    i['font'] = dict(size=fontsize)

# Set ylim to be [0, 1] on both plots
fig.update_yaxes(range=[0, 1], row=1, col=1)
fig.update_yaxes(range=[0, 1], row=1, col=2)

# Save the figure
fig.write_image('../output/figures/edge_node_auc.pdf')

# Show plot
fig.show()

In [180]:
# Print mean and std of each of our results.
# The task names must match the keys of `our_results` ('Greater-than', not 'Greaterthan');
# `task_name` is used so the notebook-wide `task` string is not overwritten.
for task_name in ['IOI', 'Greater-than', 'Docstring']:
    for level in ['E', 'N']:
        scores = our_results[task_name][level]
        print(f'{task_name} ({level}): {np.mean(scores):.3f} ± {np.std(scores):.3f}')

IOI (E): 0.840 ± 0.016
IOI (N): 0.853 ± 0.016
Greaterthan (E): 0.856 ± 0.021
Greaterthan (N): 0.832 ± 0.058
Docstring (E): 0.914 ± 0.020
Docstring (N): 0.915 ± 0.014


## Induction task qualitative results

In [None]:
# Train a sparse autoencoder on the induction task's residual streams and score it at node level.
# Gradients are switched on for training here and switched off again at the end of the cell.
torch.set_grad_enabled(True)

task = 'induction'
task_type = 'node'
assert task_type in ['node', 'edge'], "Type must be either 'node' or 'edge'"
print(f"Type: {task_type}")
task_mappings = {
    'gt': 'Greater-than',
    'ioi': 'Indirect Object Identification',
    'ds': 'Docstring',
    'induction': 'Induction',
}

print(f"Task: {task_mappings[task]}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Autoencoder width and number of training epochs
num_unique = 300
n_epochs = 500


# NOTE(review): roc_results is never appended to or read in this cell.
roc_results = []

# Load residual streams
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')
print(ground_truth)


# Shuffle and create the labels
labels = torch.ones(resid_streams.shape[0]//2) # BIG ASSUMPTION: assumes first half is positive and second half is negative
labels = torch.cat((labels, torch.zeros_like(labels)))
# NOTE(review): no random seed is set in this cell, so the shuffle (and hence the split) differs between runs.
permutation = torch.randperm(resid_streams.shape[0])
resid_shuffled = resid_streams[permutation, :, :]
labels_shuffled = labels[permutation]
# NOTE(review): only the first 10 shuffled examples are used for training and the rest for
# evaluation. The commented-out alternative scales shape[1], not the example axis shape[0].
# Confirm that a fixed cutoff of 10 is intended.
cutoff = 10# int(resid_shuffled.shape[1] * 0.8)
train_streams = resid_shuffled[:cutoff, :, :].to(device)
train_labels = labels_shuffled[:cutoff].to(device)
eval_streams = resid_shuffled[cutoff:, :, :].to(device)
eval_labels = labels_shuffled[cutoff:].to(device)
# NOTE(review): train_labels and eval_labels are not passed to train() below.


model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# alpha_=0.0 is passed for this run; see train() for how lambda_ and alpha_ are used.
model = train(model, n_epochs, optimizer, train_streams, eval_streams, lambda_=0.02, alpha_=0.0)
# Move everything back to the CPU for saving and the numpy-based evaluation below
model = model.to('cpu')
resid_streams = resid_streams.to('cpu')
# Save model (the whole module object, not just its state dict)
torch.save(model, f'../models/{task}/sparse_autoencoder.pt')

# Collect the distinct head and layer ids to size the ground-truth grid
heads = []
layers = []
for i, l in enumerate(head_labels):
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    heads.append(head)
    layers.append(layer)

heads = list(set(heads))
layers = list(set(layers))

# Binary (layer, head) grid: 1 if the head is in the ground-truth circuit, 0 otherwise
ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

normalise = False# if task == 'ds' else True

# Encode all residual streams and take, per example and head, the index of the
# largest learned activation.

model.eval()
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

print(f"\n\nNormalise: {normalise}")
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)

# Print best f1 score (and corresponding threshold)
node_best_f1 = np.max(f1)
best_threshold = thresholds[np.argmax(f1)]
print(f"Best F1 score: {node_best_f1:.4f}")

# Print ROC AUC
print(f"ROC AUC: {node_roc_auc:.4f}\n\n")

# Turn grad back off now that training is done
torch.set_grad_enabled(False)

In [None]:
# Recover the distinct head and layer ids from the head labels
layer_head_pairs = [feature_string_to_head_and_layer(idx, head_labels) for idx in range(len(head_labels))]
heads = list({head for _, head in layer_head_pairs})
layers = list({layer for layer, _ in layer_head_pairs})

# Binary (layer, head) grid marking the heads of the ground-truth circuit
print(ground_truth)
ground_truth_array = np.zeros((len(layers), len(heads)))
for gt_layer, gt_head in ground_truth:
    ground_truth_array[gt_layer, gt_head] = 1

# NOTE(review): this cell calls model.encoder(...) directly, whereas the training
# cell uses model(...)[0]; confirm the two give the same activations.
model.eval()
learned_activations = model.encoder(resid_streams).detach().cpu().numpy()
print(f"Learned activations shape: {learned_activations.shape}")

# Code of each head in each example = index of its largest learned activation
all_indices = np.argmax(learned_activations, axis=2)
print(f"All indices shape: {all_indices.shape}")

# The first 250 examples are treated as positive, the remainder as negative
n_positive = 250
positive_indices, negative_indices = all_indices[:n_positive, :], all_indices[n_positive:, :]

unique_to_positive_array = gen_array_template(head_labels)
unique_to_negative_array = gen_array_template(head_labels)

normalise = False

for idx in range(len(head_labels)):
    # Grid position of this head
    layer, head = feature_string_to_head_and_layer(idx, head_labels)

    # Distinct codes this head takes in each class
    positive = set(positive_indices[:, idx].tolist())
    negative = set(negative_indices[:, idx].tolist())
    total_unique = positive | negative

    # Codes seen in one class only
    unique_to_positive = list(positive - negative)
    unique_to_negative = list(negative - positive)

    if not normalise:
        # Raw counts
        unique_to_positive_array[layer, head] = len(unique_to_positive)
        unique_to_negative_array[layer, head] = len(unique_to_negative)
    else:
        # Fraction of all distinct codes this head uses
        unique_to_positive_array[layer, head] = len(unique_to_positive) / len(total_unique)
        unique_to_negative_array[layer, head] = len(unique_to_negative) / len(total_unique)

In [None]:
# Side-by-side heatmaps: codes unique to positive examples (left) and the ground-truth circuit (right).
# Both inputs are expected to be numpy arrays over the same (layer, head) grid.
arrays_sequence = [unique_to_positive_array, ground_truth_array]
titles = ["Codes Unique to +ve Examples", "Ground-Truth Circuit"]

# 1 row x 2 columns, with extra horizontal room for the left panel's colorbar
fig = make_subplots(rows=1, cols=2, subplot_titles=titles, horizontal_spacing=0.15)

# One heatmap per panel, each bound to its own color axis
panel_styles = [('Blues', "coloraxis1"), ('Reds', "coloraxis2")]
for col, (z_values, (scale, axis_id)) in enumerate(zip(arrays_sequence, panel_styles), start=1):
    fig.add_trace(go.Heatmap(z=z_values, colorscale=scale, coloraxis=axis_id), row=1, col=col)

fontsize=24

# Give each color axis its own colorbar, placed to the right of its panel
fig.update_layout(
    coloraxis1=dict(colorscale='Blues', colorbar=dict(x=0.43)),
    coloraxis2=dict(colorscale='Blues', colorbar=dict(x=1)),
    width=900,
    height=400,  # chosen so the heatmaps look roughly square
    margin=dict(l=50, r=50, t=50, b=50),
    title_font=dict(size=fontsize+4),
    font=dict(size=fontsize-4),  # global font size, affects tick labels and legend
)

# Axis titles; the y-axes are reversed on both panels
for col in (1, 2):
    fig.update_xaxes(title_text="Head", title_font=dict(size=20), tickfont=dict(size=16), row=1, col=col)
fig.update_yaxes(title_text="Layer", title_font=dict(size=20), tickfont=dict(size=16), autorange="reversed", row=1, col=1)
fig.update_yaxes(title_font=dict(size=fontsize), tickfont=dict(size=fontsize-4), autorange="reversed", row=1, col=2)

# Subplot titles are stored as layout annotations
for title_annotation in fig.layout.annotations:
    title_annotation.font = dict(size=fontsize)

# Save the figure
#fig.write_image(f'../output/{task}/{task}_unique_to_positive_vs_ground_truth.pdf')

# Show the figure
fig.show()

In [None]:
from jaxtyping import Float, Int
from torch import Tensor
from functools import partial
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix, HookedTransformerConfig

# tl_model = HookedTransformer.from_pretrained(model_name = "gpt2-small", device='cpu')

# all_tokens = torch.load(f'../data/{task}/all_tokens.pt')[:20, :]
# all_tokens.shape

In [None]:
# Settings for the 2-layer attention-only transformer used for the induction task
induction_settings = dict(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True,  # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b",
    seed=398,
    use_attn_result=True,
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer",
)
cfg = HookedTransformerConfig(**induction_settings)

# Path to the pretrained weights. They can be fetched with gdown:
# url = "https://drive.google.com/uc?id=1vcZLJnJoYKQs-2KOjkd6LvHZrkSdoxhu"
# output = str(weights_dir)
# gdown.download(url, output)
weights_dir = '../models/induction/attn_only_2L_half.pth'

# Build the model on the CPU and load the pretrained weights into it
device = 'cpu'
tl_model = HookedTransformer(cfg)
pretrained_weights = torch.load(weights_dir, map_location=device)
tl_model.load_state_dict(pretrained_weights)
tl_model = tl_model.to(device)

# Keep only the first 100 token sequences
all_tokens = torch.load(f'../data/{task}/all_tokens.pt')[:100, :]
all_tokens.shape

In [None]:
# Zero-ablate one attention head and compare the model loss with and without it.
# These two module-level settings select which head the hook below removes.
layer_to_ablate = 1
head_index_to_ablate = 1

# The type annotations are NOT necessary, they're just a useful guide to the reader
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_model"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_model"]:
    '''
    Forward hook that zeroes the output of the head selected by the module-level
    `head_index_to_ablate`.

    value: activation handed to the hook; the head axis is dimension 2. The tensor is
        modified in place and also returned.
    hook: the hook point this function is attached to (unused).
    '''
    # No debug printing here: the hook fires on every forward pass.
    value[:, :, head_index_to_ablate, :] = 0.
    return value

# Clean baseline: loss with no hooks attached.
original_loss = tl_model(all_tokens, return_type="loss")

# Same forward pass, with the ablation hook attached to the chosen layer's per-head output.
result_hook_name = utils.get_act_name("result", layer_to_ablate)
ablated_loss = tl_model.run_with_hooks(
    all_tokens,
    return_type="loss",
    fwd_hooks=[(result_hook_name, head_ablation_hook)],
)

print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")

In [None]:
import torch as t
import functools

def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[Tensor, "batch seq n_heads d_model"]:
    '''
    Forward hook that zero-ablates a single attention head.

    attn_result: per-head activation; the head axis is dimension 2. Modified in place.
    hook: the hook point this function is attached to (unused).
    head_index_to_ablate: index of the head to zero; bind it with functools.partial.
    '''
    # Trailing axes are implicitly selected in full, so this clears the whole head slice.
    attn_result[:, :, head_index_to_ablate] = 0.0
    return attn_result


def cross_entropy_loss(logits, tokens):
    '''
    Mean next-token cross entropy of `logits` against `tokens`.

    The prediction at position i is scored against the token at position i + 1, so the
    final prediction and the first token are dropped before averaging.
    '''
    next_token_log_probs = F.log_softmax(logits, dim=-1)[:, :-1]
    targets = tokens[:, 1:].unsqueeze(-1)
    # Pick out the log-probability assigned to each true next token
    picked = next_token_log_probs.gather(-1, targets).squeeze(-1)
    return picked.mean().neg()


def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"]
) -> Float[Tensor, "n_layers n_heads"]:
    '''
    Zero-ablates every head in turn and records how much the cross entropy loss changes.

    Returns a (n_layers, n_heads) tensor of `ablated loss - clean loss`, so zero means the
    ablation made no difference.
    '''
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    ablation_scores = t.zeros((n_layers, n_heads), device=model.cfg.device)

    # Clean baseline, computed once with all hooks cleared
    model.reset_hooks()
    loss_no_ablation = cross_entropy_loss(model(tokens, return_type="logits"), tokens)

    for layer in tqdm(range(n_layers)):
        hook_name = utils.get_act_name("result", layer)
        for head in range(n_heads):
            # Bind the head index so the hook matches the (activation, hook) call signature
            zero_this_head = functools.partial(head_ablation_hook, head_index_to_ablate=head)
            ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[(hook_name, zero_this_head)])
            ablation_scores[layer, head] = cross_entropy_loss(ablated_logits, tokens) - loss_no_ablation

    return ablation_scores


ablation_scores = get_ablation_scores(tl_model, all_tokens)
ablation_scores

In [None]:
# Imshow ablation scores
# Heatmap of the per-head ablation scores, one row per layer and one column per head
fig = px.imshow(
    ablation_scores.cpu().numpy(),
    color_continuous_scale='RdBu',
    labels={"x": "Head", "y": "Layer", "color": "Ablation Score"},
    text_auto=".2f",
)

layout_options = {
    "title": 'Ablation Scores',
    "xaxis_title": 'Head',
    "yaxis_title": 'Layer',
    "font": {"size": 18},
    "plot_bgcolor": 'white',
    "width": 800,
    "height": 600,
}
fig.update_layout(**layout_options)

fig.show()

In [None]:
# REWRITE THIS FOR THE KL DIVERGENCE (ONE ABLATION)

def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[Tensor, "batch seq n_heads d_model"]:
    '''
    Forward hook that zero-ablates a single attention head.

    attn_result: per-head activation; the head axis is dimension 2. Modified in place.
    hook: the hook point this function is attached to (unused).
    head_index_to_ablate: index of the head to zero; bind it with functools.partial.
    '''
    # Trailing axes are implicitly selected in full, so this clears the whole head slice.
    attn_result[:, :, head_index_to_ablate] = 0.0
    return attn_result


def cross_entropy_loss(logits, tokens):
    '''
    Mean next-token cross entropy of `logits` against `tokens`.

    The prediction at position i is scored against the token at position i + 1, so the
    final prediction and the first token are dropped before averaging.
    '''
    next_token_log_probs = F.log_softmax(logits, dim=-1)[:, :-1]
    targets = tokens[:, 1:].unsqueeze(-1)
    # Pick out the log-probability assigned to each true next token
    picked = next_token_log_probs.gather(-1, targets).squeeze(-1)
    return picked.mean().neg()

def kl_divergence(logits: torch.Tensor, base_model_logits: torch.Tensor, 
                  last_seq_element_only: bool = True, return_one_element: bool = True) -> torch.Tensor:
    '''
    KL divergence KL(base_model || model) between two sets of logits over the vocabulary axis.

    F.kl_div treats its first argument as the approximating log-probabilities and its second as
    the reference (target) distribution, so `base_model_logits` is the reference here.

    last_seq_element_only: compare only the final sequence position of each example.
    return_one_element: return the mean over all remaining positions instead of one value each.
    '''
    if last_seq_element_only:
        logits, base_model_logits = logits[:, -1, :], base_model_logits[:, -1, :]

    per_position = F.kl_div(
        F.log_softmax(logits, dim=-1),
        F.log_softmax(base_model_logits, dim=-1),
        log_target=True,
        reduction="none",
    ).sum(dim=-1)

    return per_position.mean() if return_one_element else per_position

def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"]
) -> Float[Tensor, "n_layers n_heads"]:
    '''
    Zero-ablates every head in turn and measures how far the output distribution moves.

    Returns a tensor of shape (n_layers, n_heads) holding KL(clean model || ablated model) for
    each head. With the defaults of `kl_divergence` this is taken at the final sequence position
    and averaged over the batch. A value of zero means the ablation did not change the output.
    '''
    # Initialize an object to store the ablation scores
    ablation_scores = t.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

    # Clean logits, computed once with all hooks cleared, act as the reference distribution
    model.reset_hooks()
    logits = model(tokens, return_type="logits")

    for layer in tqdm(range(model.cfg.n_layers)):
        for head in range(model.cfg.n_heads):
            # Use functools.partial to create a temporary hook function with the head number fixed
            temp_hook_fn = functools.partial(head_ablation_hook, head_index_to_ablate=head)
            # Run the model with the ablation hook
            ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[
                (utils.get_act_name("result", layer), temp_hook_fn),
            ])
            # kl_divergence(logits, base_model_logits) uses its SECOND argument as the reference
            # distribution, so the clean logits go second to obtain KL(clean || ablated).
            ablation_scores[layer, head] = kl_divergence(ablated_logits, logits)

    return ablation_scores


ablation_scores = get_ablation_scores(tl_model, all_tokens)
ablation_scores

In [None]:
# Imshow ablation scores
# Heatmap of the per-head ablation scores, one row per layer and one column per head
fig = px.imshow(
    ablation_scores.cpu().numpy(),
    color_continuous_scale='RdBu',
    labels={"x": "Head", "y": "Layer", "color": "Ablation Score"},
    text_auto=".2f",
)

layout_options = {
    "title": 'Ablation Scores',
    "xaxis_title": 'Head',
    "yaxis_title": 'Layer',
    "font": {"size": 18},
    "plot_bgcolor": 'white',
    "width": 800,
    "height": 600,
}
fig.update_layout(**layout_options)

fig.show()

In [None]:
# Import List and Tuple
from typing import List, Tuple

def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[Tensor, "batch seq n_heads d_model"]:
    '''
    Forward hook that zero-ablates a single attention head.

    attn_result: per-head activation; the head axis is dimension 2. Modified in place.
    hook: the hook point this function is attached to (unused).
    head_index_to_ablate: index of the head to zero; bind it with functools.partial.
    '''
    # Trailing axes are implicitly selected in full, so this clears the whole head slice.
    attn_result[:, :, head_index_to_ablate] = 0.0
    return attn_result

def cross_entropy_loss(logits, tokens):
    '''
    Mean next-token cross entropy of `logits` against `tokens`.

    The prediction at position i is scored against the token at position i + 1, so the
    final prediction and the first token are dropped before averaging.
    '''
    next_token_log_probs = F.log_softmax(logits, dim=-1)[:, :-1]
    targets = tokens[:, 1:].unsqueeze(-1)
    # Pick out the log-probability assigned to each true next token
    picked = next_token_log_probs.gather(-1, targets).squeeze(-1)
    return picked.mean().neg()

def kl_divergence(logits: torch.Tensor, base_model_logits: torch.Tensor, 
                  last_seq_element_only: bool = True, return_one_element: bool = True) -> torch.Tensor:
    '''
    KL divergence KL(base_model || model) between two sets of logits over the vocabulary axis.

    F.kl_div treats its first argument as the approximating log-probabilities and its second as
    the reference (target) distribution, so `base_model_logits` is the reference here.

    last_seq_element_only: compare only the final sequence position of each example.
    return_one_element: return the mean over all remaining positions instead of one value each.
    '''
    if last_seq_element_only:
        logits, base_model_logits = logits[:, -1, :], base_model_logits[:, -1, :]

    per_position = F.kl_div(
        F.log_softmax(logits, dim=-1),
        F.log_softmax(base_model_logits, dim=-1),
        log_target=True,
        reduction="none",
    ).sum(dim=-1)

    return per_position.mean() if return_one_element else per_position

def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    heads_to_ablate: List[Tuple[int, int]]
) -> Float[Tensor, "1"]:
    '''
    Zero-ablates all the given heads at once and measures how far the output distribution moves.

    heads_to_ablate: (layer, head) pairs to remove together; an empty list attaches no hooks.

    Returns KL(clean model || ablated model). With the defaults of `kl_divergence` this is taken
    at the final sequence position and averaged over the batch.
    '''
    # Clean logits, computed with all hooks cleared, act as the reference distribution
    model.reset_hooks()
    logits = model(tokens, return_type="logits")

    # One hook per head; several hooks may target the same layer
    fwd_hooks = [
        (utils.get_act_name("result", layer), functools.partial(head_ablation_hook, head_index_to_ablate=head))
        for layer, head in heads_to_ablate
    ]
    ablated_logits = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks)

    # kl_divergence(logits, base_model_logits) uses its SECOND argument as the reference
    # distribution, so the clean logits go second to obtain KL(clean || ablated).
    kl_div = kl_divergence(ablated_logits, logits)

    return kl_div

# Example usage
heads_to_ablate = [(0, 10)]  # List of tuples (layer, head) to ablate
ablation_score = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
print(ablation_score)

In [None]:
# Score each head by how many sparse-autoencoder codes appear only on positive examples.
# NOTE(review): `model`, `resid_streams` and `head_labels` are not assigned in this cell or
# any earlier cell of this section; they must already exist in the kernel.
model.eval()
learned_activations = model.encoder(resid_streams).detach().cpu().numpy()
print(f"Learned activations shape: {learned_activations.shape}")

# Discretise: each (example, head) is represented by the index of its largest activation
all_indices = np.argmax(learned_activations, axis=2)
print(f"All indices shape: {all_indices.shape}")    

# Hard-coded split: the first 250 rows are treated as positive, the rest as negative
positive_indices = all_indices[:250, :]
negative_indices = all_indices[250:, :]

unique_to_positive_array = gen_array_template(head_labels)

normalise = False

for i in range(len(head_labels)):
    # Calculate head and layer
    layer, head = feature_string_to_head_and_layer(i, head_labels)

    positive = set(positive_indices[:, i].tolist())
    negative = set(negative_indices[:, i].tolist())
    total_unique = positive.union(negative)

    # In positive but not negative
    unique_to_positive = list(positive - negative)
    # In negative but not positive (computed but not used below)
    unique_to_negative = list(negative - positive)

    if normalise: unique_to_positive_array[layer, head] = len(unique_to_positive) / len(total_unique)
    
    else: unique_to_positive_array[layer, head] = len(unique_to_positive)

array_shape = unique_to_positive_array.shape
unique_to_positive_array = unique_to_positive_array.flatten()

# Apply softmax over all heads jointly (no max-subtraction, so very large counts could overflow)
unique_to_positive_array = np.exp(unique_to_positive_array) / np.sum(np.exp(unique_to_positive_array))

# Reshape
unique_to_positive_array = unique_to_positive_array.reshape(array_shape)

In [None]:
def positive_array_to_ablations(pos_array, threshold):
    '''
    Lists the (layer, head) positions whose score does NOT exceed `threshold`.

    pos_array: 2-D array with layers as rows and heads as columns.
    Returns the positions in row-major order; these are the heads to ablate.
    '''
    keep_mask = pos_array > threshold
    n_layers, n_heads = keep_mask.shape[0], keep_mask.shape[1]

    ablations = []
    for layer in range(n_layers):
        for head in range(n_heads):
            if not keep_mask[layer, head]:
                ablations.append((layer, head))
    return ablations

# The ROC thresholds double as the cut-offs for the ablation sweep
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(
    all_indices, ground_truth_array, head_labels, normalise=normalise
)

kl_divs, num_nodes = [], []
num_elements_in_pos_array = torch.tensor(unique_to_positive_array).numel()

for threshold in thresholds:
    # Everything at or below the threshold is ablated; the rest is the candidate circuit
    heads_to_ablate = positive_array_to_ablations(unique_to_positive_array, threshold)
    print(heads_to_ablate)
    ablation_score = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
    kl_divs.append(ablation_score)
    num_nodes.append(num_elements_in_pos_array - len(heads_to_ablate))
    print(ablation_score, end="\n\n")

In [None]:
# Reference curves attributed to Conmy et al. (ACDC) by the plotting cells below.

# Zero ablation
# NOTE(review): the x list has 8 entries but the y list has 9, so the pairing of points is
# ambiguous — confirm against the source figure before plotting.
conmy_num_nodes_zero = [10, 17, 18, 20, 34, 39, 48, 62]
conmy_kl_zero = [4.3, 0.7, 0.7, 0.9, 0.7, 0.6, 0.5, 0.4, 0.3]

# Random ablation (12 x values, 12 y values)
conmy_num_nodes_random = [6, 7, 8, 23, 49, 51, 55, 56, 78, 82, 84, 95]
conmy_kl_random = [10.3, 10.3, 10.3, 8.4, 5.3, 3.7, 2.4, 1.9, 1.0, 0.9, 0.8, 0.5]

In [None]:
# KL divergence of the ablated model as a function of how many nodes are kept
fig = go.Figure(
    data=go.Scatter(
        x=num_nodes,
        y=kl_divs,
        mode='lines+markers',
        name="Ours",
        line=dict(width=6),
        marker=dict(size=15),
    )
)

fig.update_layout(
    xaxis_title='Number of nodes',
    yaxis_title='KL(Model, Ablated)',
    font=dict(size=24),
    plot_bgcolor='white',
    width=1000,
    height=600,
)

# Write the PDF before displaying the figure
fig.write_image('../output/figures/kl_nodes.pdf')

fig.show()

In [None]:
# Train a sparse autoencoder on the per-head residual streams of one task, then score heads.
# Turn grad on (it is globally disabled in the notebook setup and switched off again at the end)
torch.set_grad_enabled(True)

task = 'induction'
task_type = 'node'
assert task_type in ['node', 'edge'], "Type must be either 'node' or 'edge'"
print(f"Type: {task_type}")
# Short task keys -> display names
task_mappings = {
    'gt': 'Greater-than',
    'ioi': 'Indirect Object Identification',
    'ds': 'Docstring',
    'induction': 'Induction',
}

print(f"Task: {task_mappings[task]}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Size of the learned dictionary and number of training epochs
num_unique = 300
n_epochs = 500


roc_results = []

# Load residual streams
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt").to(device)
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')
print(ground_truth)


# Shuffle and create the labels
labels = torch.ones(resid_streams.shape[0]//2) # BIG ASSUMPTION: assumes first half is positive and second half is negative
labels = torch.cat((labels, torch.zeros_like(labels)))
# No seed is set here, so the permutation (and hence the train/eval split) differs between runs
permutation = torch.randperm(resid_streams.shape[0])
resid_shuffled = resid_streams[permutation, :, :]
labels_shuffled = labels[permutation]
# Only the first 10 shuffled examples are used for training; everything else is held out
cutoff = 10 #int(resid_shuffled.shape[1] * 0.8)
train_streams = resid_shuffled[:cutoff, :, :].to(device)
train_labels = labels_shuffled[:cutoff].to(device)
eval_streams = resid_shuffled[cutoff:, :, :].to(device)
eval_labels = labels_shuffled[cutoff:].to(device)


model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# The labels built above are not passed to `train`; only the streams are
model = train(model, n_epochs, optimizer, train_streams, eval_streams, lambda_=0.01, alpha_=0.0)
model = model.to('cpu')
resid_streams = resid_streams.to('cpu')
# Save model
# NOTE(review): this stores the module object itself, while a later cell loads this same path
# and indexes it as a dict ('model', 'node_best_num_unique', ...) — confirm the intended format.
torch.save(model, f'../models/{task}/sparse_autoencoder.pt')

# Collect the distinct layer and head indices present in the labels
heads = []
layers = []
for i, l in enumerate(head_labels):
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    heads.append(head)
    layers.append(layer)

heads = list(set(heads))
layers = list(set(layers))

# Binary (layer, head) grid: 1 where the head is part of the ground-truth circuit
ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

normalise = False# if task == 'ds' else True

# Discretise every (example, head) to the index of its largest learned activation

model.eval()
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

print(f"\n\nNormalise: {normalise}")
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(all_indices, ground_truth_array, head_labels, normalise=normalise)

# Print best f1 score (and corresponding threshold)
node_best_f1 = np.max(f1)
best_threshold = thresholds[np.argmax(f1)]
print(f"Best F1 score: {node_best_f1:.4f}")

# Print ROC AUC
print(f"ROC AUC: {node_roc_auc:.4f}\n\n")

# Turn grad back off now that training is finished
torch.set_grad_enabled(False)

In [None]:
# Repeat with num edges
def gen_co_occurrence_matrix(all_indices, n_heads, n_feat):
    '''
    Counts how often each pair of codes appears together in each ordered pair of distinct heads.

    all_indices: (n_examples, n_heads) array holding one code index per example and head.
    Returns a (n_heads, n_heads, n_feat, n_feat) array where entry [h1, h2, c1, c2] is the number
    of examples in which head h1 takes code c1 while head h2 takes code c2. Entries with
    h1 == h2 stay zero.
    '''
    co_occurrence_matrix = np.zeros((n_heads, n_heads, n_feat, n_feat))

    for example in range(all_indices.shape[0]):
        for h1 in range(n_heads):
            code_1 = all_indices[example, h1]
            for h2 in range(n_heads):
                if h1 == h2:
                    # A head is never paired with itself
                    continue
                code_2 = all_indices[example, h2]
                co_occurrence_matrix[h1, h2, code_1, code_2] += 1

    return co_occurrence_matrix

def normalize_co_occurrence_matrix(co_occurrence_matrix):
    '''
    Rescales each head pair's code-by-code count table so that it sums to one.

    co_occurrence_matrix: array of shape (n_heads, n_heads, n_feat, n_feat).
    Returns a new array of the same shape; self-pairs and pairs with no counts stay all-zero.
    '''
    n_heads, _, _, _ = co_occurrence_matrix.shape
    normalized_matrix = np.zeros_like(co_occurrence_matrix)

    for h1 in range(n_heads):
        for h2 in range(n_heads):
            if h1 == h2:
                continue
            pair_counts = co_occurrence_matrix[h1, h2, :, :]
            pair_total = np.sum(pair_counts)
            # Leave empty pairs at zero rather than dividing by zero
            if pair_total > 0:
                normalized_matrix[h1, h2, :, :] = pair_counts / pair_total

    return normalized_matrix

def unique_co_occurrences(positive_matrix, negative_matrix, normalise=True):
    '''
    For every ordered pair of distinct heads, counts the code pairs that co-occur in the positive
    examples but never in the negative ones.

    positive_matrix, negative_matrix: co-occurrence arrays of shape
        (n_heads, n_heads, n_feat, n_feat).
    normalise: if True, both inputs are first passed through normalize_co_occurrence_matrix and
        each count is divided by the number of non-zero entries in the positive table plus the
        number in the negative table for that pair (0 when both are empty).

    Returns a (n_heads, n_heads) array; the diagonal stays zero.
    '''
    if normalise:
        positive_matrix = normalize_co_occurrence_matrix(positive_matrix)
        negative_matrix = normalize_co_occurrence_matrix(negative_matrix)

    n_heads, _, _, _ = positive_matrix.shape
    unique_co_occurrence_counts = np.zeros((n_heads, n_heads))

    for h1 in range(n_heads):
        for h2 in range(n_heads):
            if h1 == h2:
                continue

            seen_in_positive = positive_matrix[h1, h2, :, :] > 0
            seen_in_negative = negative_matrix[h1, h2, :, :] > 0
            n_unique = np.sum(seen_in_positive & ~seen_in_negative)

            if not normalise:
                unique_co_occurrence_counts[h1, h2] = n_unique
                continue

            # Denominator: distinct code pairs seen in either class (shared pairs count twice)
            n_seen = np.sum(seen_in_positive) + np.sum(seen_in_negative)
            unique_co_occurrence_counts[h1, h2] = n_unique / n_seen if n_seen > 0 else 0

    return unique_co_occurrence_counts

# Learned activations and then take argmax to discretise
learned_activations = model(resid_streams)[0].detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

# Hard-coded split: the first 250 rows are treated as positive, the rest as negative
positive_indices = all_indices[:250, :]
negative_indices = all_indices[250:, :]
positive_learned_activations = learned_activations[:250, :, :]
negative_learned_activations = learned_activations[250:, :, :]

# Dictionary size and number of heads are read off the activations
n_feat = learned_activations.shape[-1]
n_heads = all_indices.shape[1]
positive_co_occurrence_matrix = gen_co_occurrence_matrix(positive_indices, n_heads, n_feat)
negative_co_occurrence_matrix = gen_co_occurrence_matrix(negative_indices, n_heads, n_feat)

# Calculate unique co-occurrences
normalise = False# if task in ['ds', 'ioi']  else True
print(f"Normalise: {normalise}")
unique_co_occurrence_counts = unique_co_occurrences(positive_co_occurrence_matrix, negative_co_occurrence_matrix, normalise=normalise)

# Sort (head, head) pairs by descending unique co-occurrence counts
sorted_indices = np.argsort(unique_co_occurrence_counts.flatten())[::-1]
sorted_indices = np.unravel_index(sorted_indices, unique_co_occurrence_counts.shape)
# Zip them together to create a list of (head, head) pairs
sorted_head_pairs = list(zip(sorted_indices[0], sorted_indices[1]))
print(sorted_head_pairs)
print(len(sorted_head_pairs))

# Translate each flat head index of every pair into its (layer, head) position
circuit_components = []
for first_idx, second_idx in sorted_head_pairs:
    (l1, h1) = feature_string_to_head_and_layer(first_idx, head_labels)
    (l2, h2) = feature_string_to_head_and_layer(second_idx, head_labels)
    circuit_components.append((l1, h1))
    circuit_components.append((l2, h2))

print(len(circuit_components))
k = 200

# Mark the top-k head pairs. The grid is sized from n_heads rather than a fixed 24 so that
# the flat head indices above can never fall outside it.
y_pred = np.zeros((n_heads, n_heads))
print(y_pred.shape)

for (h_1, h_2) in sorted_head_pairs[:k]:
    y_pred[h_1, h_2] += 1

In [None]:
def calculate_total_edges(connections):
    '''
    Counts the edges of a graph in which every head is connected to every head of the next
    occupied layer.

    connections: iterable of (layer, head) pairs; duplicates are counted once.
    Returns the sum, over consecutive occupied layers, of the product of their head counts.
    '''
    # Group the distinct heads by the layer they sit in
    heads_by_layer = {}
    for layer, head in connections:
        heads_by_layer.setdefault(layer, set()).add(head)

    ordered_layers = sorted(heads_by_layer)

    # Pair each occupied layer with the next occupied one
    return sum(
        len(heads_by_layer[lower]) * len(heads_by_layer[upper])
        for lower, upper in zip(ordered_layers, ordered_layers[1:])
    )

def array_to_ablations(array, threshold):
    '''
    Turns a head-by-head score matrix into the list of heads to ablate at a given threshold.

    array: 2-D score matrix indexed by flat head index on both axes.
    threshold: entries strictly greater than this survive.

    Returns (heads, num_edges):
        heads: (layer, head) position, via feature_string_to_head_and_layer and the global
            `head_labels`, of every row that has no surviving entry.
        num_edges: number of surviving entries, halved with integer division
            (presumably because each edge shows up in both directions — TODO confirm).
    '''
    binary_array = (array > threshold).astype(int)
    num_edges = np.sum(binary_array) // 2

    # Rows with at least one surviving entry are kept. When nothing survives this set is empty
    # and every head is ablated with num_edges == 0, which is what the former bare `except:`
    # branch returned; real errors now propagate instead of being swallowed.
    kept_heads = set(np.argwhere(binary_array)[:, 0].tolist())

    # Head count comes from the matrix rather than a fixed 24
    heads_to_ablate = [x for x in range(array.shape[0]) if x not in kept_heads]

    # Convert head index into (layer, head) tuple
    heads = [feature_string_to_head_and_layer(h, head_labels) for h in heads_to_ablate]
    return heads, num_edges

kl_divs, num_nodes, thresholds = [], [], []

# First pass: keep only the integer thresholds at which the number of ablated heads changes
old_heads_to_ablate = 0
for threshold in range(int(unique_co_occurrence_counts.max())):
    heads_to_ablate, num_nodes_in_circuit = array_to_ablations(unique_co_occurrence_counts, threshold)
    if len(heads_to_ablate) != old_heads_to_ablate:
        old_heads_to_ablate = len(heads_to_ablate)
        thresholds.append(threshold)

print(thresholds)


# Second pass: run the ablation at each retained threshold
for threshold in thresholds:
    heads_to_ablate, num_nodes_in_circuit = array_to_ablations(unique_co_occurrence_counts, threshold)
    print(heads_to_ablate)
    ablation_score = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
    kl_divs.append(ablation_score)
    num_nodes.append(num_nodes_in_circuit)
    print(ablation_score, end="\n\n")

In [None]:
import plotly.graph_objects as go

# Styling shared by every trace: thick lines and large markers
trace_style = dict(mode='lines+markers', line=dict(width=6), marker=dict(size=15))

# Our curve first, then the ACDC reference curves
fig = go.Figure(data=go.Scatter(x=num_nodes, y=kl_divs, name="Ours", **trace_style))

for trace_name, xs, ys in [
    ('ACDC (Zero)', conmy_num_nodes_zero, conmy_kl_zero),
    ('ACDC (Random)', conmy_num_nodes_random, conmy_kl_random),
]:
    fig.add_trace(go.Scatter(x=xs, y=ys, name=trace_name, **trace_style))

fig.update_layout(
    xaxis_title='Number of edges',
    yaxis_title='KL(Model, Ablated)',
    font=dict(size=24),
    plot_bgcolor='white',
    width=1000,
    height=600,
    legend=dict(font=dict(size=18)),
)

# Log-scale y axis; x axis clipped to [0, 100]
fig.update_yaxes(type="log")
fig.update_xaxes(range=[0, 100])

# Save figure
#fig.write_image('../output/figures/edge_node_kl_induction.pdf')

fig.show()

In [None]:
def calculate_total_edges(connections):
    '''
    Counts the edges of a graph in which every head is connected to every head of the next
    occupied layer.

    connections: iterable of (layer, head) pairs; duplicates are counted once.
    Returns the sum, over consecutive occupied layers, of the product of their head counts.
    '''
    # Group the distinct heads by the layer they sit in
    heads_by_layer = {}
    for layer, head in connections:
        heads_by_layer.setdefault(layer, set()).add(head)

    ordered_layers = sorted(heads_by_layer)

    # Pair each occupied layer with the next occupied one
    return sum(
        len(heads_by_layer[lower]) * len(heads_by_layer[upper])
        for lower, upper in zip(ordered_layers, ordered_layers[1:])
    )

# Every head of a 2-layer, 12-heads-per-layer model
con_list = []
for l in range(2):
    con_list.extend((l, h) for h in range(12))

calculate_total_edges(con_list)

## KL div plot

In [None]:
# Import List and Tuple
from typing import List, Tuple

def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[Tensor, "batch seq n_heads d_model"]:
    '''
    Forward hook that zero-ablates a single attention head.

    attn_result: per-head activation; the head axis is dimension 2. Modified in place.
    hook: the hook point this function is attached to (unused).
    head_index_to_ablate: index of the head to zero; bind it with functools.partial.
    '''
    # Trailing axes are implicitly selected in full, so this clears the whole head slice.
    attn_result[:, :, head_index_to_ablate] = 0.0
    return attn_result

def cross_entropy_loss(logits, tokens):
    '''
    Mean next-token cross entropy of `logits` against `tokens`.

    The prediction at position i is scored against the token at position i + 1, so the
    final prediction and the first token are dropped before averaging.
    '''
    next_token_log_probs = F.log_softmax(logits, dim=-1)[:, :-1]
    targets = tokens[:, 1:].unsqueeze(-1)
    # Pick out the log-probability assigned to each true next token
    picked = next_token_log_probs.gather(-1, targets).squeeze(-1)
    return picked.mean().neg()

def kl_divergence(logits: torch.Tensor, base_model_logits: torch.Tensor, 
                  last_seq_element_only: bool = True, return_one_element: bool = True) -> torch.Tensor:
    '''
    KL divergence KL(base_model || model) between two sets of logits over the vocabulary axis.

    F.kl_div treats its first argument as the approximating log-probabilities and its second as
    the reference (target) distribution, so `base_model_logits` is the reference here.

    last_seq_element_only: compare only the final sequence position of each example.
    return_one_element: return the mean over all remaining positions instead of one value each.
    '''
    if last_seq_element_only:
        logits, base_model_logits = logits[:, -1, :], base_model_logits[:, -1, :]

    per_position = F.kl_div(
        F.log_softmax(logits, dim=-1),
        F.log_softmax(base_model_logits, dim=-1),
        log_target=True,
        reduction="none",
    ).sum(dim=-1)

    return per_position.mean() if return_one_element else per_position

def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    heads_to_ablate: List[Tuple[int, int]]
) -> Float[Tensor, "1"]:
    '''
    Zero-ablates all the given heads at once and reports two measures of the damage.

    heads_to_ablate: (layer, head) pairs to remove together; an empty list attaches no hooks.

    Returns a tuple (kl_div, loss_increase):
        kl_div: KL(clean model || ablated model). With the defaults of `kl_divergence` this is
            taken at the final sequence position and averaged over the batch.
        loss_increase: ablated cross entropy loss minus clean cross entropy loss.
    '''
    # Clean run with all hooks cleared: the reference for both measures
    model.reset_hooks()
    logits = model(tokens, return_type="logits")
    loss = cross_entropy_loss(logits, tokens)

    # One hook per head, attached to the attention "z" activation of its layer
    fwd_hooks = [
        (utils.get_act_name("z", layer, "attn"), functools.partial(head_ablation_hook, head_index_to_ablate=head))
        for layer, head in heads_to_ablate
    ]
    ablated_logits = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks)

    # kl_divergence(logits, base_model_logits) uses its SECOND argument as the reference
    # distribution, so the clean logits go second to obtain KL(clean || ablated).
    kl_div = kl_divergence(ablated_logits, logits)

    ablated_loss = cross_entropy_loss(ablated_logits, tokens)

    return kl_div, ablated_loss - loss

In [None]:
# Replace the 2-layer model with pretrained GPT-2 small on the CPU; this rebinds `tl_model`,
# so every later cell that uses it now runs against GPT-2 small.
tl_model = HookedTransformer.from_pretrained(model_name = "gpt2-small", device='cpu')

In [None]:
# Load the stored prompts for the task and tokenise them with the current model.
task = 'ioi'
all_tokens = torch.load(f'../data/{task}/all_tokens.pt')
if task == 'ioi': 
    # Remove exclamation marks
    all_tokens = [x.replace('!', '') for x in all_tokens]
# Every prompt is tokenised (left-padded) before the first 50 rows are kept, so the padded
# width is set by the full prompt list rather than by the 50 retained prompts.
all_tokens = tl_model.to_tokens(all_tokens, padding_side='left')[:50, :]

In [None]:
# Example usage: ablate a single head and display both returned measures
heads_to_ablate = [(0, 10)]  # List of tuples (layer, head) to ablate
kl_div, ablated_score = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
kl_div, ablated_score

In [None]:
# Load residual streams and other data for the current task.
# map_location places the tensor on `device` at load time, so a file saved from a GPU
# session still opens in a CPU-only kernel.
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device)
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt')
ground_truth = torch.load(f'../data/{task}/ground_truth.pt')

# Ground truth array: collect the distinct layer and head indices present in the labels
heads = []
layers = []
for i, l in enumerate(head_labels):
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    heads.append(head)
    layers.append(layer)

heads = list(set(heads))
layers = list(set(layers))

# Binary (layer, head) grid: 1 where the head is part of the ground-truth circuit
ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

# Load save_dict for the current task; it is indexed as a dict with the keys read below
savepath = f"../models/{task}/sparse_autoencoder.pt"
save_dict = torch.load(savepath, map_location=device)
num_unique = save_dict['node_best_num_unique']
lambda_ = save_dict['node_best_lambda']
best_roc_auc = save_dict['node_best_roc_auc']

# Load the model for the current task
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
model.load_state_dict(save_dict['model'])

model.eval()
learned_activations = model.encoder(resid_streams).detach().cpu().numpy()
print(f"Learned activations shape: {learned_activations.shape}")

# Discretise: each (example, head) is represented by the index of its largest activation
all_indices = np.argmax(learned_activations, axis=2)
print(f"All indices shape: {all_indices.shape}")

# Hard-coded split: the first 250 rows are treated as positive, the rest as negative
positive_indices = all_indices[:250, :]
negative_indices = all_indices[250:, :]

unique_to_positive_array = gen_array_template(head_labels)

normalise = False

for i in range(len(head_labels)):
    # Calculate head and layer
    layer, head = feature_string_to_head_and_layer(i, head_labels)

    positive = set(positive_indices[:, i].tolist())
    negative = set(negative_indices[:, i].tolist())
    total_unique = positive.union(negative)

    # In positive but not negative
    unique_to_positive = list(positive - negative)
    # In negative but not positive (computed but not used below)
    unique_to_negative = list(negative - positive)

    if normalise:
        unique_to_positive_array[layer, head] = len(unique_to_positive) / len(total_unique)
    else:
        unique_to_positive_array[layer, head] = len(unique_to_positive)

array_shape = unique_to_positive_array.shape
unique_to_positive_array = unique_to_positive_array.flatten()

# Apply softmax over all heads jointly. The arithmetic is left exactly as written because the
# scores are later compared with `>` against thresholds returned by gen_softmaxed_unique_to_pos.
unique_to_positive_array = np.exp(unique_to_positive_array) / np.sum(np.exp(unique_to_positive_array))

# Reshape
unique_to_positive_array = unique_to_positive_array.reshape(array_shape)

def positive_array_to_ablations(pos_array, threshold):
    '''
    Lists the (layer, head) positions whose score does NOT exceed `threshold`.

    pos_array: 2-D array with layers as rows and heads as columns.
    Returns the positions in row-major order; these are the heads to ablate.
    '''
    keep_mask = pos_array > threshold
    n_layers, n_heads = keep_mask.shape[0], keep_mask.shape[1]

    ablations = []
    for layer in range(n_layers):
        for head in range(n_heads):
            if not keep_mask[layer, head]:
                ablations.append((layer, head))
    return ablations

# The ROC thresholds double as the cut-offs for the ablation sweep
y_true, y_pred, fpr, tpr, node_roc_auc, f1, thresholds = gen_softmaxed_unique_to_pos(
    all_indices, ground_truth_array, head_labels, normalise=normalise
)

kl_divs, ablated_scores, num_nodes = [], [], []
num_elements_in_pos_array = torch.tensor(unique_to_positive_array).numel()

for threshold in thresholds:
    # Everything at or below the threshold is ablated; the rest is the candidate circuit
    heads_to_ablate = positive_array_to_ablations(unique_to_positive_array, threshold)
    print(heads_to_ablate)
    kl_div, ablated_score = get_ablation_scores(tl_model, all_tokens, heads_to_ablate)
    kl_divs.append(kl_div)
    ablated_scores.append(ablated_score)
    num_nodes.append(num_elements_in_pos_array - len(heads_to_ablate))
    print(kl_div)
    print(ablated_score, end="\n\n")

In [None]:
# Only keep the leading entries whose circuits contain at most 30 nodes.
within_limit = np.array(num_nodes) <= 30
# argmin on a boolean array gives the position of the first False, i.e. the first entry over
# the limit. If every entry is within the limit there is no False: argmin would return 0 and
# empty all three lists, so keep everything in that case (this also covers an empty list).
idx = len(num_nodes) if within_limit.all() else int(within_limit.argmin())
print(idx)
kl_divs = kl_divs[:idx]
ablated_scores = ablated_scores[:idx]
num_nodes = num_nodes[:idx]

In [None]:
# Both measures against the number of nodes kept, drawn on the same axes
fig = go.Figure()
for trace_name, values in (('KL Divergence', kl_divs), ('Ablated Score', ablated_scores)):
    fig.add_trace(go.Scatter(x=num_nodes, y=values, mode='lines+markers', name=trace_name))

layout_options = {
    "xaxis_title": 'Nodes Remaining',
    "font": {"size": 18},
    "plot_bgcolor": 'white',
    "width": 800,
    "height": 600,
}
fig.update_layout(**layout_options)

fig.show()


## Number of examples required

In [118]:
import plotly.graph_objects as go
import numpy as np

data = {
    'gt': {
        5: [0.8522633744855967, 0.8197530864197531, 0.882716049382716, 0.8069958847736625, 0.9419753086419753],
        10: [0.8209876543209877, 0.8633744855967078, 0.8275720164609053, 0.8761316872427983, 0.859670781893004],
        25: [0.8213991769547324, 0.8230452674897119, 0.7473251028806585, 0.8773662551440329, 0.8502057613168724],
        50: [0.817283950617284, 0.8687242798353909, 0.8497942386831275, 0.8061728395061728, 0.7888888888888889],
        100: [0.6823045267489711, 0.7415637860082305, 0.6806584362139918, 0.8045267489711935, 0.7057613168724279],
        250: [0.6312757201646091, 0.7090534979423868, 0.6847736625514403, 0.674074074074074, 0.697119341563786]
    },
    'ioi': {
        5: [0.8872229465449804, 0.8811929595827901, 0.8714146023468057, 0.8547913950456323, 0.902542372881356],
        10: [0.868155149934811, 0.8521838331160365, 0.8468057366362451, 0.901238591916558, 0.8725554106910038],
        25: [0.8244784876140808, 0.8446870925684486, 0.8479465449804433, 0.8399608865710562, 0.8329530638852672],
        50: [0.8265971316818774, 0.8239895697522817, 0.8632659713168188, 0.8288787483702738, 0.811277705345502],
        100: [0.8129074315514994, 0.7980769230769231, 0.8231747066492829, 0.8192633637548892, 0.8160039113428944],
        250: [0.7713494132985659, 0.7886245110821382, 0.8155149934810952, 0.7845501955671448, 0.7371251629726205]
    },
    'ds': {
        5: [0.7916666666666666, 0.8685897435897436, 0.858974358974359, 0.8461538461538461, 0.875],
        10: [0.8461538461538461, 0.8333333333333334, 0.9230769230769231, 0.8942307692307693, 0.8493589743589743],
        25: [0.8621794871794871, 0.8461538461538461, 0.842948717948718, 0.8141025641025641, 0.8621794871794872],
        50: [0.8461538461538463, 0.8910256410256411, 0.8846153846153847, 0.9262820512820513, 0.8685897435897436],
        100: [0.8589743589743589, 0.8621794871794872, 0.8717948717948718, 0.8653846153846154, 0.8365384615384616],
        250: [0.8717948717948718, 0.8237179487179488, 0.858974358974359, 0.8429487179487181, 0.858974358974359]
    }
}

fig = go.Figure()

# One trace per task: mean ROC AUC over the repeated runs at each training-set
# size, with the standard deviation (np.std default, ddof=0) as the error bar.
for task, task_data in data.items():
    x = sorted(task_data)
    y = [np.mean(task_data[n_examples]) for n_examples in x]
    error = [np.std(task_data[n_examples]) for n_examples in x]

    error_bars = dict(type='data', array=error, visible=True, thickness=3)
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode='lines+markers',
            name=task.upper(),
            error_y=error_bars,
            line=dict(width=6),
            marker=dict(size=12),
        )
    )

# Legend box placed in the lower-right region of the plot area.
legend_box = dict(
    x=0.8,
    y=0.15,
    bgcolor='rgba(255, 255, 255, 0.8)',
    bordercolor="Black",
    borderwidth=1,
)
fig.update_layout(
    xaxis_title='No. of examples (for training SAE)',
    yaxis_title='Mean ROC AUC',
    font=dict(family="Palatino", size=24, color="black"),
    legend=legend_box,
    plot_bgcolor='white',
    width=1000,
    height=600,
)

# The example counts span 5 to 250, so use a logarithmic x-axis.
fig.update_xaxes(type="log")

# Write the PDF before displaying the figure.
fig.write_image('../output/figures/roc_auc_vs_examples.pdf')

fig.show()

In [116]:
import json
import numpy as np
import plotly.graph_objects as go

# Task mappings (short key -> display name). The traces below are labelled with
# the upper-cased short key, so this mapping is informational only.
task_mappings = {'gt': 'Greater-than', 'ioi': 'Indirect Object Identification', 'ds': 'Docstring'}

# Read in the JSON files: one per task, keyed by level (a numeric string) with
# the list of ROC AUC values obtained at that level.
results = {}
for task in ['gt', 'ioi', 'ds']:
    with open(f'../output/data/roc_results_num_codes_{task}.json', 'r') as f:
        results[task] = json.load(f)

# Per task: the levels in ascending numeric order plus the mean and standard
# deviation of the runs at each level. Levels are kept per task (rather than
# borrowed from 'ioi') so x and y stay aligned even if the tasks were run at
# different levels, and are converted to int so the axis is numeric.
levels_by_task = {}
means = {}
stds = {}
for task, task_results in results.items():
    ordered_levels = sorted(task_results, key=int)
    level_arrays = [np.array(task_results[level]) for level in ordered_levels]
    levels_by_task[task] = [int(level) for level in ordered_levels]
    means[task] = [level_array.mean() for level_array in level_arrays]
    stds[task] = [level_array.std() for level_array in level_arrays]

# Create the plot
fig = go.Figure()

for task in ['gt', 'ioi', 'ds']:
    fig.add_trace(go.Scatter(
        x=levels_by_task[task],
        y=means[task],
        error_y=dict(
            type='data',
            array=stds[task],
            visible=True,
            thickness=3
        ),
        mode='lines+markers',
        line=dict(width=6),
        marker=dict(size=12),
        name=task.upper()
    ))

fig.update_layout(
    xaxis_title='No. of examples (for counting codes)',
    yaxis_title='Mean ROC AUC',
    font=dict(
        family="Palatino",
        size=24,
        color="black"
    ),
    legend=dict(
        x=0.8,
        y=0.15,
        bgcolor='rgba(255, 255, 255, 0.8)',
        bordercolor="Black",
        borderwidth=1
    ),
    plot_bgcolor='white',
    width=1000,
    height=600
)

# Save figure
fig.write_image('../output/figures/roc_auc_vs_examples_counting.pdf')

fig.show()

## Types of easy negatives for Greater-than

In [None]:
import plotly.graph_objects as go

# data = {1: [0.8967078189300411, 0.8798353909465021, 0.8633744855967078, 0.8674897119341565, 0.8028806584362139],
#         2: [0.8308641975308642, 0.7962962962962963, 0.7876543209876543, 0.7893004115226336, 0.7925925925925925],
#         3: [0.7395061728395061, 0.7584362139917695, 0.7329218106995884, 0.7382716049382716, 0.7925925925925925],
#         4: [0.7230452674897119, 0.5740740740740741, 0.6041152263374485, 0.642798353909465, 0.6781893004115226],
#         5: [0.5, 0.5, 0.5, 0.5, 0.5]}

data = {1: [0.8654320987654321, 0.8440329218106997, 0.8864197530864198, 0.8687242798353909, 0.8263374485596707], 
        2: [0.8415637860082305, 0.8098765432098765, 0.8119341563786008, 0.8320987654320988, 0.7835390946502058], 
        3: [0.7098765432098765, 0.7320987654320987, 0.7934156378600823, 0.6835390946502056, 0.7835390946502058], 
        4: [0.7366255144032922, 0.7061728395061728, 0.854320987654321, 0.6962962962962963, 0.7], 
        5: [0.6201646090534979, 0.6275720164609053, 0.6189300411522634, 0.5584362139917696, 0.6337448559670782]}

# Bar labels, in the same order as the integer keys 1..5 of `data`.
labels = ['Range', 'Year', 'Random', 'Unrelated', 'Copy']

# Mean and standard deviation of the ROC AUC runs for each negative type.
means, errors = [], []
for negative_type in range(1, 6):
    scores = data[negative_type]
    means.append(sum(scores) / len(scores))
    errors.append(np.std(scores))

fig = go.Figure()
fig.add_trace(go.Bar(x=labels, y=means, error_y=dict(type='data', array=errors)))

layout_options = dict(
    xaxis_title='Negative Example Type',
    yaxis_title='ROC AUC',
    font=dict(size=24),
    width=1000,
    height=600,
    plot_bgcolor='white',  # white background
)
fig.update_layout(**layout_options)

# Write the PDF before displaying the figure.
fig.write_image('../output/figures/easy_negative_types.pdf')

fig.show()

## Wall time

In [None]:
import plotly.graph_objects as go
import numpy as np

data_10_200 = {
    'gpt2-small': [2.8065803050994873, 2.555044412612915, 2.550591230392456, 2.5294189453125, 2.524600028991699],
    'gpt2-medium': [3.583122730255127, 3.4121508598327637, 3.357593297958374, 3.396043300628662, 3.3576877117156982],
    'gpt2-large': [5.276214122772217, 4.897955656051636, 4.998744964599609, 4.8978612422943115, 4.905563592910767],
    'gpt2-xl': [10.40829610824585, 9.616154670715332, 9.63361382484436, 9.654004573822021, 9.636883020401001]
}

data_100_200 = {
    'gpt2-small': [2.9565839767456055, 2.6795241832733154, 2.676745891571045, 2.6934847831726074, 2.6799378395080566],
    'gpt2-medium': [8.740079879760742, 8.61795711517334, 8.620184421539307, 8.562703132629395, 8.595884799957275],
    'gpt2-large': [20.022331714630127, 19.663774251937866, 19.67916774749756, 19.648465871810913, 19.684297561645508],
    'gpt2-xl': [41.54132032394409, 40.73045110702515, 40.85405707359314, 40.80383634567261, 40.80262112617493]
}

# Model names in insertion order; both timing dicts are indexed by these keys.
models = list(data_10_200.keys())

# Mean and standard deviation of the repeated timings, per model, for each
# of the two configurations.
means_10_200 = [np.mean(data_10_200[model]) for model in models]
errors_10_200 = [np.std(data_10_200[model]) for model in models]

means_100_200 = [np.mean(data_100_200[model]) for model in models]
errors_100_200 = [np.std(data_100_200[model]) for model in models]

# (legend name, means, errors, colour, marker symbol) for each configuration.
trace_specs = [
    ('10 SAE, 200 Pos', means_10_200, errors_10_200, '#1f77b4', 'circle'),
    ('100 SAE, 200 Pos', means_100_200, errors_100_200, '#ff7f0e', 'square'),
]

fig = go.Figure()
for trace_name, trace_means, trace_errors, trace_colour, trace_symbol in trace_specs:
    fig.add_trace(
        go.Scatter(
            x=models,
            y=trace_means,
            mode='lines+markers',
            name=trace_name,
            error_y=dict(type='data', array=trace_errors, visible=True, thickness=4),
            line=dict(color=trace_colour, width=5),
            marker=dict(size=8, color=trace_colour, symbol=trace_symbol),
        )
    )

layout_options = dict(
    xaxis_title='Model',
    yaxis_title='Mean Wall Time (seconds)',
    font=dict(size=24),
    width=1000,
    height=600,
    plot_bgcolor='white',
    legend_title_text='SAE & Pos Examples',
)
fig.update_layout(**layout_options)

# Write the PDF before displaying the figure.
fig.write_image('../output/figures/wall_time.pdf')

fig.show()

## Sparsity examination

Requires:
* $\ell_0$ norm statistics for all examples, and also just on positive and just on negative
* Number of active codes and dead codes across all examples, and also just on positive and just on negative
* Distribution of firing rates for individual codes across all examples; also segmented by positive and negative
* Histogram of code activations on positive and negative data (don't discretise here, keep scalar value)
* How threshold affects the firing rates of codes

For all tasks: docstring, greater-than, indirect object identification.

Basically, just do some data analysis.

In [None]:
device = 'cpu'
task = 'ioi'

# Load the data. map_location keeps every load on `device`, so files saved
# from a GPU session still open on a CPU-only machine.
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt', map_location=device)
ground_truth = torch.load(f'../data/{task}/ground_truth.pt', map_location=device)

# Load the trained sparse autoencoder for the selected task.
# lambda_ and best_roc_auc are read from the checkpoint but not used in this cell.
save_dict = torch.load(f"../models/{task}/sparse_autoencoder_dict.pt", map_location=device)
num_unique = save_dict['node_best_num_unique']
print(f"Number of unique features: {num_unique}")
lambda_ = save_dict['node_best_lambda']
best_roc_auc = save_dict['node_best_roc_auc']
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
model.load_state_dict(save_dict['model'])

# Encode the inputs with the SAE encoder.
activations = model.encoder(resid_streams)

In [None]:
# Re-encode the inputs with the loaded SAE (the same call that ends the
# previous cell) and display the shape of the resulting tensor.
activations = model.encoder(resid_streams)
activations.shape

In [None]:
# activations have shape b x h x n
# b = batch size, h = number of heads, n = number of unique features

# For each head: count the non-zero features per example, then average that
# count over the batch.
non_zero_activations = (activations != 0).sum(dim=2).float()
average_non_zero_activations = non_zero_activations.mean(dim=0)

# Histogram of the per-head averages.
grid_style = dict(showgrid=True, gridwidth=1, gridcolor='lightgray')
fig = go.Figure(
    data=[go.Histogram(x=average_non_zero_activations, nbinsx=20, marker=dict(color='blue'))]
)
fig.update_xaxes(title_text="Average Non-Zero Activations", **grid_style)
fig.update_yaxes(title_text="Frequency", **grid_style)
fig.update_layout(height=600, width=800, title_text="Average Non-Zero Activations per Head")
fig.write_image(f'../output/figures/average_non_zero_activations.pdf')
fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Split the activations into positive and negative examples (the first 250
# rows are treated as positives, the remainder as negatives) and, per head,
# average the per-example count of non-zero activations.
positive_activations = activations[:250, :, :]
negative_activations = activations[250:, :, :]
positive_non_zero_activations = (positive_activations != 0).sum(dim=2).float()
negative_non_zero_activations = (negative_activations != 0).sum(dim=2).float()
average_positive_non_zero_activations = positive_non_zero_activations.mean(dim=0)
average_negative_non_zero_activations = negative_non_zero_activations.mean(dim=0)

# Font sizes used throughout the figure.
title_font_size = 32
axis_title_font_size = 28
tick_font_size = 24

# Side-by-side histograms sharing one y-axis.
fig = make_subplots(rows=1, cols=2, subplot_titles=("Positive Examples", "Negative Examples"), shared_yaxes=True)

grid_style = dict(showgrid=True, gridwidth=1, gridcolor='lightgray')
# (values, bar colour, y-axis title) per panel; only the left panel is titled.
panels = (
    (average_positive_non_zero_activations, '#00CC96', "Frequency"),
    (average_negative_non_zero_activations, '#AB63FA', ""),
)
for col, (per_head_average, bar_colour, y_title) in enumerate(panels, start=1):
    fig.add_trace(go.Histogram(x=per_head_average, nbinsx=20, marker=dict(color=bar_colour)), row=1, col=col)
    fig.update_xaxes(title_text="Average Non-Zero Activations", row=1, col=col, **grid_style)
    fig.update_yaxes(title_text=y_title, row=1, col=col, **grid_style)

fig.update_layout(
    height=600,
    width=1200,
    title_text="",
    showlegend=False,
    title_font=dict(size=title_font_size),
    font=dict(size=tick_font_size),
)

# Axis titles get their own, larger size.
fig.update_xaxes(title_font=dict(size=axis_title_font_size))
fig.update_yaxes(title_font=dict(size=axis_title_font_size))

# Subplot titles are stored as layout annotations.
for annotation in fig.layout.annotations:
    annotation.font = dict(size=title_font_size)

fig.write_image(f'../output/figures/average_non_zero_activations_pos_neg.pdf')
fig.show()

In [None]:
# Overall mean (across heads) of the per-head averages, for each split.
print(f"Positive Examples: {average_positive_non_zero_activations.mean():.4f}")
print(f"Negative Examples: {average_negative_non_zero_activations.mean():.4f}")

In [None]:
import torch

# Discrete code per (example, head): index of the largest activation.
codes = activations.argmax(dim=2)

# Number of distinct codes each head takes across all examples.
n_heads = codes.shape[1]
unique_counts = torch.zeros(n_heads, dtype=torch.long)
for head_idx in range(n_heads):
    unique_counts[head_idx] = codes[:, head_idx].unique().numel()

# Bar chart: one bar per head.
grid_style = dict(showgrid=True, gridwidth=1, gridcolor='lightgray')
fig = go.Figure(
    data=[go.Bar(x=list(range(unique_counts.shape[0])), y=unique_counts, marker=dict(color='blue'))]
)
fig.update_xaxes(title_text="Head", **grid_style)
fig.update_yaxes(title_text="Number of Unique Features", **grid_style)
fig.update_layout(height=600, width=800, title_text="Number of Unique Features per Head")
fig.write_image(f'../output/figures/unique_features_per_head.pdf')
fig.show()

In [None]:
# Distinct codes per head, computed separately for the positive examples
# (first 250 rows) and the negative examples (remaining rows).
positive_codes = codes[:250, :]
negative_codes = codes[250:, :]
positive_unique_counts = torch.zeros(positive_codes.shape[1], dtype=torch.long)
negative_unique_counts = torch.zeros(negative_codes.shape[1], dtype=torch.long)

for head_idx in range(positive_codes.shape[1]):
    positive_unique_counts[head_idx] = positive_codes[:, head_idx].unique().numel()
    negative_unique_counts[head_idx] = negative_codes[:, head_idx].unique().numel()

# Side-by-side bar charts sharing one y-axis.
fig = make_subplots(rows=1, cols=2, subplot_titles=("Positive Examples", "Negative Examples"), shared_yaxes=True)

grid_style = dict(showgrid=True, gridwidth=1, gridcolor='lightgray')
# (counts, bar colour, y-axis title) per panel; only the left panel is titled.
panels = (
    (positive_unique_counts, 'blue', "Number of Unique Features"),
    (negative_unique_counts, 'red', ""),
)
for col, (per_head_counts, bar_colour, y_title) in enumerate(panels, start=1):
    head_axis = list(range(per_head_counts.shape[0]))
    fig.add_trace(go.Bar(x=head_axis, y=per_head_counts, marker=dict(color=bar_colour)), row=1, col=col)
    fig.update_xaxes(title_text="Head", row=1, col=col, **grid_style)
    fig.update_yaxes(title_text=y_title, row=1, col=col, **grid_style)

fig.update_layout(height=600, width=1200, title_text="Number of Unique Codes per Head", showlegend=False)
fig.write_image(f'../output/figures/unique_features_per_head_pos_neg.pdf')
fig.show()

In [None]:
device = 'cpu'
task = 'gt'

# Load the data. map_location keeps every load on `device`, so files saved
# from a GPU session still open on a CPU-only machine.
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt', map_location=device)
ground_truth = torch.load(f'../data/{task}/ground_truth.pt', map_location=device)

# Load the trained sparse autoencoder for the selected task.
# lambda_ and best_roc_auc are read from the checkpoint but not used in this cell.
save_dict = torch.load(f"../models/{task}/sparse_autoencoder_dict.pt", map_location=device)
num_unique = save_dict['node_best_num_unique']
print(f"Number of unique features: {num_unique}")
lambda_ = save_dict['node_best_lambda']
best_roc_auc = save_dict['node_best_roc_auc']
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
model.load_state_dict(save_dict['model'])

activations = model.encoder(resid_streams)

# Discrete code per (example, head): index of the largest activation.
codes = torch.argmax(activations, dim=2)

# Number of distinct codes each head takes across all examples.
unique_counts = torch.zeros(codes.shape[1], dtype=torch.long)

for i in range(codes.shape[1]):
    unique_counts[i] = codes[:, i].unique().numel()

# Now we can calculate the number of unique features per head for positive
# (first 250 rows) and negative (remaining rows) examples
positive_codes = codes[:250, :]
negative_codes = codes[250:, :]
positive_unique_counts = torch.zeros(positive_codes.shape[1], dtype=torch.long)
negative_unique_counts = torch.zeros(negative_codes.shape[1], dtype=torch.long)

for i in range(positive_codes.shape[1]):
    positive_unique_counts[i] = positive_codes[:, i].unique().numel()
    negative_unique_counts[i] = negative_codes[:, i].unique().numel()

# Convert each ground-truth (layer, head) pair to its position in head_labels
# via the "L{layer}H{head}" label. head_labels is already loaded above, so it
# is not read from disk a second time.
ground_truth_heads = []
for layer, head in ground_truth:
    head_str = f"L{layer}H{head}"
    head_num = head_labels.index(head_str)
    ground_truth_heads.append(head_num)

fig = go.Figure()

# One colour for heads in the ground-truth circuit, another for all others.
ground_truth_color = '#00CC96'
other_color = '#AB63FA'
ground_truth_head_set = set(ground_truth_heads)
colors = [ground_truth_color if i in ground_truth_head_set else other_color for i in range(positive_unique_counts.shape[0])]

# Bar height = (# unique codes on positives) - (# unique codes on negatives).
fig.add_trace(go.Bar(x=list(range(positive_unique_counts.shape[0])),
                     y=positive_unique_counts - negative_unique_counts,
                     marker=dict(color=colors),
                     showlegend=False, width=1))  # Set showlegend to False for the bar trace

fig.update_xaxes(title_text="Head", showgrid=True, gridwidth=1, gridcolor='lightgray')
fig.update_yaxes(title_text="Difference in # Unique Features", showgrid=True, gridwidth=1, gridcolor='lightgray')

# Add legend for "In circuit" and "Not in circuit" using empty marker traces,
# since the bar trace itself carries per-bar colours and is hidden from the legend.
fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color=ground_truth_color, size=10), name='In circuit'))
fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color=other_color, size=10), name='Not in circuit'))

fig.update_layout(height=600, width=1000, title_text="Difference in No. Unique Features per Head",
                  plot_bgcolor='white', showlegend=True)

fig.write_image(f'../output/figures/unique_features_per_head_diff.pdf')
fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

device = 'cpu'
tasks = ['gt', 'ioi']

# Colours are defined before the loop because the legend traces added after it
# need them as well.
ground_truth_color = '#00CC96'
other_color = '#AB63FA'

# One column per task; titles are derived from the task keys ("GT Task", "IOI Task").
fig = make_subplots(rows=1, cols=len(tasks),
                    subplot_titles=[f"{task.upper()} Task" for task in tasks],
                    horizontal_spacing=0.05)

for i, task in enumerate(tasks):
    # Load the data. map_location keeps every load on `device`, so files saved
    # from a GPU session still open on a CPU-only machine.
    resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
    head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt', map_location=device)
    ground_truth = torch.load(f'../data/{task}/ground_truth.pt', map_location=device)

    # Load the trained sparse autoencoder for the current task.
    # lambda_ and best_roc_auc are read from the checkpoint but not used here.
    save_dict = torch.load(f"../models/{task}/sparse_autoencoder_dict.pt", map_location=device)
    num_unique = save_dict['node_best_num_unique']
    print(f"Number of unique features: {num_unique}")
    lambda_ = save_dict['node_best_lambda']
    best_roc_auc = save_dict['node_best_roc_auc']
    model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
    model.load_state_dict(save_dict['model'])

    # Discrete code per (example, head): index of the largest activation.
    activations = model.encoder(resid_streams)
    codes = torch.argmax(activations, dim=2)
    unique_counts = torch.zeros(codes.shape[1], dtype=torch.long)

    for j in range(codes.shape[1]):
        unique_counts[j] = codes[:, j].unique().numel()

    # Now we can calculate the number of unique features per head for positive
    # (first 250 rows) and negative (remaining rows) examples
    positive_codes = codes[:250, :]
    negative_codes = codes[250:, :]
    positive_unique_counts = torch.zeros(positive_codes.shape[1], dtype=torch.long)
    negative_unique_counts = torch.zeros(negative_codes.shape[1], dtype=torch.long)

    for j in range(positive_codes.shape[1]):
        positive_unique_counts[j] = positive_codes[:, j].unique().numel()
        negative_unique_counts[j] = negative_codes[:, j].unique().numel()

    # Convert each ground-truth (layer, head) pair to its position in
    # head_labels via the "L{layer}H{head}" label. head_labels is already
    # loaded above, so it is not read from disk a second time.
    ground_truth_heads = []
    for layer, head in ground_truth:
        head_str = f"L{layer}H{head}"
        head_num = head_labels.index(head_str)
        ground_truth_heads.append(head_num)

    # Per-bar colours: circuit heads versus all other heads.
    ground_truth_head_set = set(ground_truth_heads)
    colors = [ground_truth_color if j in ground_truth_head_set else other_color for j in range(positive_unique_counts.shape[0])]

    # Bar height = (# unique codes on positives) - (# unique codes on negatives).
    fig.add_trace(go.Bar(x=list(range(positive_unique_counts.shape[0])),
                         y=positive_unique_counts - negative_unique_counts,
                         marker=dict(color=colors),
                         showlegend=False, width=1.1),  # Set showlegend to False for the bar trace
                  row=1, col=i+1)

    fig.update_xaxes(title_text="Head", showgrid=True, gridwidth=1, gridcolor='lightgray', row=1, col=i+1)
    # Only the left-most panel carries the y-axis title.
    if i == 0:
        fig.update_yaxes(title_text="Diff. in # Unique Codes", showgrid=True, gridwidth=1, gridcolor='lightgray', row=1, col=i+1)
    else:
        fig.update_yaxes(title_text="", showgrid=True, gridwidth=1, gridcolor='lightgray', row=1, col=i+1)

# Add legend for "In circuit" and "Not in circuit" using empty marker traces
fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color=ground_truth_color, size=10), name='In circuit'))
fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color=other_color, size=10), name='Not in circuit'))

fig.update_layout(height=600, width=1600,# title_text="Difference in No. Unique Features per Head",
                  plot_bgcolor='white', showlegend=True)

# Make the fonts bigger
fontsize = 24
fig.update_layout(font=dict(size=fontsize-2))
fig.update_xaxes(tickfont=dict(size=fontsize-4))
fig.update_yaxes(tickfont=dict(size=fontsize-4))
# Make subplot titles bigger (they are stored as layout annotations)
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(size=fontsize+4)
# Make legend font smaller
fig.update_layout(legend=dict(font=dict(size=fontsize-4)))

fig.write_image(f'../output/figures/unique_features_per_head_diff_combined.pdf')
fig.show()

In [None]:
task = 'ioi'
device = 'cpu'
# map_location keeps every load on `device`, so files saved from a GPU session
# still open on a CPU-only machine.
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt', map_location=device)
ground_truth = torch.load(f'../data/{task}/ground_truth.pt', map_location=device)

# Load the trained sparse autoencoder for the selected task.
# lambda_ and best_roc_auc are read from the checkpoint but not used in this cell.
save_dict = torch.load(f"../models/{task}/sparse_autoencoder_dict.pt", map_location=device)
num_unique = save_dict['node_best_num_unique']
lambda_ = save_dict['node_best_lambda']
best_roc_auc = save_dict['node_best_roc_auc']
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
model.load_state_dict(save_dict['model'])

# Discrete code per (example, head): index of the largest activation.
activations = model.encoder(resid_streams)
codes = torch.argmax(activations, dim=2)
unique_counts = torch.zeros(codes.shape[1], dtype=torch.long)

for j in range(codes.shape[1]):
    unique_counts[j] = codes[:, j].unique().numel()

# Split the codes into positive (first 250 rows) and negative (remaining rows)
# examples. The two *_unique_counts tensors are only zero-initialised here and
# are not used further in this cell.
positive_codes = codes[:250, :]
negative_codes = codes[250:, :]
positive_unique_counts = torch.zeros(positive_codes.shape[1], dtype=torch.long)
negative_unique_counts = torch.zeros(negative_codes.shape[1], dtype=torch.long)

# Hard-coded head indices to inspect.
ground_truth_heads = [141, 127, 93]

# Plot one histogram figure per selected head k
for h in ground_truth_heads:
    k = h

    # Find the most common code in the positive indices/codes for head k
    positive_codes_k = positive_codes[:, k]
    unique_codes, code_counts = torch.unique(positive_codes_k, return_counts=True)
    # Ignore code 0, which is treated as "zero activations". torch.unique
    # returns sorted values, so position 0 is code 0 only when code 0 actually
    # occurs; masking by value avoids discarding a real code when it does not.
    code_counts[unique_codes == 0] = 0
    most_common_code = unique_codes[torch.argmax(code_counts)]

    # Retrieve the activations corresponding to the most common code for both positive and negative examples
    positive_activations = activations[:250, k, most_common_code]
    negative_activations = activations[250:, k, most_common_code]

    # Remove zero activations from both positive and negative activations
    positive_activations = positive_activations[positive_activations != 0]
    negative_activations = negative_activations[negative_activations != 0]

    # Create the histogram traces for positive and negative activations
    positive_trace = go.Histogram(x=positive_activations.cpu().numpy(), name='Positive', opacity=0.75, nbinsx=20)
    negative_trace = go.Histogram(x=negative_activations.cpu().numpy(), name='Negative', opacity=0.75, nbinsx=20)

    # Create the layout for the plot
    layout = go.Layout(
        title=f'Activation Histograms for Most Common Code in Head {k}',
        xaxis=dict(title='Activation'),
        yaxis=dict(title='Frequency'),
        bargap=0.1,
        bargroupgap=0.1
    )

    # Create the figure and add the traces
    fig = go.Figure(data=[positive_trace, negative_trace], layout=layout)

    # Display the plot
    fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

task = 'ioi'
device = 'cpu'
# map_location keeps every load on `device`, so files saved from a GPU session
# still open on a CPU-only machine.
resid_streams = torch.load(f"../data/{task}/resid_heads_mean.pt", map_location=device).to(device)
head_labels = torch.load(f'../data/{task}/labels_heads_mean.pt', map_location=device)
ground_truth = torch.load(f'../data/{task}/ground_truth.pt', map_location=device)

# Load the trained sparse autoencoder for the selected task.
# lambda_ and best_roc_auc are read from the checkpoint but not used in this cell.
save_dict = torch.load(f"../models/{task}/sparse_autoencoder_dict.pt", map_location=device)
num_unique = save_dict['node_best_num_unique']
lambda_ = save_dict['node_best_lambda']
best_roc_auc = save_dict['node_best_roc_auc']

model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
model.load_state_dict(save_dict['model'])

# Discrete code per (example, head): index of the largest activation.
activations = model.encoder(resid_streams)
codes = torch.argmax(activations, dim=2)
unique_counts = torch.zeros(codes.shape[1], dtype=torch.long)
for j in range(codes.shape[1]):
    unique_counts[j] = codes[:, j].unique().numel()

# Split the codes into positive (first 250 rows) and negative (remaining rows)
# examples. The two *_unique_counts tensors are only zero-initialised here and
# are not used further in this cell.
positive_codes = codes[:250, :]
negative_codes = codes[250:, :]
positive_unique_counts = torch.zeros(positive_codes.shape[1], dtype=torch.long)
negative_unique_counts = torch.zeros(negative_codes.shape[1], dtype=torch.long)

# Hard-coded head indices to inspect.
ground_truth_heads = [141, 127, 93]

# Create subplots: one row per selected head
fig = make_subplots(rows=len(ground_truth_heads), cols=1, subplot_titles=[f"Head {h}" for h in ground_truth_heads])

# Make subplot title font size larger (titles are stored as layout annotations)
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(size=24)

for i, h in enumerate(ground_truth_heads):
    k = h

    # Find the most common code in the positive indices/codes for head k
    positive_codes_k = positive_codes[:, k]
    unique_codes, code_counts = torch.unique(positive_codes_k, return_counts=True)
    # Ignore code 0, which is treated as "zero activations". torch.unique
    # returns sorted values, so position 0 is code 0 only when code 0 actually
    # occurs; masking by value avoids discarding a real code when it does not.
    code_counts[unique_codes == 0] = 0
    most_common_code = unique_codes[torch.argmax(code_counts)]

    # Retrieve the activations corresponding to the most common code for both positive and negative examples
    positive_activations = activations[:250, k, most_common_code]
    negative_activations = activations[250:, k, most_common_code]

    # Remove zero activations from both positive and negative activations
    positive_activations = positive_activations[positive_activations != 0]
    negative_activations = negative_activations[negative_activations != 0]

    # Create the histogram traces; only the first row contributes legend entries
    positive_trace = go.Histogram(x=positive_activations.cpu().numpy(), name='Positive', opacity=0.75, nbinsx=30, marker=dict(color='#00CC96'), showlegend=(i==0))
    negative_trace = go.Histogram(x=negative_activations.cpu().numpy(), name='Negative', opacity=0.75, nbinsx=30, marker=dict(color='#AB63FA'), showlegend=(i==0))

    # Add traces to the corresponding subplot
    fig.add_trace(positive_trace, row=i+1, col=1)
    fig.add_trace(negative_trace, row=i+1, col=1)

    # Only the bottom subplot gets an x-axis title
    if i == len(ground_truth_heads) - 1:
        fig.update_xaxes(title_text="Activation", row=i+1, col=1)
    else:
        fig.update_xaxes(title_text="", row=i+1, col=1)
    fig.update_yaxes(title_text="Frequency", row=i+1, col=1)

# Update layout for publication quality
fig.update_layout(
    #title="Activation Histograms for Most Common Code in Selected Heads",
    height=900,
    width=1200,
    plot_bgcolor='white',
    showlegend=True,
    legend=dict(x=0.8, y=1.0),
    font=dict(size=24)
)

# Save the figure
fig.write_image(f'../output/figures/activation_histograms.pdf')

# Display the plot
fig.show()

# Individual task examinations

## Greater-than

In [5]:
# Create a dataset by filling in the year
def create_prompt(year, century):
    """Build a greater-than prompt of the form
    "The war lasted from the year {century}{year} to {century}".

    Args:
        year: Year within the century; values below 10 get a leading "0".
        century: Leading digits of the year (e.g. 17 for the 1700s).

    Returns:
        The prompt string, ending with the century so that the next tokens
        would complete the end year.
    """
    # Prefix a "0" to years below 10 so the year part has two digits.
    year_text = ("0" if year < 10 else "") + str(year)
    century_text = str(century)
    return f"The war lasted from the year {century_text}{year_text} to {century_text}"

# Build one prompt for every (century, year) pair: centuries vary slowest,
# years 00-99 vary fastest, giving 5 x 100 prompts.
years = list(range(0, 100))
centuries = [15, 16, 17, 18, 19]
prompts = [create_prompt(year, century) for century in centuries for year in years]

prompts

['The war lasted from the year 1500 to 15',
 'The war lasted from the year 1501 to 15',
 'The war lasted from the year 1502 to 15',
 'The war lasted from the year 1503 to 15',
 'The war lasted from the year 1504 to 15',
 'The war lasted from the year 1505 to 15',
 'The war lasted from the year 1506 to 15',
 'The war lasted from the year 1507 to 15',
 'The war lasted from the year 1508 to 15',
 'The war lasted from the year 1509 to 15',
 'The war lasted from the year 1510 to 15',
 'The war lasted from the year 1511 to 15',
 'The war lasted from the year 1512 to 15',
 'The war lasted from the year 1513 to 15',
 'The war lasted from the year 1514 to 15',
 'The war lasted from the year 1515 to 15',
 'The war lasted from the year 1516 to 15',
 'The war lasted from the year 1517 to 15',
 'The war lasted from the year 1518 to 15',
 'The war lasted from the year 1519 to 15',
 'The war lasted from the year 1520 to 15',
 'The war lasted from the year 1521 to 15',
 'The war lasted from the year 1

In [7]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Load pretrained GPT-2 into a HookedTransformer on the CPU.
tl_model = HookedTransformer.from_pretrained('gpt2', device='cpu')
# Tokenise all prompts into a single tensor.
# NOTE(review): the displayed token ids show some rows ending in 50256 while
# others end in a year token, so the prompts presumably tokenise to different
# lengths and are padded. Confirm this is acceptable for the downstream mean.
tokens = tl_model.to_tokens(prompts)
tokens

Loaded pretrained model gpt2 into HookedTransformer


tensor([[50256,   464,  1175,  ...,   284,  1315, 50256],
        [50256,   464,  1175,  ...,   486,   284,  1315],
        [50256,   464,  1175,  ...,    17,   284,  1315],
        ...,
        [50256,   464,  1175,  ...,   284,   678, 50256],
        [50256,   464,  1175,  ...,   284,   678, 50256],
        [50256,   464,  1175,  ...,   284,   678, 50256]])

In [8]:
# Run the model on every prompt and keep the activation cache (the model
# output itself is discarded).
_, cache = tl_model.run_with_cache(tokens)
# Presumably populates the per-head outputs that stack_head_results reads;
# confirm against the TransformerLens ActivationCache docs.
cache.compute_head_results()
# NOTE(review): head_resid is presumably (head, prompt, position, d_model) with
# head_labels the matching head names. Confirm the axis order.
head_resid, head_labels = cache.stack_head_results(return_labels=True)
# Average over dim 2 and drop any size-1 axes. Despite the variable name, this
# is a mean over that axis, not a last-token selection.
# NOTE(review): if dim 2 is the position axis, padded positions are included in
# the mean. Confirm this is intended.
last_token_accum = head_resid.mean(dim=2).squeeze()
# Swap the first two axes; the shape displayed in the next cell is
# (500, 144, 768).
resid_streams = einops.rearrange(last_token_accum, 'b n h -> n b h')

In [9]:
# Sanity check of the rearranged tensor's shape.
resid_streams.shape

torch.Size([500, 144, 768])

In [10]:
import torch
from common_utils import *
# import sparse autoencoder
from sparse_autoencoder import SparseAutoencoder

# Run configuration used by the cells below
task = 'gt'  # presumably the "greater-than" task -- confirm
device = 'cpu'
# Print the input shape as a sanity check
print(resid_streams.shape)

torch.Size([500, 144, 768])


In [12]:
# Hyperparameters for the sparse autoencoder run
num_unique = 100   # number of learned features (passed as n_learned_features)
lambda_ = 0.01     # passed to train(); presumably the sparsity-loss weight -- confirm against train()
n_epochs = 500

# Turn torch grad on
# (gradients were disabled globally in the setup cell; training needs them back)
torch.set_grad_enabled(True)

# Set random seed
# Must run before the model is constructed so the weight initialisation is reproducible
torch.manual_seed(42)

# Input width is taken from the last axis of resid_streams
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_unique, geometric_median_dataset=None).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model
# NOTE(review): resid_streams is passed twice; if these are the train and eval sets
# (the log prints both a train and an eval loss), the eval loss is not a held-out
# metric -- confirm against train(), which is not defined in this cell.
model = train(model, n_epochs, optimizer, resid_streams, resid_streams, lambda_)

  0%|          | 1/500 [00:00<02:46,  2.99it/s]

Train loss = 0.4896, Eval loss = 0.4468


 10%|█         | 52/500 [00:09<01:37,  4.61it/s]

Train loss = 0.1321, Eval loss = 0.1296


 20%|██        | 102/500 [00:19<01:18,  5.06it/s]

Train loss = 0.0825, Eval loss = 0.0823


 30%|███       | 152/500 [00:28<01:07,  5.12it/s]

Train loss = 0.0744, Eval loss = 0.0742


 40%|████      | 202/500 [00:37<00:57,  5.15it/s]

Train loss = 0.0689, Eval loss = 0.0688


 50%|█████     | 252/500 [00:46<00:49,  5.05it/s]

Train loss = 0.0652, Eval loss = 0.0651


 60%|██████    | 302/500 [00:55<00:38,  5.09it/s]

Train loss = 0.0623, Eval loss = 0.0622


 70%|███████   | 352/500 [01:04<00:29,  5.03it/s]

Train loss = 0.0600, Eval loss = 0.0599


 80%|████████  | 402/500 [01:13<00:19,  5.06it/s]

Train loss = 0.0580, Eval loss = 0.0580


 90%|█████████ | 452/500 [01:22<00:09,  5.05it/s]

Train loss = 0.0562, Eval loss = 0.0562


100%|██████████| 500/500 [01:31<00:00,  5.46it/s]


In [13]:
# Training is finished: switch autograd back off for the analysis cells
torch.set_grad_enabled(False)

# Encode the inputs with the trained autoencoder's encoder
activations = model.encoder(resid_streams)

# A "code" is the index of the strongest learned feature along dim 2
codes = activations.argmax(dim=2)

In [14]:
# Display the argmax code indices computed in the previous cell
codes

tensor([[ 0,  0, 20,  ..., 85, 37,  0],
        [ 0,  0,  0,  ..., 85,  0, 85],
        [ 0,  0,  0,  ..., 85, 81,  0],
        ...,
        [ 0,  0,  0,  ..., 85,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ..., 85,  0,  0]])

In [15]:
import re

def prompt_to_year_number(prompt: str):
    """
    Extract the within-century part of the first year mentioned in a prompt.

    The first run of digits that follows the word 'year ' is located and its
    last two digits are returned as an int. Returns None when the prompt
    contains no such run of digits.

    E.g. 'The war lasted from the year 1504 to 15' -> 4
    E.g. 'The war lasted from the year 1800 to 18' -> 0
    """
    first_year = re.search(r'year (\d+)', prompt)
    if first_year is None:
        return None
    digits = first_year.group(1)
    return int(digits[-2:])

def prompt_to_century_number(prompt: str):
    """
    Extract the century prefix of the first year mentioned in a prompt.

    The first run of digits that follows the word 'year ' is located and
    everything except its last two digits is returned as an int. Returns None
    when the prompt contains no such run of digits. A year with fewer than
    three digits leaves nothing to convert, so int() raises ValueError.

    E.g. 'The war lasted from the year 1504 to 15' -> 15
    E.g. 'The war lasted from the year 1902 to 19' -> 19
    """
    first_year = re.search(r'year (\d+)', prompt)
    if first_year is None:
        return None
    digits = first_year.group(1)
    return int(digits[:-2])

# One label per prompt, in prompt order: within-century year and century prefix
year_numbers = list(map(prompt_to_year_number, prompts))
century_numbers = list(map(prompt_to_century_number, prompts))

In [16]:
# Spot check: a year ending in "00" should map to year number 0
prompt_to_year_number('The war lasted from the year 1800 to 18')

0

In [17]:
# Distribution of the within-century year numbers across all prompts
grid_style = dict(showgrid=True, gridwidth=1, gridcolor='lightgray')

fig = go.Figure(data=[go.Histogram(x=year_numbers, nbinsx=10, marker={'color': 'blue'})])
fig.update_xaxes(title_text="Year Number", **grid_style)
fig.update_yaxes(title_text="Frequency", **grid_style)
fig.update_layout(title_text="Year Number Histogram", width=800, height=600)
fig.show()

In [18]:
# Check that there is one row of codes per year label
len(codes), len(year_numbers)

(500, 500)

In [19]:
# Heatmap of the code indices with rows in their original order
imshow(codes)

In [20]:
# Reorder the rows of codes so that examples appear in ascending year-number order
year_order = torch.tensor(year_numbers).argsort()
codes_sorted = codes[year_order, :]

imshow(codes_sorted)

In [44]:
import numpy as np
import plotly.graph_objects as go

# Heatmap of how often each (head, code) pair occurs within each group of prompts.
# Inputs defined in earlier cells:
# - codes: tensor indexed as codes[example, head] holding the code for each example and head
# - century_numbers / year_numbers: one grouping value per example
# Head count is taken from the data rather than hard-coded (144 in this run)
n_heads = codes.shape[1]

# Grouping variable for the rows of the heatmap. To group by year instead, switch to
# year_numbers.copy() and set group_label to 'Year' so the axis title stays truthful.
to_use = century_numbers.copy() #year_numbers.copy()
group_label = 'Century'

unique_years = sorted(set(to_use))
unique_codes = sorted(torch.unique(codes).tolist())
n_codes = len(unique_codes)

# code_freq_dict[group][head][code] -> number of examples in that group with that code
code_freq_dict = {year: {head: {code: 0 for code in unique_codes} for head in range(n_heads)} for year in unique_years}

# Convert once to nested Python lists instead of calling .item() on every element
codes_list = codes.tolist()
for example_idx, year in enumerate(to_use):
    for head in range(n_heads):
        code_freq_dict[year][head][codes_list[example_idx][head]] += 1

# Flatten to a 2D array: one row per group, one column per (head, code) pair
code_freq_array = np.zeros((len(unique_years), n_heads * n_codes))
for year_idx, year in enumerate(unique_years):
    for head in range(n_heads):
        for code_idx, code in enumerate(unique_codes):
            code_freq_array[year_idx, head * n_codes + code_idx] = code_freq_dict[year][head][code]

# Column labels, in the same head-major order as the array columns.
# This definition was previously commented out, so the cell only ran when a later
# cell had already left x_tick_labels behind in the kernel.
x_tick_labels = [f'H{head+1}C{code}' for head in range(n_heads) for code in unique_codes]

# Create a heatmap using Plotly
fig = go.Figure(data=go.Heatmap(
    z=code_freq_array,
    x=x_tick_labels,
    y=unique_years,
    colorscale='Blues',
    colorbar=dict(title='Code Frequency')
))

# Customize the layout
fig.update_layout(
    title='Code Frequency Heatmap',
    xaxis=dict(title='Head and Code', tickangle=-90),
    yaxis=dict(title=group_label),
    width=1500,
    height=600
)

# Display the interactive heatmap
fig.show()

In [61]:
import numpy as np
import plotly.graph_objects as go

# Binary heatmap showing which (head, code) pairs occur at all within each group of prompts.
# Inputs defined in earlier cells:
# - codes: tensor indexed as codes[example, head] holding the code for each example and head
# - century_numbers / year_numbers: one grouping value per example
# Head count is taken from the data rather than hard-coded (144 in this run)
n_heads = codes.shape[1]

# Grouping variable for the rows of the heatmap. To group by year instead, switch to
# year_numbers.copy() and set group_label to 'Year' so the axis title stays truthful.
to_use = century_numbers.copy() #year_numbers.copy()
group_label = 'Century'
unique_years = sorted(set(to_use))
unique_codes = sorted(torch.unique(codes).tolist())
n_codes = len(unique_codes)

# code_freq_dict[group][head][code] -> number of examples in that group with that code
code_freq_dict = {year: {head: {code: 0 for code in unique_codes} for head in range(n_heads)} for year in unique_years}

# Convert once to nested Python lists instead of calling .item() on every element
codes_list = codes.tolist()
for example_idx, year in enumerate(to_use):
    for head in range(n_heads):
        code_freq_dict[year][head][codes_list[example_idx][head]] += 1

# Flatten to a 2D array: one row per group, one column per (head, code) pair
code_freq_array = np.zeros((len(unique_years), n_heads * n_codes))
for year_idx, year in enumerate(unique_years):
    for head in range(n_heads):
        for code_idx, code in enumerate(unique_codes):
            code_freq_array[year_idx, head * n_codes + code_idx] = code_freq_dict[year][head][code]

# Column labels, in the same head-major order as the array columns
x_tick_labels = [f'H{head+1}C{code}' for head in range(n_heads) for code in unique_codes]

# 1 where the (head, code) pair was seen at least once in the group, else 0
binary_array = (code_freq_array > 0).astype(int)

# Create a heatmap using Plotly
fig = go.Figure(data=go.Heatmap(
    z=binary_array,
    x=x_tick_labels,
    y=unique_years,
    colorscale=[[0, 'white'], [1, 'navy']],
    showscale=False
))

# Customize the layout.
# Axis titles use title=dict(text=..., font=...): the older `titlefont` shorthand is
# rejected by plotly 6, while the nested form is accepted by older releases too.
fontsize = 28
fig.update_layout(
    xaxis=dict(title=dict(text='Head and Code', font=dict(size=fontsize)), tickangle=-90, tickfont=dict(size=14)),
    yaxis=dict(title=dict(text=group_label, font=dict(size=fontsize)), tickfont=dict(size=fontsize-4)),
    width=1500,
    height=600,
)

# Save figure
fig.write_image('../output/figures/code_frequency_binary_heatmap_century.pdf')

# Display the interactive heatmap
fig.show()


In [None]:
# Inspect the shape of the encoder activations before slicing them per head
activations.shape

In [None]:
# Translate the head labels we care about into their positions in head_labels
head_labels_of_interest = ['L1H0', 'L6H0', 'L10H8', 'L11H4', 'L11H8']
head_numbers_of_interest = list(map(head_labels.index, head_labels_of_interest))
head_numbers_of_interest

In [63]:
import numpy as np
from sklearn.decomposition import PCA
import plotly.express as px
import pandas as pd

# Per-head PCA of the autoencoder activations, coloured by year number.
# Inputs defined in earlier cells:
# - activations: tensor indexed as activations[example, head, code]
# - head_labels: label string for each head index
# - year_numbers: a list of year numbers corresponding to each example
# - prompts: a list of text examples corresponding to each example

head_labels_of_interest = ['L1H0', 'L6H0', 'L10H8', 'L11H8']
head_numbers_of_interest = [head_labels.index(head) for head in head_labels_of_interest]

for head_number, layer_head_str in zip(head_numbers_of_interest, head_labels_of_interest):
    # layer_head_str already is head_labels[head_number], so no second lookup is needed
    print(f"Head {head_number}, Label = {layer_head_str}")
    # All examples' activations for this one head
    reshaped_activations = activations[:, head_number, :]

    # Apply PCA to reduce the dimensionality to 2 components
    pca = PCA(n_components=2)
    reduced_activations = pca.fit_transform(reshaped_activations)

    # Create a DataFrame with the reduced activations, corresponding years, and text examples
    df = pd.DataFrame(reduced_activations, columns=['PC1', 'PC2'])
    df['Year'] = year_numbers
    df['Text'] = prompts

    # Create a scatter plot using Plotly Graph Objects.
    # Titles use title=dict(text=..., font=...): the older `titlefont` shorthand is
    # rejected by plotly 6, while the nested form is accepted by older releases too.
    fig = go.Figure(data=go.Scatter(
        x=df['PC1'],
        y=df['PC2'],
        mode='markers',
        marker=dict(
            size=8,
            color=df['Year'],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title=dict(text='Year', font=dict(size=24, family='Palatino')), tickfont=dict(size=18, family='Palatino'))
        ),
        text=[f"Year: {year}<br>Text: {text}" for year, text in zip(df['Year'], df['Text'])],
        hoverinfo='text'
    ))

    # Customize the layout
    fontsize = 24
    fig.update_layout(
        xaxis=dict(title=dict(text='PC1', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
        yaxis=dict(title=dict(text='PC2', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
        font=dict(family='Palatino', size=fontsize),
        plot_bgcolor='white',
        width=800,
        height=600,
        showlegend=False
    )

    # Save the plot (one PDF per head, named after the head label)
    fig.write_image(f'../output/figures/pca_plot_gt_{layer_head_str}.pdf')

    # Display the plot
    fig.show()

Head 12, Label = L1H0


Head 72, Label = L6H0


Head 128, Label = L10H8


Head 140, Label = L11H8


In [37]:
import numpy as np
from sklearn.manifold import TSNE
import plotly.graph_objects as go
import pandas as pd

# t-SNE of the flattened autoencoder activations, coloured by year number.
# Inputs defined in earlier cells:
# - activations: tensor indexed as activations[example, head, code]
# - year_numbers: a list of year numbers corresponding to each example
# - prompts: a list of text examples corresponding to each example

# Flatten everything after the first axis so each example becomes one feature vector
reshaped_activations = activations.reshape(activations.shape[0], -1)

# Apply t-SNE to reduce the dimensionality to 2 components (fixed seed for reproducibility).
# NOTE(review): newer scikit-learn releases rename TSNE's `n_iter` to `max_iter`;
# check the installed version before upgrading.
tsne = TSNE(n_components=2, random_state=39, init='pca', perplexity=100, n_iter=10000, n_iter_without_progress=500, verbose=0)
reduced_activations = tsne.fit_transform(reshaped_activations)

# Create a DataFrame with the reduced activations, corresponding years, and text examples
df = pd.DataFrame(reduced_activations, columns=['t-SNE1', 't-SNE2'])
df['Year'] = year_numbers
df['Text'] = prompts

# Create a scatter plot using Plotly Graph Objects.
# Titles use title=dict(text=..., font=...): the older `titlefont` shorthand is
# rejected by plotly 6, while the nested form is accepted by older releases too.
fig = go.Figure(data=go.Scatter(
    x=df['t-SNE1'],
    y=df['t-SNE2'],
    mode='markers',
    marker=dict(
        size=8,
        color=df['Year'],
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(title=dict(text='Year', font=dict(size=24, family='Palatino')), tickfont=dict(size=18, family='Palatino'))
    ),
    text=[f"Year: {year}<br>Text: {text}" for year, text in zip(df['Year'], df['Text'])],
    hoverinfo='text'
))

# Customize the layout
fontsize = 24
fig.update_layout(
    xaxis=dict(title=dict(text='t-SNE Dimension 1', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
    yaxis=dict(title=dict(text='t-SNE Dimension 2', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
    font=dict(family='Palatino', size=fontsize),
    plot_bgcolor='white',
    width=800,
    height=600,
    showlegend=False
)

# Save the plot
fig.write_image('../output/figures/tsne_plot_gt.pdf')

# Display the plot
fig.show()

In [42]:
import numpy as np
from sklearn.manifold import TSNE
import plotly.graph_objects as go
import pandas as pd

# t-SNE of the flattened autoencoder activations, coloured by century.
# Inputs defined in earlier cells:
# - activations: tensor indexed as activations[example, head, code]
# - century_numbers: a list of century numbers corresponding to each example
# - prompts: a list of text examples corresponding to each example

# Flatten everything after the first axis so each example becomes one feature vector
reshaped_activations = activations.reshape(activations.shape[0], -1)

# Apply t-SNE to reduce the dimensionality to 2 components (fixed seed for reproducibility).
# NOTE(review): newer scikit-learn releases rename TSNE's `n_iter` to `max_iter`;
# check the installed version before upgrading.
tsne = TSNE(n_components=2, random_state=39, init='pca', perplexity=100, n_iter=10000, n_iter_without_progress=500, verbose=0)
reduced_activations = tsne.fit_transform(reshaped_activations)

# Create a DataFrame with the reduced activations, corresponding century numbers, and text examples
df = pd.DataFrame(reduced_activations, columns=['t-SNE1', 't-SNE2'])
df['Century'] = century_numbers
df['Text'] = prompts

# Create a scatter plot using Plotly Graph Objects.
# The colorbar gets one integer tick per century instead of interpolated tick values.
# Titles use title=dict(text=..., font=...): the older `titlefont` shorthand is
# rejected by plotly 6, while the nested form is accepted by older releases too.
unique_centuries = sorted(set(century_numbers))
fig = go.Figure(data=go.Scatter(
    x=df['t-SNE1'],
    y=df['t-SNE2'],
    mode='markers',
    marker=dict(
        size=8,
        color=df['Century'],
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(
            title=dict(text='Century', font=dict(size=24, family='Palatino')),
            tickfont=dict(size=18, family='Palatino'),
            tickvals=unique_centuries,
            ticktext=[str(century) for century in unique_centuries]
        )
    ),
    text=[f"Century: {century}<br>Text: {text}" for century, text in zip(df['Century'], df['Text'])],
    hoverinfo='text'
))

# Customize the layout
fontsize = 24
fig.update_layout(
    xaxis=dict(title=dict(text='t-SNE Dimension 1', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
    yaxis=dict(title=dict(text='t-SNE Dimension 2', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
    font=dict(family='Palatino', size=fontsize),
    plot_bgcolor='white',
    width=800,
    height=600,
    showlegend=False
)

# Save the plot
fig.write_image('../output/figures/tsne_plot_gt_century_integer_ticks.pdf')

# Display the plot
fig.show()

In [28]:
import numpy as np
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
import plotly.graph_objects as go
import pandas as pd

# t-SNE scatter (coloured by year) drawn over a background that shows, for each
# grid cell, the majority century among its nearest embedded examples.
# Inputs defined in earlier cells:
# - activations: tensor indexed as activations[example, head, code]
# - year_numbers: a list of year numbers corresponding to each example
# - century_numbers: a list of century numbers corresponding to each example
# - prompts: a list of text examples corresponding to each example

# Flatten everything after the first axis so each example becomes one feature vector
reshaped_activations = activations.reshape(activations.shape[0], -1)

# Apply t-SNE to reduce the dimensionality to 2 components (fixed seed for reproducibility).
# NOTE(review): newer scikit-learn releases rename TSNE's `n_iter` to `max_iter`;
# check the installed version before upgrading.
tsne = TSNE(n_components=2, random_state=39, init='pca', perplexity=100, n_iter=10000, n_iter_without_progress=500, verbose=0)
reduced_activations = tsne.fit_transform(reshaped_activations)

# Create a DataFrame with the reduced activations, corresponding years, centuries, and text examples
df = pd.DataFrame(reduced_activations, columns=['t-SNE1', 't-SNE2'])
df['Year'] = year_numbers
df['Century'] = century_numbers
df['Text'] = prompts

# Create a regular grid of points spanning the embedding's bounding box
grid_resolution = 100
x_min, x_max = df['t-SNE1'].min(), df['t-SNE1'].max()
y_min, y_max = df['t-SNE2'].min(), df['t-SNE2'].max()
xx, yy = np.meshgrid(np.linspace(x_min, x_max, grid_resolution), np.linspace(y_min, y_max, grid_resolution))
background_points = np.vstack((xx.ravel(), yy.ravel())).T

# Find the k-nearest embedded examples for each background point.
# Fit on a plain array: the query points are a plain array too, and fitting on a
# DataFrame made scikit-learn warn about missing feature names at query time.
k = 10
nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(df[['t-SNE1', 't-SNE2']].to_numpy())
distances, indices = nbrs.kneighbors(background_points)

# Assign the majority century among the neighbors as the color of each background point
# (mode() returns sorted values, so ties resolve to the smallest century)
background_colors = [df.loc[idx, 'Century'].mode()[0] for idx in indices]

# Create the background trace; rows of z follow y and columns follow x, matching meshgrid's layout
background_trace = go.Heatmap(
    x=np.linspace(x_min, x_max, grid_resolution),
    y=np.linspace(y_min, y_max, grid_resolution),
    z=np.array(background_colors).reshape((grid_resolution, grid_resolution)),
    colorscale='Viridis',
    showscale=False,
    hoverinfo='none'
)

# Create the scatter plot trace.
# Titles use title=dict(text=..., font=...): the older `titlefont` shorthand is
# rejected by plotly 6, while the nested form is accepted by older releases too.
scatter_trace = go.Scatter(
    x=df['t-SNE1'],
    y=df['t-SNE2'],
    mode='markers',
    marker=dict(
        size=8,
        color=df['Year'],
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(title=dict(text='Year', font=dict(size=24, family='Palatino')), tickfont=dict(size=18, family='Palatino'))
    ),
    text=[f"Year: {year}<br>Text: {text}" for year, text in zip(df['Year'], df['Text'])],
    hoverinfo='text'
)

# Create the figure with both traces (background first so the markers draw on top)
fig = go.Figure(data=[background_trace, scatter_trace])

# Customize the layout
fontsize = 24
fig.update_layout(
    xaxis=dict(title=dict(text='t-SNE Dimension 1', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
    yaxis=dict(title=dict(text='t-SNE Dimension 2', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
    font=dict(family='Palatino', size=fontsize),
    plot_bgcolor='white',
    width=800,
    height=600,
    showlegend=False
)

# Save the plot
# NOTE(review): the translucent-background cell further down writes to this same path,
# so whichever of the two cells runs last determines the file on disk.
fig.write_image('../output/figures/tsne_plot_gt_background.pdf')

# Display the plot
fig.show()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

X does not have valid feature names, but NearestNeighbors was fitted with feature names

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [36]:
import numpy as np
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
import plotly.graph_objects as go
import pandas as pd

# t-SNE scatter (coloured by year) drawn over a translucent background that shows,
# for each grid cell, the majority century among its nearest embedded examples.
# Inputs defined in earlier cells:
# - activations: tensor indexed as activations[example, head, code]
# - year_numbers: a list of year numbers corresponding to each example
# - century_numbers: a list of century numbers corresponding to each example
# - prompts: a list of text examples corresponding to each example

# Flatten everything after the first axis so each example becomes one feature vector
reshaped_activations = activations.reshape(activations.shape[0], -1)

# Apply t-SNE to reduce the dimensionality to 2 components (fixed seed for reproducibility).
# NOTE(review): newer scikit-learn releases rename TSNE's `n_iter` to `max_iter`;
# check the installed version before upgrading.
tsne = TSNE(n_components=2, random_state=39, init='pca', perplexity=100, n_iter=10000, n_iter_without_progress=500, verbose=0)
reduced_activations = tsne.fit_transform(reshaped_activations)

# Create a DataFrame with the reduced activations, corresponding years, centuries, and text examples
df = pd.DataFrame(reduced_activations, columns=['t-SNE1', 't-SNE2'])
df['Year'] = year_numbers
df['Century'] = century_numbers
df['Text'] = prompts

# Create a regular grid of points spanning the embedding's bounding box
grid_resolution = 100
x_min, x_max = df['t-SNE1'].min(), df['t-SNE1'].max()
y_min, y_max = df['t-SNE2'].min(), df['t-SNE2'].max()
xx, yy = np.meshgrid(np.linspace(x_min, x_max, grid_resolution), np.linspace(y_min, y_max, grid_resolution))
background_points = np.vstack((xx.ravel(), yy.ravel())).T

# Find the k-nearest embedded examples for each background point.
# Fit on a plain array: the query points are a plain array too, and fitting on a
# DataFrame made scikit-learn warn about missing feature names at query time.
k = 10
nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(df[['t-SNE1', 't-SNE2']].to_numpy())
distances, indices = nbrs.kneighbors(background_points)

# Assign the majority century among the neighbors as the color of each background point
# (mode() returns sorted values, so ties resolve to the smallest century)
background_colors = [df.loc[idx, 'Century'].mode()[0] for idx in indices]

# Create the background trace with its own colorscale and transparency so it stays
# distinguishable from the year-coloured markers
background_colorscale = 'Plasma'
background_opacity = 0.4  # Adjust the opacity value between 0 and 1
background_trace = go.Heatmap(
    x=np.linspace(x_min, x_max, grid_resolution),
    y=np.linspace(y_min, y_max, grid_resolution),
    z=np.array(background_colors).reshape((grid_resolution, grid_resolution)),
    colorscale=background_colorscale,
    showscale=False,
    hoverinfo='none',
    opacity=background_opacity
)

# Create the scatter plot trace.
# Titles use title=dict(text=..., font=...): the older `titlefont` shorthand is
# rejected by plotly 6, while the nested form is accepted by older releases too.
scatter_trace = go.Scatter(
    x=df['t-SNE1'],
    y=df['t-SNE2'],
    mode='markers',
    marker=dict(
        size=8,
        color=df['Year'],
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(title=dict(text='Year', font=dict(size=24, family='Palatino')), tickfont=dict(size=18, family='Palatino'))
    ),
    text=[f"Year: {year}<br>Text: {text}" for year, text in zip(df['Year'], df['Text'])],
    hoverinfo='text'
)

# Create the figure with the background and scatter traces (background first so markers draw on top)
fig = go.Figure(data=[background_trace, scatter_trace])

# Customize the layout
fontsize = 24
fig.update_layout(
    xaxis=dict(title=dict(text='t-SNE Dimension 1', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
    yaxis=dict(title=dict(text='t-SNE Dimension 2', font=dict(size=fontsize, family='Palatino')), tickfont=dict(size=16, family='Palatino')),
    font=dict(family='Palatino', size=fontsize),
    plot_bgcolor='white',
    width=800,
    height=600
)

# Save the plot
# NOTE(review): the opaque-background cell above writes to this same path, so this
# cell overwrites that file when the notebook is run top to bottom.
fig.write_image('../output/figures/tsne_plot_gt_background.pdf')

# Display the plot
fig.show()


X does not have valid feature names, but NearestNeighbors was fitted with feature names

