In [14]:
from nnsight import LanguageModel
from nnsight.util import WrapperModule, fetch_and_set, fetch_attr
import torch
from typing import List

In [15]:
# Load GPT-2 through nnsight and place it on the first CUDA device.
# NOTE(review): dispatch=True presumably loads the real weights right away
# instead of lazily -- confirm against the nnsight documentation.
model = LanguageModel("openai-community/gpt2", device_map="cuda:0", dispatch=True)

# Run the one-token prompt "a" once and keep the inputs of transformer block 0
# and of its attention submodule. `attn_input` is reused in a later cell as the
# sample input for the compiled attention module; `h_0_input` is not read again
# in the cells shown here.
with model.trace("a"):
    h_0_input = model.transformer.h[0].input.save()
    attn_input = model.transformer.h[0].attn.input.save()

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [16]:
class Edit:
    """Plain record describing one module to splice into a traced graph.

    Attributes:
        parent: Dotted path of the module the edit belongs to
            (e.g. ``"transformer.h.0.attn"``).
        target: Name of the graph node whose value should be rerouted.
        key: Attribute name under which ``replacement`` is registered.
        replacement: Module that will receive the target node's value.
    """

    def __init__(
        self,
        parent: str,
        target: str,
        key: str,
        replacement: torch.nn.Module,
    ) -> None:
        self.parent, self.target = parent, target
        self.key, self.replacement = key, replacement

    def __str__(self) -> str:
        # Human-readable summary: "<parent>.<target> -> <key>".
        return f"{self.parent}.{self.target} -> {self.key}"

    def __repr__(self) -> str:
        # The repr is intentionally the same text as the str form.
        return str(self)

In [17]:
class WrapperModule(torch.nn.Module):
    """Debug hook module that scales everything passing through it.

    This is NOT a pass-through: each positional input is multiplied by
    ``scale`` and a message is printed on every call, so it is easy to tell
    whether a node spliced into a graph is really being executed.

    NOTE: this definition reuses the name ``WrapperModule`` that is imported
    from ``nnsight.util`` at the top of the notebook, so it shadows that import
    for every cell executed after this one.

    Args:
        scale: Factor applied to each positional input. Defaults to 10000000.
    """

    def __init__(self, scale: float = 10000000) -> None:
        super().__init__()
        self.scale = scale

    def forward(self, *args, **kwargs):
        """Scale the positional inputs by ``self.scale``.

        Args:
            *args: Values to scale (anything supporting ``value * scale``).
            **kwargs: Accepted for call compatibility but ignored.

        Returns:
            The scaled value itself when exactly one positional argument is
            given; otherwise a tuple with one scaled entry per argument (an
            empty tuple when called without arguments).
        """
        if len(args) == 1:
            value = args[0] * self.scale
        else:
            # Scale element by element. Multiplying the tuple itself
            # (``args * scale``) would repeat the tuple ``scale`` times rather
            # than scale its contents.
            value = tuple(arg * self.scale for arg in args)

        print("stuff is going through this node!")

        return value

# The single edit exercised by this notebook: route the node named "query"
# inside block 0's attention module through a WrapperModule registered under
# the attribute name "query_wrapper".
wrapper_edit_one = Edit(
    "transformer.h.0.attn",
    "query",
    "query_wrapper",
    WrapperModule()
)

edits = [wrapper_edit_one]

# Later cells look edits up with `grouped_edits["transformer.h.0.attn"]`, i.e.
# by the parent module path, and iterate the result as Edit objects. Build that
# mapping here so the notebook runs top to bottom on a fresh kernel.
grouped_edits = {}
for edit in edits:
    grouped_edits.setdefault(edit.parent, []).append(edit)


In [19]:
def get_backend(
    edit_batch: List[Edit],
):
    """Create a graph-rewriting backend suitable for ``torch.compile``.

    The returned backend registers every ``edit.replacement`` on the traced
    ``GraphModule`` under ``edit.key``. For each non-placeholder node whose
    name equals some ``edit.target`` it then reroutes that node's value through
    the registered module, so all former users of the node read the module's
    output instead.

    Args:
        edit_batch: Edits to apply. Only ``target``, ``key`` and
            ``replacement`` are read here.

    Returns:
        A callable ``(gm, example_inputs)`` that mutates ``gm`` in place,
        prints the resulting graph as a table and returns ``gm.forward``.
    """
    # Node name to look for -> attribute name of the module that wraps it.
    wrapper_key_for = {}
    for edit in edit_batch:
        wrapper_key_for[edit.target] = edit.key

    def edited_backend(gm: torch.fx.GraphModule, _: List[torch.Tensor]):
        graph = gm.graph
        # Targets not yet found in this particular graph.
        pending = set(wrapper_key_for)

        # Make every replacement reachable by `call_module` nodes.
        for edit in edit_batch:
            gm.add_submodule(edit.key, edit.replacement)

        for node in graph.nodes:
            if node.op != "placeholder" and node.name in pending:
                matched = node.name

                with graph.inserting_before(node):
                    # Recreate the computation under a new name so the wrapper
                    # can consume it without referring to the node that is
                    # about to be replaced.
                    clone = graph.create_node(
                        node.op,
                        node.target,
                        args=node.args,
                        kwargs=node.kwargs,
                        name="original_" + matched,
                    )
                    wrapped = graph.call_module(wrapper_key_for[matched], args=(clone,))
                    # Order matters: redirect users first, then drop the
                    # now-unused original node.
                    node.replace_all_uses_with(wrapped)
                    graph.erase_node(node)

                pending.remove(matched)

            # Nothing left to splice: stop walking the graph.
            if not pending:
                break

        gm.recompile()

        graph.print_tabular()

        return gm.forward

    return edited_backend

