### Imports

In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import dataclasses
import sys

import einops
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import tqdm.auto as tqdm
import transformer_lens.utils as tl_utils
from jaxtyping import Float
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn import LogitLens, TunedLens
from tuned_lens.plotting import PredictionTrajectory

from pii import utils
from pii.prompt_data import LensData, PromptData

# Turn off gradient tracking globally for everything that follows.
torch.set_grad_enabled(False)

### Load model / lenses

In [None]:
# You will need to login to huggingface first:
#   huggingface-cli login
# The tokenizer and the causal LM are loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
hf_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    low_cpu_mem_usage=True,
)
# Both lenses are built from the same HF model and then moved to the GPU.
logit_lens = LogitLens.from_model(hf_model).to("cuda")
tuned_lens = TunedLens.from_model_and_pretrained(hf_model).to("cuda")

In [None]:
# Wrap the already-loaded HF model (passed via hf_model=) in a
# HookedTransformer on the GPU, reusing the tokenizer loaded above.
# fold_ln, center_writing_weights and center_unembed are all disabled.
tl_model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    hf_model=hf_model,
    device="cuda",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
)
# Quick generation check of the wrapped model (temperature=0, 20 new tokens).
tl_model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

### Develop prompts

In [None]:
def get_corr_prompt(question: str, correct: bool = True):
    """Build the "[INST] <<SYS>> ... [/INST] " prompt asking for a one-word answer.

    Args:
        question: Text inserted after "Complete the statment. ".
        correct: When truthy the system text asks for "correct" answers,
            otherwise for "incorrect" answers.

    Returns:
        The prompt string. It ends with "[/INST] " (trailing space, no
        trailing newline).
    """
    answer_kind = "correct" if correct else "incorrect"
    system_text = (
        "You are an assistant who only responds with a single word. "
        f"You only give {answer_kind} answers."
    )
    # The misspelling "statment" is reproduced verbatim: correcting it would
    # change the exact prompt text handed to the model.
    prompt_lines = [
        "[INST] <<SYS>>",
        system_text,
        "<</SYS>>",
        f"Complete the statment. {question} [/INST] ",
    ]
    return "\n".join(prompt_lines)


def get_dontsay_prompt(question: str, forbidden_word: str):
    """Build the "[INST] <<SYS>> ... [/INST] " prompt that bans one word.

    Args:
        question: Text placed on its own line, after a blank line that
            follows the system block.
        forbidden_word: Word quoted in the system text as the forbidden word.

    Returns:
        The prompt string. It ends with "[/INST] " (trailing space, no
        trailing newline).
    """
    system_text = (
        "You are an obedient assistant who only responds with a single word. "
        f'You will never say the forbidden word "{forbidden_word}".'
    )
    prompt_lines = [
        "[INST] <<SYS>>",
        system_text,
        "<</SYS>>",
        "",  # blank line between the system block and the question
        f"{question} [/INST] ",
    ]
    return "\n".join(prompt_lines)


# Statement stem passed to the prompt builders in the cells below.
question = "Lionel Messi plays the sport of"

In [None]:
# Print the prompt and query the model's top responses (top_k=5), first for
# the "correct" variant and then for the "incorrect" variant.
for i, corr in enumerate((True, False)):
    prompt = get_corr_prompt(question, correct=corr)
    print(prompt)
    utils.get_top_responses(prompt=prompt, model=tl_model, top_k=5)
    if not i:
        # Blank separator line after the first variant only.
        print()

In [None]:
# Print the prompt and query the model's top responses (top_k=5) for each
# forbidden word in turn.
for i, forbidden_word in enumerate(("apple", "football")):
    prompt = get_dontsay_prompt(question, forbidden_word)
    print(prompt)
    utils.get_top_responses(prompt=prompt, model=tl_model, top_k=5)
    if not i:
        # Blank separator line after the first variant only.
        print()

### Cache activations

