# Edit

In [1]:
from nnsight import LanguageModel
from nnsight.edit import Edit
import torch

  from .autonotebook import tqdm as notebook_tqdm


Load a dispatched model.

In [2]:
# Fall back to CPU when no CUDA device is available so the notebook still runs
# end to end on machines without a GPU; on a GPU machine this is still "cuda:0".
model = LanguageModel(
    "openai-community/gpt2",
    device_map="cuda:0" if torch.cuda.is_available() else "cpu",
    dispatch=True,
)

Create a simple wrapper module that passes its args through. Useful for hooking the forward pass.

In [3]:
class WrapperModule(torch.nn.Module):
    """Identity module: hands back its positional arguments unchanged.

    A single positional argument is returned bare; any other number of
    positional arguments comes back as the tuple it arrived in. Keyword
    arguments are accepted but ignored.
    """

    def forward(self, *args, **kwargs):
        # Unwrap the one-element case so callers get the value, not a 1-tuple.
        return args[0] if len(args) == 1 else args

Declare an edit.

In [4]:
# Describe an edit that wraps `query` inside the attention module at
# transformer.h.0.attn with a pass-through module named `query_wrapper`.
edit_one = Edit(
    "transformer.h.0.attn", # Path to module that we want to edit
    "query", # Name of parameter we want to wrap
    "query_wrapper", # Name the wrapper will be reachable under on that module
    WrapperModule() # Module we'd like to wrap with (returns its input unchanged)
)

We could also modify the args as they are passed through.

In [5]:
class EditModule(torch.nn.Module):
    """Wrapper module that scales whatever passes through it by 100.

    A single positional argument is scaled and returned bare. With several
    positional arguments each one is scaled individually and a tuple of the
    scaled values is returned. Keyword arguments are accepted but ignored.
    """

    def forward(self, *args, **kwargs):
        if len(args) == 1:
            return args[0] * 100

        # `args * 100` on the tuple itself would repeat it 100 times rather
        # than scale its contents, so scale element by element.
        return tuple(arg * 100 for arg in args)

# Wrap `attn_weights` with the scaling EditModule so the value is actually
# modified as it passes through; passing a plain WrapperModule here would
# leave EditModule unused and the weights untouched.
edit_two = Edit(
    "transformer.h.0.attn", # Path to module that we want to edit
    "attn_weights", # Name of parameter we want to wrap
    "weights_wrapper", # Name of wrapper in our module
    EditModule() # Module we'd like to wrap with (multiplies the value by 100)
)

We can use various utilities from the edit module to see how Dynamo will transform the bytecode.

In [6]:
from nnsight.edit import print_gm

# Run a trace on a placeholder prompt just to capture the input of the first MLP.
with model.trace(" "):
    mlp_input = model.transformer.h[0].mlp.input.save()

# Pass the MLP module and the saved input, indexed with [0][0], to print_gm.
# NOTE(review): presumably [0][0] selects the first positional argument out of
# an (args, kwargs) pair — confirm.
print_gm(model.transformer.h[0].mlp, mlp_input[0][0])

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


We'll load the edits into the model and access values at the points we declared.

In [7]:
# Register both edits on the model; the wrappers are then reachable under the
# names given in each Edit (query_wrapper, weights_wrapper).
model.load_edits([edit_one, edit_two])

Note that the first time running a model after loading edits will take a little longer.

In [8]:
with model.trace("National Deep Inferencing Fabric", scan=False, validate=False):
    # Output of the wrapper placed around `query`.
    value_one = model.transformer.h[0].attn.query_wrapper.output.save()
    # Input handed to the wrapper placed around `attn_weights`.
    value_two = model.transformer.h[0].attn.weights_wrapper.input.save()

# Both values had .save() called on them inside the trace, so they are inspected here.
print(value_one.shape)
print(value_two.shape)

torch.Size([1, 6, 768])
((torch.Size([1, 12, 6, 6]),), {})


You can edit and intervene as normal.

In [None]:
with model.trace("National Deep Inferencing Fabric", scan=False, validate=False):
    # Gradient of the wrapped `query` value.
    value_one = model.transformer.h[0].attn.query_wrapper.output.grad.save()
    # Save the scaled weights as well: without .save() the value is not kept
    # once the trace exits, and the print below has nothing to show.
    value_two = (model.transformer.h[0].attn.weights_wrapper.output * 100).save()

    # Reduce the logits to a scalar and backpropagate so the gradient above
    # gets populated.
    value = model.output.logits.sum()
    value.backward()

print(value_one.shape)
print(value_two.shape)

# Technical Notes

Edit has an optional kwarg, "instance", which takes the instance at which you want to intervene.

When Dynamo encounters a graph break, variable names are reset. By default, the edit module grabs the value from the first graph in which it is found.

In [10]:
# NOTE(review): this call is identical to the earlier `edit_one` definition and
# does not pass the `instance` kwarg described in the notes above — presumably
# it was meant to demonstrate that option; confirm the intended value and add
# `instance=...` here.
edit_one = Edit(
    "transformer.h.0.attn", # Path to module that we want to edit
    "query", # Name of parameter we want to wrap
    "query_wrapper", # Name of wrapper in our module
    WrapperModule() # Module we'd like to wrap with
)