In [20]:
# Attention module of transformer block 0, reached through `model._model`
# (which appears to be the wrapped torch model -- confirm).
mod = model._model.transformer.h[0].attn

# Attach each replacement module to the attention module under its key, so it
# shows up as a named child (e.g. `query_wrapper`).
# `grouped_edits` is expected to map a parent module path to the list of Edit
# objects that belong to that parent.
for edit in grouped_edits["transformer.h.0.attn"]:
    setattr(mod, edit.key, edit.replacement)

In [21]:
# Reset TorchDynamo's state so that a previous compilation is not reused and
# the custom backend below is invoked again when this cell is re-run.
torch._dynamo.reset()

# Compile only block 0's attention module, using the graph-editing backend
# built from the edits registered for that module.
opt_model = torch.compile(model._model.transformer.h[0].attn, backend=get_backend(grouped_edits["transformer.h.0.attn"]), dynamic=True)

# Call the compiled module once on the attention input saved earlier; the graph
# table printed below this cell comes from the backend running during this call.
# NOTE(review): despite its name, `gm` receives whatever the compiled module
# returns, not a GraphModule. `.value[0][0]` presumably selects the first
# positional argument of the saved input -- confirm.
gm = opt_model(attn_input.value[0][0])

opcode         name              target                                                    args                                      kwargs
-------------  ----------------  --------------------------------------------------------  ----------------------------------------  --------
placeholder    l_x_              L_x_                                                      ()                                        {}
call_method    size              size                                                      (l_x_,)                                   {}
get_attr       l__self___bias    L__self___bias                                            ()                                        {}
call_method    size_1            size                                                      (l_x_, -1)                                {}
call_method    view              view                                                      (l_x_, -1, size_1)                        {}
get_attr       l__self___weight  L__se

In [22]:
# Replace block 0's attention module inside the underlying model with the
# compiled version created above.
model._model.transformer.h[0].attn = opt_model

# Rebuild the Envoy tree from the modified model.
# NOTE(review): presumably needed so that tracing can address the new
# submodule path (`attn._orig_mod.query_wrapper`) -- confirm; this writes to a
# private attribute of the nnsight model.
from nnsight.envoy import Envoy
model._envoy = Envoy(model._model)

# Display the module tree to check that `query_wrapper` is present.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
          (query_wrapper): WrapperModule()
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1-11): 11 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p

In [23]:
# Trace the edited model on the prompt "empty" and save two values: the output
# of the spliced-in `query_wrapper` module and the output of the (compiled)
# attention module of block 0.
# `_orig_mod` is the attribute through which the compiled wrapper exposes the
# original attention module.
# NOTE(review): scan=False / validate=False presumably turn off nnsight's
# pre-run scanning and validation -- confirm why they are required here.
with model.trace("empty", scan=False, validate=False):
    test = model.transformer.h[0].attn._orig_mod.query_wrapper.output.save()
    out_two = model.transformer.h[0].attn.output.save()

opcode         name            target                       args                             kwargs
-------------  --------------  ---------------------------  -------------------------------  ----------
placeholder    l_stack0_       L_stack0_                    ()                               {}
call_method    split           split                        (l_stack0_, 768)                 {'dim': 2}
call_function  original_query  <built-in function getitem>  (split, 0)                       {}
call_module    query_wrapper   query_wrapper                (original_query,)                {}
call_function  key             <built-in function getitem>  (split, 1)                       {}
call_function  value           <built-in function getitem>  (split, 2)                       {}
call_method    tensor          view                         (query_wrapper, (1, 1, 12, 64))  {}
call_method    query_1         permute                      (tensor, 0, 2, 1, 3)             {}
call_method    tenso

In [27]:
# Fraction of elements of `out[0]` that equal the corresponding elements of
# `out_two[0]` (1.0 means every element matches).
# NOTE(review): `out` is not assigned in any cell shown here; confirm where it
# comes from and what it holds before relying on this comparison.
(out[0] == out_two[0]).sum() / out[0].numel() 

tensor(1., device='cuda:0')