In [1]:
import os#环境代理设置
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import transformers
import numpy as np
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show

In [4]:
model_ckpt = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)

In [5]:
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Lin

In [6]:
sample_text = 'time flies like an arrow'
show(model, model_type="bert", tokenizer=tokenizer, sentence_a=sample_text, 
     display_mode="light", layer=0, head=8)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### computation of self-attention

In [7]:
tokenizer.model_input_names

['input_ids', 'token_type_ids', 'attention_mask']

In [8]:
tokenizer

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [9]:
model_inputs = tokenizer(sample_text, return_tensors='pt', add_special_tokens=False)

In [10]:
model_inputs

{'input_ids': tensor([[ 2051, 10029,  2066,  2019,  8612]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [11]:
from torch import nn
from transformers import AutoConfig, AutoTokenizer, AutoModel

In [12]:
config = AutoConfig.from_pretrained(model_ckpt)
config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.46.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [13]:
token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
token_embedding

Embedding(30522, 768)

In [14]:
model_inputs['input_ids'].shape

torch.Size([1, 5])

In [15]:
model_inputs['input_ids']

tensor([[ 2051, 10029,  2066,  2019,  8612]])

In [16]:
input_embeddings = token_embedding(model_inputs['input_ids'])

In [17]:
input_embeddings.shape

torch.Size([1, 5, 768])

In [18]:
# 暂时不考虑position embedding

In [19]:
def scaled_dot_product_attention(query, key, value):
    dim_k = key.size(-1)
    attn_scores = torch.bmm(query, key.transpose(1,2)) / np.sqrt(dim_k)
    attn_weights = F.softmax(attn_scores, dim=-1)
    return torch.bmm(attn_weights, value)

In [20]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.W_q = nn.Linear(embed_dim, head_dim)
        self.W_k = nn.Linear(embed_dim, head_dim)
        self.W_v = nn.Linear(embed_dim, head_dim)
    def forward(self, hidden_state):
        q = self.W_q(hidden_state)
        k = self.W_k(hidden_state)
        v = self.W_v(hidden_state)
        attn_outputs = scaled_dot_product_attention(q, k, v)
        return attn_outputs

In [21]:
class MutiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList([
            AttentionHead(embed_dim, head_dim) for _ in range(num_heads)
        ])
        self.output_layer = nn.Linear(embed_dim, embed_dim)
    def forward(self, hidden_state):
        print(f'input hidden_state: {hidden_state.shape}')
        print(f'head(hidden_state): {self.heads[11](hidden_state).shape}')
        x = torch.cat([head(hidden_state) for head in self.heads], dim=-1)
        print(f'cat heads: {x.shape}')
        x = self.output_layer(x)
        return x

In [22]:
mha = MutiHeadAttention(config)

In [23]:
mha(input_embeddings)

input hidden_state: torch.Size([1, 5, 768])
head(hidden_state): torch.Size([1, 5, 64])
cat heads: torch.Size([1, 5, 768])


tensor([[[-0.3563,  0.3929,  0.2303,  ..., -0.1719, -0.0214, -0.0459],
         [-0.3417,  0.4604,  0.0924,  ..., -0.1610, -0.0160, -0.1048],
         [-0.2919,  0.4289,  0.2050,  ..., -0.1073,  0.0074, -0.1263],
         [-0.2956,  0.3219,  0.1450,  ..., -0.2711,  0.0137, -0.0327],
         [-0.3275,  0.3571,  0.1405,  ..., -0.2336, -0.0738, -0.0190]]],
       grad_fn=<ViewBackward0>)