In [125]:
import io
from logging import warning
from typing import Union, List
from site import PREFIXES
import warnings
import numpy as np
from tqdm import tqdm
import pandas as pd
from transformers import AutoTokenizer
import random
import re
import matplotlib.pyplot as plt
import random as rd
import copy
import random
from typing import List, Union
from pathlib import Path
import torch
from transformer_lens import HookedTransformer

from sparse_autoencoder import SparseAutoencoder

from sklearn.metrics import roc_curve, roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
# Optim
import torch.optim as optim

from typing import List, Tuple, Dict, Union, Optional, Callable, Any
from time import ctime
import einops
import torch
import numpy as np
from copy import deepcopy
from collections import OrderedDict
import pickle
from subprocess import call

In [126]:
import transformer_lens.utils as utils
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def imshow(tensor, **kwargs):
    """Display `tensor` as a heatmap on a red/blue diverging scale centred at zero.

    The tensor is converted with `utils.to_numpy`; any extra keyword arguments
    are passed straight through to `plotly.express.imshow`. Nothing is returned.
    """
    heatmap = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    heatmap.show()


def line(tensor, **kwargs):
    """Display `tensor` as a line plot (values on the y axis).

    Extra keyword arguments are forwarded to `plotly.express.line`.
    Nothing is returned.
    """
    values = utils.to_numpy(tensor)
    figure = px.line(y=values, **kwargs)
    figure.show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Display a scatter plot of `y` against `x`.

    `xaxis`, `yaxis` and `caxis` are the display labels for the x axis, the
    y axis and the colour axis. Extra keyword arguments are forwarded to
    `plotly.express.scatter`. Nothing is returned.
    """
    axis_titles = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(
        x=utils.to_numpy(x),
        y=utils.to_numpy(y),
        labels=axis_titles,
        **kwargs,
    )
    figure.show()

## Generating a dataset algorithmically

In [340]:
# Everything runs on CPU; `device` is reused by later cells when creating tensors and models.
device = 'cpu'
# Pretrained GPT-2 small wrapped as a TransformerLens HookedTransformer.
# NOTE(review): later cells rebind the name `model` to a sparse autoencoder.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [341]:
# Prompt templates. Each one stops right before the tag-question pronoun
# ("..., isn't he" / "..., isn't she"), so the next-token prediction is the pronoun.
templates = [
    "So {name} is a really great friend, isn't",
    "So {name} is such a good cook, isn't",
    "So {name} is a very good athlete, isn't",
    "So {name} is a really nice person, isn't",
    "So {name} is such a funny person, isn't"
    ]

male_names = [
    "John",
    "David",
    "Mark",
    "Paul",
    "Ryan",
    "Gary",
    "Jack",
    "Sean",
    "Carl",
    "Joe",
]
female_names = [
    "Mary",
    "Lisa",
    "Anna",
    "Sarah",
    "Amy",
    "Carol",
    "Karen",
    "Susan",
    "Julie",
    "Judy"
]

# Candidate next tokens (with the leading space they carry after "isn't").
responses = [' he', ' she']

positive_examples = []   # prompts containing a gendered name
negative_examples = []   # same templates with the name replaced by 'that person'
# Storing the answers
answers = []             # correct pronoun for each positive prompt
wrongs = []              # opposite pronoun for each positive prompt

# Derive the target pronoun from the list the name came from, rather than
# assuming the first half of the batch is male and the second half female.
# The resulting order (all male names, then all female names) is unchanged.
for names, correct, incorrect in (
    (male_names, responses[0], responses[1]),
    (female_names, responses[1], responses[0]),
):
    for name in names:
        for template in templates:
            positive_examples.append(template.format(name = name))
            negative_examples.append(template.format(name = 'that person'))
            answers.append(correct)
            wrongs.append(incorrect)

batch_size = len(positive_examples)

positive_examples = model.to_tokens(positive_examples, prepend_bos = True)
negative_examples = model.to_tokens(negative_examples, prepend_bos = True)
# One token id per prompt; squeeze only the token axis so a batch of one stays 1-D.
answers = torch.tensor(model.tokenizer(answers)["input_ids"]).squeeze(-1)
wrongs = torch.tensor(model.tokenizer(wrongs)["input_ids"]).squeeze(-1)

# Right-pad both token matrices with token id 0 to a common sequence length.
# F.pad keeps dtype and device of the input, and padding whichever side is
# shorter avoids a negative-size error if the positive prompts are the longer ones.
# NOTE(review): the metric cells read logits at position -1; if any padding is
# added here, that position is a pad token for the padded prompts - confirm intended.
common_len = max(positive_examples.shape[1], negative_examples.shape[1])
positive_examples = F.pad(positive_examples, (0, common_len - positive_examples.shape[1]), value=0)
negative_examples = F.pad(negative_examples, (0, common_len - negative_examples.shape[1]), value=0)
positive_examples.shape

torch.Size([100, 12])

In [38]:
def pronoun_metric(model, tokens):
    """Mean logit difference (correct pronoun minus wrong pronoun) at the last position.

    Runs `model` on `tokens` and, for each prompt, compares the final-position
    logit of the token in `answers` with that of the token in `wrongs`.

    Relies on the notebook-level `batch_size`, `answers` and `wrongs`, so
    `tokens` must hold `batch_size` prompts in the same order as those tensors.

    Returns:
        The mean difference as a Python float.
    """
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(tokens)
    batch_index = torch.arange(batch_size)
    logits_on_correct = logits[batch_index, -1, answers]
    logits_on_wrong = logits[batch_index, -1, wrongs]
    return torch.mean(logits_on_correct - logits_on_wrong).item()

In [39]:
# Mean correct-minus-wrong pronoun logit difference on the prompts that contain a name.
pronoun_metric(model, positive_examples)

4.727617263793945

In [40]:
# Same metric on the 'that person' prompts. These prompts do not depend on the name,
# so with equal numbers of ' he' and ' she' targets the per-prompt differences cancel.
pronoun_metric(model, negative_examples)

0.0

In [342]:
# Stack the positive and negative examples along the first (batch) dimension:
# positive prompts come first, followed by the 'that person' prompts.
all_examples = torch.cat([positive_examples, negative_examples], dim = 0)

In [48]:
# Extract per-component residual-stream vectors for a single prompt
def prompt_to_resid_stream(prompt: str, model: HookedTransformer, resid_type: str = 'accumulated', position: str = 'last') -> Tuple[torch.Tensor, List[str]]:
    """
    Run `model` on `prompt` and return one vector per residual-stream component.

    Args:
        prompt: Text passed to `model.run_with_cache`.
        model: Model exposing `run_with_cache`.
        resid_type: Which decomposition of the cache to use:
            'accumulated' -> `cache.accumulated_resid(apply_ln=True)`,
            'decomposed'  -> `cache.decompose_resid()`,
            'heads'       -> `cache.stack_head_results()` (attention heads only;
                             MLP outputs are not included).
        position: 'last' keeps the final position of batch element 0;
            'mean' averages over dim 2 (the position axis) and squeezes
            singleton dimensions.

    Returns:
        `(vectors, labels)`: the selected tensor with the component axis first,
        and the component labels reported by the cache for `resid_type`.

    Raises:
        ValueError: if `resid_type` or `position` is not one of the supported values.
    """
    # Validate cheap arguments before paying for a forward pass.
    if position not in ('last', 'mean'):
        raise ValueError("position must be one of 'last', 'mean'")

    # Run the model over the prompt
    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)

        # Keep the labels for every decomposition; previously only the 'heads'
        # branch bound `labels`, so the other two failed at the return statement.
        if resid_type == 'accumulated':
            resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
        elif resid_type == 'decomposed':
            resid, labels = cache.decompose_resid(return_labels=True)
        elif resid_type == 'heads':
            cache.compute_head_results()
            resid, labels = cache.stack_head_results(return_labels=True)
        else:
            raise ValueError("resid_type must be one of 'accumulated', 'decomposed', 'heads'")

    # POSITION
    if position == 'last':
        last_token_accum = resid[:, 0, -1, :]  # layer, batch, pos, d_model
    else:
        last_token_accum = resid.mean(dim=2).squeeze()
    return last_token_accum, labels

def all_prompts_to_resid_streams(prompts, prompts_cf, model, resid_type='accumulated', position='mean'):
    """
    Build one residual-stream tensor per prompt for `prompts` followed by `prompts_cf`.

    Both inputs are token tensors; they are concatenated along dim 0, and each row
    is decoded back to text with `model.to_string` before being passed to
    `prompt_to_resid_stream`.

    Returns `(streams, labels)`: the per-prompt results stacked along a new first
    dimension, and the component labels from the last prompt processed.
    """
    stacked_tokens = torch.cat([prompts, prompts_cf], dim=0)
    collected = []
    for row in tqdm(range(stacked_tokens.shape[0])):
        # All exclamation marks are removed from the decoded text
        # (presumably to drop decoded padding tokens - TODO confirm).
        text = model.to_string(stacked_tokens[row]).replace("!", "")
        stream, labels = prompt_to_resid_stream(text, model, resid_type, position)
        collected.append(stream)
    return torch.stack(collected), labels

# Per-attention-head outputs, averaged over positions, for the named prompts followed by
# the 'that person' prompts. `labels` holds the head labels reported by the cache.
resid_streams, labels = all_prompts_to_resid_streams(positive_examples, negative_examples, model, resid_type='heads', position='mean')

100%|██████████| 200/200 [00:09<00:00, 21.29it/s]


In [52]:
# Save the residual streams to our data folder.
# NOTE(review): only `resid_streams` is written here; `labels` is not saved by this cell.
torch.save(resid_streams, '../data/gender/resid_streams.pt')

## Training the sparse autoencoder

In [55]:
# Loss function is MSE (reconstruction loss) + lambda_ * mean L1 norm of the learned activations
def loss_fn(decoded_activations, learned_activations, resid_streams, lambda_=0.01):
    """Sparse-autoencoder training loss.

    Args:
        decoded_activations: Reconstruction produced by the autoencoder.
        learned_activations: Hidden (feature) activations; the last dimension
            indexes features, any leading dimensions are treated as samples.
        resid_streams: Reconstruction target, same shape as `decoded_activations`.
        lambda_: Weight of the sparsity term.

    Returns:
        Scalar tensor `mse(decoded, target) + lambda_ * mean_over_samples(L1(features))`.
    """
    # RECONSTRUCTION LOSS
    recon_loss = F.mse_loss(decoded_activations, resid_streams)

    # SPARSITY LOSS
    # L1 norm over the feature axis, averaged over every leading (sample) axis.
    # Working on the last axis directly accepts inputs of any rank, not just (b, s, n).
    sparsity_loss = learned_activations.abs().sum(dim=-1).mean()

    # combine
    return recon_loss + (lambda_ * sparsity_loss)


def train(model, n_epochs, optimizer, train_streams, eval_streams, lambda_=0.01):
    """Full-batch training loop for the sparse autoencoder.

    Each epoch performs one optimiser step on all of `train_streams`. About ten
    times per run the train loss and the loss on `eval_streams` are printed.

    Args:
        model: Callable returning `(learned_activations, decoded_activations)`.
        n_epochs: Number of optimisation steps.
        optimizer: Optimiser over `model`'s parameters.
        train_streams: Training inputs (also the reconstruction target).
        eval_streams: Held-out inputs used only for the printed eval loss.
        lambda_: Sparsity weight passed to `loss_fn`.

    Returns:
        The trained `model` (the same object that was passed in).
    """
    # max(1, ...) keeps the modulo below valid when n_epochs < 10.
    log_every = max(1, n_epochs // 10)
    for epoch in tqdm(range(n_epochs)):
        model.train()
        optimizer.zero_grad()
        learned_activations, decoded_activations = model(train_streams)
        loss = loss_fn(decoded_activations, learned_activations, train_streams, lambda_=lambda_)
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            model.eval()
            with torch.no_grad():
                eval_learned_activations, eval_decoded_activations = model(eval_streams)
                eval_loss = loss_fn(eval_decoded_activations, eval_learned_activations,
                                                             eval_streams, lambda_=lambda_)
                print(f"Train loss = {loss.item():.4f}, Eval loss = {eval_loss.item():.4f}")
    return model

In [358]:
# Set manual seed
torch.manual_seed(1000)

# Shuffle and create the labels
resid_streams = torch.load('../data/gender/resid_streams.pt')
# NOTE(review): `labels` is rebound here to 1/0 class labels, replacing the component
# labels returned by all_prompts_to_resid_streams in the earlier cell.
labels = torch.ones(resid_streams.shape[0]//2) # BIG ASSUMPTION: assumes first half is positive and second half is negative
labels = torch.cat((labels, torch.zeros_like(labels)))
permutation = torch.randperm(resid_streams.shape[0])
resid_shuffled = resid_streams[permutation, :, :]
labels_shuffled = labels[permutation]
# With cutoff = 10 only the first 10 shuffled prompts are used for training and all
# remaining prompts for evaluation; the commented-out expression is an 80/20 split.
cutoff = 10 #int(resid_shuffled.shape[0] * 0.8)
train_streams = resid_shuffled[:cutoff, :, :].to(device)
train_labels = labels_shuffled[:cutoff].to(device)
eval_streams = resid_shuffled[cutoff:, :, :].to(device)
eval_labels = labels_shuffled[cutoff:].to(device)

# Autoencoder / training hyperparameters
num_learned_features = 200
lambda_ = 0.02
n_epochs = 500

# NOTE(review): this rebinds `model`, which earlier cells use for the HookedTransformer.
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_learned_features, geometric_median_dataset=None).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
model = train(model, n_epochs, optimizer, train_streams, eval_streams, lambda_=lambda_)

  2%|▏         | 12/500 [00:00<00:07, 63.00it/s]

Train loss = 0.8355, Eval loss = 0.7403


 16%|█▌        | 79/500 [00:00<00:02, 148.59it/s]

Train loss = 0.1707, Eval loss = 0.1672


 28%|██▊       | 141/500 [00:00<00:02, 179.44it/s]

Train loss = 0.1143, Eval loss = 0.1136


 36%|███▋      | 182/500 [00:01<00:01, 186.11it/s]

Train loss = 0.1002, Eval loss = 0.0998


 45%|████▌     | 225/500 [00:01<00:01, 194.22it/s]

Train loss = 0.0915, Eval loss = 0.0912


 59%|█████▉    | 294/500 [00:01<00:00, 207.12it/s]

Train loss = 0.0861, Eval loss = 0.0859


 68%|██████▊   | 338/500 [00:01<00:00, 203.21it/s]

Train loss = 0.0815, Eval loss = 0.0813


 76%|███████▋  | 382/500 [00:02<00:00, 200.46it/s]

Train loss = 0.0769, Eval loss = 0.0768


 85%|████████▌ | 426/500 [00:02<00:00, 199.90it/s]

Train loss = 0.0732, Eval loss = 0.0731


 99%|█████████▉| 494/500 [00:02<00:00, 202.25it/s]

Train loss = 0.0706, Eval loss = 0.0706


100%|██████████| 500/500 [00:02<00:00, 180.09it/s]


In [359]:
# Let's save our model: the weights plus the hyperparameters used to train it,
# so the autoencoder can be rebuilt with the same number of learned features.
save_dict = {
    'model': model.state_dict(),
    'num_learned_features': num_learned_features,
    'lambda': lambda_,
    'n_epochs': n_epochs
}
save_path = '../models/gender/sparse_autoencoder_dict.pt'
torch.save(save_dict, save_path)

## Perform circuit identification

In [360]:
# Load the save dict and rebuild the autoencoder with the stored number of features.
# The input width is taken from the `resid_streams` tensor currently in memory.
save_dict = torch.load(save_path)
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=save_dict['num_learned_features'], geometric_median_dataset=None).to(device)
model.load_state_dict(save_dict['model'])

<All keys matched successfully>

In [377]:
def feature_string_to_head_and_layer(feature_index, head_labels):
    """Translate the label at `head_labels[feature_index]` into a `(layer, head)` pair.

    Attention-head labels look like 'L0H1' and map to `(0, 1)`. Labels containing
    'mlp' (case-insensitive) take the layer from the text before the first
    underscore and are assigned the fixed head index 12.
    """
    label = head_labels[feature_index]

    if 'mlp' not in label.lower():
        # Layer: digits between 'L' and 'H'; head: digits following 'H'
        layer_digits = re.findall(r'L(\d+)H', label)[0]
        head_digits = re.findall(r'H(\d+)', label)[0]
        return int(layer_digits), int(head_digits)

    # MLP entry: '<layer>_...' placed in the extra column after the attention heads
    return int(label.split('_')[0]), 12

def gen_array_template(head_labels):
    """Return a zero array with one row per distinct layer and one column per
    distinct head index found in `head_labels`."""
    pairs = [
        feature_string_to_head_and_layer(position, head_labels)
        for position in range(len(head_labels))
    ]
    distinct_layers = {layer for layer, _ in pairs}
    distinct_heads = {head for _, head in pairs}
    return np.zeros((len(distinct_layers), len(distinct_heads)))

def softmax(x, axis):
    """Softmax of `x` along `axis`.

    The maximum along `axis` is subtracted before exponentiating so large
    inputs do not overflow; the output has the same shape as `x`.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exponentials = np.exp(shifted)
    normaliser = exponentials.sum(axis=axis, keepdims=True)
    return exponentials / normaliser

