In [2]:
from IPython.display import Image
import torch
from torch import nn
import math
from bertviz.transformers_neuron_view import BertModel, BertConfig
from transformers import BertTokenizer


### 1. model config and load

In [3]:
max_length = 256
model_name = 'bert-base-uncased'
config = BertConfig.from_pretrained(model_name, output_attention=True,
                                    output_hidden_states=True,
                                    return_dict=True)
tokenizer = BertTokenizer.from_pretrained(model_name)
config.max_position_embeddings = max_length
model = BertModel(config).from_pretrained(model_name)
model = model.eval()

100%|██████████| 433/433 [00:00<?, ?B/s]
100%|██████████| 440473133/440473133 [00:36<00:00, 12025115.90B/s]


In [4]:
model.config

{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": true,
  "output_hidden_states": false,
  "pad_token_id": 0,
  "torchscript": false,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [5]:
att_head_size = int(model.config.hidden_size/model.config.num_attention_heads)

In [6]:
att_head_size

64

In [7]:
model.encoder

BertEncoder(
  (layer): ModuleList(
    (0-11): 12 x BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): BertSelfOutput(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (LayerNorm): BertLayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (intermediate): BertIntermediate(
        (dense): Linear(in_features=768, out_features=3072, bias=True)
      )
      (output): BertOutput(
        (dense): Linear(in_features=3072, out_features=768, bias=True)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
)

In [12]:
model.encoder.layer[0].attention.self.query.weight.T[:, 64:128]
# head1[:, :64]  head2[:, 64:128]

tensor([[-0.0112, -0.0324, -0.0615,  ..., -0.0383,  0.0031,  0.0059],
        [ 0.0260, -0.0067, -0.0616,  ...,  0.1097,  0.0029, -0.0540],
        [-0.0169,  0.0232,  0.0068,  ...,  0.0124, -0.0168,  0.0301],
        ...,
        [ 0.1083,  0.0056,  0.0968,  ...,  0.0188, -0.0171,  0.0141],
        [-0.0436, -0.1032, -0.1035,  ...,  0.0138, -0.0488, -0.0453],
        [-0.0611,  0.0224, -0.0320,  ...,  0.0376,  0.0186, -0.0482]],
       grad_fn=<SliceBackward0>)

### 2. data

In [13]:
from sklearn.datasets import fetch_20newsgroups
newsgroup_train = fetch_20newsgroups(subset='train')
inputs_tests = tokenizer(newsgroup_train['data'][:1],
                        truncation=True, padding=True, max_length=max_length,
                        return_tensors='pt')

In [14]:
inputs_tests.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [15]:
inputs_tests['input_ids'].shape

torch.Size([1, 201])

### 3. model output

In [16]:
model_output = model(**inputs_tests)

In [17]:
len(model_output)

3

In [18]:
model_output[-1][0].keys()

dict_keys(['attn', 'queries', 'keys'])

In [24]:
model_output[-1][0]['attn'].shape # torch.Size([1, 12, 201, 201])
model_output[-1][0]['attn'][0, 0, :, :]

tensor([[0.0053, 0.0109, 0.0052,  ..., 0.0039, 0.0036, 0.0144],
        [0.0086, 0.0041, 0.0125,  ..., 0.0045, 0.0041, 0.0071],
        [0.0051, 0.0043, 0.0046,  ..., 0.0043, 0.0045, 0.0031],
        ...,
        [0.0010, 0.0023, 0.0055,  ..., 0.0012, 0.0018, 0.0011],
        [0.0010, 0.0023, 0.0057,  ..., 0.0012, 0.0017, 0.0007],
        [0.0022, 0.0056, 0.0063,  ..., 0.0045, 0.0048, 0.0015]],
       grad_fn=<SliceBackward0>)

### 4. from scratch

In [25]:
emb_output = model.embeddings(inputs_tests['input_ids'], inputs_tests['token_type_ids'])

In [26]:
emb_output.shape

torch.Size([1, 201, 768])

In [27]:
model.encoder.layer[0]

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [33]:
model.encoder.layer[0].attention.self.query.weight

Parameter containing:
tensor([[-0.0164,  0.0261, -0.0263,  ...,  0.0154,  0.0768,  0.0548],
        [-0.0326,  0.0346, -0.0423,  ..., -0.0527,  0.1393,  0.0078],
        [ 0.0105,  0.0334,  0.0109,  ..., -0.0279,  0.0258, -0.0468],
        ...,
        [-0.0085,  0.0514,  0.0555,  ...,  0.0282,  0.0543, -0.0541],
        [-0.0198,  0.0944,  0.0617,  ..., -0.1042,  0.0601,  0.0470],
        [ 0.0015, -0.0952,  0.0099,  ..., -0.0191, -0.0508, -0.0085]],
       requires_grad=True)

In [37]:
model.encoder.layer[0].attention.self.query.bias.shape

torch.Size([768])

In [38]:
# emb_output[0].shape: 201*768
# query.weight.T.shape: 768*768
# query.weight.T[:, :att_head_size].shape: 768*64
# -> 201*64
q_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.query.weight.T[:,:att_head_size] \
                        + model.encoder.layer[0].attention.self.query.bias[:att_head_size]

In [40]:
q_first_head_first_layer.shape

torch.Size([201, 64])

In [41]:
k_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.key.weight.T[:,:att_head_size] \
                        + model.encoder.layer[0].attention.self.key.bias[:att_head_size]

In [42]:
k_first_head_first_layer.shape

torch.Size([201, 64])

In [52]:
attn_scores = torch.nn.Softmax(q_first_head_first_layer @ k_first_head_first_layer.T / math.sqrt(att_head_size))

In [63]:
attn_scores = torch.nn.Softmax(dim=-1)(q_first_head_first_layer @ k_first_head_first_layer.T / math.sqrt(att_head_size))

In [64]:
attn_scores # (201*64) * (64*201) -> (201*201)

tensor([[0.0053, 0.0109, 0.0052,  ..., 0.0039, 0.0036, 0.0144],
        [0.0086, 0.0041, 0.0125,  ..., 0.0045, 0.0041, 0.0071],
        [0.0051, 0.0043, 0.0046,  ..., 0.0043, 0.0045, 0.0031],
        ...,
        [0.0010, 0.0023, 0.0055,  ..., 0.0012, 0.0018, 0.0011],
        [0.0010, 0.0023, 0.0057,  ..., 0.0012, 0.0017, 0.0007],
        [0.0022, 0.0056, 0.0063,  ..., 0.0045, 0.0048, 0.0015]],
       grad_fn=<SoftmaxBackward0>)

In [57]:
v_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.value.weight.T[:,:att_head_size] \
                        + model.encoder.layer[0].attention.self.value.bias[:att_head_size]

In [58]:
v_first_head_first_layer.shape

torch.Size([201, 64])

In [65]:
attn_emb = attn_scores @ v_first_head_first_layer

In [67]:
attn_emb

tensor([[-4.5640e-01,  4.6211e-02,  4.3913e-02,  ..., -2.0099e-02,
         -1.2756e-02,  6.4255e-03],
        [-4.5674e-01,  3.4322e-02,  3.2707e-02,  ..., -4.9206e-02,
          1.4975e-02, -3.0628e-02],
        [-4.9474e-01, -2.9539e-04, -7.5375e-04,  ..., -2.0035e-02,
          1.7146e-02, -3.0126e-02],
        ...,
        [-3.7991e-01,  5.2831e-02,  2.2534e-02,  ..., -1.8338e-02,
         -6.9508e-02,  2.1317e-02],
        [-3.8071e-01,  4.0900e-02,  2.8770e-02,  ..., -2.1192e-02,
         -5.2893e-02,  1.9734e-02],
        [-4.7131e-01,  1.0947e-01,  1.1631e-02,  ..., -3.4542e-02,
         -2.3752e-02, -5.0505e-03]], grad_fn=<MmBackward0>)

In [68]:
attn_emb.shape

torch.Size([201, 64])