In [1]:
import pickle
from nnsight import LanguageModel
import torch
import plotly.express as px

In [2]:
model = LanguageModel("meta-llama/Llama-2-7b-hf", device_map="auto",torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
tokenizer = model.tokenizer

In [3]:
def conditional_probability(tokens, logits):
    probs = torch.gather(logits.cpu().softmax(-1), -1, tokens)
    return torch.prod(probs, dim=1)

In [4]:
cot = ["When John and Mary went to the shops, John gave the bag to Mary", "When John and Mary went to the shops, Mary gave the bag to John"]
corrupted_cot = ["When John and Mary went to the shops, Mary gave the bag to John", "When John and Mary went to the shops, John gave the bag to Mary"]

In [5]:
with model.invoke(cot[:2]) as invoker_clean:
    pass

clean_tokens = tokenizer(cot, return_tensors="pt", padding=True)
clean_tokens = clean_tokens.input_ids[:,-1:]

clean_probs = conditional_probability(clean_tokens, invoker_clean.output.logits[:,-2,:])

del invoker_clean.output.logits
del invoker_clean

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [6]:
clean_probs

tensor([0.8083, 0.4268])

In [24]:
def patching_metric(corrupted_logits, corrupted_tokens, clean_probs=clean_probs): 
    # corrupted_logits = corrupted_logits[:,-3:-1,:]

    corrupted_probs = conditional_probability(corrupted_tokens, corrupted_logits)
    print(corrupted_probs)
    return (clean_probs - corrupted_probs).mean()

In [19]:
clean_resid_out = []
corrupted_resid_out = []

metric = patching_metric

with model.invoke(cot[:2], fwd_args = {"inference": False}) as invoker_clean:
    for layer in model.model.layers:
        l = layer.output[0].save()
        l.retain_grad()
        clean_resid_out.append(l)

with model.invoke(corrupted_cot[:2], fwd_args = {"inference": False}) as invoker_corrupted:
    corrupted_logits = model.lm_head.output[:,-2,:].save()

    for layer in model.model.layers:
        l = layer.output[0].save()
        l.retain_grad()
        corrupted_resid_out.append(l)

In [16]:
corrupted_tokens = tokenizer(corrupted_cot, return_tensors="pt", padding=True)
corrupted_tokens = corrupted_tokens.input_ids[:,-1:]

In [22]:
corrupted_logits.shape

torch.Size([2, 32000])

In [25]:
value = patching_metric( corrupted_logits.value, corrupted_tokens)
value.backward()
value

tensor([0.4258, 0.8086], dtype=torch.bfloat16, grad_fn=<ProdBackward1>)


tensor(0.0003, grad_fn=<MeanBackward0>)

In [26]:
import einops

patching_results = []
for clean, corrupted in zip(clean_resid_out, corrupted_resid_out):
    residual_attr = einops.reduce(
        corrupted.value.grad * (clean.value - corrupted.value),
        "batch pos d_model -> pos",
        "sum"
    )
    patching_results.append(residual_attr.detach().cpu().to(dtype=torch.float32).numpy())

In [27]:
import plotly.express as px

token_ids = [tokenizer.decode(t) + "_"+str(i) for i, t in enumerate(tokenizer(cot[0]).input_ids)]
px.imshow(
    patching_results, 
    color_continuous_scale="RdBu", 
    color_continuous_midpoint=0.0, 
    title="Patching Results",
    x=token_ids,
)