# Head Boost for ICL

In [204]:
import os
os.environ['HF_HOME'] = '/workspace/huggingface'

from transformer_lens import HookedTransformer, ActivationCache, utils
import torch

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial

init_notebook_mode(connected=True)

Device: cuda


### Model

In [205]:
model_name = 'gemma-2b'

model = HookedTransformer.from_pretrained(model_name, device=device)

model.eval()
model.set_use_attn_result(True)
model.set_use_attn_in(True)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


### Tasks

In [206]:
import json

with open("tasks.json", "r") as f:
    tasks = json.load(f)['tasks']

In [207]:
idx = 1

slices = []

pos = 0
train_prompt = ""
for i, (inp, out) in enumerate(zip(tasks[idx]['train_input'], tasks[idx]['train_output'])):
    train_prompt += inp + '\n' + out + "\n\n"
    pos += len(model.to_tokens(inp, prepend_bos=False)[0]) + 1
    if i != 0:
        slices.append((pos, pos + len(model.to_tokens(out, prepend_bos=False)[0])))
    pos += len(model.to_tokens(out, prepend_bos=False)[0]) + 1 + int(model_name == 'gpt2')

test_prompt = ""
for inp, out in zip(tasks[idx]['test_input'], tasks[idx]['test_output']):
    test_prompt += inp + '\n' + out + "\n\n"

In [208]:
idxs = []
for s in slices:
    idxs.extend(range(s[0], s[1]))
idxs = torch.tensor(idxs, requires_grad=False, device=device)

for i, c in enumerate(model.to_str_tokens(train_prompt)):
    if i in idxs:
        print('\033[1m' + c + '\033[0m', sep='', end='')
    else:
        print(c, sep='', end='')

<bos><html> <h1> <p>
</p> </h1> </html>

