In [2]:
from nnsight import LanguageModel
from nnsight.envoy import Envoy
from nnsight.patching import Patch, Patcher

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers.models import gpt2
from nnsight import util
from __future__ import annotations
from typing import Optional, Tuple, Union
import torch

class GPT2AttentionAltered(gpt2.modeling_gpt2.GPT2Attention):
    """GPT2Attention with extra submodules wrapped around the per-head query, key and value.

    Construction matches the parent apart from three extra children (`query`, `key`,
    `value`), each a `util.WrapperModule` from nnsight. `forward` passes each projected,
    head-split tensor through the matching child and continues with whatever it returns.
    This gives nnsight a named module (e.g. `attn.query`) whose output can be saved or
    intervened on.

    NOTE(review): `forward` appears to mirror the parent's implementation except for the
    section marked "Altered". It relies on private parent helpers (`_split_heads`, `_attn`,
    `_upcast_and_reordered_attn`, `_merge_heads`), so it is tied to the installed
    transformers version. Re-check it against the parent when upgrading.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        # Build the stock attention layer first (c_attn, c_proj, dropouts, ...).
        super().__init__(config, is_cross_attention, layer_idx)

        # Hook points. They are presumably pass-through (identity) modules — TODO confirm
        # with nnsight's util.WrapperModule. They appear as (query)/(key)/(value) children
        # in the module repr.
        self.query = util.WrapperModule()
        self.key = util.WrapperModule()
        self.value = util.WrapperModule()

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Multi-head attention forward pass with query/key/value hook points.

        Args:
            hidden_states: input projected by `c_attn` (or by `q_attn` for cross attention).
                NOTE(review): annotated as a tuple, but it is used as a single tensor
                (`self.c_attn(hidden_states).split(..., dim=2)`); the annotation looks
                inaccurate.
            layer_past: optional cached `(key, value)` pair, concatenated in front of the
                current keys/values along dim -2.
            attention_mask: mask passed to the attention computation. It is replaced by
                `encoder_attention_mask` in the cross-attention branch.
            head_mask: passed unchanged to the attention computation.
            encoder_hidden_states: if given, keys/values are projected from these (cross
                attention). The layer must then have a `q_attn` projection.
            encoder_attention_mask: mask used instead of `attention_mask` for cross attention.
            use_cache: `present` is the `(key, value)` pair only when `use_cache is True`.
                This is an identity check, so truthy non-bool values do not count.
            output_attentions: when truthy, the attention weights are appended to the output.

        Returns:
            `(attn_output, present)` or `(attn_output, present, attn_weights)`.

        Raises:
            ValueError: if `encoder_hidden_states` is given but the layer has no `q_attn`.
        """
        if encoder_hidden_states is not None:
            # Cross attention: queries come from hidden_states, keys/values from the encoder.
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(
                self.split_size, dim=2
            )
            attention_mask = encoder_attention_mask
        else:
            # Self attention: one fused projection split into query, key and value.
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        # Reshape each projection into per-head form via the parent's private helper.
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Altered -------------
        # The hooks run after the head split and before the cache concatenation below.
        # As a result, the key/value hooks see only this call's entries (not the cached
        # past), and whatever they return is what ends up in `present`.

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # ---------------------

        if layer_past is not None:
            # Prepend cached keys/values along dim -2 (presumably the sequence axis of the
            # split-head layout — TODO confirm against _split_heads).
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # Pick the parent's attention routine according to its reorder_and_upcast_attn flag.
        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query, key, value, attention_mask, head_mask
            )
        else:
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask
            )

        # Merge the heads back, then apply the output projection and residual dropout.
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)

In [6]:
# Patch transformers so that GPT2Attention resolves to GPT2AttentionAltered while the model
# is being built. The layers constructed during loading then use the altered class (h[0] is
# printed below to confirm).
GPT2Patcher = Patcher(
    [Patch(gpt2.modeling_gpt2, GPT2AttentionAltered, "GPT2Attention")]
)

# `with` guarantees the patch is undone even if loading the model raises. Otherwise
# transformers would stay globally patched for the rest of the kernel session.
with GPT2Patcher:
    model = LanguageModel("openai-community/gpt2", unified=False, device_map="auto")

    print(model._model.transformer.h[0].attn)

# Leaving the patcher only affects classes constructed from now on. Layers that already
# exist keep the class they were built with, so this still prints GPT2AttentionAltered.
# NOTE(review): if LanguageModel defers loading real weights until the first trace, that
# later load happens after the patch is undone and may rebuild stock layers — confirm for
# the nnsight version in use (e.g. by loading eagerly inside the `with` block).
print(model._model.transformer.h[0].attn)

GPT2AttentionAltered(
  (c_attn): Conv1D()
  (c_proj): Conv1D()
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
  (query): WrapperModule()
  (key): WrapperModule()
  (value): WrapperModule()
)
GPT2AttentionAltered(
  (c_attn): Conv1D()
  (c_proj): Conv1D()
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
  (query): WrapperModule()
  (key): WrapperModule()
  (value): WrapperModule()
)


