In [1]:
from transformer_lens import HookedTransformer
from nnsight import LanguageModel
import torch

In [24]:
import inspect

class UnifiedTransformer(HookedTransformer):
    def __init__(self, cfg, tokenizer=None, move_to_device=True, default_padding_side="right", device: int = 0):
        """
        Initializes the Wrapped version of HookedTransformer.

        Args:
            cfg: The config to use for the model.
            tokenizer: The tokenizer to use for the model.
            move_to_device: Whether to move the model to the device specified in cfg.
            default_padding_side: Which side to pad on.
            device: The device to use for the model.
        """
        super().__init__(cfg, tokenizer, move_to_device, default_padding_side)
        self.device = torch.device(device)
        
    def forward(self, input_ids, labels=None, **kwargs):
        """
        A wrapper method to resolve naming conventions.
        """
        sig = inspect.signature(super().forward)

        if "labels" in sig.parameters.keys():
            return super().forward(input_ids=input_ids, labels=labels,**kwargs)
        
        return super().forward(input=input_ids,**kwargs)
    
    def __repr__(self):
        """
        Some __repr__ overrides to make the model more readable.
        """
        lines = [self.__class__.__name__ + '(']
        for name, module in self.named_children():

            module_str = repr(module)
            
            module_str = module_str.split('\n')
            
            module_str = [line for line in module_str if ('_input' not in line and '_output' not in line)]
            module_str = [line.replace('hook_', '') for line in module_str]

            module_str = '\n'.join(module_str)

            lines.append(f'  ({name}): {module_str}')

        return '\n'.join(lines) + '\n)'

In [3]:
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # model.resize_token_embeddings(len(tokenizer))

In [25]:
gpt2_small: HookedTransformer = UnifiedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [26]:
gpt2_small

UnifiedTransformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
  (0-11): 12 x TransformerBlock(
    (ln1): LayerNormPre(
      (scale): HookPoint()
      (normalized): HookPoint()
    )
    (ln2): LayerNormPre(
      (scale): HookPoint()
      (normalized): HookPoint()
    )
    (attn): Attention(
      (k): HookPoint()
      (q): HookPoint()
      (v): HookPoint()
      (z): HookPoint()
      (attn_scores): HookPoint()
      (pattern): HookPoint()
      (result): HookPoint()
    )
    (mlp): MLP(
      (pre): HookPoint()
      (post): HookPoint()
    )
    (attn_in): HookPoint()
    (mlp_in): HookPoint()
    (attn_out): HookPoint()
    (mlp_out): HookPoint()
    (resid_pre): HookPoint()
    (resid_mid): HookPoint()
    (resid_post): HookPoint()
  )
)
  (ln_final): LayerNormPre(
  (scale): HookPoint()
  (normalized): HookPoint()
)
  (unembed): Unembed()
)

In [14]:
gpt2_small

embed
hook_embed
pos_embed
hook_pos_embed
blocks
ln_final
unembed


TypeError: __repr__ returned non-string (type NoneType)

In [5]:
ground_model = LanguageModel("gpt2", device_map="auto")

In [6]:
model = LanguageModel(gpt2_small, tokenizer=ground_model.tokenizer, device_map="cuda")


Moving model to device:  meta


In [16]:
with model.invoke("The city of Paris is the capital of the country of") as invoker:
    pass

In [17]:
tokenizer.decode(invoker.output[:,-1,:].softmax(-1).argmax(-1))

' France'

In [18]:
model

UnifiedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (