This notebook is for prototyping subnetwork flow.

### Imports

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `pii` package) are picked up before each cell runs,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [24]:
import dataclasses

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim
import transformer_lens.utils as tl_utils
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from pii import utils

# Disable autograd globally for the whole notebook: no graph is built by any
# forward pass below unless a cell explicitly re-enables gradients
# (e.g. `with torch.enable_grad():`). Any cell that calls `.backward()` must do so.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa328501600>

In [3]:
from pii.subnet_flow.masked_model import PerTokenMaskedTransformer

### Load models

In [4]:
# Load the Llama-2-7b-chat tokenizer and causal-LM weights via Hugging Face.
# You will need to login to huggingface first:
#   huggingface-cli login
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
# low_cpu_mem_usage=True asks transformers to keep peak host RAM down while
# the checkpoint shards are loaded.
hf_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    low_cpu_mem_usage=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Wrap the already-loaded HF weights in a TransformerLens HookedTransformer on
# the GPU, reusing the tokenizer loaded above. All optional weight-processing
# steps (fold_ln, center_writing_weights, center_unembed) are switched off.
tl_model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    hf_model=hf_model,
    device="cuda",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
)
# Sanity check: a short temperature=0 generation, displayed as the cell output.
tl_model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


  0%|          | 0/20 [00:00<?, ?it/s]

'The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions'

In [13]:
prompt = "The capital of Germany is"

mtf = PerTokenMaskedTransformer(tl_model, base_prompt=prompt)

# Show which parameters are trainable.
trainable_names = [
    name for name, param in mtf.named_parameters() if param.requires_grad
]
print(trainable_names)

# Zero out every mask in place.
for mask in mtf.masks.values():
    mask.mul_(0)

# Both generations below share the same settings.
generation_kwargs = dict(
    max_new_tokens=20,
    temperature=0,
    use_past_kv_cache=False,  # This is important!
    # NOTE(review): presumably the per-token masks need the full sequence to be
    # re-run at every step, which the KV cache would skip — confirm.
)

# First generation: masks zeroed and still in their initial (active) state.
print(mtf.tl_model.generate(prompt, **generation_kwargs))

# Second generation: same prompt with the masks switched off.
mtf.masks_active = False
print(mtf.tl_model.generate(prompt, **generation_kwargs))

