In [2]:
# Goal: be able to implement a Transformer encoder from scratch.
import torch
from torch import nn
from transformers import AutoConfig, AutoTokenizer

# Seed torch's global RNG so the randomly initialised embedding / encoder weights
# created in this and later cells are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

text = "The true nobility is in being superior to your previous self."
# add_special_tokens=False removes [CLS] and [SEP] from the tokenized output
inputs = tokenizer(text=text, return_tensors="pt", add_special_tokens=False)
print(inputs.input_ids)

config = AutoConfig.from_pretrained(model_ckpt)
# Freshly initialised (untrained) token embedding table of size vocab_size x hidden_size
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
print(token_emb)

inputs_embeds = token_emb(inputs.input_ids)
print(inputs_embeds.size())

tensor([[ 1996,  2995, 11760,  2003,  1999,  2108,  6020,  2000,  2115,  3025,
          2969,  1012]])
Embedding(30522, 768)
torch.Size([1, 12, 768])


In [3]:
import torch
import torch.nn.functional as F
from math import sqrt

# Attention mechanism
def scaled_dot_product_attention(query, key, value, query_mask=None, key_mask=None, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: tensor of shape (..., seq_q, d_k).
        key: tensor of shape (..., seq_k, d_k).
        value: tensor of shape (..., seq_k, d_v).
        query_mask: optional (..., seq_q) tensor; 0 marks query positions to exclude.
        key_mask: optional (..., seq_k) tensor; 0 marks key positions to exclude.
            When both query_mask and key_mask are given they are combined into a
            (..., seq_q, seq_k) mask that replaces ``mask``.
        mask: optional tensor broadcastable to (..., seq_q, seq_k); entries equal
            to 0 are excluded from attention.

    Returns:
        Tensor of shape (..., seq_q, d_v).
    """
    dim_k = query.size(-1)
    # matmul (rather than bmm) also accepts extra leading dims, e.g. (batch, heads, seq, dim)
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    if query_mask is not None and key_mask is not None:
        # Outer product of the two masks; an elementwise multiply also works for
        # integer or bool masks, unlike bmm.
        mask = query_mask.unsqueeze(-1) * key_mask.unsqueeze(-2)
    if mask is not None:
        # Fill excluded positions with the dtype's most negative finite value rather
        # than -inf: rows with at least one visible key behave the same, but a fully
        # masked row gets uniform weights instead of NaN.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [12]:
# Single attention head -> multi-head attention
class AttentionHead(nn.Module):
    """One attention head: project query/key/value to ``head_dim``, then attend."""

    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # Independent learned projections embed_dim -> head_dim for Q, K and V
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, query, key, value, query_mask=None, key_mask=None, mask=None):
        """Project the three inputs and delegate to ``scaled_dot_product_attention``."""
        projected_query = self.q(query)
        projected_key = self.k(key)
        projected_value = self.v(value)
        return scaled_dot_product_attention(
            projected_query,
            projected_key,
            projected_value,
            query_mask=query_mask,
            key_mask=key_mask,
            mask=mask,
        )
    
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``AttentionHead`` modules.

    ``config.hidden_size`` is split evenly across ``config.num_attention_heads``
    heads; the per-head outputs are concatenated and mixed by a final linear layer.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        # The concatenated head outputs must add back up to embed_dim; otherwise
        # output_linear would fail with a shape mismatch only at forward time.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"hidden_size ({embed_dim}) must be divisible by "
                f"num_attention_heads ({num_heads})"
            )
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, query_mask=None, key_mask=None, mask=None):
        """Run every head on the same inputs/masks and project the concatenation.

        The result's last dimension is ``hidden_size``.
        """
        head_outputs = [h(query, key, value, query_mask, key_mask, mask) for h in self.heads]
        x = torch.cat(head_outputs, dim=-1)
        return self.output_linear(x)

In [13]:
# Test code (currently commented out; the output below is from an earlier run)
# multihead_attn = MultiHeadAttention(config)
# query = key = value = inputs_embeds
# attn_output = multihead_attn(query, key, value)
# print(attn_output.size())

torch.Size([1, 12, 768])


In [14]:
# Encoder building blocks
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Expands ``hidden_size`` to ``intermediate_size`` and projects back.
    """

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.embed_dim = hidden
        self.linear_1 = nn.Linear(hidden, inner)
        self.linear_2 = nn.Linear(inner, hidden)
        self.gelu = nn.GELU()  # Gaussian Error Linear Unit
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Expand, apply GELU, project back to ``hidden_size``, then apply dropout."""
        return self.dropout(self.linear_2(self.gelu(self.linear_1(x))))

In [15]:
# Test code (currently commented out; it needs attn_output from the previous test cell)
# feed_forward = FeedForward(config)
# ff_outputs = feed_forward(attn_output)
# print(ff_outputs.size())

torch.Size([1, 12, 768])


In [17]:
class TransformerEncoderLayer(nn.Module):
    """One post-LayerNorm encoder layer.

    Computes ``h = LayerNorm1(x + SelfAttention(x))`` and then
    ``out = LayerNorm2(h + FeedForward(h))``.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.layer_norm_1 = nn.LayerNorm(hidden)
        self.layer_norm_2 = nn.LayerNorm(hidden)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x, mask=None):
        """Self-attention and feed-forward sub-layers, each followed by residual add + LayerNorm.

        Note the residual placement: LayerNorm is applied *after* each residual
        addition (post-LN), and the feed-forward residual uses the normalised output.
        """
        attended = self.layer_norm_1(x + self.attention(x, x, x, mask=mask))
        return self.layer_norm_2(attended + self.feed_forward(attended))

