In [1]:
# Import stuff
import torch
import torch.nn as nn
import einops
import plotly.express as px

from jaxtyping import Float
from functools import partial

In [2]:
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import circuitsvis as cv

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x119cb8f20>

In [4]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [5]:
device = utils.get_device()
device

device(type='mps')

In [25]:
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", device=device, dtype="float16")
# model = HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device=device)
# model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", device=device)

Loading checkpoint shards: 100%|██████████| 4/4 [00:26<00:00,  6.64s/it]


Loaded pretrained model meta-llama/Llama-3.1-8B-Instruct into HookedTransformer


# Identify Induction Heads

In [8]:
text = "one two three one two three one two three"
tokens = model.to_tokens(text)
print(tokens.device)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

mps:0


In [9]:
print(type(cache))
attention_pattern = cache["pattern", 0, "attn"]
print(attention_pattern.shape)
str_tokens = model.to_str_tokens(text)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([32, 10, 10])


In [10]:
print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


In [26]:
def plot_head_detection_scores(
    scores: torch.Tensor,
    zmin: float = -1,
    zmax: float = 1,
    xaxis: str = "Head",
    yaxis: str = "Layer",
    title: str = "Head Matches"
) -> None:
    imshow(scores, zmin=zmin, zmax=zmax, xaxis=xaxis, yaxis=yaxis, title=title)

In [27]:
prompts = [
    "one two three one two three one two three",
    "1 2 3 4 5 1 2 3 4 1 2 3 1 2 3 4 5 6 7",
    "green ideas sleep furiously; green ideas don't sleep furiously"
]

In [28]:
from transformer_lens.head_detector import detect_head


head_scores = detect_head(model, prompts, "induction_head", exclude_bos=False, exclude_current_token=False, error_measure="abs")

In [29]:
head_scores

tensor([[0.1013, 0.0567, 0.0922,  ..., 0.0012, 0.0047, 0.0000],
        [0.0107, 0.0063, 0.0032,  ..., 0.0057, 0.0080, 0.0189],
        [0.0087, 0.0120, 0.0119,  ..., 0.0263, 0.0034, 0.0038],
        ...,
        [0.0032, 0.0030, 0.0034,  ..., 0.0219, 0.0892, 0.0732],
        [0.0266, 0.0748, 0.0368,  ..., 0.0543, 0.0643, 0.0922],
        [0.0359, 0.0540, 0.0267,  ..., 0.0523, 0.0583, 0.0261]],
       dtype=torch.float16)

In [30]:
plot_head_detection_scores(head_scores, title="Induction head; average across 3 prompts")

In [85]:
batch_size = 10
seq_len = 50
size = (batch_size, seq_len)
input_tensor = torch.randint(1000, 10000, size)

random_tokens = input_tensor.to(model.cfg.device)
repeated_tokens = einops.repeat(random_tokens, "batch seq_len -> batch (2 seq_len)")

In [86]:
# We make a tensor to store the induction score for each head. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)
def induction_score_hook(
    pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    # We take the diagonal of attention paid from each destination position to source positions seq_len-1 tokens back
    # (This only has entries for tokens with index>=seq_len)
    induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1-seq_len)
    # Get an average score per head
    induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
    # Store the result.
    induction_score_store[hook.layer(), :] = induction_score

# We make a boolean filter on activation names, that's true only on attention pattern names.
pattern_hook_names_filter = lambda name: name.endswith("pattern")

model.run_with_hooks(
    repeated_tokens, 
    return_type=None, # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        induction_score_hook
    )]
)

imshow(induction_score_store, xaxis="Head", yaxis="Layer", title="Induction Score by Head")

In [141]:
from collections import defaultdict
import numpy as np

min_score = 0.6 # Manually inspected
top_k = 5 # TODO
layer_inds, head_inds = np.where(induction_score_store.cpu().numpy() > min_score)

induction_layer_to_head_map = defaultdict(list)
for layer, head in zip(layer_inds, head_inds):
    induction_layer_to_head_map[layer].append(head)

induction_layer_to_head_map

defaultdict(list,
            {np.int64(2): [np.int64(22)],
             np.int64(5): [np.int64(8)],
             np.int64(8): [np.int64(1)],
             np.int64(10): [np.int64(14)],
             np.int64(15): [np.int64(1), np.int64(30)],
             np.int64(16): [np.int64(20)]})

In [126]:
size = (1, 5)
input_tensor = torch.randint(1000, 10000, size)

# TODO it would be nice to add a BOS token at the beginning
single_random_sequence = input_tensor.to(model.cfg.device)
repeated_random_sequence = einops.repeat(single_random_sequence, "batch seq_len -> batch (2 seq_len)")

repeated_random_sequence

