In [1]:
import torch
from transformer_lens import HookedTransformer
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [25]:
cfact = load_dataset("azhx/counterfact")

wikitext = load_dataset("Salesforce/wikitext", "wikitext-2-v1")
# Keep only training rows whose text is longer than 50 characters (drops blank and very short lines).
wikitext["train"] = wikitext["train"].filter(function=lambda example: len(example["text"]) > 50)

In [32]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = HookedTransformer.from_pretrained("Qwen/Qwen2-1.5B", default_padding_side="left")
# Left padding keeps each prompt's last real token at position -1 for batched next-token reads.
model.tokenizer.padding_side = "left"



Loaded pretrained model Qwen/Qwen2-1.5B into HookedTransformer


In [19]:
# Show the HookedTransformerConfig (layer count, widths such as d_mlp, n_ctx, dtype, device).
model.cfg

HookedTransformerConfig:
{'act_fn': 'silu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 11.313708498984761,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 128,
 'd_mlp': 8960,
 'd_model': 1536,
 'd_vocab': 151936,
 'd_vocab_out': 151936,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-06,
 'experts_per_token': None,
 'final_rms': True,
 'from_checkpoint': False,
 'gated_mlp': True,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02,
 'load_in_4bit': False,
 'model_name': 'Qwen2-1.5B',
 'n_ctx': 2048,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': 2,
 'n_layers': 28,
 'n_params': 1420296192,
 'normalization_type': 'RMSPre',
 'num_experts': None,
 'original_architecture': 'Qwen2ForCausalLM',
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'po

In [20]:
# List every hook point name; the helpers below select the "mlp.hook_post" names from these.
model.hook_dict

{'hook_embed': HookPoint(),
 'blocks.0.ln1.hook_scale': HookPoint(),
 'blocks.0.ln1.hook_normalized': HookPoint(),
 'blocks.0.ln2.hook_scale': HookPoint(),
 'blocks.0.ln2.hook_normalized': HookPoint(),
 'blocks.0.attn.hook_k': HookPoint(),
 'blocks.0.attn.hook_q': HookPoint(),
 'blocks.0.attn.hook_v': HookPoint(),
 'blocks.0.attn.hook_z': HookPoint(),
 'blocks.0.attn.hook_attn_scores': HookPoint(),
 'blocks.0.attn.hook_pattern': HookPoint(),
 'blocks.0.attn.hook_result': HookPoint(),
 'blocks.0.attn.hook_rot_k': HookPoint(),
 'blocks.0.attn.hook_rot_q': HookPoint(),
 'blocks.0.mlp.hook_pre': HookPoint(),
 'blocks.0.mlp.hook_pre_linear': HookPoint(),
 'blocks.0.mlp.hook_post': HookPoint(),
 'blocks.0.hook_attn_in': HookPoint(),
 'blocks.0.hook_q_input': HookPoint(),
 'blocks.0.hook_k_input': HookPoint(),
 'blocks.0.hook_v_input': HookPoint(),
 'blocks.0.hook_mlp_in': HookPoint(),
 'blocks.0.hook_attn_out': HookPoint(),
 'blocks.0.hook_mlp_out': HookPoint(),
 'blocks.0.hook_resid_pre': H

In [44]:
# BOS token id; the prefix samplers below seed model.generate with a column of this token.
torch.tensor(model.tokenizer.bos_token_id)

151643

In [33]:
from jaxtyping import Float
from torch import Tensor
from transformer_lens.HookedTransformer import HookPoint
import einops
from dataclasses import dataclass

@dataclass
class EditRequest:
    ''' A single factual edit request: a prompt about `subject` plus the old and new target strings.

    NOTE(review): presumably `prompt` is a template with a "{}" placeholder for `subject`, like the
    CounterFact `requested_rewrite` rows handled in correctness_filter — confirm before relying on it.
    '''
    prompt: str       # prompt text (presumably a "{}" template for the subject — see note above)
    subject: str      # subject string; get_subject_representations appends ". {subject}" to sampled prefixes
    target_true: str  # presumably the object the model currently predicts
    target_new: str   # presumably the object the edit should make the model predict instead

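# TODO: Make these functions layer-specific: a hook filter of `"mlp.hook_post" in name` matches the MLP of every block, so activations from all layers get pooled.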

def get_key_cov(model: HookedTransformer, wiki_dataset, layer=None, max_tokens=100_000):
    ''' Estimate the uncentered key covariance C = E[k k^T] from wiki_dataset.

    Keys k are the MLP hidden activations (``mlp.hook_post``, i.e. the input to W_out) at every
    token position. Texts from the ``"train"`` split are run one at a time until at least
    ``max_tokens`` tokens have been processed.

    Args:
        model: HookedTransformer whose MLP activations are collected.
        wiki_dataset: dataset dict with a ``"train"`` split holding a ``"text"`` column.
        layer: index of the block whose MLP keys are used. ``None`` (default) hooks
            ``mlp.hook_post`` in every block and pools all of their key vectors.
        max_tokens: stop reading new texts once this many tokens have been processed.

    Returns:
        (d_mlp, d_mlp) tensor: sum of k k^T over all collected key vectors divided by their count.

    Raises:
        ValueError: if no tokens were processed (e.g. an empty dataset).
    '''
    d_mlp = model.cfg.d_mlp
    second_moment = torch.zeros((d_mlp, d_mlp), device=model.cfg.device)
    num_keys = 0

    def collect_ks_hook(act: Float[Tensor, "batch seq d_mlp"], hook: HookPoint):
        nonlocal num_keys
        flat = act.reshape(-1, act.shape[-1]).to(second_moment.dtype)  # (batch*seq, d_mlp)
        # Accumulate sum_t k_t k_t^T in place (avoids a d_mlp x d_mlp temporary per call).
        # The previous version summed k and returned mean(k) mean(k)^T, which is rank-1 and not invertible.
        second_moment.addmm_(flat.T, flat)
        num_keys += flat.shape[0]
        return act

    if layer is None:
        hook_filter = lambda name: "mlp.hook_post" in name
    else:
        hook_filter = f"blocks.{layer}.mlp.hook_post"

    num_tokens = 0
    with torch.no_grad():
        for text in wiki_dataset["train"]["text"]:
            if num_tokens >= max_tokens:
                break
            # Truncate to the model's context length; longer inputs cannot be run by the model.
            toks = model.tokenizer(
                text, return_tensors="pt", max_length=model.cfg.n_ctx, truncation=True, padding=False
            )["input_ids"]
            if toks.shape[1] == 0:
                continue
            num_tokens += toks.shape[1]
            model.run_with_hooks(toks, fwd_hooks=[(hook_filter, collect_ks_hook)])

    if num_keys == 0:
        raise ValueError("get_key_cov: no tokens were processed, cannot estimate the key covariance")
    return second_moment / num_keys

def correctness_filter(model, dataset, verbose=False):
    ''' Filter out any examples that the model gets wrong.

    An example counts as correct when, at the last position of its prompt, the model's logit for the
    first token of ``' ' + target_true`` is larger than its logit for the first token of
    ``' ' + target_new``.

    Args:
        model: callable mapping a token tensor to logits (batch, seq, vocab); must expose ``.tokenizer``.
        dataset: CounterFact-style dataset whose ``requested_rewrite`` column holds dicts with
            ``prompt`` (a "{}" template), ``subject``, ``target_true['str']`` and ``target_new['str']``.
        verbose: print the first example of every mapped batch.

    Returns:
        The filtered dataset with two added columns: ``correct_token`` (first token id of the true
        target) and ``is_model_correct`` (True for every kept row).
    '''
    def get_correctness(model, examples):
        ''' Compute the first token of the true answer and whether the model prefers it over the edit. '''
        rewrites = examples['requested_rewrite']
        if not rewrites:
            return {"correct_token": [], "is_model_correct": []}

        true_string = [' ' + r['target_true']['str'] for r in rewrites]
        edit_string = [' ' + r['target_new']['str'] for r in rewrites]
        question_string = [r['prompt'].format(r['subject']) for r in rewrites]

        # Without special tokens or padding, index 0 of each id list is the answer's first real token,
        # so there is no need to flip the tokenizer's shared padding side for this.
        correct_tokens = [ids[0] for ids in model.tokenizer(true_string, add_special_tokens=False)["input_ids"]]
        edit_tokens = [ids[0] for ids in model.tokenizer(edit_string, add_special_tokens=False)["input_ids"]]

        # Position -1 must be each prompt's last real token, which needs left padding. The forward pass
        # stays inside the try as well, since the model may derive its padding mask from this setting.
        orig_padding_side = model.tokenizer.padding_side
        model.tokenizer.padding_side = "left"
        try:
            question_tokens = model.tokenizer(question_string, return_tensors="pt", padding=True)["input_ids"]
            last_logits = model(question_tokens)[:, -1, :]
        finally:
            model.tokenizer.padding_side = orig_padding_side

        rows = torch.arange(last_logits.shape[0], device=last_logits.device)
        true_idx = torch.tensor(correct_tokens, device=last_logits.device)
        edit_idx = torch.tensor(edit_tokens, device=last_logits.device)
        # softmax is monotonic within a row, so comparing raw logits gives the same ranking
        # (and avoids ties from exp() saturating).
        is_model_correct = (last_logits[rows, true_idx] > last_logits[rows, edit_idx]).tolist()

        if verbose:
            print(f"String: {true_string[0]} tokenized as {correct_tokens[0]}, model is correct? {is_model_correct[0]}")

        return {"correct_token": correct_tokens, "is_model_correct": is_model_correct}

    with torch.no_grad():
        dataset = dataset.map(
            lambda batch: get_correctness(model, batch),
            batched=True,
            batch_size=1_000
        )
        dataset = dataset.filter(lambda row: row["is_model_correct"])

    # Check the filtered result itself (previously this read the global `cfact`, not `dataset`).
    assert all(dataset["is_model_correct"]), "Filter failed, model is not correct on all examples"

    return dataset

def get_subject_representations(model, edit_request: EditRequest, layer=None):
    ''' Get the subject representation k* for the fact.

    Builds 20 prompts: 10 continuations of 5 new tokens and 10 of 10 new tokens are sampled from a
    BOS token with ``model.generate``, decoded, and suffixed with ". {subject}". The MLP hidden
    activation (``mlp.hook_post``) at the last position of every prompt is averaged over the prompts.

    Args:
        model: HookedTransformer (uses tokenizer, generate, run_with_hooks, cfg.d_mlp, cfg.device).
        edit_request: the edit whose ``subject`` is appended to every sampled prefix.
        layer: index of the block to read. ``None`` (default) keeps the original behaviour of hooking
            ``mlp.hook_post`` in every block, so the result is the sum over blocks of the per-block means.

    Returns:
        (d_mlp,) tensor on ``model.cfg.device``.
    '''
    k_sum = torch.zeros(model.cfg.d_mlp, device=model.cfg.device)

    def collect_ks_hook(act: Float[Tensor, "batch seq d_mlp"], hook: HookPoint):
        # Prompts are left-padded below, so position -1 is the final subject token of every prompt.
        k_sum.add_(act[:, -1, :].sum(dim=0))
        return act

    if layer is None:
        hook_filter = lambda name: "mlp.hook_post" in name
    else:
        hook_filter = f"blocks.{layer}.mlp.hook_post"

    start_input = torch.full((10, 1), model.tokenizer.bos_token_id, device=model.cfg.device, dtype=torch.int)

    with torch.no_grad():
        # According to E.5 they use 10 prefixes of length 5 and 10 prefixes of length 10
        prefix_5 = model.tokenizer.batch_decode(model.generate(start_input, max_new_tokens=5))
        prefix_10 = model.tokenizer.batch_decode(model.generate(start_input, max_new_tokens=10))
        subject_prompts = [p + f". {edit_request.subject}" for p in prefix_5 + prefix_10]

        # Force left padding for tokenization and the forward pass (the model may build its padding
        # mask from this setting), then restore whatever the caller had.
        orig_padding_side = model.tokenizer.padding_side
        model.tokenizer.padding_side = "left"
        try:
            subject_toks = model.tokenizer(
                subject_prompts, return_tensors="pt", padding=True, add_special_tokens=False
            )["input_ids"]
            model.run_with_hooks(subject_toks, fwd_hooks=[(hook_filter, collect_ks_hook)])
        finally:
            model.tokenizer.padding_side = orig_padding_side

    return k_sum / len(subject_prompts)


def get_value_representations(model, edit_request):
    ''' Get the value representation v* for a fact in the dataset.

    Not implemented yet. The random prefixes (10 of 5 new tokens, 10 of 10 new tokens, sampled from
    a BOS token) are generated, but the optimisation of v* is still missing. The function therefore
    raises instead of silently returning None.

    Raises:
        NotImplementedError: always, until the v* optimisation is written.
    '''
    # eval() only switches layers such as dropout to inference mode; it does NOT disable autograd.
    # Freezing the model's own weights would need requires_grad_(False) or a no_grad context.
    model = model.eval()

    start_input = torch.full((10, 1), model.tokenizer.bos_token_id, device=model.cfg.device, dtype=torch.int)

    with torch.no_grad():
        prefix_5 = model.tokenizer.batch_decode(model.generate(start_input, max_new_tokens=5))
        prefix_10 = model.tokenizer.batch_decode(model.generate(start_input, max_new_tokens=10))

    # TODO: build prompts from prefix_5 + prefix_10 and optimise v* (log-prob of the new target plus
    # a KL term against the original model, as described in get_rome_edit's docstring).
    raise NotImplementedError("get_value_representations: v* optimisation is not implemented yet")



def _rome_rank_one_update(W, C, k_star, v_star):
    ''' Return W_hat = W + Lambda (C^-1 k*)^T with Lambda = (v* - W k*) / ((C^-1 k*)^T k*).

    W is (d_out, d_in) and acts on column keys (W @ k); C is (d_in, d_in); k_star is (d_in,);
    v_star is (d_out,). By construction W_hat @ k_star == v_star.
    '''
    c_inv_k = torch.linalg.solve(C, k_star)  # C^-1 k* without forming C^-1 explicitly
    lam = (v_star - W @ k_star) / torch.dot(c_inv_k, k_star)  # scalar denominator
    return W + torch.outer(lam, c_inv_k)


def get_rome_edit(model, edit_request, cov_dataset, layer=None):
    ''' Implement the ROME edit function.
        W_hat = W + A (C^-1 k_star)^T, where:
        k_star = E[mlp_out(x_i + subject)], x_i is random prefix
        C = covariance matrix of keys
        A = (v_star - W k_star) / ((C^-1 k_star)^T k_star)   (the denominator is a scalar)
        v_star optimizes log probability of outputting correct object and KL with original model

    Args:
        model: HookedTransformer to edit (its weights are not modified here).
        edit_request: the fact to rewrite.
        cov_dataset: dataset passed to get_key_cov to estimate C.
        layer: index of the block whose MLP output weight is edited; required.

    Returns:
        The edited weight for ``model.blocks[layer].mlp.W_out``, in the same (d_mlp, d_model) layout.

    Raises:
        ValueError: if ``layer`` is not given.
        NotImplementedError: if get_value_representations does not return v*.
    '''
    if layer is None:
        raise ValueError("get_rome_edit: pass layer=<index of the MLP block to edit>")

    # Get v* first: it is the unfinished step, so fail before the expensive covariance pass.
    v_star = get_value_representations(model, edit_request)
    if v_star is None:
        raise NotImplementedError("get_rome_edit: get_value_representations did not return v*")

    # NOTE(review): k_star and C come from calls that take no layer argument. They are only specific
    # to `layer` if those helpers restrict their hooks to that block — confirm before trusting the edit.
    k_star = get_subject_representations(model, edit_request)
    C = get_key_cov(model, cov_dataset)

    with torch.no_grad():
        # Assumes TransformerLens computes mlp_out = hook_post @ W_out with W_out of shape (d_mlp, d_model).
        # The paper's W (applied as W @ k) is then W_out.T — confirm against the MLP implementation.
        W_out = model.blocks[layer].mlp.W_out
        W_hat = _rome_rank_one_update(W_out.T, C, k_star, v_star)
    return W_hat.T

In [51]:
# Inspect how a subject string (leading space, no special tokens) splits into token ids.
model.tokenizer.encode(" Isaac Newton", add_special_tokens=False, return_tensors="pt")

tensor([[41508, 20668]])

In [34]:
# Estimate the (d_mlp x d_mlp) key matrix from wikitext; runs the model over training texts until
# about 100k tokens have been processed.
cov = get_key_cov(model, wikitext)

In [57]:
# Ten BOS-only rows to seed sampling; decode the five new tokens generated after each.
start_input = torch.full((10, 1), model.tokenizer.bos_token_id, device=device, dtype=torch.int)
model.tokenizer.batch_decode(model.generate(start_input, max_new_tokens=5))

100%|██████████| 5/5 [00:00<00:00,  9.00it/s]


['<|endoftext|>什么是切断面部神经\n',
 '<|endoftext|>import boto3\nimport',
 '<|endoftext|>Imperialist Adm',
 '<|endoftext|>F Commands with MATLAB R',
 '<|endoftext|>当我们选择了好友们，',
 '<|endoftext|><?php\n// This',
 '<|endoftext|>Human: There is a',
 '<|endoftext|>Before this question has a',
 '<|endoftext|>接下来病历评价 �',
 '<|endoftext|>【答案】【答案']

In [7]:
# Keep only CounterFact rows where the model already prefers the true target over the new one.
# NOTE(review): this overwrites cfact["train"], so re-running the cell re-filters the already-filtered split.
cfact["train"] = correctness_filter(model, cfact["train"], verbose=True)


Map:   5%|▌         | 1000/19728 [00:05<01:43, 180.86 examples/s]

String:  French tokenized as 8585, model is correct? True


Map:  10%|█         | 2000/19728 [00:11<01:41, 174.37 examples/s]

String:  Finnish tokenized as 57853, model is correct? False


Map:  15%|█▌        | 3000/19728 [00:15<01:26, 193.13 examples/s]

String:  biology tokenized as 33358, model is correct? True


Map:  20%|██        | 4000/19728 [00:21<01:21, 194.08 examples/s]

String:  Moscow tokenized as 22415, model is correct? True


Map:  25%|██▌       | 5000/19728 [00:26<01:18, 187.31 examples/s]

String:  French tokenized as 8585, model is correct? True


Map:  30%|███       | 6000/19728 [00:31<01:11, 193.20 examples/s]

String:  bishop tokenized as 53206, model is correct? True


Map:  35%|███▌      | 7000/19728 [00:36<01:02, 202.53 examples/s]

String:  Europe tokenized as 4505, model is correct? True


Map:  41%|████      | 8000/19728 [00:41<00:58, 199.48 examples/s]

String:  Nokia tokenized as 35706, model is correct? True


Map:  46%|████▌     | 9000/19728 [00:46<00:54, 195.48 examples/s]

String:  goaltender tokenized as 79622, model is correct? False


Map:  51%|█████     | 10000/19728 [00:51<00:49, 198.33 examples/s]

String:  Microsoft tokenized as 5100, model is correct? True


Map:  56%|█████▌    | 11000/19728 [00:57<00:45, 191.84 examples/s]

String:  Philadelphia tokenized as 19335, model is correct? True


Map:  61%|██████    | 12000/19728 [01:01<00:39, 195.39 examples/s]

String:  Germany tokenized as 9856, model is correct? True


Map:  66%|██████▌   | 13000/19728 [01:08<00:38, 176.90 examples/s]

String:  Spanish tokenized as 15154, model is correct? True


Map:  71%|███████   | 14000/19728 [01:14<00:33, 173.24 examples/s]

String:  Kerala tokenized as 60407, model is correct? True


Map:  76%|███████▌  | 15000/19728 [01:19<00:26, 180.92 examples/s]

String:  Oslo tokenized as 57858, model is correct? False


Map:  81%|████████  | 16000/19728 [01:25<00:21, 174.58 examples/s]

String:  Sweden tokenized as 23190, model is correct? True


Map:  86%|████████▌ | 17000/19728 [01:30<00:14, 185.45 examples/s]

String:  Paris tokenized as 12095, model is correct? True


Map:  91%|█████████ | 18000/19728 [01:35<00:09, 187.38 examples/s]

String:  Amsterdam tokenized as 37741, model is correct? False


Map:  96%|█████████▋| 19000/19728 [01:41<00:03, 188.50 examples/s]

String:  actor tokenized as 12089, model is correct? True


Map: 100%|██████████| 19728/19728 [01:45<00:00, 187.52 examples/s]


String:  Antarctica tokenized as 71687, model is correct? False


Filter: 100%|██████████| 19728/19728 [00:02<00:00, 9113.44 examples/s]