In [12]:
# Swap layer 0's patched attention for a stock GPT2Attention that carries the same weights.
#
# - `gpt2.modeling_gpt2.GPT2Attention` is a class, not a module instance, so it holds no
#   weights. A module has to be constructed before anything can be copied into it.
# - That module attribute may still be patched (it can resolve to GPT2AttentionAltered
#   itself). The stock class is therefore taken from GPT2AttentionAltered's base, which was
#   bound when that class was defined, before any patching.
# - GPT2Attention keeps its weights in child modules (c_attn, c_proj), not in top-level
#   `weight` / `bias` parameters, so the whole state_dict is copied.
LAYER_IDX = 0
OriginalGPT2Attention = GPT2AttentionAltered.__bases__[0]

block = model._model.transformer.h[LAYER_IDX]
patched_component = block.attn
reference_param = next(patched_component.parameters())
if reference_param.is_meta:
    # NOTE(review): the model may still hold meta (not yet loaded) weights — e.g. if nnsight
    # defers loading until the first trace. Confirm for the nnsight version in use.
    raise RuntimeError(
        "Layer weights are on the meta device; load real weights before swapping modules."
    )

original_component = OriginalGPT2Attention(model._model.config, layer_idx=LAYER_IDX).to(
    device=reference_param.device, dtype=reference_param.dtype
)
# Strict load: this fails loudly if the patched module carries state the stock one lacks.
original_component.load_state_dict(patched_component.state_dict())
# A freshly constructed module is in training mode (dropout active); mirror the layer's mode.
original_component.train(patched_component.training)

block.attn = original_component
# NOTE(review): nnsight's Envoy tree was built around the patched module, which defines
# `query` / `key` / `value`. Confirm whether traces pick up this swap, and note that
# `.attn.query` no longer exists on the installed module.

In [None]:
def swap_in_module(
    parent: torch.nn.Module, attr_name: str, replacement: torch.nn.Module
) -> torch.nn.Module:
    """Replace the child module ``parent.<attr_name>`` with ``replacement``, keeping its weights.

    This is a reusable version of a weight-preserving layer swap:

    - ``replacement`` is moved to the device and dtype of the current child's first
      parameter, if it has any.
    - The current child's full ``state_dict`` is loaded into it strictly.
    - It adopts the current child's train/eval mode.
    - Finally it is installed with ``setattr``.

    Args:
        parent: module that owns the child to be replaced.
        attr_name: attribute name of that child on ``parent``.
        replacement: module to install. Its parameter/buffer names and shapes must match
            the current child's.

    Returns:
        The module that was removed from ``parent``.

    Raises:
        AttributeError: if ``parent`` has no attribute ``attr_name``.
        RuntimeError: from ``load_state_dict`` if the state dicts do not match. In that case
            ``parent`` is left unchanged.
    """
    current = getattr(parent, attr_name)
    reference = next(current.parameters(), None)
    if reference is not None:
        # Put the new module where the existing weights live before copying them over.
        replacement.to(device=reference.device, dtype=reference.dtype)
    # Strict load: mismatched names or shapes raise instead of being silently skipped.
    replacement.load_state_dict(current.state_dict())
    # Adopt the replaced module's train/eval mode (e.g. so dropout stays off in eval mode).
    replacement.train(current.training)
    setattr(parent, attr_name, replacement)
    return current

In [None]:
# Run one forward pass on a short prompt and keep the output of layer 0's `query` hook point.
# `.query` exists only on GPT2AttentionAltered, so h[0].attn must still be the patched layer
# when this runs (i.e. before any cell that swaps it back to the stock class).
prompt = "Hello world"
with model.trace(prompt):
    layer0_attn = model.transformer.h[0].attn
    query_output = layer0_attn.query.output.save()

In [None]:
# Undo every patch held by the patcher so that models built from here on use the stock
# GPT2Attention. Layers that already exist keep the class they were built with.
for gpt2_patch in GPT2Patcher.patches:
    gpt2_patch.restore()

# Confirm that transformers no longer resolves GPT2Attention to the altered class.
assert gpt2.modeling_gpt2.GPT2Attention is not GPT2AttentionAltered

In [None]:
# Display the value saved during the trace (the output of layer 0's `query` hook point).
# NOTE(review): depending on the nnsight version, the saved proxy may need `.value` to
# yield the tensor itself — confirm.
query_output

In [None]:
# Show which class transformers currently exposes as GPT2Attention: GPT2AttentionAltered
# while the patch is active, the stock class once it has been restored.
gpt2.modeling_gpt2.GPT2Attention

In [None]:
# Inspect the raw instance attributes of transformer block 0 (`vars(obj)` is `obj.__dict__`),
# e.g. to check which attention module the block currently holds.
vars(model._model.transformer.h[0])

In [None]:
model._model.