tensor([[5185, 5331, 5985, 6836, 4416, 5185, 5331, 5985, 6836, 4416]],
       device='mps:0')

In [142]:
def get_model_attention_patterns(model, input: torch.Tensor, induction_layer_to_head_map, visualise: bool = False):
    model.reset_hooks()

    pattern_store = {}
    def visualize_pattern_hook(
        pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
        hook: HookPoint,
    ):
        layer_ind = hook.layer()
        for head_ind in layer_to_head_map[layer_ind]:
            p = pattern[0, head_ind, :, :].detach().cpu().numpy()
            pattern_store[f"layer_{layer_ind}_head_{head_ind}"] = p

            if visualise:
                display(
                    cv.attention.attention_patterns(
                        tokens=model.to_str_tokens(input, ), 
                        attention=p[None, :, :] # Add a dummy axis, as CircuitsVis expects 3D patterns.
                    )
                )

    model.run_with_hooks(
        input, 
        return_type=None, 
        fwd_hooks=[
            (utils.get_act_name("pattern", layer), visualize_pattern_hook) for layer in induction_layer_to_head_map.keys()
        ]
    )

    model.reset_hooks()

    return pattern_store

pattern_store_res = get_model_attention_patterns(model, repeated_random_sequence, induction_layer_to_head_map)

In [143]:
pattern_store_res.keys()

dict_keys(['layer_2_head_22', 'layer_5_head_8', 'layer_8_head_1', 'layer_10_head_14', 'layer_15_head_1', 'layer_15_head_30', 'layer_16_head_20'])

In [144]:
from scipy.stats import mode


def get_token_attention(pattern_store: dict) -> np.ndarray:
    seq_len = next(iter(pattern_store.values())).shape[0]

    max_token_attention = np.zeros((len(pattern_store), seq_len), dtype=int)

    for i, pattern in enumerate(pattern_store.values()):
        max_token_attention[i] = pattern.argmax(axis=1)

    token_attention = mode(max_token_attention, axis=0).mode

    return token_attention

token_attention_res = get_token_attention(pattern_store_res)
token_attention_res

array([0, 0, 0, 0, 0, 0, 2, 3, 0, 5])

# Generate from model answers to MSMarco

In [34]:
from pathlib import Path

from data_utils import load_prompts_from_msmarco_samples_from_rag_truth


dataset_path = Path("/Users/oganes/citations/RAGTruth-main/dataset/")
prompts = load_prompts_from_msmarco_samples_from_rag_truth(dataset_path)
len(prompts)

989

In [35]:
prompt = prompts[1]
print(prompt)

Briefly answer the following question:
tips how to conserve water
Bear in mind that your response should be strictly based on the following three passages:
passage 1:1 Take shorter showers. 2  Replace you showerhead with an ultra-low-flow version. 3  Some units are available that allow you to cut off the flow without adjusting the water temperature knobs. 4  Use the minimum amount of water needed for a bath by closing the drain first and filling the tub only 1/3 full.

passage 2:Here are 20 water-saving tips to get you going…. 1. Shower Bucket. Instead of letting the water pour down the drain, stick a bucket under the faucet while you wait for your shower water to heat up. You can use the water for flushing the toilet or watering your plants. 2. Turn off the tap while brushing your teeth. Water comes out of the average faucet at 2.5 gallons per minute. Don’t let all that water go down the drain while you brush! Turn off the faucet after you wet your brush, and leave it off until it’s t

In [36]:
torch.manual_seed(7575)

# TODO do people actuall generate with do_sample=True and temperature=1 in RAG scenarios?
res = model.generate(prompt, max_new_tokens=256, temperature=1, do_sample=True)

100%|██████████| 256/256 [00:46<00:00,  5.45it/s]


In [37]:
print(res)

Briefly answer the following question:
tips how to conserve water
Bear in mind that your response should be strictly based on the following three passages:
passage 1:1 Take shorter showers. 2  Replace you showerhead with an ultra-low-flow version. 3  Some units are available that allow you to cut off the flow without adjusting the water temperature knobs. 4  Use the minimum amount of water needed for a bath by closing the drain first and filling the tub only 1/3 full.

passage 2:Here are 20 water-saving tips to get you going…. 1. Shower Bucket. Instead of letting the water pour down the drain, stick a bucket under the faucet while you wait for your shower water to heat up. You can use the water for flushing the toilet or watering your plants. 2. Turn off the tap while brushing your teeth. Water comes out of the average faucet at 2.5 gallons per minute. Don’t let all that water go down the drain while you brush! Turn off the faucet after you wet your brush, and leave it off until it’s t

In [146]:
input_tensor = model.to_tokens(res).to(model.cfg.device)
len(input_tensor[0])

617

In [161]:
next(iter(induction_layer_to_head_map.items()))

