In [None]:
try:
    import google.colab
    IN_COLAB = True
    !pip install circuitsvis
    !pip install git+https://github.com/neelnanda-io/TransformerLens.git
except:
    IN_COLAB = False

In [None]:
from transformer_lens import HookedTransformer
import torch
import einops
from transformer_lens import utils
from tqdm.auto import tqdm
from datasets import load_dataset

In [None]:
# Disable autograd globally: the cells below only run forward passes.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pretrained "attn-only-2l" model onto the chosen device, with LayerNorm folding requested.
model = HookedTransformer.from_pretrained("attn-only-2l", device=device, fold_ln=True)
print(device)

In [None]:
def get_bigram_logits(tokens, model: HookedTransformer, add_bias=False):
    """Compute next-token logits from the current token alone (embed -> unembed).

    Each token id selects its row of `model.W_E`, which is multiplied directly by
    `model.W_U`; no attention layer is involved.

    Args:
        tokens: Integer token ids indexed as (batch, pos).
        model: Model providing `W_E`, `W_U` and `b_U`.
        add_bias: When True, `model.b_U` is added to the logits.

    Returns:
        Tensor of shape (batch, pos, d_vocab_out).
    """
    token_embeddings = model.W_E[tokens, :]
    direct_logits = einops.einsum(
        token_embeddings,
        model.W_U,
        "batch pos d_model, d_model d_vocab_out -> batch pos d_vocab_out",
    )
    if not add_bias:
        return direct_logits
    direct_logits += model.b_U
    return direct_logits

def get_topk_words(logits, model, pos=-1, k=10):
    """Return the `k` highest-scoring tokens at one position, best first.

    Args:
        logits: Tensor indexed as (batch, pos, vocab); only batch element 0 is read.
        model: Object whose `tokenizer.convert_ids_to_tokens` maps ids to token strings.
        pos: Position to inspect (default: the last one).
        k: Number of tokens to return.

    Returns:
        Whatever `convert_ids_to_tokens` returns for the top-k ids.
    """
    position_logits = logits[0, pos]
    best_ids = torch.topk(position_logits, k=k).indices
    return model.tokenizer.convert_ids_to_tokens(best_ids)

def get_log_probs(logits, tokens):
    """Log-probability the logits assign to each actual next token.

    Args:
        logits: Tensor of shape (batch, pos, vocab).
        tokens: Integer token ids of shape (batch, pos).

    Returns:
        Tensor of shape (batch, pos - 1); entry [b, i] is the log-probability of
        tokens[b, i + 1] under the prediction made at position i.
    """
    # The prediction at the last position has no following token to score.
    predictions = torch.log_softmax(logits, dim=-1)[:, :-1]
    targets = tokens[:, 1:, None]
    return torch.gather(predictions, -1, targets).squeeze(-1)

def get_log_probs_batched(logits, tokens):
    """Log-probability of each actual next token, for logits with extra leading axes.

    Generalizes `get_log_probs` to logits of shape (..., batch, pos, vocab), e.g.
    one set of logits per component stacked along a leading axis. The same
    `tokens` are used as targets for every leading index.

    Args:
        logits: Tensor of shape (..., batch, pos, vocab).
        tokens: Integer token ids of shape (batch, pos).

    Returns:
        Tensor of shape (..., batch, pos - 1).
    """
    log_probs = logits.log_softmax(dim=-1)
    # Drop the final *position* (not a vocab entry): it has no next token to score.
    next_token_log_probs = log_probs[..., :-1, :]
    targets = tokens[:, 1:].unsqueeze(-1)
    # gather needs an index with the same rank as its input, so broadcast the
    # targets across any extra leading axes.
    targets = targets.expand(*next_token_log_probs.shape[:-1], 1)
    return next_token_log_probs.gather(dim=-1, index=targets).squeeze(-1)

## Bigram and embedding analysis

In [None]:
# Run the model on a single prompt and cache its intermediate activations.
text = "Social security is a government program that produces exuberant daisies"
tokens = model.to_tokens(text)
logits, cache = model.run_with_cache(tokens)
print(logits.shape)

In [None]:
# Highest-logit (greedy) next-token prediction at every position of the prompt.
predicted_tokens = logits[0].argmax(dim=-1)
predicted_text = model.to_string(predicted_tokens)
print(predicted_text)

In [None]:
# Model's prediction made at position 1 (the name suggests it decodes to " media" — confirm from the output).
media_token = predicted_tokens[1]
# Actual prompt token at position 2 (presumably " security", with BOS at position 0 — confirm from the output).
security_token = tokens[0, 2]

print(media_token, model.to_string(media_token))
print(security_token, model.to_string(security_token))

In [None]:
# Stack the cached per-head outputs into one tensor (apply_ln=True requests LayerNorm scaling).
head_results = cache.stack_head_results(apply_ln=True)
head_results.shape # head batch pos res

In [None]:
# Treat the unembedding bias alone as logits (reshaped to batch=1, pos=1) and list its top tokens.
print("Unigram most likely tokens")
get_topk_words(model.b_U.view(1, 1, -1), model)

In [None]:
def zero_ablation_hook(value, hook):
    """Forward hook that replaces an activation with zeros.

    Returns a new all-zero tensor with the shape, dtype and device of `value`;
    `value` itself is left untouched and `hook` is unused.
    """
    return torch.zeros(value.shape, dtype=value.dtype, device=value.device)

# Layer whose attention output is zero-ablated below.
layer_to_ablate = 0
# Not filled in by this cell; kept so any cell that references these names still runs.
original_losses = []
ablation_losses = []

original_loss = model(tokens, return_type="loss")
# Same forward pass, but with the chosen layer's attention output replaced by zeros.
ablated_loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[(f"blocks.{layer_to_ablate}.hook_attn_out", zero_ablation_hook)],
)

# Baseline that ignores attention entirely: embed -> unembed logits.
bigram_pred = get_bigram_logits(tokens, model)
bigram_loss = -get_log_probs(bigram_pred, tokens).mean()

print(f"Original model loss: {original_loss:.6f}")
print(f"Ablated loss: {ablated_loss:.6f}")
print(f"Bigram loss: {bigram_loss:.6f}")

### Head attribution

In [None]:
# Result of individual heads, shape=(head, batch, pos, d_model)
head_results = cache.stack_head_results(apply_ln=True)

# Residual-stream directions of the model's top predicted token at every position, shape=(batch, pos, d_model)
model_tokens = logits.argmax(dim=-1)
directions = model.tokens_to_residual_directions(model_tokens)

In [None]:
# Dot product between the head activations and the output directions,
# summed over batch and position so that one scalar remains per head
head_attribution = einops.einsum(head_results, directions, "head batch pos d_model, batch pos d_model -> head")
print(head_attribution)

In [None]:
cos = torch.nn.CosineSimilarity(dim=-1)
# Cosine similarities between head outputs and directions, shape=(head, pos)
similarities = cos(directions[0], head_results[:, 0])
# Per-head similarity at position 2 (the token stored as `security_token` earlier)
print(similarities[:, 2])
print(torch.max(similarities))

In [None]:
# Cosine similarity between each prompt token's embedding and the direction of
# the token predicted at that position, shape=(pos,)
embed = model.W_E[tokens, :]
similarities = cos(directions[0], embed[0])
print(similarities)
print(torch.max(similarities))

## Induction analysis

In [None]:
# Alternative prompt, currently disabled:
#text = "Harry Potter is great. Harry"
# Rebinds `text`, `tokens`, `logits` and `cache` to the induction prompt.
text = "Leonard Potter is great. Leonard"
tokens = model.to_tokens(text)
logits, cache = model.run_with_cache(tokens)

In [None]:
# Compute cosine similarity of prediction " Potter" to head outputs
potter_token = model.to_tokens(" Potter", prepend_bos=False)
potter_direction = model.tokens_to_residual_directions(potter_token)
print("Potter direction shape", potter_direction.shape)

# Head outputs for the current (induction) prompt's cache
head_results = cache.stack_head_results(apply_ln=True)
print("Head result shape", head_results.shape) # head batch pos res

cos = torch.nn.CosineSimilarity(dim=-1)
similarities = cos(potter_direction, head_results[:, 0])
print("Cosine similarity shape", similarities.shape) # head position
# Per-head similarity with " Potter" at the final position (the second "Leonard" of the active prompt)
similarities[:, -1]

In [None]:
# Most likely next words for "Leonard Potter is great. Leonard" according to the full model
get_topk_words(logits, model)

In [None]:
# Same question for the embed -> unembed path only: top next tokens after the final token.
bigram_logits = get_bigram_logits(tokens, model)
print(get_topk_words(bigram_logits, model, pos=-1))

In [None]:
# Check how the model scores " Potter" as the continuation of the induction prompt.
utils.test_prompt("Leonard Potter is great. Leonard", " Potter", model, prepend_bos=True)

In [None]:
import plotly.express as px

# Per-head outputs at the final position of the current prompt, shape=(head, d_model)
head_results_final_token = head_results[:, 0, -1, :]
# Direction of the token the model predicts at the final position of the *current*
# prompt. It is recomputed from the current `logits` because the notebook-level
# `directions` was built for the earlier "Social security ..." prompt.
directions_final_token = model.tokens_to_residual_directions(logits.argmax(dim=-1))[0, -1, :]

# Dot product between the head activations and the output directions
head_attribution = einops.einsum(head_results_final_token, directions_final_token, "head d_model, d_model -> head")
# Lay the flat head axis out as a (layer, head) grid for plotting
head_attribution = torch.reshape(head_attribution, (model.cfg.n_layers, model.cfg.n_heads))

def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a plotly heatmap on a diverging red/blue scale centred at 0.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy`.
        renderer: Passed through to `fig.show`.
        **kwargs: Extra `px.imshow` arguments; they override the defaults below.
    """
    plot_options = dict(
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        text_auto=".2f",
    )
    plot_options.update(kwargs)

    figure = px.imshow(utils.to_numpy(tensor), **plot_options)
    figure.show(renderer=renderer)

imshow(head_attribution, labels={"x": "Head", "y": "Layer"}, title="Logit Attribution by Head")

cos = torch.nn.CosineSimilarity(dim=-1)
# Per-head cosine similarity with the " Potter" direction at the final position
similarities = cos(potter_direction, head_results_final_token)
print("Cosine similarity shape", similarities.shape)

# Lay the flat head axis out as a (layer, head) grid for plotting
similarities = torch.reshape(similarities, (model.cfg.n_layers, model.cfg.n_heads))
# The final token of the active prompt is "Leonard", so the titles name that token.
imshow(similarities, labels={"x": "Head", "y": "Layer"}, title="Cosine similarities of heads at final token \"Leonard\" with rank 4 answer token \"Potter\"")

# Same comparison against the direction of " is", the token that follows " Potter" in the prompt
is_token = model.to_tokens(" is", prepend_bos=False)
is_direction = model.tokens_to_residual_directions(is_token)
is_similarities = cos(is_direction, head_results_final_token)
is_similarities = torch.reshape(is_similarities, (model.cfg.n_layers, model.cfg.n_heads))
imshow(is_similarities, labels={"x": "Head", "y": "Layer"}, title="Cosine similarities of heads at final token \"Leonard\" with answer token \"is\"")

# Induction Heads Test

In [None]:
def generate_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch: int = 1
): # Int[Tensor, "batch full_seq_len"]
    '''
    Builds sequences of the form [BOS, random block, the same random block].

    The random block holds `seq_len` token ids drawn uniformly from
    [0, model.cfg.d_vocab). The result is moved to the notebook-level `device`.

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
    '''
    bos_column = torch.full((batch, 1), model.tokenizer.bos_token_id, dtype=torch.long)
    random_block = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
    pieces = (bos_column, random_block, random_block)
    return torch.cat(pieces, dim=-1).to(device)

def run_and_cache_model_repeated_tokens(model: HookedTransformer, seq_len: int, batch: int = 1): # -> Tuple[torch.Tensor, torch.Tensor, ActivationCache]
    '''
    Creates repeated random token sequences with `generate_repeated_tokens`
    and runs the model on them with activation caching.

    Outputs are (in this order):
        rep_tokens: [batch, 1+2*seq_len]
        rep_logits: [batch, 1+2*seq_len, d_vocab]
        rep_cache: The cache of the model run on rep_tokens
    '''
    repeated = generate_repeated_tokens(model, seq_len, batch)
    logits_out, activation_cache = model.run_with_cache(repeated)
    return repeated, logits_out, activation_cache
# Baseline run on one repeated random sequence (BOS + 50 random tokens + the same 50 again).
seq_len = 50
batch = 1
(rep_tokens, rep_logits, rep_cache) = run_and_cache_model_repeated_tokens(model, seq_len, batch)
rep_cache.remove_batch_dim()
rep_str = model.to_str_tokens(rep_tokens)
model.reset_hooks()
# log_probs[i] scores token i+1, so indices < seq_len cover the first (random) copy
# and indices >= seq_len cover the repeated copy.
log_probs = get_log_probs(rep_logits, rep_tokens).squeeze()
print(f"Performance on the first half: {log_probs[:seq_len].mean():.3f}")
print(f"Performance on the second half: {log_probs[seq_len:].mean():.3f}")

# Disabled: `plot_loss_difference` is not defined in the visible part of this notebook.
# plot_loss_difference(log_probs, rep_str, seq_len)

In [None]:
#@title Induction Scores

# `seq_len` is also read as a notebook global by `induction_score_hook` below.
seq_len = 50
batch = 10
rep_tokens_10 = generate_repeated_tokens(model, seq_len, batch)

# We make a tensor to store the induction score for each head.
# We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

def induction_score_hook(
    pattern, # : Float[Tensor, "batch head_index dest_pos source_pos"]
    hook,
):
    '''
    Writes each head's induction score into row `hook.layer()` of the
    notebook-level `induction_score_store` tensor.

    The score is the attention on the diagonal at offset `1 - seq_len`
    (`seq_len` is read from the notebook globals), averaged over batch and position.
    '''
    # Attention from each destination position to the source (seq_len - 1) tokens back;
    # only destinations with index >= seq_len - 1 have such an entry.
    stripe_offset = 1 - seq_len
    stripe = torch.diagonal(pattern, offset=stripe_offset, dim1=-2, dim2=-1)
    # Collapse batch and position to get one number per head
    per_head_mean = einops.reduce(stripe, "batch head_index position -> head_index", "mean")
    induction_score_store[hook.layer(), :] = per_head_mean

# Select every hook whose name ends in "pattern"
pattern_hook_names_filter = lambda name: name.endswith("pattern")

# Run with hooks (this is where we write to the `induction_score_store` tensor`)
model.run_with_hooks(
    rep_tokens_10,
    return_type=None, # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        induction_score_hook
    )]
)

# Plot the induction scores for each head in each layer
imshow(
    induction_score_store,
    labels={"x": "Head", "y": "Layer"},
    title="Induction Score by Head",
    text_auto=".2f",
    width=900, height=400
)

In [None]:
#@title Prev token scores

# Fresh batch of 10 repeated random sequences for the previous-token measurement.
seq_len = 50
batch = 10
rep_tokens_10 = generate_repeated_tokens(model, seq_len, batch)

# One previous-token score per (layer, head), filled in by the hook below.
prev_token_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

def prev_token_score_hook(
    pattern, # : Float[Tensor, "batch head_index dest_pos source_pos"]
    hook,
):
    '''
    Writes each head's previous-token score into row `hook.layer()` of the
    notebook-level `prev_token_score_store` tensor.

    The score is the attention each destination position pays to the position
    directly before it, averaged over batch and position.
    '''
    # Offset -1 selects the entries where source_pos == dest_pos - 1
    stripe = torch.diagonal(pattern, offset=-1, dim1=-2, dim2=-1)
    # Collapse batch and position to get one number per head
    per_head_mean = einops.reduce(stripe, "batch head_index position -> head_index", "mean")
    prev_token_score_store[hook.layer(), :] = per_head_mean

# Select every hook whose name ends in "pattern"
pattern_hook_names_filter = lambda name: name.endswith("pattern")

# Run with hooks (this is where we write to the `prev_token_score_store` tensor`)
model.run_with_hooks(
    rep_tokens_10,
    return_type=None, # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        prev_token_score_hook
    )]
)

# Plot the prev_token scores for each head in each layer
imshow(
    prev_token_score_store,
    labels={"x": "Head", "y": "Layer"},
    title="Prev Token Score by Head",
    text_auto=".2f",
    width=900, height=400
)

In [None]:
# Build reference ("mean") attention patterns from random tokens without the repeated half.
seq_len = 100
batch = 50
rep_tokens_for_mean = generate_repeated_tokens(model, seq_len, batch)
# Keep BOS plus the first random block only; expressed via seq_len so the slice
# stays correct if seq_len is changed.
rand_tokens_for_mean = rep_tokens_for_mean[:, :seq_len + 1]
logits, cache = model.run_with_cache(rand_tokens_for_mean)

# Attention patterns averaged over the batch, one per layer: (head, dest_pos, source_pos)
mean_acts_0 = cache["blocks.0.attn.hook_pattern"].mean(dim=0)
mean_acts_1 = cache["blocks.1.attn.hook_pattern"].mean(dim=0)


In [None]:
#@title Define zero ablation hooks and try out

# print(cache.keys())

def head_3_zero_ablation_hook(
    value,
    hook
):
    """Zero out index 3 along axis 1 of `value` in place and return the same tensor.

    `value` is indexed with four axes; axis 1 is presumably the head axis (confirm
    against the hook point this is attached to). `hook` is unused.
    """
    value[:, 3].zero_()
    return value
def head_6_zero_ablation_hook(
    value,
    hook
):
    """Zero out index 6 along axis 1 of `value` in place and return the same tensor.

    `value` is indexed with four axes; axis 1 is presumably the head axis (confirm
    against the hook point this is attached to). `hook` is unused. Prints nothing,
    so it stays quiet when called on every forward pass.
    """
    value[:, 6, :, :] = 0.
    return value

def head_3_mean_ablation_hook(
    pattern,
    hook
):
    """Overwrite head 3's attention pattern with its mean pattern, in place.

    The replacement comes from the notebook-level `mean_acts_0[3]`, cropped to the
    query/key lengths of `pattern` and broadcast over the batch axis.
    Returns the modified `pattern`; `hook` is unused.
    """
    _batch, _heads, n_query, n_key = pattern.shape
    mean_pattern = mean_acts_0[3, :n_query, :n_key]
    pattern[:, 3, :, :] = mean_pattern
    return pattern

def head_6_mean_ablation_hook(
    pattern,
    hook
):
    """Overwrite head 6's attention pattern with its mean pattern, in place.

    The replacement comes from the notebook-level `mean_acts_1[6]`, cropped to the
    query/key lengths of `pattern` and broadcast over the batch axis.
    Returns the modified `pattern`; `hook` is unused.
    """
    _batch, _heads, n_query, n_key = pattern.shape
    mean_pattern = mean_acts_1[6, :n_query, :n_key]
    pattern[:, 6, :, :] = mean_pattern
    return pattern

# Baseline: the unmodified model on the single repeated sequence `rep_tokens`.
logits = model(rep_tokens)
# seq_len was set to 100 when the mean patterns were built; restore the 50 that
# `rep_tokens` was generated with so the first/second-half split below is right.
seq_len = 50
batch = 1
model.reset_hooks()
log_probs = get_log_probs(logits, rep_tokens).squeeze()
print(f"Performance on the first half: {log_probs[:seq_len].mean():.3f}")
print(f"Performance on the second half: {log_probs[seq_len:].mean():.3f}")


# Same forward pass with layer-0 head 3 and layer-1 head 6 mean-ablated;
# the hooks are only active inside the `with` block.
with model.hooks(fwd_hooks=[("blocks.0.attn.hook_pattern", head_3_mean_ablation_hook),
                            ("blocks.1.attn.hook_pattern", head_6_mean_ablation_hook)]):
    logits = model(rep_tokens)


seq_len = 50
batch = 1
model.reset_hooks()
log_probs = get_log_probs(logits, rep_tokens).squeeze()
print(f"Performance on the first half with ablated induction circuit: {log_probs[:seq_len].mean():.3f}")
print(f"Performance on the second half with ablated induction circuit: {log_probs[seq_len:].mean():.3f}")



# non_induction_pattern = lambda name : name != "blocks.0."

In [None]:
#@title Induction scores with highest scoring induction head mean ablated

# `seq_len` is also read as a notebook global by `induction_score_hook` below.
seq_len = 50
batch = 10
rep_tokens_10 = generate_repeated_tokens(model, seq_len, batch)

# We make a tensor to store the induction score for each head.
# We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

def induction_score_hook(
    pattern, # : Float[Tensor, "batch head_index dest_pos source_pos"]
    hook,
):
    '''
    Mean-ablates head 6 when called on layer 1, then writes each head's induction
    score into row `hook.layer()` of the notebook-level `induction_score_store`.

    `pattern` is modified in place. This definition replaces the earlier
    `induction_score_hook` in the notebook namespace.
    '''
    _batch, _heads, n_query, n_key = pattern.shape
    layer_index = hook.layer()
    if layer_index == 1:
        # Swap in the batch-averaged reference pattern for this head
        pattern[:, 6, :, :] = mean_acts_1[6, :n_query, :n_key]
    # Attention from each destination position to the source (seq_len - 1) tokens back;
    # only destinations with index >= seq_len - 1 have such an entry.
    stripe = torch.diagonal(pattern, offset=1 - seq_len, dim1=-2, dim2=-1)
    # Collapse batch and position to get one number per head
    per_head_mean = einops.reduce(stripe, "batch head_index position -> head_index", "mean")
    induction_score_store[layer_index, :] = per_head_mean

# Select every hook whose name ends in "pattern"
pattern_hook_names_filter = lambda name: name.endswith("pattern")

# Run with hooks (this is where we write to the `induction_score_store` tensor`)
model.run_with_hooks(
    rep_tokens_10,
    return_type=None, # For efficiency, we don't need to calculate the logits
    fwd_hooks=[
        (pattern_hook_names_filter, induction_score_hook)]
)

# Plot the induction scores for each head in each layer
imshow(
    induction_score_store,
    labels={"x": "Head", "y": "Layer"},
    title="Induction Score by Head",
    text_auto=".2f",
    width=900, height=400
)

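Observation: there are no backup induction heads — after mean-ablating the strongest induction head, no other head shows a high induction score.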

In [None]:
#@title Prev token scores with the highest scoring previous token head mean ablated (useless because turning off one previous token heads won't affect the others)

seq_len = 50
batch = 10
# NOTE(review): generated here, but the run below in this cell uses `rep_tokens` instead.
rep_tokens_10 = generate_repeated_tokens(model, seq_len, batch)

# One previous-token score per (layer, head), filled in by the hook below.
prev_token_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

def prev_token_score_hook(
    pattern, # : Float[Tensor, "batch head_index dest_pos source_pos"]
    hook,
):
    '''
    Mean-ablates head 3 when called on layer 0, then writes each head's
    previous-token score into row `hook.layer()` of `prev_token_score_store`.

    `pattern` is modified in place. This definition replaces the earlier
    `prev_token_score_hook` in the notebook namespace.
    '''
    _batch, _heads, n_query, n_key = pattern.shape
    layer_index = hook.layer()
    if layer_index == 0:
        # Swap in the batch-averaged reference pattern for this head
        pattern[:, 3, :, :] = mean_acts_0[3, :n_query, :n_key]
    # Offset -1 selects the entries where source_pos == dest_pos - 1
    stripe = torch.diagonal(pattern, offset=-1, dim1=-2, dim2=-1)
    # Collapse batch and position to get one number per head
    per_head_mean = einops.reduce(stripe, "batch head_index position -> head_index", "mean")
    prev_token_score_store[layer_index, :] = per_head_mean

# Select every hook whose name ends in "pattern"
pattern_hook_names_filter = lambda name: name.endswith("pattern")

# Run with hooks (this is where we write to the `prev_token_score_store` tensor`)
# NOTE(review): this runs on `rep_tokens` (batch of 1), not the `rep_tokens_10` generated
# at the top of the cell — presumably so the logits line up with `rep_tokens` in
# get_log_probs below; confirm this is intended.
logits = model.run_with_hooks(
    rep_tokens,
    # return_type=None, # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        prev_token_score_hook
    )]
)

# Reset to the values `rep_tokens` was generated with before splitting into halves.
seq_len = 50
batch = 1
model.reset_hooks()
log_probs = get_log_probs(logits, rep_tokens).squeeze()
print(f"Performance on the first half: {log_probs[:seq_len].mean():.3f}")
print(f"Performance on the second half: {log_probs[seq_len:].mean():.3f}")


# Plot the prev_token scores for each head in each layer
imshow(
    prev_token_score_store,
    labels={"x": "Head", "y": "Layer"},
    title="Prev Token Score by Head",
    text_auto=".2f",
    width=900, height=400
)

In [None]:
#@title Ablate everything except the induction circuit

seq_len = 50
batch = 10
# NOTE(review): generated here, but the run below in this cell uses `rep_tokens` instead.
rep_tokens_10 = generate_repeated_tokens(model, seq_len, batch)
# One previous-token score per (layer, head), filled in by the hook below.
prev_token_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

def prev_token_score_hook(
    pattern, # : Float[Tensor, "batch head_index dest_pos source_pos"]
    hook,
):
    '''
    Mean-ablates every head except layer-0 head 3 and layer-1 head 6, then writes
    each head's previous-token score into row `hook.layer()` of
    `prev_token_score_store`.

    `pattern` is modified in place. The number of heads is read from `pattern`
    itself, so nothing is tied to a fixed head count. This definition replaces
    the earlier `prev_token_score_hook` in the notebook namespace.
    '''
    _, pattern_head, pattern_q_pos, pattern_v_pos = pattern.shape
    if hook.layer() == 0:
        # Keep head 3 of layer 0; replace all other heads with their mean pattern
        nums_not_3 = [i for i in range(pattern_head) if i != 3]
        pattern[:, nums_not_3, :, :] = mean_acts_0[nums_not_3, :pattern_q_pos, :pattern_v_pos]
    if hook.layer() == 1:
        # Keep head 6 of layer 1; replace all other heads with their mean pattern
        nums_not_6 = [i for i in range(pattern_head) if i != 6]
        pattern[:, nums_not_6, :, :] = mean_acts_1[nums_not_6, :pattern_q_pos, :pattern_v_pos]

    # Offset -1 selects the entries where source_pos == dest_pos - 1
    prev_token_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=-1)
    # Get an average score per head
    prev_token_score = einops.reduce(prev_token_stripe, "batch head_index position -> head_index", "mean")
    # Store the result.
    prev_token_score_store[hook.layer(), :] = prev_token_score

# Select every hook whose name ends in "pattern"
pattern_hook_names_filter = lambda name: name.endswith("pattern")

# Run with hooks (this is where we write to the `prev_token_score_store` tensor`)
# NOTE(review): this runs on `rep_tokens` (batch of 1), not the `rep_tokens_10` generated
# at the top of the cell — presumably so the logits line up with `rep_tokens` in
# get_log_probs below; confirm this is intended.
logits = model.run_with_hooks(
    rep_tokens,
    # return_type=None, # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        prev_token_score_hook
    )]
)

# Disabled alternative: apply the same hooks via a context manager and test a natural-language prompt.
# with model.hooks(fwd_hooks=[(
#         pattern_hook_names_filter,
#         prev_token_score_hook
#     )]):
#   logits = model(rep_tokens)
#   utils.test_prompt("apple banana orange apple banana", " orange", model, prepend_bos=True)

# Reset to the values `rep_tokens` was generated with before splitting into halves.
seq_len = 50
batch = 1
model.reset_hooks()
log_probs = get_log_probs(logits, rep_tokens).squeeze()
print(f"Performance on the first half: {log_probs[:seq_len].mean():.3f}")
print(f"Performance on the second half: {log_probs[seq_len:].mean():.3f}")


# Plot the prev_token scores for each head in each layer
imshow(
    prev_token_score_store,
    labels={"x": "Head", "y": "Layer"},
    title="Previous token scores for mean ablations on everything except our induction head and previous token head",
    text_auto=".2f",
    width=900, height=400
)

In [None]:
# Load the train split of the "NeelNanda/c4-10k" dataset and batch its raw text, 32 documents at a time.
dataset = load_dataset("NeelNanda/c4-10k", split="train")
dataloader = torch.utils.data.DataLoader(dataset["text"], batch_size=32)

In [None]:
# Average the model's loss over all batches of the dataset.
total_loss = 0.0
for batch in tqdm(dataloader):
    tokens = model.to_tokens(batch)
    # Feed the pre-computed tokens so each batch is tokenized only once.
    loss = model(tokens, return_type="loss")
    # .item() keeps the running total a plain Python float rather than a device tensor.
    total_loss += loss.item()
# Unweighted mean of the per-batch losses.
print(f"Loss {(total_loss / len(dataloader)):.2f}")

## Bigram evaluation

Motivation
- We want to understand how the 2L model predicts stuff
- The two easiest components are bigrams and unigrams
- Unigrams should be in the unembed bias (?)
- We want bigram statistics
- The direct embed → unembed path doesn't give them to us on its own, because every attention head can also retrieve additional information about the current token
- We can try to recover bigrams from heads by using self-attention values

Plan
- Compute bigrams on c4 10k dataset
- Construct easy test cases that can be solved by doing bigram lookup
- See how well embed-unembed does on them: if it works well, we're done here; if not, the attention heads must be contributing relevant bigram information too
- Check with logit attribution where the model stores bigram information

How to see where the model stores bigram information
- Idea: set attention patterns to diagonal (either original diagonal values or 1s) to make model only attend to current token which means it can only do bigram from current to next token
- Problem: setting other attentions to 0 will probably mess with the outputs

In [None]:
# Tokenize the whole dataset in a single call and report the resulting tensor shape.
all_text = dataset["text"]
all_tokens = model.to_tokens(all_text)
print(all_tokens.shape)

In [None]:

# Print every adjacent token pair (bigram) of the first 10 tokenized documents.
# NOTE(review): this prints one line per position and only displays the pairs; nothing is counted or stored yet.
for example in all_tokens[:10]:
    # `token` is a position index here, not a token id
    for token in range(len(example)-1):
        bigram = example[token:token+2]
        print(bigram)