In [1]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Tokenizer
from datasets import load_dataset
from tqdm import tqdm
import json
import torch
import argparse
import datasets
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pylab as plt
from os import listdir
import json
import math
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


# Config

In [2]:
n_layers = 12
list_modules = ['attn', 'mlp']
trace_module_id = "transformer.h.{l}.{m}"

# Fully-qualified names of the submodules to hook, ordered layer by layer and,
# within a layer, in list_modules order:
# transformer.h.0.attn, transformer.h.0.mlp, transformer.h.1.attn, ...
list_trace_module_ids = [
    trace_module_id.format(l=layer_idx, m=module_name)
    for layer_idx in range(n_layers)
    for module_name in list_modules
]

In [3]:
# Only the tokenizer is needed at this point. The model is loaded in the
# hook-registration cell, which reloads it on every run so hooks never stack.
# Loading it here as well would only double the load time and memory.
gpt2_tokeniser = GPT2Tokenizer.from_pretrained("gpt2")

In [4]:
def tokenize_function(examples):
    """Tokenize ``examples`` with the module-level GPT-2 tokenizer.

    The tokenizer's return value is passed through unchanged; callers in this
    notebook read its ``'input_ids'`` entry.
    """
    return gpt2_tokeniser(examples)

In [10]:
# Most recent capture per traced submodule: temp[m_id] = [module, input, output].
# Each hook call overwrites the entry, so only the latest forward pass is kept.
temp = {}

def func_v1(m_id):
    """Build a forward hook that records call order and I/O for submodule ``m_id``.

    On every call the returned hook:
      * appends the current value of the global counter ``seq`` to
        ``dict_sequence[m_id]`` and then increments ``seq``;
      * stores ``[module, input, output]`` in ``temp[m_id]``. Tuple inputs and
        outputs are converted to lists, and anything else is stored as-is.

    ``seq`` and ``dict_sequence`` are module-level globals. They must be defined
    (with a key for ``m_id``) before the first forward pass that fires the hook.
    """
    def hook_v1(module, _input, _output):
        global seq
        # Record this submodule's position in the global execution order.
        dict_sequence[m_id].append(seq)
        seq += 1

        input_ = list(_input) if isinstance(_input, tuple) else _input
        output_ = list(_output) if isinstance(_output, tuple) else _output
        # A forward hook receives the module itself (never a tuple), so store it as-is.
        temp[m_id] = [module, input_, output_]

    return hook_v1

# Build a new model instance on every run, so re-running this cell attaches
# exactly one recording hook per traced submodule instead of stacking them.
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")

for m_id in list_trace_module_ids:
    gpt2.get_submodule(m_id).register_forward_hook(
        func_v1(m_id)
    )


In [16]:
input_text = ["hi hello", 'good bye']
data = tokenize_function(input_text)

# No padding is applied, so torch.tensor can only build a rectangular batch
# when every prompt tokenizes to the same number of tokens. Fail clearly otherwise.
assert len({len(ids) for ids in data['input_ids']}) == 1, (
    "all prompts must tokenize to the same length (no padding is applied): "
    f"{[len(ids) for ids in data['input_ids']]}"
)
inputs = torch.tensor(data['input_ids'])
inputs

tensor([[ 5303, 23748],
        [11274, 33847]])

In [17]:
# Reset the state the forward hooks write into, so re-running this cell
# starts the call-order numbering from 0 again.
dict_sequence = {x: [] for x in list_trace_module_ids}
seq = 0

with torch.no_grad():
    clean_outputs = gpt2(inputs, labels=inputs.clone())
    # exp of the returned LM loss is a perplexity, not a loss. This assumes
    # .loss is a mean token cross-entropy; confirm against the transformers docs.
    clean_perplexity = np.exp(clean_outputs.loss.item())
    # Kept under its original (misleading) name for any cell that still reads it.
    clean_loss = clean_perplexity


## Hook call order

In [18]:
# Value of the global counter `seq` (starting at 0) when each traced submodule's hook fired.
dict_sequence

{'transformer.h.0.attn': [0],
 'transformer.h.0.mlp': [1],
 'transformer.h.1.attn': [2],
 'transformer.h.1.mlp': [3],
 'transformer.h.2.attn': [4],
 'transformer.h.2.mlp': [5],
 'transformer.h.3.attn': [6],
 'transformer.h.3.mlp': [7],
 'transformer.h.4.attn': [8],
 'transformer.h.4.mlp': [9],
 'transformer.h.5.attn': [10],
 'transformer.h.5.mlp': [11],
 'transformer.h.6.attn': [12],
 'transformer.h.6.mlp': [13],
 'transformer.h.7.attn': [14],
 'transformer.h.7.mlp': [15],
 'transformer.h.8.attn': [16],
 'transformer.h.8.mlp': [17],
 'transformer.h.9.attn': [18],
 'transformer.h.9.mlp': [19],
 'transformer.h.10.attn': [20],
 'transformer.h.10.mlp': [21],
 'transformer.h.11.attn': [22],
 'transformer.h.11.mlp': [23]}

## ATTN

### module

In [19]:
# Module object captured by the layer-2 attention hook (element 0 of its temp entry).
temp['transformer.h.2.attn'][0]

GPT2Attention(
  (c_attn): Conv1D()
  (c_proj): Conv1D()
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

### input

In [20]:
# Number of positional inputs the layer-2 attention hook received (tuple converted to list).
len(temp['transformer.h.2.attn'][1])

1

In [21]:
# Shape of the first positional input tensor of the layer-2 attention module.
temp['transformer.h.2.attn'][1][0].shape

torch.Size([2, 2, 768])

### output (activation, (k, v))

In [22]:
# Number of elements in the layer-2 attention output (tuple converted to list).
len(temp['transformer.h.2.attn'][2])

2

#### activation $ \in \mathbb{R}^{n\_batch \times n\_tokens \times dim\_model}$

In [23]:
# Element 0 of the attention output: the activation tensor.
temp['transformer.h.2.attn'][2][0].shape

torch.Size([2, 2, 768])

#### k, v $ \in \mathbb{R}^{n\_batch \times n\_heads \times n\_tokens \times dim\_head}$

In [24]:
# Element 1 of the attention output holds (k, v); index 0 selects k.
temp['transformer.h.2.attn'][2][1][0].shape

torch.Size([2, 12, 2, 64])

## MLP

### module

In [25]:
# Module object captured by the layer-0 MLP hook (element 0 of its temp entry).
temp['transformer.h.0.mlp'][0]

GPT2MLP(
  (c_fc): Conv1D()
  (c_proj): Conv1D()
  (act): NewGELUActivation()
  (dropout): Dropout(p=0.1, inplace=False)
)

### input: a 1-tuple holding a tensor $\in \mathbb{R}^{n\_batch \times n\_tokens \times dim\_model}$

In [26]:
# Number of positional inputs the layer-0 MLP hook received (tuple converted to list).
len(temp['transformer.h.0.mlp'][1])

1

In [27]:
# Shape of the first positional input tensor of the layer-0 MLP module.
temp['transformer.h.0.mlp'][1][0].shape

torch.Size([2, 2, 768])

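### output: a single tensor $\in \mathbb{R}^{n\_batch \times n\_tokens \times dim\_model}$ (not a tuple, so `len` gives $n\_batch$ and `[0]` is the first sequence)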

In [28]:
# Unlike attention, the MLP output is a single tensor rather than a tuple, so len()
# returns its first dimension (n_batch), not a number of outputs.
len(temp['transformer.h.0.mlp'][2])

2

In [29]:
# Indexing the MLP output tensor with [0] selects the first sequence in the batch: (n_tokens, dim_model).
temp['transformer.h.0.mlp'][2][0].shape

torch.Size([2, 768])

In [30]:
# Layer-0 MLP output values for the first sequence in the batch.
temp['transformer.h.0.mlp'][2][0]

tensor([[-0.1473, -0.8551,  0.2517,  ..., -1.0792,  0.4386, -0.8061],
        [-0.9995, -0.4823, -1.1081,  ..., -0.7381, -0.0529, -1.1930]])