In [1]:
from nnsight import LanguageModel

# Instantiate nnsight's LanguageModel for the "openai-community/gpt2" checkpoint.
# The later cells in this notebook all operate on this `model` object.
model = LanguageModel("openai-community/gpt2", device_map="auto")

In [2]:
from 

In [3]:
def fetch_sub_envoy(envoy, target):
    """Depth-first lookup of the envoy whose ``_module_path`` equals ``target``.

    Args:
        envoy: Root of the search; must expose ``_module_path`` and an
            iterable ``_sub_envoys``.
        target: Module path string to match exactly.

    Returns:
        The first matching envoy in pre-order (the root is checked before its
        children, children in iteration order), or ``None`` if nothing matches.
    """
    if envoy._module_path == target:
        return envoy

    # Lazily search each child subtree; stop at the first one that yields a hit.
    subtree_hits = (fetch_sub_envoy(child, target) for child in envoy._sub_envoys)
    return next((hit for hit in subtree_hits if hit is not None), None)

import einops

def create_fn_hook(envoy):
    """Wrap ``envoy`` in an ``FnEnvoy`` built from a pair of einops reshapes.

    The first callable unflattens the trailing axis into ``heads``, ``d`` and
    ``qkv`` (``qkv=3`` and ``d=768`` are fixed; ``heads`` is inferred by
    einops). The second callable flattens those three axes back into one.

    NOTE(review): ``FnEnvoy`` is not defined or imported in the visible cells,
    and how it uses the two callables is not shown here -- confirm.
    NOTE(review): both patterns describe a 2-D input (``batch`` plus one
    flattened axis) -- confirm the tensor handed to them has that rank.

    Args:
        envoy: The envoy passed as the first argument to ``FnEnvoy``.

    Returns:
        The constructed ``FnEnvoy`` instance.
    """
    unflatten_pattern = "batch (heads d qkv) -> batch heads d qkv"
    flatten_pattern = "batch heads d qkv -> batch (heads d qkv)"
    axis_sizes = {"qkv": 3, "d": 768}

    return FnEnvoy(
        envoy,
        lambda x: einops.rearrange(x, unflatten_pattern, **axis_sizes),
        lambda x: einops.rearrange(x, flatten_pattern, **axis_sizes),
    )

# For each of the 12 layer indices: look up that layer's `attn.c_attn` envoy,
# wrap it with create_fn_hook, and set the result as attribute "qkv" on the
# layer's `attn` envoy.
for layer_idx in range(12): 
    envoy_path = f".transformer.h.{layer_idx}.attn.c_attn"
    envoy = fetch_sub_envoy(model._envoy, envoy_path)
    hook = create_fn_hook(envoy)

    parent_path = f".transformer.h.{layer_idx}.attn"
    # Chained assignment: `envoy` is rebound to the parent envoy here as well,
    # so after the loop `envoy` no longer refers to a c_attn envoy.
    parent = envoy = fetch_sub_envoy(model._envoy, parent_path)
    setattr(parent, "qkv", hook)

In [30]:
import einops


def split_qkv(envoy):
    """Queue a split-then-revert reshape of ``envoy.output[0]`` on the tracer graph.

    Two nodes are added through ``envoy._tracer._graph.add``: the first applies
    the unflatten pattern (``qkv=3``, ``d=768``) to ``envoy.output[0]``, and the
    second applies the inverse pattern to the result of the first.

    Args:
        envoy: Object exposing ``output`` and ``_tracer._graph.add``.

    Returns:
        Whatever the second ``add`` call returns, i.e. the handle for the
        re-flattened value rather than the split one.
    """
    to_axes = lambda x: einops.rearrange(x, "batch (heads d qkv) -> batch heads d qkv", qkv=3, d=768)
    to_flat = lambda x: einops.rearrange(x, "batch heads d qkv -> batch (heads d qkv)", qkv=3, d=768)

    graph = envoy._tracer._graph

    # First node: unflatten the envoy's first output element.
    split_node = graph.add(target=to_axes, args=[envoy.output[0]])

    # Second node: flatten the split value back to its original layout.
    return graph.add(target=to_flat, args=[split_node])

# Run the same prompt twice inside one trace: the first invocation routes
# layer 0's c_attn envoy through split_qkv, the second leaves it untouched.
with model.trace() as tracer:
    with tracer.invoke("hello"):
        out = model.transformer.h[0].attn.c_attn
        q = split_qkv(out)

        # NOTE(review): the scaled value is only bound to the local name `q`;
        # nothing in this cell assigns it back to a module output -- confirm
        # whether the edit is expected to reach layer 1 at all.
        q = q * 1000
        # model.transformer.h[0].attn.c_attn.output *= 0.

        # Save layer 1's attention input for the invocation with split_qkv.
        edited = model.transformer.h[1].attn.input[0][0].save()
    with tracer.invoke("hello"):
        # Save the same value for the invocation without split_qkv.
        clean = model.transformer.h[1].attn.input[0][0].save()