In [18]:
# Smoke test: a single encoder layer returns a tensor with the same shape as its input.
encoder_layer = TransformerEncoderLayer(config)
print(inputs_embeds.shape)
print(encoder_layer(inputs_embeds).size())

torch.Size([1, 12, 768])
torch.Size([1, 12, 768])


In [19]:
# Positional embeddings: a conceptual understanding of this part is enough.
class Embeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings, then LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Read eps from the config when available; 1e-12 is the previous hard-coded value.
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-12)
        )
        # Use the configured dropout rate (matching FeedForward). A bare nn.Dropout()
        # silently uses p=0.5; that value is kept only as the fallback when the
        # config has no hidden_dropout_prob.
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.5))

    def forward(self, input_ids):
        """Embed token ids of shape (batch, seq_len) into (batch, seq_len, hidden_size)."""
        seq_length = input_ids.size(1)
        # Positions 0..seq_len-1, created on the input's device so the module also
        # works on GPU; shape (1, seq_len) broadcasts over the batch.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        # Sum token and position embeddings, then normalise and regularise.
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)

# Smoke test: embed the tokenized sentence and print the resulting tensor shape.
embedding_layer = Embeddings(config)
print(embedding_layer(inputs.input_ids).size())

torch.Size([1, 12, 768])


In [20]:
class TransformerEncoder(nn.Module):
    """Embedding layer followed by ``config.num_hidden_layers`` stacked encoder layers."""

    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        self.layers = nn.ModuleList(
            TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)
        )

    def forward(self, x, mask=None):
        """Map token ids ``x`` to hidden states, passing ``mask`` to every layer."""
        hidden_states = self.embeddings(x)
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states, mask=mask)
        return hidden_states

In [22]:
# Run the full (randomly initialised) encoder on the tokenized sentence.
# The forward pass is computed once and reused. Calling the module twice doubles
# the work and, with dropout active in training mode, prints results from two
# different forward passes.
encoder = TransformerEncoder(config)
encoder_output = encoder(inputs.input_ids)
print(encoder_output.size())
print(encoder_output)

torch.Size([1, 12, 768])
tensor([[[ 0.1278, -0.2382,  0.0507,  ..., -0.2470,  0.7110,  0.5316],
         [ 0.7871, -2.6126, -0.0287,  ..., -1.0430, -1.7671,  0.6151],
         [ 0.6588, -0.4444,  0.8571,  ..., -0.3877, -0.7966,  0.1449],
         ...,
         [-0.2746, -0.5145, -0.5651,  ..., -1.6235, -0.7572,  0.3775],
         [-0.6288,  0.2726, -0.4630,  ..., -0.4225, -0.3616,  0.3818],
         [ 0.2232, -0.1742,  0.2891,  ...,  0.1275,  0.1498,  0.8632]]],
       grad_fn=<NativeLayerNormBackward0>)