In [None]:
# One PromptData per prompt variant, keyed "<prompt_type>_<base|competing>".
prompt_dict = {
    "corr_base": PromptData.get_data(get_corr_prompt(question), tl_model),
    "corr_competing": PromptData.get_data(get_corr_prompt(question, correct=False), tl_model),
    "dontsay_base": PromptData.get_data(get_dontsay_prompt(question, "apple"), tl_model),
    "dontsay_competing": PromptData.get_data(get_dontsay_prompt(question, "football"), tl_model),
}

# For each prompt family, print the final logit that each prompt (base, then
# competing) assigns to the base `ml_token` and to the competing `ml_token`.
for prompt_type in ("corr", "dontsay"):
    base_data = prompt_dict[f"{prompt_type}_base"]
    competing_data = prompt_dict[f"{prompt_type}_competing"]

    print(f"Logits for {prompt_type}_base")
    for token_source in (base_data, competing_data):
        print(
            token_source.ml_str,
            token_source.ml_token,
            base_data.final_logits[token_source.ml_token].item(),
        )
    print(f"Logits for {prompt_type}_competing")
    for token_source in (base_data, competing_data):
        print(
            token_source.ml_str,
            token_source.ml_token,
            competing_data.final_logits[token_source.ml_token].item(),
        )

In [None]:
# Run the "correct" (base) and "incorrect" (competing) prompts and keep their
# activation caches. remove_batch_dim=True is requested for the caches here.
logits_base, cache_base = tl_model.run_with_cache(
    get_corr_prompt(question, correct=True),
    remove_batch_dim=True,
)
logits_competing, cache_competing = tl_model.run_with_cache(
    get_corr_prompt(question, correct=False),
    remove_batch_dim=True,
)

# The returned logits are squeezed separately with the tl_utils helper.
logits_base = tl_utils.remove_batch_dim(logits_base)
logits_competing = tl_utils.remove_batch_dim(logits_competing)
# Cell output: the two logit shapes, as a sanity check.
logits_base.shape, logits_competing.shape

### Compute lens data

In [None]:
# One LensData per PromptData, under the same keys as `prompt_dict`, built
# with both the logit lens and the tuned lens.
lens_dict = {
    name: LensData.get_data(
        prompt_data=data,
        tl_model=tl_model,
        logit_lens=logit_lens,
        tuned_lens=tuned_lens,
    )
    for name, data in prompt_dict.items()
}

### Analyze attention

In [None]:
import numpy as np

def subject_relation_attention(
        cache: dict,
        subject_index: int,
        relation_index: int,
        correct_indices: tuple,
):
    """Sum the last query position's attention over three key-position spans.

    For every layer found in ``cache`` and every head of that layer, the
    attention row of the final query position is read from
    ``cache["blocks.{layer}.attn.hook_pattern"]`` and summed over:

    * the relation span: key positions ``relation_index:`` (through the end),
    * the subject span: key positions ``subject_index:relation_index``,
    * the "correct" span: key positions
      ``correct_indices[0]:correct_indices[1]``.

    Args:
        cache: Mapping (a plain dict or any object with ``keys()`` and
            ``[]``) from hook names to attention patterns indexed as
            ``[head, query_pos, key_pos]``, i.e. without a batch dimension.
        subject_index: First key position of the subject span.
        relation_index: First key position of the relation span, which is
            also the exclusive end of the subject span.
        correct_indices: ``(start, stop)`` key positions of the "correct" span.

    Returns:
        A list indexed ``[layer][head]`` of
        ``(relation_attention, subject_attention, correct_attention)`` tuples
        of NumPy scalars.

    Raises:
        ValueError: If a cached pattern is not 3-dimensional (for example if
            it still carries a batch dimension).
    """
    # Count consecutive layers present in the cache instead of assuming 32.
    hook_names = set(cache.keys())
    n_layers = 0
    while f"blocks.{n_layers}.attn.hook_pattern" in hook_names:
        n_layers += 1

    total_list = []
    for layer_num in range(n_layers):
        pattern = cache[f"blocks.{layer_num}.attn.hook_pattern"]
        if pattern.ndim != 3:
            raise ValueError(
                "expected attention pattern of shape [head, query_pos, key_pos] "
                f"for layer {layer_num}, got {tuple(pattern.shape)}"
            )
        # One lookup and one host transfer per layer; rows are [head, key_pos].
        final_query_rows = pattern[:, -1, :].cpu().numpy()
        head_list = [
            (
                np.sum(head_row[relation_index:]),
                np.sum(head_row[subject_index:relation_index]),
                np.sum(head_row[correct_indices[0]:correct_indices[1]]),
            )
            for head_row in final_query_rows
        ]
        total_list.append(head_list)
    return total_list

