In [2]:
from nnsight import LanguageModel
from nnsight.models.UnifiedTransformer import UnifiedTransformer

import sys

sys.path.append("../..")

from nngine import alter

import torch 
import einops

# TransformerLens GPT-2 small, loaded through nnsight's UnifiedTransformer wrapper.
# Every weight-processing step is switched off -- presumably so its parameters
# stay numerically comparable to the HuggingFace checkpoint loaded below.
tl_load_kwargs = dict(
    processing=False,
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device="cuda",
)
tl_model = UnifiedTransformer('gpt2-small', **tl_load_kwargs)

# Enable the optional hook points the later cells read from
# (same three setters, same order as before).
for setter_name in ("set_use_hook_mlp_in", "set_use_split_qkv_input", "set_use_attn_result"):
    getattr(tl_model, setter_name)(True)

# Single tokenizer used to encode prompts for both models
# (assumed interchangeable since both are GPT-2 checkpoints).
tokenizer = tl_model.tokenizer

# HuggingFace GPT-2 wrapped by nnsight.
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

def test_resolution(ground_truth, sample):
    for resolution in [1e-12, 1e-10, 1e-6, 1e-4, 1e-3, 1e-2, 1.0]:
        pct = (sample - ground_truth > resolution).float().mean().item()
        print(f'pct out of range {resolution:.1e} = {pct:.2%}')

Loaded pretrained model gpt2-small into HookedTransformer


In [53]:
clean = tokenizer.encode("When Alan and Alex got a drink at the store, Alex decided to give it to")

# Backprop sum(last-position logits) through HF GPT-2 (via nnsight) and save
# gradients at several points in block 0, then do the same through the
# TransformerLens copy so the two sets can be compared (e.g. with test_resolution).
with model.trace(clean):

    # Grad w.r.t. block 0's first positional input (residual stream entering the block).
    # Assumes nnsight's `.input` is `(args, kwargs)` -- TODO confirm for the installed version.
    resid_pre = model.transformer.h[0].input[0][0].grad.save()

    # Grad w.r.t. the first element of the attention module's output tuple.
    attn_out = model.transformer.h[0].attn.output[0].grad.save()

    # Forward value (activation, not gradient) of that same attention output.
    attn_out_act = model.transformer.h[0].attn.output[0].save()

    # Grad w.r.t. ln_2's input (resid_mid). The saved grad must be bound to a name
    # to be readable after the trace; taking `.grad` of a `.clone()` would instead
    # hook a tensor that nothing downstream consumes, so it would receive no grad.
    resid_mid = model.transformer.h[0].ln_2.input[0][0].grad.save()

    logits = model.output.logits.save()
    value = logits[:,-1,:].sum()
    value.backward()

with tl_model.trace(clean):

    # TransformerLens counterparts of the HF gradients saved above.
    tl_resid_pre = tl_model.blocks[0].hook_resid_pre.input[0][0].grad.save()

    tl_attn_out = tl_model.blocks[0].hook_attn_out.output.grad.save()

    tl_resid_mid = tl_model.blocks[0].hook_resid_mid.output.grad.save()

    # Grad at hook_mlp_in (enabled in the setup cell via set_use_hook_mlp_in(True)).
    tl_mlp_in = tl_model.blocks[0].hook_mlp_in.output.grad.save()

    tl_logits = tl_model.output.save()
    tl_value = tl_logits[:,-1,:].sum()
    tl_value.backward()

In [32]:
# Gradient of sum(last-position logits) w.r.t. block 0's residual stream before
# the MLP, captured three ways: (1) HF GPT-2 with the MLP input rebuilt from a
# recomputed resid_mid, (2) HF GPT-2 at ln_2's input, (3) TransformerLens hook_mlp_in.
clean = tokenizer.encode("When Alan and Alex got a drink at the store, Alex decided to give it to")

# Module handles for block 0.
# NOTE(review): cell In[53] binds `resid_pre` / `attn_out` to saved gradients;
# here they are modules -- whichever cell ran last determines what these names hold.
resid_pre = model.transformer.h[0]
attn_out = model.transformer.h[0].attn
mlp = model.transformer.h[0].mlp
ln_2 = model.transformer.h[0].ln_2

with model.trace(clean):

    # Recompute resid_mid = block input + attention output as a fresh tensor whose
    # only consumer is the patched MLP input below.
    resid_mid = resid_pre.input[0][0] + attn_out.output[0]

    # Overwrite the MLP's input in place with ln_2(resid_mid), so the gradient
    # reaching the MLP flows back into the recomputed resid_mid.
    mlp.input[0][0][:] = ln_2(resid_mid)

    # Grad w.r.t. the recomputed resid_mid. Nothing else in this trace uses it, so
    # this is intended to be the MLP-branch-only gradient (compare to tl_mlp_in).
    mlp_in = resid_mid.grad.save()

    logits = model.output.logits.save()
    value = logits[:,-1,:].sum()
    value.backward()

with model.trace(clean):

    # Author's intent: this should match mlp_in above.
    # NOTE(review): if the tensor fed to ln_2 is also the residual carried around
    # the MLP (as in HF's GPT2Block), this grad includes the skip path and would
    # instead match TL's hook_resid_mid gradient -- verify.
    ln_2_grad = ln_2.input[0][0].grad.save()

    logits = model.output.logits.save()
    value = logits[:,-1,:].sum()
    value.backward()

with tl_model.trace(clean):

    # Forward activation (not its gradient) at TL's resid_mid hook.
    tl_resid_mid = tl_model.blocks[0].hook_resid_mid.output.save()

    # Grad at hook_mlp_in (enabled in the setup cell via set_use_hook_mlp_in(True)).
    tl_mlp_in = tl_model.blocks[0].hook_mlp_in.output.grad.save()

    tl_logits = tl_model.output.save()
    tl_value = tl_logits[:,-1,:].sum()
    tl_value.backward()