In [1]:
from nnsight import LanguageModel
import torch
from typing import List

model = LanguageModel("openai-community/gpt2", device_map="cuda:0", dispatch=True)

with model.trace("a"):
    h_0_input = model.transformer.h[0].input.save()
    attn_input = model.transformer.h[0].attn.input.save()
    
    value_to_comapre = model.transformer.h[0].attn.output.save()

class Edit:

    def __init__(self, 
        parent: str, 
        target: str, 
        key: str, 
        replacement: torch.nn.Module,
    ) -> None:
        self.parent = parent
        self.target = target
        self.key = key
        self.replacement = replacement

    def __str__(self) -> str:
        return f"{self.parent}.{self.target} -> {self.key}"
    
    def __repr__(self) -> str:
        return self.__str__()

class WrapperModule(torch.nn.Module):
    """Simple torch module which passes it's input through. Useful for hooking.
    If there is only one argument, returns the first element.
    """

    def forward(self, *args, **kwargs):
        if len(args) == 1:
            args = args[0]

        value = args * 10000000
    
        print("stuff is going through this node!")

        return value

edit = Edit(
    "transformer.h.0.attn", 
    "query", 
    "query_wrapper",
    WrapperModule()
)
    

def edited_backend(gm: torch.fx.GraphModule, _: List[torch.Tensor]):

    gm.add_submodule(edit.key, edit.replacement)

    for node in gm.graph.nodes:
        
        if node.name == "query":
            with gm.graph.inserting_before(node):
                new = gm.graph.create_node(node.op, node.target, args=node.args, kwargs=node.kwargs, name="original_" + node.name)
                wrapper_node = gm.graph.call_module(edit.key, args=(new,))
                node.replace_all_uses_with(wrapper_node)
                gm.graph.erase_node(node)


    gm.recompile()

    # gm.graph.print_tabular()

    return gm.forward

mod = model._model.transformer.h[0].attn

setattr(mod, edit.key, edit.replacement)

torch._dynamo.reset()

opt_model = torch.compile(model._model.transformer.h[0].attn, backend=edited_backend, dynamic=True)
gm = opt_model(attn_input.value[0][0])

model._model.transformer.h[0].attn = opt_model

from nnsight.envoy import Envoy
model._envoy = Envoy(model._model)

with model.trace("empty", scan=False, validate=False):
    model.transformer.h[0].attn._orig_mod.query_wrapper.output *= 100
    out_two = model.transformer.h[0].attn.output.save()

value_to_comapre.value[0] == out_two.value[0]

  from .autonotebook import tqdm as notebook_tqdm
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


stuff is going through this node!
stuff is going through this node!


tensor([[[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, Fa

In [8]:
value_to_comapre

(tensor([[[-1.0461e-01, -3.2664e-02,  9.7539e-02, -4.1179e-02,  8.3337e-02,
          -9.8935e-02, -5.4285e-02,  4.3335e-01, -1.2695e-01, -2.3745e-02,
          -3.7920e-01, -8.8503e-03, -9.3370e-02,  3.1767e-02,  3.0472e-02,
          -2.1082e-01, -1.5864e-01, -5.7906e-02,  3.8972e-02, -6.7503e-01,
          -1.7552e-02, -1.1526e-01,  1.1738e-02, -2.2705e-01, -1.9910e-02,
           3.1568e-02, -4.6815e-01,  4.6357e-02, -3.2329e-02, -2.6866e-02,
          -2.7751e-02, -8.3548e-03,  2.0700e-02, -1.8181e-02, -8.4641e-02,
          -2.3391e-01,  1.6156e-01,  2.1596e-04,  4.7437e-02,  7.1068e-02,
           1.7528e-02, -2.0074e-03,  5.9652e-03, -7.7801e-03,  3.0390e-02,
           2.5477e-04,  5.5166e-02,  5.1722e-03, -4.4916e-01,  8.3837e-02,
           1.0899e-02, -8.1486e-02,  8.5085e-01,  1.5919e-02, -1.0070e-01,
           9.8739e-01,  5.7250e-03, -4.3184e-01,  3.9712e-02, -1.2233e-01,
          -1.0468e-01,  1.3469e-04, -3.8979e-01,  4.3021e-01,  2.2164e+00,
           9.9928e-03,  

In [7]:
out_two

(tensor([[[ 1.1373e+00,  5.6976e-01, -8.9985e-02,  1.0620e-02, -3.9465e-03,
          -1.1492e-01,  1.1256e+00, -8.3002e-01,  4.3760e-01, -1.7915e-02,
           9.1425e-02,  3.8186e-02, -6.9904e-02, -4.8988e-02, -4.7959e-01,
           6.4840e-01,  1.0849e+00,  4.6035e-02, -4.1320e-02, -4.4889e-01,
          -2.8605e-02, -6.3296e-03,  1.2247e-01,  1.0062e+00, -1.9515e-02,
          -7.7060e-02,  1.5860e-01,  5.9957e-02, -1.8366e-02, -3.1568e-03,
          -9.9951e-02, -6.6356e-02,  7.9214e-02, -3.0222e-02, -1.5563e-01,
           7.6013e-01, -7.5835e-02,  1.1890e-01,  7.7665e-02,  1.3501e-01,
          -5.5514e-02, -1.1477e-02, -1.3684e-02,  2.2160e-02,  2.2982e-01,
          -4.5196e-02, -2.0880e-02,  1.7007e-02,  1.2033e+00,  2.0139e-01,
           4.9523e-02,  4.4231e-03,  5.3443e-01,  1.4307e-02,  2.0237e-01,
           8.2740e-01,  7.9362e-02, -7.3116e-02,  1.3923e-01,  1.3653e+00,
          -1.7305e+00, -1.2357e-02, -1.1265e+00,  8.6180e-01,  1.0459e+00,
          -1.0681e-02,  