def circuit_prediction(all_indices, head_labels, normalise=False, across_layer=False):
    """Score each (layer, head) by how many feature indices occur only for positive prompts.

    Args:
        all_indices: 2-D array of feature indices, one row per prompt and one
            column per entry of `head_labels`. The first half of the rows is
            treated as positive prompts, the second half as negative prompts.
        head_labels: Component labels understood by `feature_string_to_head_and_layer`.
        normalise: If True, divide each count by the number of distinct indices
            seen for that component across both halves.
        across_layer: If True, apply the softmax along axis 0 of the
            (layer, head) grid; otherwise apply it over the flattened grid.

    Returns:
        `(y_pred, unique_to_positive_array, unique_to_negative_array)`, each of
        shape (n_layers, n_heads): the softmaxed positive-only scores, and the
        raw positive-only and negative-only scores.
    """
    # First half of the rows: positive prompts; second half: negative prompts
    midpoint = all_indices.shape[0] // 2
    positive_rows = all_indices[:midpoint, :]
    negative_rows = all_indices[midpoint:, :]

    unique_to_positive_array = gen_array_template(head_labels)
    unique_to_negative_array = np.zeros_like(unique_to_positive_array)
    n_layers, n_heads = unique_to_positive_array.shape

    for column in range(len(head_labels)):
        layer, head = feature_string_to_head_and_layer(column, head_labels)

        seen_in_positive = set(positive_rows[:, column].tolist())
        seen_in_negative = set(negative_rows[:, column].tolist())

        only_positive = len(seen_in_positive - seen_in_negative)
        only_negative = len(seen_in_negative - seen_in_positive)

        # Either divide by the number of distinct indices, or leave the raw counts
        denominator = len(seen_in_positive | seen_in_negative) if normalise else 1
        unique_to_positive_array[layer, head] = only_positive / denominator
        unique_to_negative_array[layer, head] = only_negative / denominator

    # Normalise the positive-only scores with a softmax
    if across_layer:
        # Along axis 0 of the grid, i.e. separately for every head column
        y_pred = softmax(unique_to_positive_array.copy(), axis=0)
    else:
        y_pred = softmax(unique_to_positive_array.flatten(), axis=0)

    return y_pred.reshape((n_layers, n_heads)), unique_to_positive_array, unique_to_negative_array

