In [None]:
!pip install bertviz

In [None]:
# Attention visualization with bertviz's neuron view.
from transformers import AutoTokenizer
# neuron_view.show is paired here with bertviz's own BertModel
# (from bertviz.transformers_neuron_view), not transformers.BertModel.
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show

model_ckpt = "bert-base-uncased"
# Layer / head indices passed to show(); named so they are easy to find and change.
LAYER, HEAD = 0, 8

tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)
text = "time flies like an arrow"
show(model, "bert", tokenizer, text, display_mode="light", layer=LAYER, head=HEAD)

In [None]:
# Input: tokenize `text` without special tokens ([CLS]/[SEP]) so only its words remain.
inputs = tokenizer(
    text,
    add_special_tokens=False,
    return_tensors="pt",
)
inputs["input_ids"]  # 5 token ids once special tokens are excluded

In [None]:
# Prepare a dense token-embedding lookup table sized from the checkpoint's config.
import torch
from torch import nn
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_ckpt)
# nn.Embedding starts from random weights (not BERT's pretrained embeddings), so
# seed torch first: otherwise the embeddings and every score/weight/output derived
# from them change on each Restart & Run All.
torch.manual_seed(0)
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb  # lookup table: one hidden_size-dim row per vocabulary id

In [None]:
# Embed: look up the dense vector for every token id in the table.
inputs_embeds = token_emb(inputs["input_ids"])
inputs_embeds.shape  # [1, 5, 768]: (batch, tokens, hidden_size)

In [None]:
# Attention scores: scaled dot product between every pair of token vectors.
import torch
from math import sqrt

# Toy self-attention: no learned projections yet, so Q, K and V are all the raw embeddings.
query = key = value = inputs_embeds
dim_k = key.size(-1)
# transpose(-2, -1) swaps the last two axes (tokens <-> hidden) of key, so
# query @ key^T yields one score per (query token, key token) pair.
scores = (query @ key.transpose(-2, -1)) / sqrt(dim_k)
scores.size()  # [1, 5, 5]

In [None]:
# Apply softmax over the key axis so each row of scores becomes attention weights.
import torch.nn.functional as F  # also used by scaled_dot_product_attention below

weights = scores.softmax(dim=-1)
weights.sum(dim=-1)  # all 5 entries equal 1

In [None]:
# Multiply the attention weights by the values: a weighted sum of value vectors per token.
attn_outputs = weights @ value
attn_outputs.size()  # [1, 5, 768]

In [None]:
# Scaled dot-product attention function
def scaled_dot_product_attention(query, key, value, mask=None):
  """Compute softmax(query @ key^T / sqrt(d_k)) @ value.

  Args:
    query: tensor of shape (..., seq_q, d_k).
    key: tensor of shape (..., seq_k, d_k).
    value: tensor of shape (..., seq_k, d_v).
    mask: optional tensor broadcastable to (..., seq_q, seq_k). Positions where
      mask == 0 get a score of -inf, i.e. zero attention weight. A row whose
      positions are all masked yields NaN.

  Returns:
    Tensor of shape (..., seq_q, d_v).
  """
  dim_k = query.size(-1)
  # matmul + transpose(-2, -1) equals bmm for 3-D inputs, but also accepts extra
  # leading batch dims, e.g. (batch, heads, seq, d_k) in multi-head attention.
  scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
  if mask is not None:
    scores = scores.masked_fill(mask == 0, float("-inf"))
  weights = F.softmax(scores, dim=-1)
  return torch.matmul(weights, value)