# Head Boost for ICL

In [1]:
import os
os.environ['HF_HOME'] = '/workspace/huggingface'

from transformer_lens import HookedTransformer, ActivationCache, utils
import torch

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial

init_notebook_mode(connected=True)

Device: cuda


### Model

In [2]:
model_name = 'gpt2'

model = HookedTransformer.from_pretrained(model_name, device=device)

model.eval()
model.set_use_attn_result(True)
model.set_use_attn_in(True)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



Loaded pretrained model gpt2 into HookedTransformer


### Tasks

In [33]:
import json

with open("tasks.json", "r") as f:
    tasks = json.load(f)['tasks']

In [45]:
idx = 0

slices = []

pos = 0
train_prompt = ""
for i, (inp, out) in enumerate(zip(tasks[idx]['train_input'], tasks[idx]['train_output'])):
    train_prompt += inp + '\n' + out + "\n\n"
    pos += len(model.to_tokens(inp, prepend_bos=False)[0]) + 1
    if i != 0:
        slices.append((pos, pos + len(model.to_tokens(out, prepend_bos=False)[0])))
    pos += len(model.to_tokens(out, prepend_bos=False)[0]) + 1 + int(model_name == 'gpt2')

test_prompt = ""
for inp, out in zip(tasks[idx]['test_input'], tasks[idx]['test_output']):
    test_prompt += inp + '\n' + out + "\n\n"

In [46]:
idxs = []
for s in slices:
    idxs.extend(range(s[0], s[1]))
idxs = torch.tensor(idxs, requires_grad=False, device=device)

for i, c in enumerate(model.to_str_tokens(train_prompt)):
    if i in idxs:
        print('\033[1m' + c + '\033[0m', sep='', end='')
    else:
        print(c, sep='', end='')

<|endoftext|>I'm Davide, I'm 20 years old and I live in Rome.
{
'name': 'Davide',
'age': '20',
'city': 'Rome'
}