In [378]:
# Component labels, loaded from disk.
# NOTE(review): the path points at the `ioi` data folder although this notebook's
# streams are saved under `gender` - confirm these labels match `resid_streams`.
head_labels = torch.load('../data/ioi/labels_heads_mean.pt')

# (layer, head) pairs treated as the reference circuit when scoring predictions.
ground_truth = [(3, 4), (1, 4), (2, 6), (9, 7), (10, 9), (4, 3), (6, 0)]

# Collect the distinct layer and head indices present in the labels
# (same computation as gen_array_template). `layers` is reused by a later cell.
heads = []
layers = []
for i, l in enumerate(head_labels):
    layer, head = feature_string_to_head_and_layer(i, head_labels)
    heads.append(head)
    layers.append(layer)

heads = list(set(heads))
layers = list(set(layers))

# Binary (layer, head) grid: 1 where the pair is in `ground_truth`, 0 elsewhere.
ground_truth_array = np.zeros((len(layers), len(heads)))
for layer, head in ground_truth:
    ground_truth_array[layer, head] = 1

# `normalise` is also read by later cells.
normalise = False

# Encode every prompt and keep, for each component, the index of its most active feature.
model.eval()
learned_activations = model.encoder(resid_streams).detach().cpu().numpy()
all_indices = np.argmax(learned_activations, axis=2)

