In [1]:
import plotly.express as px

from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy

In [2]:
# Load the language model under study (currently phi-2) onto the first CUDA device.
# To experiment with GPT-2 instead, swap the checkpoint for "openai-community/gpt2".
model = LanguageModel(
    "microsoft/phi-2",
    device_map="cuda:0",
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Array-indexing task: the clean and corrupted prompts differ only in the
# array element at index 3 (7 in the clean prompt, 4 in the corrupted one).
#
# Alternative IOI-style prompts, kept for reference:
#   clean:     "After John and Mary went to the store, Mary gave a bottle of milk to"
#   corrupted: "After John and Mary went to the store, John gave a bottle of milk to"
# (An earlier variant of this cell used the clean prompt for both runs; it is not used here.)
_PROMPT_TEMPLATE = (
    "Output the value of an array at a given index.\na=[1, 5, 0, {value}, 7]. a[3]="
)
clean_prompt = _PROMPT_TEMPLATE.format(value=7)
corrupted_prompt = _PROMPT_TEMPLATE.format(value=4)

# Vocabulary id of the first token of each candidate answer string.
correct_index, incorrect_index = (
    model.tokenizer(answer)["input_ids"][0] for answer in ("7", "4")
)

for _answer, _token_id in (("7", correct_index), ("4", incorrect_index)):
    print(f"'{_answer}': {_token_id}")

'7': 22
'4': 19


In [4]:
# Token ids of both prompts; the clean ids later drive the per-position
# patching loop and the heat-map axis labels.
new_clean_tokens, corrupted_tokens = (
    model.tokenizer(prompt_text)["input_ids"]
    for prompt_text in (clean_prompt, corrupted_prompt)
)

In [7]:
# Activation patching over every (layer, token position) pair.
#
# For each pair, the corrupted prompt is re-run with the hidden state at that
# layer/position overwritten by the one from the clean run. We then record how
# much of the clean-vs-corrupted logit difference the patch recovers:
#
#   (patched_diff - corrupted_diff) / (clean_diff - corrupted_diff)
#
# Result: ioi_patching_results[layer_idx][token_idx] -> saved nnsight proxy.

# Take the layer count from the loaded model instead of hard-coding it, so this
# constant cannot disagree with the modules that are actually indexed below.
N_LAYERS = len(model.model.layers)
N_TOKENS = len(new_clean_tokens)

ioi_patching_results = []

# One trace per layer rather than one trace for the whole sweep.
# Queuing every (layer x token) patched invoke inside a single trace ended in
# "CUDA out of memory" with this model; per nnsight's tracing docs the invokes
# of one trace are run together, so a per-layer trace only has to hold
# 2 + N_TOKENS prompts at a time. The clean and corrupted baselines are
# recomputed in every trace so each trace is self-contained.
for layer_idx in range(N_LAYERS):
    _ioi_patching_results = []

    # Enter nnsight tracing context
    with model.trace() as tracer:

        # Clean run
        with tracer.invoke(clean_prompt) as invoker:
            clean_tokens = model.input[1]["input_ids"].squeeze().save()

            # Hidden state of the layer being patched.
            # We index the output at 0 because it's a tuple where the first index is the hidden state.
            # No need to call .save() as the value is only used inside this trace.
            clean_layer_hs = model.model.layers[layer_idx].output[0]

            # Get logits from the lm_head.
            clean_logits = model.lm_head.output

            # Difference between the correct and incorrect answer logits at the last position.
            clean_logit_diff = (
                clean_logits[0, -1, correct_index]
                - clean_logits[0, -1, incorrect_index]
            ).save()

        # Corrupted run
        with tracer.invoke(corrupted_prompt) as invoker:
            corrupted_logits = model.lm_head.output

            # Same logit difference for the unpatched corrupted prompt.
            corrupted_logit_diff = (
                corrupted_logits[0, -1, correct_index]
                - corrupted_logits[0, -1, incorrect_index]
            ).save()

        # Iterate through all tokens
        for token_idx in range(N_TOKENS):

            # Patching corrupted run at given layer and token
            with tracer.invoke(corrupted_prompt) as invoker:

                # Apply the patch from the clean hidden states to the corrupted hidden states.
                model.model.layers[layer_idx].output[0].t[token_idx] = (
                    clean_layer_hs.t[token_idx]
                )

                patched_logits = model.lm_head.output

                patched_logit_diff = (
                    patched_logits[0, -1, correct_index]
                    - patched_logits[0, -1, incorrect_index]
                )

                # Calculate the improvement in the correct token after patching.
                patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                    clean_logit_diff - corrupted_logit_diff
                )

                _ioi_patching_results.append(patched_result.save())

    ioi_patching_results.append(_ioi_patching_results)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You're using a CodeGenTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


OutOfMemoryError: CUDA out of memory. Tried to allocate 984.00 MiB. GPU 

In [31]:
# Report the baseline logit differences saved by the patching cell.
print(f"Clean logit difference: {clean_logit_diff.value:.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}")

# Replace every saved proxy in the nested result list with its Python float.
ioi_patching_results = util.apply(ioi_patching_results, lambda x: x.value.item(), Proxy)

# Decode each clean-prompt token id to text and suffix it with its position,
# so the x-axis labels are readable and unique even when a token repeats.
# (Previously the labels were built from the raw ids and the decoded strings
# were never used.)
clean_tokens = [model.tokenizer.decode(token) for token in new_clean_tokens]
token_labels = [f"{token}_{index}" for index, token in enumerate(clean_tokens)]

# Heat map: rows are layers, columns are token positions; the colour scale is
# centred on 0 (patch made no difference).
# NOTE(review): the title still says "IOI Task" although the active prompts
# are the array-indexing task — confirm which wording is wanted.
fig = px.imshow(
    ioi_patching_results,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    labels={"x": "Position", "y": "Layer"},
    x=token_labels,
    title="Normalized Logit Difference After Patching Residual Stream on the IOI Task",
)

fig.show()

Clean logit difference: -1.170
Corrupted logit difference: -1.590