In [31]:
def test_resolution(ground_truth, sample):
    for resolution in [1e-12, 1e-10, 1e-6, 1e-4, 1e-3, 1e-2, 1.0]:
        pct = (sample - ground_truth > resolution).float().mean().item()
        print(f'pct out of range {resolution:.1e} = {pct:.2%}')

# Compare the two saved activations: `clean` is the reference, `edited` the sample.
test_resolution(clean, edited)

pct out of range 1.0e-12 = 0.00%
pct out of range 1.0e-10 = 0.00%
pct out of range 1.0e-06 = 0.00%
pct out of range 1.0e-04 = 0.00%
pct out of range 1.0e-03 = 0.00%
pct out of range 1.0e-02 = 0.00%
pct out of range 1.0e+00 = 0.00%


In [3]:
def fetch_sub_envoy(envoy, target):
    """Find the envoy whose ``_module_path`` equals ``target``.

    The tree rooted at ``envoy`` is walked in pre-order: a node is checked
    before its children, and children are visited in the order given by
    ``_sub_envoys``.

    Args:
        envoy: Root of the search; must expose ``_module_path`` and an
            iterable ``_sub_envoys``.
        target: Module path string to match exactly.

    Returns:
        The first matching envoy, or ``None`` when no node matches.
    """
    pending = [envoy]
    while pending:
        node = pending.pop()
        if node._module_path == target:
            return node
        # Push children reversed so the first child is popped (visited) first.
        pending.extend(reversed(list(node._sub_envoys)))
    return None

import einops

def create_fn_hook(envoy, target):
    """Build an ``FnEnvoy`` for ``envoy`` and ``target`` with a per-head reshape pair.

    The forward callable splits the last axis into ``heads=12`` and
    ``head_dim`` and moves ``heads`` ahead of ``seq``. The inverse callable
    merges the heads back and returns the result wrapped in a one-element
    tuple.

    NOTE(review): ``FnEnvoy`` is not defined or imported in the visible cells;
    the role of ``target`` and the reason the inverse returns a tuple are not
    shown here -- confirm against its definition.

    Args:
        envoy: First positional argument handed to ``FnEnvoy``.
        target: Second positional argument handed to ``FnEnvoy``.

    Returns:
        The constructed ``FnEnvoy`` instance.
    """
    split_pattern = "batch seq (heads head_dim) -> batch heads seq head_dim"
    merge_pattern = "batch heads seq head_dim -> batch seq (heads head_dim)"

    return FnEnvoy(
        envoy,
        target,
        lambda x: einops.rearrange(x, split_pattern, heads=12),
        # The trailing comma is deliberate: the inverse yields a 1-tuple.
        lambda x: (einops.rearrange(x, merge_pattern),),
    )

In [4]:
# For each layer index, build a hook from that layer's attention `c_proj`
# envoy and the next layer's, then attach it as attribute "heads" on the
# current layer's `c_proj` envoy.
# NOTE(review): on the last pass (layer_idx == 11) target_path points at index
# 12; fetch_sub_envoy returns None when no path matches, so `target` is
# presumably None there -- confirm FnEnvoy tolerates that.
# NOTE(review): the search root here is `model` itself, whereas the earlier
# attachment loop passes `model._envoy` -- confirm which one is intended.
for layer_idx in range(12): 
    
    envoy_path = f".transformer.h.{layer_idx}.attn.c_proj"
    target_path = f".transformer.h.{layer_idx + 1}.attn.c_proj"

    envoy = fetch_sub_envoy(model, envoy_path)
    target = fetch_sub_envoy(model, target_path)

    hook = create_fn_hook(envoy, target)

    setattr(envoy, "heads", hook)

In [5]:
# Display the model's repr to check where the hooks ended up in the module tree.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (embed): Embedding(50257, 768)
    (pos_embed): Embedding(1024, 768)
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D(
            (heads): Placeholder
          )
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (unembed): Linear(in_features=768, out_features=50257, bias=False)
  (generator): WrapperModule()
)

In [7]:
# Trace a single prompt and save the output of the `heads` hook attached under
# layer 0's attention `c_proj`.
# NOTE(review): this cell addresses the blocks as `transformer.layers[...]`,
# while the attachment loops build paths with `.transformer.h.` -- confirm
# both refer to the same modules in the session this was run in.
with model.trace("nn") as tracer:
    
    test = model.transformer.layers[0].attn.c_proj.heads.output().save()

# Print the shape of the saved value once the trace block has exited.
print(test.shape)

torch.Size([1, 12, 1, 64])