def topk_heads_attention(total_list, k=1):
    """Return the ``k`` heads with the largest third tuple entry.

    Args:
        total_list: Nested list indexed ``[layer][head]`` whose items are
            tuples; element ``[2]`` of each tuple is the ranking score.
        k: Number of entries to keep from the front of the ranking.

    Returns:
        A list of ``((layer_num, head_num), score)`` pairs in descending
        score order. Ties keep their original (layer, head) order.
    """
    scored_heads = []
    for layer_idx, heads in enumerate(total_list):
        for head_idx, attention_triple in enumerate(heads):
            scored_heads.append(((layer_idx, head_idx), attention_triple[2]))

    # In-place descending sort on the score; list.sort is stable.
    scored_heads.sort(key=lambda entry: entry[1], reverse=True)

    return scored_heads[:k]


In [None]:
# Example usage:
# Rank heads of the "incorrect answers" run by attention on the (25, 26) span.
# NOTE(review): 40 and 44 are hard-coded token positions; they match the
# [40:44] span decoded in the "Split DLA by attention" section. (25, 26)
# covers one position while that section decodes [25:27] -- confirm intent.
total_list = subject_relation_attention(cache_competing, 40, 44, (25, 26))
top_heads = topk_heads_attention(total_list, k=5)
print(top_heads)

### Split DLA by attention

In [None]:
# Take the second id of the encoded word.
# NOTE(review): presumably index 0 is a BOS token the tokenizer prepends and
# "football" encodes to a single token -- confirm.
football_token = tokenizer.encode("football")[1]
print(football_token)

In [None]:
# Tokenize the "correct answers" prompt once, print its length, then decode
# the hard-coded spans in the same order as before.
prompt = get_corr_prompt(question, correct=True)
prompt_token_ids = tokenizer.encode(prompt)
print(len(prompt_token_ids))
for span_start, span_stop in ((40, 44), (0, 40), (44, 52), (25, 27)):
    print(tokenizer.decode(prompt_token_ids[span_start:span_stop]))

In [None]:
# NOTE: `dla_attention_token_group` is defined in the cell that follows this
# one, so that cell has to be executed before this call can run.
#
# The token id is wrapped in a list so the function receives a sequence of
# attributes. The spans are passed as slices mirroring the [0:40], [40:44]
# and [44:52] ranges decoded above: a (start, stop) tuple used as an index
# would pick only those two positions, not the range between them.
split_dla = dla_attention_token_group(
    cache_base,
    tl_model,
    [football_token],
    slice(0, 40),
    slice(40, 44),
    slice(44, 52),
)