y_pred, unique_to_positive_array, unique_to_negative_array = circuit_prediction(all_indices, head_labels, normalise=normalise, across_layer=True)

In [379]:
# Treat the flattened ground-truth grid as binary targets and the flattened
# predictions as scores. `fpr` / `tpr` are computed but not used in this cell.
fpr, tpr, _ = roc_curve(ground_truth_array.flatten(), y_pred.flatten())
roc_auc = roc_auc_score(ground_truth_array.flatten(), y_pred.flatten())
print(roc_auc)

0.7194994786235662


In [317]:
# Sanity check: predictions form a (layers, heads) grid.
y_pred.shape

(12, 12)

In [318]:
# Raw per-(layer, head) score for feature indices seen only on positive prompts.
imshow(unique_to_positive_array)

In [319]:
# The same scores after the softmax applied inside circuit_prediction.
imshow(y_pred)

In [168]:
# Binary mask of the entries whose softmax score exceeds a tiny threshold
circuit_array = (y_pred > 1e-10).astype(int)
# Collect the selected entries as (layer, head) tuples: row index is the layer,
# column index is the head, in row-major order
circuit_indices = []
for i, row_flags in enumerate(circuit_array):
    for j, flag in enumerate(row_flags):
        if flag == 1:
            circuit_indices.append((i, j))

In [172]:
# (layer, head) pairs selected by the threshold above.
circuit_indices