['masks.0_attn_z', 'masks.0_mlp_out', 'masks.10_attn_z', 'masks.10_mlp_out', 'masks.11_attn_z', 'masks.11_mlp_out', 'masks.12_attn_z', 'masks.12_mlp_out', 'masks.13_attn_z', 'masks.13_mlp_out', 'masks.14_attn_z', 'masks.14_mlp_out', 'masks.15_attn_z', 'masks.15_mlp_out', 'masks.16_attn_z', 'masks.16_mlp_out', 'masks.17_attn_z', 'masks.17_mlp_out', 'masks.18_attn_z', 'masks.18_mlp_out', 'masks.19_attn_z', 'masks.19_mlp_out', 'masks.1_attn_z', 'masks.1_mlp_out', 'masks.20_attn_z', 'masks.20_mlp_out', 'masks.21_attn_z', 'masks.21_mlp_out', 'masks.22_attn_z', 'masks.22_mlp_out', 'masks.23_attn_z', 'masks.23_mlp_out', 'masks.24_attn_z', 'masks.24_mlp_out', 'masks.25_attn_z', 'masks.25_mlp_out', 'masks.26_attn_z', 'masks.26_mlp_out', 'masks.27_attn_z', 'masks.27_mlp_out', 'masks.28_attn_z', 'masks.28_mlp_out', 'masks.29_attn_z', 'masks.29_mlp_out', 'masks.2_attn_z', 'masks.2_mlp_out', 'masks.30_attn_z', 'masks.30_mlp_out', 'masks.31_attn_z', 'masks.31_mlp_out', 'masks.3_attn_z', 'masks.3_mlp

  0%|          | 0/20 [00:00<?, ?it/s]

The capital of Germany isnt The The of The The The The The of The The The of The of The The The The


  0%|          | 0/20 [00:00<?, ?it/s]

The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions


### Perform optimization

In [23]:
def get_corr_prompt(question: str, correct: bool = True):
    """Build a Llama-2 chat prompt asking for a one-word answer.

    Args:
        question: The user question placed between the system block and the
            closing ``[/INST]`` tag.
        correct: If truthy, the system message asks for correct answers;
            otherwise it asks for incorrect answers.

    Returns:
        The prompt string. It ends with ``"[/INST] "`` (one trailing space,
        no trailing newline).
    """
    answer_kind = "correct" if correct else "incorrect"
    system_message = (
        "You are an assistant who only responds with a single word. "
        f"You only give {answer_kind} answers."
    )
    pieces = [
        "[INST] <<SYS>>\n",
        system_message,
        "\n<</SYS>>\n\n",
        f"{question} [/INST] ",
    ]
    return "".join(pieces)

# Build the "correct answers" prompt for one factual question and print it.
question = "Nick Kyrgios, what sport does he play?"
prompt = get_corr_prompt(question, correct=True)
print(prompt)
# Ask the project helper for the unmasked model's top 5 responses to this prompt.
# NOTE(review): `utils.get_top_responses` is defined in the `pii` package, not
# in this notebook; its exact ranking/printing behavior should be checked there.
utils.get_top_responses(
    prompt=prompt,
    model=tl_model,
    top_k=5,
)

[INST] <<SYS>>
You are an assistant who only responds with a single word. You only give correct answers.
<</SYS>>

Nick Kyrgios, what sport does he play? [/INST] 
Rank 0. Logit: 26.75 Prob: 99.86% Tokens: (29010) |Tennis|</s>|
Rank 1. Logit: 19.81 Prob:  0.10% Tokens: (  323) |T|EN|N|IS|</s>|
Rank 2. Logit: 19.00 Prob:  0.04% Tokens: (22556) |tennis|</s>|
Rank 3. Logit: 15.33 Prob:  0.00% Tokens: (27884) |ATP|</s>|
Rank 4. Logit: 15.17 Prob:  0.00% Tokens: (  376) |"|T|ennis|"|</s>|


In [21]:
correct_str = "Tennis"
# Tokenize the target answer. The code below reads index [0, 1], i.e. the
# second position of the first (only) sequence; with TransformerLens' default
# BOS prepending, position 0 is the BOS token and position 1 the answer token.
correct_tokens = tl_model.to_tokens(correct_str)
# Guard against answers that split into several tokens: without this check
# only the first piece would silently be used as the optimization target.
assert correct_tokens.shape == (1, 2), (
    f"Expected {correct_str!r} to tokenize to BOS + a single token, "
    f"got shape {tuple(correct_tokens.shape)}"
)
correct_token = correct_tokens[0, 1].item()
correct_token

29010

In [None]:
# Fresh masked wrapper for the question prompt, plus an Adam optimizer that
# updates only the mask parameters.
mtf = PerTokenMaskedTransformer(tl_model, base_prompt=prompt)
prompt_tokens = tl_model.to_tokens(prompt)
optim = torch.optim.Adam(
    mtf.masks.parameters(),
    lr=1e-3,
)

# Autograd is disabled globally in the imports cell
# (`torch.set_grad_enabled(False)`), so it has to be re-enabled here;
# otherwise the forward pass builds no graph and `loss.backward()` raises.
with torch.enable_grad():
    for i in tqdm(range(100)):
        # With return_type="logits" the forward pass returns a single
        # (batch, pos, d_vocab) tensor, so index it exactly once:
        # batch element 0, last position.
        logits = mtf.tl_model.forward(prompt_tokens, return_type="logits")
        final_logits = logits[0, -1]

        # Regularizer: total of all mask entries. The mask container itself
        # has no `.sum()`, so reduce each mask tensor and add them up.
        reg = sum(mask.sum() for mask in mtf.masks.values())
        # Maximize the logit of the correct answer token while penalizing
        # the total mask value.
        loss = -final_logits[correct_token] + reg

        optim.zero_grad()
        loss.backward()
        optim.step()