In [None]:
from jaxtyping import Float
from einops import rearrange, reduce, repeat
from torch import einsum
def dla_attention_token_group(
    cache: ActivationCache,
    model: HookedTransformer,
    all_tokenized_attributes,
    bos_token_indices,
    subject_token_indices,
    relation_token_indices,
) -> Float[torch.Tensor, "attribute model_component token_group"]:
    """DLA by attention head, by token group that is attended to

    The token groups are (bos, subject, relation).

    For every layer and head, the final query position's attention output is
    split by the key position it came from, projected onto the residual
    directions of the attribute tokens, and summed within each token group.

    Args:
        cache: Activations holding ``blocks.{layer}.attn.hook_v`` and
            ``blocks.{layer}.attn.hook_pattern`` (plus ``ln_final.hook_scale``
            when the model's normalization type is "LN" or "LNPre"). Caches
            with and without a batch dimension are both accepted; with a batch
            dimension only the first batch item is used.
        model: Model providing ``cfg``, ``blocks[layer].attn.W_O`` and
            ``tokens_to_residual_directions``.
        all_tokenized_attributes: A single token id or a 1-D sequence/tensor
            of token ids whose logit directions are attributed.
        bos_token_indices: Index applied to the key-position axis for the
            first group. Pass a list/tensor of positions or a ``slice``; a
            ``(start, stop)`` tuple selects just those two positions.
        subject_token_indices: Same, for the subject group.
        relation_token_indices: Same, for the relation group.

    Returns:
        Tensor of shape ``[attribute, layer * head_index, token_group]`` with
        token groups ordered (bos, subject, relation).
    """
    # A lone token id is promoted to a length-1 vector so that the
    # "attribute" axis always exists.
    attribute_tokens = torch.atleast_1d(torch.as_tensor(all_tokenized_attributes))
    logit_directions: Float[
        torch.Tensor, "attribute d_model"
    ] = model.tokens_to_residual_directions(attribute_tokens)

    # ActivationCache records whether its tensors still carry a batch axis.
    # Objects without the attribute are treated as batched, which is what the
    # previous implementation always assumed.
    has_batch_dim = getattr(cache, "has_batch_dim", True)

    def first_batch_item(activation):
        """Drop the batch axis (keeping batch item 0) when the cache has one."""
        return activation[0] if has_batch_dim else activation

    # NOTE(review): only "LN"/"LNPre" outputs are centred and rescaled. If the
    # model uses another normalization that also has a final scale, the
    # attributions are left unscaled -- confirm this is intended.
    apply_final_ln = model.cfg.normalization_type in ["LN", "LNPre"]
    if apply_final_ln:
        # Scale of the final norm at the last token (loop-invariant).
        scale = first_batch_item(cache["ln_final.hook_scale"])[-1, :]

    token_groups = (bos_token_indices, subject_token_indices, relation_token_indices)
    layers = []

    for layer in range(model.cfg.n_layers):
        value: Float[torch.Tensor, "key_pos head_index d_head"] = first_batch_item(
            cache[f"blocks.{layer}.attn.hook_v"]
        )

        # Only the final query position is attributed, so slice it out before
        # any product instead of building the full query x key tensor.
        pattern_final_query: Float[
            torch.Tensor, "head_index key_pos"
        ] = first_batch_item(cache[f"blocks.{layer}.attn.hook_pattern"])[:, -1, :]

        # Z is usually values * attention summed across keys; here the key
        # axis is kept so every source position can be attributed separately.
        z_per_key: Float[torch.Tensor, "key_pos head_index d_head"] = (
            value * pattern_final_query.transpose(0, 1).unsqueeze(-1)
        )

        weights_output: Float[torch.Tensor, "head_index d_head d_model"] = model.blocks[
            layer
        ].attn.W_O

        # torch.einsum only accepts single-letter subscripts:
        # k=key_pos, h=head_index, d=d_head, m=d_model, a=attribute.
        result_on_final_token: Float[
            torch.Tensor, "key_pos head_index d_model"
        ] = torch.einsum("khd,hdm->khm", z_per_key, weights_output)

        if apply_final_ln:
            centered = result_on_final_token - result_on_final_token.mean(
                dim=-1, keepdim=True
            )
            scaled_result = centered / scale
        else:
            scaled_result = result_on_final_token

        logit_attrs: Float[
            torch.Tensor, "attribute head_index key_pos"
        ] = torch.einsum("khm,am->ahk", scaled_result, logit_directions)

        # Sum by category (bos, subject, relation)
        combined: Float[torch.Tensor, "attribute head_index token_group"] = torch.stack(
            [logit_attrs[:, :, group].sum(dim=-1) for group in token_groups],
            dim=-1,
        )
        layers.append(combined)

    layers_tensor: Float[
        torch.Tensor, "layer attribute head_index token_group"
    ] = torch.stack(layers, dim=0)

    # Flatten (layer, head) into a single "model component" axis.
    rearranged = einops.rearrange(
        layers_tensor,
        "layer attribute head_index token_group -> attribute (layer head_index) token_group",
    )

    return rearranged