[(0, 11),
 (1, 4),
 (1, 5),
 (1, 8),
 (1, 11),
 (2, 0),
 (2, 2),
 (2, 4),
 (2, 5),
 (2, 10),
 (3, 6),
 (3, 8),
 (4, 3),
 (4, 5),
 (4, 11),
 (5, 5),
 (6, 0),
 (6, 6),
 (7, 9),
 (8, 8),
 (8, 10),
 (8, 11),
 (9, 5),
 (10, 0),
 (10, 1),
 (10, 2),
 (10, 6),
 (11, 10),
 (11, 11)]

## Circuit performance

In [129]:
# Separate handle on GPT-2 small for the ablation experiments
# (the name `model` is rebound to the sparse autoencoder by the training cells).
tl_model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [130]:
# Clean run: prompts containing a name. Corrupted run: the 'that person' prompts.
# The corrupted cache supplies replacement activations for the 'random' ablation below.
clean_logits, clean_cache = tl_model.run_with_cache(positive_examples)
corrupted_logits, corrupted_cache = tl_model.run_with_cache(negative_examples)

In [136]:
# Baseline metric values on the clean and corrupted runs.
# NOTE(review): pronoun_metric_from_logits is defined in the *following* cell, so this
# cell fails on a fresh kernel when the notebook is run top to bottom.
original_average_logit_diff = pronoun_metric_from_logits(clean_logits)
print(f"Original average logit difference: {original_average_logit_diff:.4f}")
corrupted_average_logit_diff = pronoun_metric_from_logits(corrupted_logits) 
# Printed with 100 decimal places to show whether the value is exactly zero.
print(f"Corrupted average logit difference: {corrupted_average_logit_diff:.100f}")

Original average logit difference: 4.3714
Corrupted average logit difference: 0.0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000


In [137]:
from jaxtyping import Float
from functools import partial
from tqdm import tqdm

def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    ablation_type = 'zero',
    corrupted_cache = corrupted_cache,
):
    """Forward hook that overwrites a single head's slice of an activation, in place.

    Args:
        corrupted_head_vector: Activation passed to the hook; modified in place.
        hook: Hook point; its `name` selects the matching entry of `corrupted_cache`.
        head_index: Index of the head (dim 2) to overwrite.
        ablation_type: 'zero' writes zeros; 'random' copies the same head's
            activation from `corrupted_cache`.
        corrupted_cache: Cache to copy from. The default is bound to the
            notebook-level `corrupted_cache` at definition time.

    Returns:
        The same tensor object, after modification.
    """
    assert ablation_type in ['zero', 'random'], f"ablation_type must be one of 'zero', 'random'"
    if ablation_type != 'zero':
        replacement = corrupted_cache[hook.name][:, :, head_index, :]
    else:
        replacement = 0.0
    corrupted_head_vector[:, :, head_index, :] = replacement
    return corrupted_head_vector

def pronoun_metric_from_logits(logits):
    """Mean final-position logit difference between the correct and the wrong pronoun.

    Uses the notebook-level `batch_size`, `answers` and `wrongs`; row i of
    `logits` must correspond to entry i of those tensors. Returns a Python float.
    """
    rows = torch.arange(batch_size)
    final_position = logits[:, -1, :]
    per_prompt_diff = final_position[rows, answers] - final_position[rows, wrongs]
    return per_prompt_diff.mean().item()

# Reference values read by normalize_patched_logit_diff below.
original_average_logit_diff = pronoun_metric_from_logits(clean_logits)
corrupted_average_logit_diff = pronoun_metric_from_logits(corrupted_logits) 

def normalize_patched_logit_diff(patched_logit_diff):
    """Rescale a logit difference relative to the clean and corrupted baselines.

    0 corresponds to the corrupted baseline and 1 to the clean baseline; values
    below 0 or above 1 fall outside that range. Reads the notebook-level
    `original_average_logit_diff` and `corrupted_average_logit_diff`.
    """
    gain_over_corrupted = patched_logit_diff - corrupted_average_logit_diff
    clean_minus_corrupted = original_average_logit_diff - corrupted_average_logit_diff
    return gain_over_corrupted / clean_minus_corrupted

# Per-head ablation sweep on the clean prompts.
# The TransformerLens model is referenced as `tl_model`: the name `model` is rebound to
# the sparse autoencoder by the training cells, which has no `.cfg` / `.run_with_hooks`,
# so using it here breaks a top-to-bottom run.
patched_head_z_diff_zero = torch.zeros(
    tl_model.cfg.n_layers, tl_model.cfg.n_heads, device=device, dtype=torch.float32
)

patched_head_z_diff_random = torch.zeros(
    tl_model.cfg.n_layers, tl_model.cfg.n_heads, device=device, dtype=torch.float32
)

# Inference only, so no autograd graph is needed for the forward passes.
with torch.no_grad():
    for layer in tqdm(range(tl_model.cfg.n_layers)):
        for head_index in range(tl_model.cfg.n_heads):
            # 'zero' overwrites the head output with zeros; 'random' replaces it with
            # the same head's activation from the corrupted run (see patch_head_vector).
            for ablation_type, results in (
                ('zero', patched_head_z_diff_zero),
                ('random', patched_head_z_diff_random),
            ):
                hook_fn = partial(patch_head_vector, head_index=head_index, ablation_type=ablation_type, corrupted_cache=corrupted_cache)
                patched_logits = tl_model.run_with_hooks(
                    positive_examples,
                    fwd_hooks=[(utils.get_act_name("z", layer, "attn"), hook_fn)],
                    return_type="logits",
                )
                patched_metric = pronoun_metric_from_logits(patched_logits)
                normalised_patched_metric = normalize_patched_logit_diff(patched_metric)
                results[layer, head_index] = normalised_patched_metric

100%|██████████| 12/12 [02:08<00:00, 10.75s/it]


