## Head Scaling

Reference notebook (TransformerLens Exploratory Analysis Demo): https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=3XtmNaDFO0eu


Model: Pythia-1.4b <br>
Paper: https://arxiv.org/pdf/2310.15910.pdf <br>
Heads of interest, written as (layer.head): memory head (15.7), in-context head (19.14)


In [191]:
import torch, transformer_lens, itertools
from functools import partial
from measureLM import visualizing, patching

from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import numpy as np

In [3]:
# Load the pretrained "pythia-1.4b" checkpoint as a HookedTransformer and move it to the CPU.
model = transformer_lens.HookedTransformer.from_pretrained("pythia-1.4b").to("cpu")
# Custom attribute stored on the config: get_token_idx prepends it to each word
# before looking the word up in the tokenizer vocabulary.
# NOTE(review): presumably the tokenizer's leading-space marker — confirm.
model.cfg.spacing = "Ġ"
# Reuse the end-of-sequence token as the padding token.
model.tokenizer.pad_token = model.tokenizer.eos_token

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/2.93G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-1.4b into HookedTransformer
Moving model to device:  cpu


In [295]:
def get_token_ids(toks=["Berlin", "Paris"]):
    """Map each word in `toks` to a single token id using the notebook-global `model`.

    Every word is tokenized with a leading space and without a BOS token, and
    only index 0 along the last axis of the result is kept. A word that is
    split into several tokens is therefore represented by the first one only
    (assumes the last axis of `model.to_tokens` output is the token position —
    TODO confirm).

    Args:
        toks: words to look up.

    Returns:
        A tensor with one token id per entry of `toks`, in the same order.
    """
    first_ids = []
    for word in toks:
        encoded = model.to_tokens(f" {word}", prepend_bos=False)
        first_ids.append(encoded[..., 0])
    return torch.tensor(first_ids)

def get_token_ranks(logits, toks=["Berlin", "Paris"]):
    """Return the rank of each word in `toks` among the last-position logits.

    Rank 0 is the highest-scoring token. The lookup uses `.item()`, so it
    expects exactly one score vector after squeezing, i.e. a single prompt
    rather than a batch of several.

    Args:
        logits: model output whose second-to-last axis is indexed at -1 (the
            final position) and whose last axis holds one score per token id.
        toks: words to rank; each is mapped to one token id by `get_token_ids`.

    Returns:
        dict mapping each word in `toks` to its integer rank.
    """
    # Scores at the final position; squeeze drops singleton axes such as a batch of one.
    scores = logits[..., -1, :].squeeze()
    # sorted_ids[r] is the token id holding rank r.
    sorted_ids = torch.argsort(scores, descending=True)

    # Resolve the token ids once and reuse them for every lookup.
    token_ids = get_token_ids(toks)
    return {
        tok: torch.where(sorted_ids == tok_id)[0].item()
        for tok, tok_id in zip(toks, token_ids)
    }


In [296]:
# The context sentence asserts "Paris" as the capital of Germany before the question is asked,
# so the two candidates compared below are the in-context answer ("Paris") and "Berlin"
# (presumably the answer the model has memorized).
prompt = "The capital of Germany is Paris. Q: What is the capital of Germany? A:" # alternative context sentence: "The capital of Jamaica is Paris."
# Unpatched forward pass; `activs` holds the cached activations (not used in the cells visible here).
logits, activs = model.run_with_cache(prompt)
# Baseline ranks of both candidates at the last position (0 = top prediction).
get_token_ranks(logits, ["Berlin", "Paris"])

{'Berlin': 0, 'Paris': 2}

In [299]:
def get_token_idx(toks=["Berlin", "London"]):
    """Look up vocabulary ids for `toks` directly in the tokenizer vocabulary.

    Each word is prefixed with `model.cfg.spacing` and passed to
    `convert_tokens_to_ids`; no tokenization is run on the word itself.

    Args:
        toks: words to look up.

    Returns:
        A tensor with one id per entry of `toks`, in the same order.
    """
    prefixed = [model.cfg.spacing + word for word in toks]
    vocab_ids = list(map(model.tokenizer.convert_tokens_to_ids, prefixed))
    return torch.tensor(vocab_ids)

def scale_attn_vec(attn_head, hook: HookPoint, head_idx=0, alpha=1.0):
    """Forward hook: scale one head's vector at the last token position by `alpha`.

    The activation is modified in place and the same tensor object is returned.
    A progress line naming the hook, head and factor is printed on every call.

    Args:
        attn_head: hooked activation, laid out as (batch, tokens, heads, dim)
            with dim being model.cfg.d_head.
        hook: hook point being run; only its `name` is read, for the log line.
        head_idx: index of the head to scale.
        alpha: multiplicative factor applied to the selected vector.

    Returns:
        `attn_head`, with the selected slice scaled.
    """
    print(f'patching {hook.name}, head: {head_idx}, alpha {alpha}')
    # (any leading axes, last token, chosen head, every feature)
    target = (..., -1, head_idx, slice(None))
    attn_head[target] = attn_head[target] * alpha
    return attn_head

#memory head (15.7), in-context head (19.14)
def intervene(prompt, model, layer_idx=15, head_idx=7, alpha=1.0, att_comp="attn.hook_v"):
    """Run `prompt` through `model` with one attention head scaled at the last position.

    A single forward hook wrapping `scale_attn_vec` is attached at
    "blocks.{layer_idx}.{att_comp}" for this run only.

    Args:
        prompt: input passed straight to `model.run_with_hooks`.
        model: object exposing `run_with_hooks`.
        layer_idx: block index used in the hook name.
        head_idx: head whose vector is scaled.
        alpha: scaling factor forwarded to `scale_attn_vec`.
        att_comp: hook suffix inside the block.

    Returns:
        The logits produced by the hooked forward pass.
    """
    hook_name = f"blocks.{layer_idx}.{att_comp}"
    scale_fn = partial(scale_attn_vec, head_idx=head_idx, alpha=alpha)
    return model.run_with_hooks(
        prompt,
        fwd_hooks=[(hook_name, scale_fn)],
        return_type="logits",
        reset_hooks_end=True,
    )

# Scale head 15.7 (labelled "memory head" in the notes above) by -2 at the last position.
# NOTE(review): the recorded output below names `attn.hook_z`, whereas intervene's default
# att_comp is "attn.hook_v" — the output looks stale; re-run to confirm which hook was patched.
patched_logits = intervene(prompt, model, layer_idx=15, head_idx=7, alpha=-2.0)
# Ranks of both candidates after the intervention (0 = top prediction).
get_token_ranks(patched_logits, ["Berlin", "Paris"])

patching blocks.15.attn.hook_z, head: 7, alpha -2.0


{'Berlin': 3, 'Paris': 2}

In [301]:
## Recompute the last-position ordering of the patched logits for manual inspection
## (the same first steps get_token_ranks performs internally).
scores = patched_logits[...,-1,:].squeeze()
# token_ranks[r] is the index along the last axis with the r-th highest score (0 = top).
token_ranks = torch.argsort(scores.squeeze(), descending=True)