My name is Susan and I live in San Francisco. I've just turned 12.[1m
[0m[1m{[0m[1m
[0m[1m'[0m[1mname[0m[1m':[0m[1m '[0m[1mSusan[0m[1m',[0m[1m
[0m[1m'[0m[1mage[0m[1m':[0m[1m '[0m[1m12[0m[1m',[0m[1m
[0m[1m'[0m[1mcity[0m[1m':[0m[1m '[0m[1mSan[0m[1m Francisco[0m[1m'[0m[1m
[0m}



In [47]:
print(list(enumerate(model.to_str_tokens(train_prompt))))

[(0, '<|endoftext|>'), (1, 'I'), (2, "'m"), (3, ' Dav'), (4, 'ide'), (5, ','), (6, ' I'), (7, "'m"), (8, ' 20'), (9, ' years'), (10, ' old'), (11, ' and'), (12, ' I'), (13, ' live'), (14, ' in'), (15, ' Rome'), (16, '.'), (17, '\n'), (18, '{'), (19, '\n'), (20, "'"), (21, 'name'), (22, "':"), (23, " '"), (24, 'D'), (25, 'av'), (26, 'ide'), (27, "',"), (28, '\n'), (29, "'"), (30, 'age'), (31, "':"), (32, " '"), (33, '20'), (34, "',"), (35, '\n'), (36, "'"), (37, 'city'), (38, "':"), (39, " '"), (40, 'R'), (41, 'ome'), (42, "'"), (43, '\n'), (44, '}'), (45, '\n'), (46, '\n'), (47, 'My'), (48, ' name'), (49, ' is'), (50, ' Susan'), (51, ' and'), (52, ' I'), (53, ' live'), (54, ' in'), (55, ' San'), (56, ' Francisco'), (57, '.'), (58, ' I'), (59, "'ve"), (60, ' just'), (61, ' turned'), (62, ' 12'), (63, '.'), (64, '\n'), (65, '{'), (66, '\n'), (67, "'"), (68, 'name'), (69, "':"), (70, " '"), (71, 'Susan'), (72, "',"), (73, '\n'), (74, "'"), (75, 'age'), (76, "':"), (77, " '"), (78, '12'), 

### Boosting

In [48]:
def head_modifier_hook(x, hook, lam):
    x = lam[None, None, :, None] * x # b pos head dim
    return x

def mlp_modifier_hook(x, hook, gam):
    x = gam[None, :, None] * x # b pos dim
    return x

In [49]:
# Tuned lambdas

start_layer = 6

lambdas = torch.nn.Parameter(torch.ones(
    (model.cfg.n_layers - start_layer, model.cfg.n_heads), device=device), requires_grad=True
)

gammas = torch.nn.Parameter(torch.ones(
    (model.cfg.n_layers - start_layer), device=device), requires_grad=True
)

optimizer = torch.optim.Adam([lambdas], lr=0.1) 

for param in model.parameters():
    param.requires_grad = False

tokens = model.to_tokens(train_prompt)
labels = tokens[:, idxs+1]

In [50]:
from torch.nn import CrossEntropyLoss

loss_fn = CrossEntropyLoss()
losses = []
l1_coefficient = 0.001  # Set this to regulate l1 penalty
l2_coefficient = 0.000

for e in tqdm(range(40)):
    proba = model.run_with_hooks(
                tokens,
                fwd_hooks=[(
                        f"blocks.{l}.attn.hook_result",
                        partial(head_modifier_hook, lam=lambdas[l - start_layer]),
                    ) for l in range(start_layer, model.cfg.n_layers)] + 
                    [(
                        f"blocks.{l}.hook_mlp_out",
                        partial(mlp_modifier_hook, gam=gammas[None, l - start_layer]),
                    ) for l in range(start_layer, model.cfg.n_layers)]
            ).softmax(-1)[:, idxs]
    
    loss = loss_fn(proba.view(-1, model.cfg.d_vocab), labels.view(-1))
    
    # Add the L1 regularization term to the loss
    l1_norm = lambdas.abs().sum() #+ gammas.abs().sum()
    l2_norm = lambdas.norm(p=2) 
    loss += l1_coefficient * l1_norm
    loss += l2_coefficient * l2_norm

    losses.append(loss.item())

    optimizer.zero_grad()  # Clear previous gradients
    loss.backward()        # Compute gradients for `lambdas`
    optimizer.step()

100%|██████████| 40/40 [00:02<00:00, 16.53it/s]


In [51]:
import plotly.express as px

fig = px.line(y=losses, title='Loss')
fig.write_html('loss.html')

In [52]:
import plotly.express as px
import pandas as pd
import numpy as np

data = lambdas.detach().cpu().numpy()
fig = px.imshow(data,
                labels=dict(x="Heads", y="Layers", color="Lambda"),
                title="Lambda values", aspect='auto', color_continuous_scale='RdBu', zmin=-2, zmax=2)

# Update the heatmap to show the annotations
fig.update_layout(
    height=800,
    yaxis=dict(
        tickmode='array', 
        tickvals=list(range(model.cfg.n_layers - start_layer)),
        ticktext=list(range(start_layer, model.cfg.n_layers))
    )
)
fig.update_traces(showscale=True)

# Round the annotations to 3 decimal places
for i in range(model.cfg.n_heads):
    for j in range(model.cfg.n_layers - start_layer):
        fig.add_annotation(dict(font=dict(color="black",size=12),
                                x=i,
                                y=j,
                                text=str(round(data[j, i], 3)),
                                showarrow=False,
                                align='center',
                                opacity=0.6))

fig.write_html('lambdas.html')
#fig.show()

In [53]:
layer_id = 9
head_id = 7

with torch.no_grad():
    _, cache = model.run_with_cache(model.to_tokens(train_prompt))

data = cache[f'blocks.{layer_id}.attn.hook_pattern'][0, head_id].cpu()

labels = [f"{tok} ({i})" for i, tok in enumerate(model.to_str_tokens(train_prompt))]

# Create the plot using Plotly Express
fig = px.imshow(
    data,
    labels=dict(x="Keys", y="Queries", color="Attention Score"),
    x=labels,
    y=labels,
    title=f'Attention patter at head {head_id} of layer {layer_id}',
    color_continuous_scale="Blues",
    aspect='auto'
)

# Adjust the layout for better readability
fig.update_xaxes(tickangle=35)
fig.update_layout(coloraxis_colorbar=dict(title="Score"), height=800)
fig.write_html('pattern.html')

### Testing

In [54]:
print(model.generate(test_prompt, stop_at_eos=False, temperature=0, max_new_tokens=64))

  0%|          | 0/64 [00:00<?, ?it/s]

I'm Davide, I'm 20 years old and I live in Rome.
{
'name': 'Davide',
'age': '20',
'city': 'Rome'
}

Hello, I'm Paul and I'm 36. I am an engineer and I've just moved to Taipei.


I'm a member of the Taipei-based team, and I'm a member of the Taipei-based team, and I'm a member of the Taipei-based team, and I'm a member of the Taipei-based team, and I'm a member of the Taipei-based team,


In [55]:
tokens = model.to_tokens(test_prompt)
max_new_tokens = 64

for i in tqdm(range(max_new_tokens)): 
    with torch.no_grad():
        new_tok = model.run_with_hooks(
            tokens,
            fwd_hooks=[(
                        f"blocks.{l}.attn.hook_result",
                        partial(head_modifier_hook, lam=lambdas[l - start_layer]),
                    ) for l in range(start_layer, model.cfg.n_layers)] + 
                    [(
                        f"blocks.{l}.hook_mlp_out",
                        partial(mlp_modifier_hook, gam=gammas[None, l - start_layer]),
                    ) for l in range(start_layer, model.cfg.n_layers)]
        ).argmax(-1)[:, -1, None]

    tokens = torch.cat([tokens, new_tok], dim=-1)

print(model.to_string(tokens)[0])

100%|██████████| 64/64 [00:02<00:00, 26.75it/s]

<|endoftext|>I'm Davide, I'm 20 years old and I live in Rome.
{
'name': 'Davide',
'age': '20',
'city': 'Rome'
}

Hello, I'm Paul and I'm 36. I am an engineer and I've just moved to Taipei.


{
'name': 'Paul',
'age': '36',
'city': 'Taipei'

}

{
'name': 'Paul',
'age': '36',
'city': 'Taipei'

}

{

'name': 'Paul