In [138]:
# Normalised metric after zero-ablating each head (rows: layers, columns: heads).
imshow(patched_head_z_diff_zero)

In [139]:
# Normalised metric after replacing each head with its corrupted-run activation.
imshow(patched_head_z_diff_random)

In [179]:
from typing import List, Tuple, Dict, Union, Optional, Callable, Any
from jaxtyping import Float, Int
from torch import Tensor
import functools
from transformer_lens.hook_points import (
    HookPoint,
)

def head_ablation_hook(
    attn_result: Float[Tensor, "batch seq n_heads d_model"],
    hook: HookPoint,
    head_index_to_ablate: int
) -> Float[Tensor, "batch seq n_heads d_model"]:
    """Forward hook that zeroes one head's slice (dim 2) of `attn_result` in place.

    `hook` is accepted to match the hook signature and is not used.
    Returns the same tensor object.
    """
    ablated_slice = (slice(None), slice(None), head_index_to_ablate, slice(None))
    attn_result[ablated_slice] = 0.0
    return attn_result

def get_ablation_scores(
    model: HookedTransformer,
    tokens: Int[Tensor, "batch seq"],
    heads_to_ablate: List[Tuple[int, int]],
) -> float:
    """Pronoun metric of `model` on `tokens` with the given heads zero-ablated.

    Args:
        model: Model exposing `reset_hooks` and `run_with_hooks`.
        tokens: Prompt tokens; must line up with the notebook-level `answers` /
            `wrongs` used by `pronoun_metric_from_logits`.
        heads_to_ablate: `(layer, head)` pairs whose "z" activation is zeroed.

    Returns:
        The metric on the ablated run, as a Python float.
    """
    # Start from a model without leftover hooks
    model.reset_hooks()

    # One zero-ablation hook per (layer, head) pair
    ablation_hooks = [
        (utils.get_act_name("z", layer), functools.partial(head_ablation_hook, head_index_to_ablate=head))
        for layer, head in heads_to_ablate
    ]

    # Only the ablated run is needed: the un-ablated baseline that used to be
    # computed here was never used, costing a full forward pass per call.
    with torch.no_grad():
        ablated_logits = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=ablation_hooks)

    return pronoun_metric_from_logits(ablated_logits)

# NOTE(review): `circuit` is assigned twice; only the last assignment takes effect,
# so the hand-written list on the next line is dead.
circuit = [(3, 4), (1, 4), (2, 6), (9, 7), (10, 9), (4, 3), (6, 0)]
# circuit = [(0, 11), (1, 4), (1, 5), (1, 8), (1, 11), (2, 0), (2, 2), (2, 4), (2, 5), (2, 6),
#            (2, 10), (3, 4), (3, 6), (3, 8), (4, 3), (4, 5), (4, 11), (5, 5), (6, 0), (6, 6), 
#            (7, 9), (8, 8), (8, 10), (8, 11), (9, 5), (9, 7), (10, 0), (10, 1), (10, 2), (10, 6), (10, 9), (11, 10), (11, 11)]
# range(11) covers layers 0-10 and heads 0-10 only, so layer 11 and head 11 are left out
# of `circuit` and therefore end up in `not_in_circuit`.
circuit = [(layer, head) for layer in range(11) for head in range(11)]
full_model = [(layer, head) for layer in range(12) for head in range(12)]

# Heads outside the circuit: these are the ones ablated in the next cell.
not_in_circuit = [x for x in full_model if x not in circuit]
print(circuit)
print()
print(not_in_circuit)

[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10), (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10), (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), (5, 10), (6, 0), (6, 1), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7), (6, 8), (6, 9), (6, 10), (7, 0), (7, 1), (7, 2), (7, 3), (7, 4), (7, 5), (7, 6), (7, 7), (7, 8), (7, 9), (7, 10), (8, 0), (8, 1), (8, 2), (8, 3), (8, 4), (8, 5), (8, 6), (8, 7), (8, 8), (8, 9), (8, 10), (9, 0), (9, 1), (9, 2), (9, 3), (9, 4), (9, 5), (9, 6), (9, 7), (9, 8), (9, 9), (9, 10), (10, 0), (10, 1), (10, 2), (10, 3), (10, 4), (10, 5), (10, 6), (10, 7), (10, 8), (10, 9), (10, 10)]