<html> <h1> <div>[1m
[0m[1m</[0m[1mdiv[0m[1m>[0m[1m [0m[1m</h1>[0m[1m </[0m[1mhtml[0m>

<h1> <div> <p>[1m
[0m[1m</[0m[1mp[0m[1m>[0m[1m </[0m[1mdiv[0m[1m>[0m[1m [0m</h1>

<html> <h1> <p> <div>[1m
[0m[1m</[0m[1mdiv[0m[1m>[0m[1m </[0m[1mp[0m[1m>[0m[1m [0m[1m</h1>[0m[1m </[0m[1mhtml[0m>

<h1> <h2> <div> <p>[1m
[0m[1m</[0m[1mp[0m[1m>[0m[1m </[0m[1mdiv[0m[1m>[0m[1m [0m[1m</h2>[0m[1m [0m</h1>



In [186]:
print(list(enumerate(model.to_str_tokens(train_prompt))))

[(0, '<|endoftext|>'), (1, '<'), (2, 'html'), (3, '>'), (4, ' <'), (5, 'h'), (6, '1'), (7, '>'), (8, ' <'), (9, 'p'), (10, '>'), (11, '\n'), (12, '</'), (13, 'p'), (14, '>'), (15, ' </'), (16, 'h'), (17, '1'), (18, '>'), (19, ' </'), (20, 'html'), (21, '>'), (22, '\n'), (23, '\n'), (24, '<'), (25, 'html'), (26, '>'), (27, ' <'), (28, 'h'), (29, '1'), (30, '>'), (31, ' <'), (32, 'div'), (33, '>'), (34, '\n'), (35, '</'), (36, 'div'), (37, '>'), (38, ' </'), (39, 'h'), (40, '1'), (41, '>'), (42, ' </'), (43, 'html'), (44, '>'), (45, '\n'), (46, '\n'), (47, '<'), (48, 'h'), (49, '1'), (50, '>'), (51, ' <'), (52, 'div'), (53, '>'), (54, ' <'), (55, 'p'), (56, '>'), (57, '\n'), (58, '</'), (59, 'p'), (60, '>'), (61, ' </'), (62, 'div'), (63, '>'), (64, ' </'), (65, 'h'), (66, '1'), (67, '>'), (68, '\n'), (69, '\n'), (70, '<'), (71, 'html'), (72, '>'), (73, ' <'), (74, 'h'), (75, '1'), (76, '>'), (77, ' <'), (78, 'p'), (79, '>'), (80, ' <'), (81, 'div'), (82, '>'), (83, '\n'), (84, '</'), (8

### Boosting

In [187]:
def head_modifier_hook(x, hook, lam):
    x = lam[None, None, :, None] * x # b pos head dim
    return x

def mlp_modifier_hook(x, hook, gam):
    x = gam[None, :, None] * x # b pos dim
    return x

In [196]:
# Tuned lambdas

start_layer = 6

lambdas = torch.nn.Parameter(torch.ones(
    (model.cfg.n_layers - start_layer, model.cfg.n_heads), device=device), requires_grad=True
)

gammas = torch.nn.Parameter(torch.ones(
    (model.cfg.n_layers - start_layer), device=device), requires_grad=True
)

optimizer = torch.optim.Adam([lambdas], lr=0.1) 

for param in model.parameters():
    param.requires_grad = False

tokens = model.to_tokens(train_prompt)
labels = tokens[:, idxs+1]

In [197]:
from torch.nn import CrossEntropyLoss

loss_fn = CrossEntropyLoss()
losses = []
l1_coefficient = 0.001  # Set this to regulate l1 penalty

for e in tqdm(range(100)):
    proba = model.run_with_hooks(
                tokens,
                fwd_hooks=[(
                        f"blocks.{l}.attn.hook_result",
                        partial(head_modifier_hook, lam=lambdas[l - start_layer]),
                    ) for l in range(start_layer, model.cfg.n_layers)] + 
                    [(
                        f"blocks.{l}.hook_mlp_out",
                        partial(mlp_modifier_hook, gam=gammas[None, l - start_layer]),
                    ) for l in range(start_layer, model.cfg.n_layers)]
            ).softmax(-1)[:, idxs]
    
    loss = loss_fn(proba.view(-1, model.cfg.d_vocab), labels.view(-1))
    
    # Add the L1 regularization term to the loss
    l1_norm = lambdas.abs().sum() + gammas.abs().sum()
    loss += l1_coefficient * l1_norm

    losses.append(loss.item())

    optimizer.zero_grad()  # Clear previous gradients
    loss.backward()        # Compute gradients for `lambdas`
    optimizer.step()

100%|██████████| 100/100 [00:05<00:00, 16.70it/s]


In [198]:
import plotly.express as px

fig = px.line(y=losses, title='Loss')
fig.write_html('loss.html')

In [199]:
import plotly.express as px
import pandas as pd
import numpy as np

data = lambdas.detach().cpu().numpy()
fig = px.imshow(data,
                labels=dict(x="Heads", y="Layers", color="Lambda"),
                title="Lambda values", aspect='auto', color_continuous_scale='RdBu', zmin=-2, zmax=2)

# Update the heatmap to show the annotations
fig.update_layout(
    height=800,
    yaxis=dict(
        tickmode='array', 
        tickvals=list(range(model.cfg.n_layers - start_layer)),
        ticktext=list(range(start_layer, model.cfg.n_layers))
    )
)
fig.update_traces(showscale=True)

# Round the annotations to 3 decimal places
for i in range(model.cfg.n_heads):
    for j in range(model.cfg.n_layers - start_layer):
        fig.add_annotation(dict(font=dict(color="black",size=12),
                                x=i,
                                y=j,
                                text=str(round(data[j, i], 3)),
                                showarrow=False,
                                align='center',
                                opacity=0.6))

fig.write_html('lambdas.html')
#fig.show()

In [200]:
layer_id = 9
head_id = 7

with torch.no_grad():
    _, cache = model.run_with_cache(model.to_tokens(train_prompt))

data = cache[f'blocks.{layer_id}.attn.hook_pattern'][0, head_id].cpu()

labels = [f"{tok} ({i})" for i, tok in enumerate(model.to_str_tokens(train_prompt))]

# Create the plot using Plotly Express
fig = px.imshow(
    data,
    labels=dict(x="Keys", y="Queries", color="Attention Score"),
    x=labels,
    y=labels,
    title=f'Attention patter at head {head_id} of layer {layer_id}',
    color_continuous_scale="Blues",
    aspect='auto'
)

# Adjust the layout for better readability
fig.update_xaxes(tickangle=35)
fig.update_layout(coloraxis_colorbar=dict(title="Score"), height=800)
fig.write_html('pattern.html')

In [201]:
gammas

Parameter containing:
tensor([1., 1., 1., 1., 1., 1.], device='cuda:0', requires_grad=True)

### Testing

In [202]:
print(model.generate(test_prompt, stop_at_eos=False, temperature=0, max_new_tokens=32))

  0%|          | 0/32 [00:00<?, ?it/s]

<html> <h1> <p>
</p> </h1> </html>

<h1> <div> <p>
</p> </div> </h1>

<html> <h1> <p> <div>
</div> </p> </h1> </html>

<h1> <h2> <div> <p>
</p> </div> </h2> </h1>

<h1> <h2> <h3> <h4>


<p>

</p> </h2> </h1> </html>

<h1> <h2> <p>


In [203]:
tokens = model.to_tokens(test_prompt)
max_new_tokens = 32

for i in tqdm(range(max_new_tokens)): 
    with torch.no_grad():
        new_tok = model.run_with_hooks(
            tokens,
            fwd_hooks=[
                (
                    f"blocks.{l}.attn.hook_result",
                    partial(head_modifier_hook, lam=lambdas[l - start_layer]),
                ) for l in range(start_layer, model.cfg.n_layers)
            ]
        ).argmax(-1)[:, -1, None]

    tokens = torch.cat([tokens, new_tok], dim=-1)

print(model.to_string(tokens)[0])

100%|██████████| 32/32 [00:00<00:00, 42.18it/s]

<|endoftext|><html> <h1> <p>
</p> </h1> </html>

<h1> <div> <p>
</p> </div> </h1>

<html> <h1> <p> <div>
</div> </p> </h1> </html>

<h1> <h2> <div> <p>
</p> </div> </h2> </h1>

<h1> <h2> <h3> <h4>


<h1> <h3> <h3> <h3> <h3> <h3> <h3> <h3>