(np.int64(2), [np.int64(22)])

In [166]:
layer_ind, head_indices = next(iter(induction_layer_to_head_map.items()))
first_induction_layer_to_head_map = {layer_ind: head_indices}
first_induction_layer_to_head_map

{np.int64(2): [np.int64(22)]}

In [167]:
def get_model_attention_patterns(model, input: torch.Tensor, layer_to_head_map, visualise: bool = False):
    model.reset_hooks()

    pattern_store = {}
    def visualize_pattern_hook(
        pattern: Float[torch.Tensor, "batch head_index dest_pos source_pos"],
        hook: HookPoint,
    ):
        layer_ind = hook.layer()
        for head_ind in layer_to_head_map[layer_ind]:
            p = pattern[0, head_ind, :, :].detach().cpu().numpy()
            pattern_store[f"layer_{layer_ind}_head_{head_ind}"] = p

            if visualise:
                display(
                    cv.attention.attention_patterns(
                        tokens=model.to_str_tokens(input, ), 
                        attention=p[None, :, :] # Add a dummy axis, as CircuitsVis expects 3D patterns.
                    )
                )

    model.run_with_hooks(
        input, 
        return_type=None, 
        fwd_hooks=[
            (utils.get_act_name("pattern", layer), visualize_pattern_hook) for layer in layer_to_head_map.keys()
        ]
    )

    model.reset_hooks()

    return pattern_store


# pattern_store_res = get_model_attention_patterns(model, input_tensor, induction_layer_to_head_map)
pattern_store_res = get_model_attention_patterns(model, input_tensor, first_induction_layer_to_head_map, visualise=True)

In [171]:
model_generation_without_prompt = res.replace(prompt, "")
len(res), len(model_generation_without_prompt)

(2714, 1109)

In [170]:
len(token_attention_res)

617

In [168]:
token_attention_res = get_token_attention(pattern_store_res)
token_attention_res

array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  32,  33,  34,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0, 125,   0,   0,   0,   0,   0,   0,   0,   0,   

In [184]:
prompt_token_count = len(model.to_tokens(prompt)[0])
generated_token_attention_res = token_attention_res[prompt_token_count:]
generated_token_attention_res

array([  0,   0,   0, 251, 252, 253, 254, 255,   0, 257, 258, 259, 260,
       261, 262, 263, 264, 265,   0,   0,   0,   0,   0, 268,   0, 270,
       271, 272,   0, 274, 275, 276, 277,   0,   0,   0, 281,   0,   0,
         0,   0, 283,   0, 285, 286, 287, 288, 289, 290, 291, 292,   0,
         0,   0,   0,   0,   0, 296,   0,   0, 299, 300,   0, 302, 303,
       304, 305,   0, 307, 308, 309, 310, 311, 312,   0, 314, 315, 316,
         0,   0,   0,   0,   0,   0, 320, 321,   0, 323, 324, 325, 326,
       327, 328, 329, 330,   0,   0,   0,   0,   0,   0, 175, 176, 240,
       178, 179, 180,   0,   0,   0, 243,   0,   0,   0,   0,   0,   0,
         0, 134, 135, 136, 137,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0, 151, 152, 153, 154, 155,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   

'!!!�����!    in t        er   on are!!!!!en! th\n\n c! sitanar!!! p!!!!ou!is       inges wionedic!!!!!! m!!roas!ctnd in h!id nam            to re! { ofom!!!!!! (il! andurse lex Sad "!!!!!!������!!!�!!!!!!!����!!!!!!!!!!!!!�����!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'

In [177]:
model_generation_without_prompt

'1. Water lawns in the early morning or later evening in 20 minute intervals. \n2. Use a broom instead of your hose to remove debris from your driveway or sidewalk. \n3. Bathe your pets outdoors in areas that need water. \n4. Install a shut-off nozzle on your water hose; a garden hose left running can waste up to 20 litres per minute. \n5. Direct downspouts towards shrubs and trees in your garden. \n6. Turn off the tap while brushing your teeth and washing your hands. \n7. Let the water pour down the drain in your shower and fill up a bucket while waiting for your shower water to heat up. \n8. Keep your tank water topped up (Interpreted as aware of when your water tank is full in order not to waste resources). \n9. Only put dishwashing liquid in the dishwasher just before turning on the machine, else you waste water. \n10. Water in the optimal time. Keep your lawn compost, because cut clippings are highly water. \n11. Use an old clothes line or retractable clothes line for drying cloth

In [176]:
generated_tokens_as_str = model.to_str_tokens(model_generation_without_prompt, prepend_bos=False)
len(generated_tokens_as_str), generated_tokens_as_str[:5]

(256, ['1', '.', ' Water', ' law', 'ns'])