[(0, 11)

In [180]:
# Metric on the named prompts when every head outside `circuit` is zero-ablated.
pronoun_score = get_ablation_scores(tl_model, positive_examples, not_in_circuit)
pronoun_score

4.930678844451904

## Trying with keys, queries and values

In [411]:
# Fresh copy of GPT-2 small for the query/key/value experiment.
tl_model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [412]:
# Cache activations for the named prompts followed by the 'that person' prompts.
logits, cache = tl_model.run_with_cache(all_examples)

In [413]:
# Shapes of the layer-0 query, key and value activations.
cache['blocks.0.attn.hook_q'].shape, cache['blocks.0.attn.hook_k'].shape, cache['blocks.0.attn.hook_v'].shape

(torch.Size([200, 12, 12, 64]),
 torch.Size([200, 12, 12, 64]),
 torch.Size([200, 12, 12, 64]))

In [414]:
# Collect per-head query / key / value vectors for every layer.
# The cached q/k/v activations are laid out as [batch, pos, head_index, d_head]
# (the same layout the `z` hook is annotated with in patch_head_vector).
resid_streams = []
head_labels = []
# Iterate over the model's own layer count instead of the `layers` list left
# behind by an earlier, unrelated cell.
for layer in range(tl_model.cfg.n_layers):
    for calc_type in ['q', 'k', 'v']:
        # Average over the position axis (dim=1) so that one vector per *head* remains,
        # matching the labels below. Averaging over dim=2 would collapse the heads and
        # keep positions; the shapes only coincide because seq_len == n_heads here.
        resid_streams.append(cache[f'blocks.{layer}.attn.hook_{calc_type}'].mean(dim=1))
        head_labels.extend([f'L{layer}_H{i}_{calc_type}' for i in range(tl_model.cfg.n_heads)])

In [415]:
# Combine all tensors in resid stream along dim 1, giving one row per prompt and one
# entry per (layer, q/k/v, index) in the same order as `head_labels`.
resid_streams = torch.cat(resid_streams, dim=1)
resid_streams.shape

torch.Size([200, 432, 64])

In [416]:
# Labels: first half of the rows are the named (positive) prompts, second half the
# 'that person' (negative) prompts, following the order of `all_examples`.
halfway = resid_streams.shape[0] // 2
labels = torch.cat([torch.ones(halfway), torch.zeros(halfway)])
# NOTE(review): no manual seed is set in this cell, unlike the earlier training cell.
permutation = torch.randperm(resid_streams.shape[0])
resid_shuffled = resid_streams[permutation, :, :]
labels_shuffled = labels[permutation]
# With cutoff = 10 only 10 shuffled prompts are used for training and the rest for
# evaluation; the commented-out expression is an 80/20 split.
cutoff = 10 #int(resid_shuffled.shape[0] * 0.8)
train_streams = resid_shuffled[:cutoff, :, :].to(device)
train_labels = labels_shuffled[:cutoff].to(device)
eval_streams = resid_shuffled[cutoff:, :, :].to(device)
eval_labels = labels_shuffled[cutoff:].to(device)


# Train an autoencoder on the residual streams
num_learned_features = 200
lambda_ = 0.02
n_epochs = 500

# NOTE(review): rebinds `model` to a new, untrained sparse autoencoder.
model = SparseAutoencoder(n_input_features=resid_streams.shape[-1], n_learned_features=num_learned_features, geometric_median_dataset=None).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
model = train(model, n_epochs, optimizer, train_streams, eval_streams, lambda_=lambda_)

  7%|▋         | 33/500 [00:00<00:02, 173.98it/s]

Train loss = 0.8274, Eval loss = 0.8018


 15%|█▍        | 74/500 [00:00<00:02, 183.84it/s]

Train loss = 0.2995, Eval loss = 0.2968


 28%|██▊       | 140/500 [00:00<00:01, 186.11it/s]

Train loss = 0.1547, Eval loss = 0.1553


 35%|███▌      | 176/500 [00:01<00:02, 146.16it/s]

Train loss = 0.1284, Eval loss = 0.1292


 45%|████▌     | 226/500 [00:01<00:01, 180.54it/s]

Train loss = 0.1141, Eval loss = 0.1152


 55%|█████▍    | 274/500 [00:01<00:01, 196.89it/s]

Train loss = 0.1044, Eval loss = 0.1056


 70%|██████▉   | 348/500 [00:01<00:00, 217.38it/s]

Train loss = 0.0972, Eval loss = 0.0985


 79%|███████▉  | 397/500 [00:02<00:00, 224.53it/s]

Train loss = 0.0915, Eval loss = 0.0929


 89%|████████▉ | 445/500 [00:02<00:00, 219.48it/s]

Train loss = 0.0874, Eval loss = 0.0889


100%|██████████| 500/500 [00:02<00:00, 196.64it/s]

Train loss = 0.0841, Eval loss = 0.0856





In [417]:
# Encode every prompt (kept as a torch tensor) and record, for each component,
# the index of its most active learned feature
learned_activations = model.encoder(resid_streams).detach().cpu()
all_indices = learned_activations.argmax(dim=2)
all_indices.shape

torch.Size([200, 432])

In [443]:
# Plotly imshow all indices.
# `all_indices` has one row per prompt and one column per (layer, q/k/v, head)
# component; each value is the index of the most active learned feature.
fig = px.imshow(all_indices)
# Axis titles describe the actual axes (the previous "Head index" / "Layer"
# titles did not match a prompts-by-components matrix).
fig.update_layout(
    title="Most active learned feature per prompt and component",
    xaxis_title="Component index (layer, q/k/v, head)",
    yaxis_title="Prompt index",
)
fig.show()

In [436]:
print(learned_activations.shape)

# Negative and positive indices: first half of the rows are positive prompts,
# second half negative prompts
halfway = all_indices.shape[0] // 2
positive_activations = learned_activations[:halfway, :, :]
negative_activations = learned_activations[halfway:, :, :]

# Mean feature activation per component for each class. Computed once here:
# these do not depend on the loop index, so recomputing them per component
# only repeated the same reduction over the whole tensor.
positive_mean = positive_activations.mean(dim=0)
negative_mean = negative_activations.mean(dim=0)

dotproduct_values = np.zeros((len(head_labels)))

# Score each component by the negated dot product of its two class means:
# the less the mean activation patterns overlap, the higher the score.
for i in range(len(head_labels)):
    v = positive_mean[i, :]
    u = negative_mean[i, :]
    dotproduct_values[i] = -np.dot(v, u) #/ (np.linalg.norm(v) * np.linalg.norm(u))


# Reshape them into layers (12) * the rest (36)
dotproduct_values = dotproduct_values.reshape((12, 36))

across_layer = True

# Normalise y_pred with softmax
if not across_layer:
    y_pred = dotproduct_values.flatten()
    y_pred = softmax(y_pred, axis=0)
else:
    y_pred = dotproduct_values.copy()
    # Softmax along axis 1, i.e. within each layer's row.
    # NOTE(review): the set-based cells below use axis=0 in this branch - confirm which is intended.
    y_pred = softmax(y_pred, axis=1)


# Reshape
y_pred = y_pred.reshape((12, 36))

torch.Size([200, 432, 200])


In [437]:
# Softmaxed dot-product scores: one row per layer, 36 columns per layer.
imshow(y_pred)

In [440]:
# For each layer, take the sum of the groups of 3 as we move along the row
y_pred_final = np.zeros((12, 12))
for i in range(12):
    for j in range(12):
        y_pred_final[i, j] = np.sum(y_pred[i, j*3:(j+1)*3])

#y_pred_final = softmax(y_pred_final.flatten(), axis=0).reshape((12, 12))

imshow(y_pred_final)

In [441]:
# Reference grid for visual comparison: 1 where (layer, head) is in `ground_truth`.
imshow(ground_truth_array)

In [402]:
# Inline variant of circuit_prediction for the q/k/v components: components are kept as a
# flat list (one entry per label) instead of being mapped onto a (layer, head) grid.
# Negative and positive indices
halfway = all_indices.shape[0] // 2
positive_indices = all_indices[:halfway, :]
negative_indices = all_indices[halfway:, :]

unique_to_positive_array = np.zeros((len(head_labels)))
unique_to_negative_array = np.zeros((len(head_labels)))

for i in range(len(head_labels)):

    # Distinct argmax feature indices seen for this component in each class
    positive = set(positive_indices[:, i].tolist())
    negative = set(negative_indices[:, i].tolist())
    # NOTE(review): `total_unique` stays defined after this cell and is read by a later cell.
    total_unique = positive.union(negative)

    # In positive but not negative
    unique_to_positive = list(positive - negative)
    # In negative but not positive
    unique_to_negative = list(negative - positive)

    # `normalise` is set in an earlier cell
    if normalise:
        # Normalise by total number of unique indices
        unique_to_positive_array[i] = len(unique_to_positive) / len(total_unique)
        unique_to_negative_array[i] = len(unique_to_negative) / len(total_unique)
    
    else:
        # Set the values
        unique_to_positive_array[i] = len(unique_to_positive)
        unique_to_negative_array[i] = len(unique_to_negative)

# Reshape them into layers (12) * the rest (36)
unique_to_positive_array = unique_to_positive_array.reshape((12, 36))
unique_to_negative_array = unique_to_negative_array.reshape((12, 36))

across_layer = True

# Normalise y_pred with softmax
if not across_layer: 
    y_pred = unique_to_positive_array.flatten()
    y_pred = softmax(y_pred, axis=0)
else:
    y_pred = unique_to_positive_array.copy()
    # Softmax along axis 0, i.e. across layers, separately for each column
    y_pred = softmax(y_pred, axis=0)


# Reshape
y_pred = y_pred.reshape((12, 36))

In [406]:
# Softmaxed positive-only counts: one row per layer, 36 columns per layer.
imshow(y_pred)

In [403]:
# Raw positive-only counts before the softmax.
imshow(unique_to_positive_array)

In [404]:
# For each layer, take the sum of the groups of 3 as we move along the row
y_pred_final = np.zeros((12, 12))
for i in range(12):
    for j in range(12):
        y_pred_final[i, j] = np.sum(unique_to_positive_array[i, j*3:(j+1)*3])

In [405]:
# Per-(layer, head) totals computed in the previous cell.
imshow(y_pred_final)

In [444]:
# Variant score: size of the positive index set minus size of the negative index set.
# Negative and positive indices
halfway = all_indices.shape[0] // 2
positive_indices = all_indices[:halfway, :]
negative_indices = all_indices[halfway:, :]

unique_to_positive_array = np.zeros((len(head_labels)))
unique_to_negative_array = np.zeros((len(head_labels)))

for i in range(len(head_labels)):

    # Distinct argmax feature indices seen for this component in each class
    positive = set(positive_indices[:, i].tolist())
    negative = set(negative_indices[:, i].tolist())
    # Computed per component. Previously this cell read a `total_unique` left over from
    # an earlier cell (the value for that cell's last component), or failed with a
    # NameError on a fresh kernel whenever `normalise` was True.
    total_unique = positive.union(negative)

    # Size of positive set minus size of negative set
    unique_to_positive = len(positive) - len(negative)

    if normalise:
        # Normalise by total number of unique indices
        unique_to_positive_array[i] = unique_to_positive / len(total_unique)
        unique_to_negative_array[i] = unique_to_positive / len(total_unique)

    else:
        # Set the values
        # NOTE(review): the "negative" array receives the same value as the positive
        # one in both branches - confirm whether a negated or separate score was intended.
        unique_to_positive_array[i] = unique_to_positive
        unique_to_negative_array[i] = unique_to_positive

# Reshape them into layers (12) * the rest (36)
unique_to_positive_array = unique_to_positive_array.reshape((12, 36))
unique_to_negative_array = unique_to_negative_array.reshape((12, 36))

across_layer = True

# Normalise y_pred with softmax
if not across_layer:
    y_pred = unique_to_positive_array.flatten()
    y_pred = softmax(y_pred, axis=0)
else:
    y_pred = unique_to_positive_array.copy()
    # Softmax along axis 0, i.e. across layers, separately for each column
    y_pred = softmax(y_pred, axis=0)


# Reshape
y_pred = y_pred.reshape((12, 36))

In [445]:
# Softmaxed set-size differences: one row per layer, 36 columns per layer.
imshow(y_pred)

In [446]:
# For each layer, take the sum of the groups of 3 as we move along the row
y_pred_final = np.zeros((12, 12))
for i in range(12):
    for j in range(12):
        y_pred_final[i, j] = np.sum(unique_to_positive_array[i, j*3:(j+1)*3])

imshow(y_pred_final)