### Plot by head

In [None]:
# Per-head tuned-lens logits for the "answer correctly" (base) and
# "answer incorrectly" (competing) prompts.
ld_base = lens_dict["corr_base"]
ld_comp = lens_dict["corr_competing"]

logits_base = ld_base.get_logits("attn_head", lens_type="tuned")
logits_comp = ld_comp.get_logits("attn_head", lens_type="tuned")

# One figure per prompt; each figure shows the logits of both prompts'
# `ml_token`, read from that prompt's own per-head logits.
for head_logits, plotted_ld, plot_title in (
    (logits_base, ld_base, "Answer correctly prompt"),
    (logits_comp, ld_comp, "Answer incorrectly prompt"),
):
    utils.plot_head_data(
        lines=[
            (f"logit: |{ld.pd.ml_str}|", head_logits[:, ld.pd.ml_token])
            for ld in (ld_base, ld_comp)
        ],
        # Annotate with the last line of the prompt text.
        annotation_text=plotted_ld.pd.prompt.split("\n")[-1],
        title=plot_title,
        yaxis_title="Logit",
    ).show()

In [None]:
# Per-head tuned-lens logits for the "don't say apple" (base) and
# "don't say football" (competing) prompts.
ld_base = lens_dict["dontsay_base"]
ld_comp = lens_dict["dontsay_competing"]

logits_base = ld_base.get_logits("attn_head", lens_type="tuned")
logits_comp = ld_comp.get_logits("attn_head", lens_type="tuned")

# One figure per prompt; each figure shows the logits of both prompts'
# `ml_token`, read from that prompt's own per-head logits.
for head_logits, plotted_ld, plot_title in (
    (logits_base, ld_base, "Don't say apple"),
    (logits_comp, ld_comp, "Don't say football"),
):
    utils.plot_head_data(
        lines=[
            (f"logit: |{ld.pd.ml_str}|", head_logits[:, ld.pd.ml_token])
            for ld in (ld_base, ld_comp)
        ],
        # Annotate with the last line of the prompt text.
        annotation_text=plotted_ld.pd.prompt.split("\n")[-1],
        title=plot_title,
        yaxis_title="Logit",
    ).show()

## Attention plots

In [None]:
# String tokens of the "incorrect answers" prompt, used to label the
# attention visualisations below.
str_tokens = tl_model.to_str_tokens(get_corr_prompt(question, correct=False))

In [None]:
# NOTE(review): this cell repeats the previous cell's statement verbatim;
# it recomputes the same value and could be removed.
str_tokens = tl_model.to_str_tokens(get_corr_prompt(question, correct=False))

In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the
# imports cell at the top so dependencies are visible in one place.
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

In [None]:
# `cache_incorrect` is never defined in this notebook. The cache recorded for
# the "incorrect answers" prompt is `cache_competing`.
cache_competing.keys()

In [None]:
# Render every head's attention pattern, from the last layer down to layer 0.
# `cache_competing` is the cache recorded for the "incorrect answers" prompt,
# the same prompt `str_tokens` was built from; the previously referenced
# `cache_incorrect` is never defined in this notebook.
# The layer count comes from the model config instead of a hard-coded 32.
for i in reversed(range(tl_model.cfg.n_layers)):
    attention_pattern = cache_competing[f"blocks.{i}.attn.hook_pattern"]

    print(attention_pattern.shape)
    print(f"Layer {i} Head Attention Patterns:")
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))