In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from collections import Counter
from torch import optim

In [2]:
# 正弦函数位置编码表
def get_sin_function_positional_coding_table(max_len_sequence, dim_embedding):
    sin_function_positional_coding_table = torch.zeros(max_len_sequence, dim_embedding)
    for pos_i in range(max_len_sequence):
        for hid_j in range(dim_embedding):
            sin_function_positional_coding_table[pos_i, hid_j] = pos_i / torch.pow(torch.tensor(10000.0), 2 * (hid_j // 2) / dim_embedding)
    sin_function_positional_coding_table[:, 0::2] = torch.sin(sin_function_positional_coding_table[:, 0::2])
    sin_function_positional_coding_table[:, 1::2] = torch.sin(sin_function_positional_coding_table[:, 1::2])
    return sin_function_positional_coding_table

In [3]:
# 填充注意力掩码
def get_padding_attn_mask(sequence_Q, sequence_K):
    batch_size, len_Q = sequence_Q.size()
    batch_size, len_K = sequence_K.size()
    padding_attn_mask = sequence_K.eq(0).unsqueeze(1).expand(batch_size, len_Q, len_K)
    return padding_attn_mask

In [4]:
# 缩放-点积-填充掩码-注意力
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, padding_attn_mask):
        scale_weight = torch.matmul(Q, K.transpose(-2, -1)) / torch.pow(torch.tensor(Q.size(-1)), 2)
        scale_weight.masked_fill(padding_attn_mask, -1e9)
        attn_weight = F.softmax(scale_weight, dim=-1)
        context = torch.matmul(attn_weight, V)
        return context, attn_weight

In [5]:
# 多头-缩放-点积-填充掩码-注意力
class MultiHeadAttention(nn.Module):
    def __init__(self, dim_embedding, dim_Q, dim_K, dim_V, n_head):
        super().__init__()
        self.dim_Q = dim_Q
        self.dim_K = dim_K
        self.dim_V = dim_V
        self.n_head = n_head

        self.multi_head_linear_Q = nn.Linear(dim_embedding, self.dim_Q * n_head)
        self.multi_head_linear_K = nn.Linear(dim_embedding, self.dim_K * n_head)
        self.multi_head_linear_V = nn.Linear(dim_embedding, self.dim_V * n_head)

        self.MultiHeadAttention_ScaledDotProductAttention = ScaledDotProductAttention()
        self.linear = nn.Linear(self.n_head * self.dim_V, dim_embedding)
        self.layer_norm = nn.LayerNorm(dim_embedding)

    def forward(self, Q, K, V, padding_attn_mask):
        residual, batch_size = Q, Q.size(0)

        multi_head_Q  = self.multi_head_linear_Q(Q).view(batch_size, -1, self.n_head, self.dim_Q).transpose(-3, -2)
        multi_head_K  = self.multi_head_linear_K(K).view(batch_size, -1, self.n_head, self.dim_K).transpose(-3, -2)
        multi_head_V  = self.multi_head_linear_V(V).view(batch_size, -1, self.n_head, self.dim_V).transpose(-3, -2)

        multi_head_padding_attn_mask = padding_attn_mask.unsqueeze(1).repeat(1, self.n_head, 1, 1)

        context, attn_weight = self.MultiHeadAttention_ScaledDotProductAttention(multi_head_Q, multi_head_K, multi_head_V, multi_head_padding_attn_mask)
        context = context.transpose(-3, -2).contiguous().view(batch_size, -1, self.n_head * self.dim_V)
        context = self.layer_norm(self.linear(context) + residual)
        return context, attn_weight

In [6]:
# 逐位-卷积-前馈网络
class PositionWiseFeedForwardNet(nn.Module):
    def __init__(self, dim_embedding, dim_pwffn):
        super().__init__()
        self.conv1 = nn.Conv1d(dim_embedding, dim_pwffn, kernel_size=1)
        self.conv2 = nn.Conv1d(dim_pwffn, dim_embedding, kernel_size=1)
        self.layer_norm = nn.LayerNorm(dim_embedding)

    def forward(self, context):
        residual = context
        context = F.relu(self.conv1(context.transpose(1, 2)))
        context = self.conv2(context).transpose(1, 2)
        context = self.layer_norm(context + residual)
        return context

In [7]:
# 编码器层
class EncoderLayer(nn.Module):
    def __init__(self, dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn):
        super().__init__()
        self.EncoderLayer_MultiHeadAttention = MultiHeadAttention(dim_embedding, dim_Q, dim_K, dim_V, n_head)
        self.EncoderLayer_PositionWiseFeedForwardNet = PositionWiseFeedForwardNet(dim_embedding, dim_pwffn)

    def forward(self, encoder_layer_input, encoder_layer_padding_attn_mask):
        encoder_layer_output, encoder_layer_output_attn_weight = self.EncoderLayer_MultiHeadAttention(encoder_layer_input, encoder_layer_input, encoder_layer_input, encoder_layer_padding_attn_mask)
        encoder_layer_output = self.EncoderLayer_PositionWiseFeedForwardNet(encoder_layer_output)
        return encoder_layer_output, encoder_layer_output_attn_weight

In [8]:
# 编码器
class Encoder(nn.Module):
    def __init__(self, corpus, dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn, n_layer):
        super().__init__()
        self.token_embedding = nn.Embedding(corpus.len_source_vocabulary, dim_embedding)
        self.position_embedding = nn.Embedding.from_pretrained(get_sin_function_positional_coding_table(corpus.max_len_source_sequence + 1, dim_embedding), freeze=True)
        self.layers = nn.ModuleList(EncoderLayer(dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn) for _ in range(n_layer))

    def forward(self, encoder_input):
        positional_coding_template = torch.arange(1, encoder_input.size(1) + 1).unsqueeze(0)
        encoder_output = self.token_embedding(encoder_input) + self.position_embedding(positional_coding_template)

        encoder_padding_attn_mask = get_padding_attn_mask(encoder_input, encoder_input)

        encoder_attn_weights = []
        for layer in self.layers:
            encoder_output, encoder_attn_weight = layer(encoder_output, encoder_padding_attn_mask)
            encoder_attn_weights.append(encoder_attn_weight)
            
        return encoder_output, encoder_attn_weights

In [9]:
# 后续注意力掩码
def get_subsequent_attn_mask(decoder_input):
    return torch.triu(torch.ones(decoder_input.size(0), decoder_input.size(1), decoder_input.size(1)), diagonal=1).byte()

In [10]:
# 解码器层
class DecoderLayer(nn.Module):
    def __init__(self, dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn):
        super().__init__()
        self.DecoderLayer_MultiHeadAttention = MultiHeadAttention(dim_embedding, dim_Q, dim_K, dim_V, n_head)
        self.EncoderDecoderLayer_MultiHeadAttention = MultiHeadAttention(dim_embedding, dim_Q, dim_K, dim_V, n_head)
        self.DecoderLayer_PositionWiseFeedForwardNet = PositionWiseFeedForwardNet(dim_embedding, dim_pwffn)

    def forward(self, decoder_layer_input, padding_subsequent_attn_mask, encoder_output, encoder_decoder_padding_attn_mask):
        decoder_layer_output, decoder_layer_attn_weight = self.DecoderLayer_MultiHeadAttention(decoder_layer_input, decoder_layer_input, decoder_layer_input, padding_subsequent_attn_mask)
        encoder_decoder_layer_output, encoder_decoder_layer_attn_weight = self.EncoderDecoderLayer_MultiHeadAttention(decoder_layer_output, encoder_output, encoder_output, encoder_decoder_padding_attn_mask)
        decoder_layer_output = self.DecoderLayer_PositionWiseFeedForwardNet(encoder_decoder_layer_output)
        return decoder_layer_output, decoder_layer_attn_weight, encoder_decoder_layer_attn_weight

In [11]:
# 解码器
class Decoder(nn.Module):
    def __init__(self, corpus, dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn, n_layer):
        super().__init__()
        self.token_embedding = nn.Embedding(corpus.len_target_vocabulary, dim_embedding)
        self.position_embedding = nn.Embedding.from_pretrained(get_sin_function_positional_coding_table(corpus.max_len_target_sequence + 1, dim_embedding), freeze=True)
        self.layers = nn.ModuleList(DecoderLayer(dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn) for _ in range(n_layer))
        
    def forward(self, decoder_input, encoder_input, encoder_output):
        positional_coding_template = torch.arange(1, decoder_input.size(1) + 1).unsqueeze(0)
        decoder_output = self.token_embedding(decoder_input) + self.position_embedding(positional_coding_template)

        decoder_padding_attn__mask = get_padding_attn_mask(decoder_input, decoder_input)
        decoder_subsequent_attn_mask = get_subsequent_attn_mask(decoder_input)
        decoder_padding_subsequent_attn_mask = torch.gt((decoder_padding_attn__mask + decoder_subsequent_attn_mask), 0)
        encoder_decoder_padding_attn_mask = get_padding_attn_mask(decoder_input, encoder_input)

        decoder_attn_weights, encoder_decoder_attn_weights = [], []
        for layer in self.layers:
            decoder_output, decoder_attn_weight, encoder_decoder_attn_weight = layer(decoder_output, decoder_padding_subsequent_attn_mask, encoder_output, encoder_decoder_padding_attn_mask)
            decoder_attn_weights.append(decoder_attn_weight)
            encoder_decoder_attn_weights.append(encoder_decoder_attn_weight)

        return decoder_output, decoder_attn_weights, encoder_decoder_attn_weights

In [12]:
class Transformer(nn.Module):
    def __init__(self, corpus, dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn, n_layer):
        super().__init__()
        self.Encoder = Encoder(corpus, dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn, n_layer)
        self.Decoder = Decoder(corpus, dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn, n_layer)
        self.projection = nn.Linear(dim_embedding, corpus.len_target_vocabulary, bias=False)

    def forward(self, encoder_input, decoder_input):
        encoder_output, encoder_attn_weights = self.Encoder(encoder_input)
        decoder_output, decoder_attn_weights, encoder_decoder_attn_weights = self.Decoder(decoder_input, encoder_input, encoder_output)
        decoder_logits = self.projection(decoder_output)
        return decoder_logits, encoder_attn_weights, decoder_attn_weights, encoder_decoder_attn_weights

In [13]:
class TranslationCorpus:
    def __init__(self, sentences):
        self.sentences = sentences
        self.max_len_source_sequence = max(len(sentence[0].split()) for sentence in sentences) + 1
        self.max_len_target_sequence = max(len(sentence[1].split()) for sentence in sentences) + 2
        self.source_vocabulary, self.target_vocabulary, self.len_source_vocabulary, self.len_target_vocabulary = self.create_vocabulary()
        self.index_vocabulary_to_source_vocabulary = {v:k for k, v in self.source_vocabulary.items()}
        self.index_vocabulary_to_target_vocabulary = {v:k for k, v in self.target_vocabulary.items()}

    def create_vocabulary(self):
        source_counter = Counter(token for sentence in self.sentences for token in sentence[0].split())
        target_counter = Counter(token for sentence in self.sentences for token in sentence[1].split())
        source_vocabulary = {'<pad>': 0, **{token: i + 1 for i, token in enumerate(source_counter)}}
        target_vocabulary = {'<pad>': 0, '<sos>': 1, '<eos>': 2, **{token: i + 3 for i, token in enumerate(target_counter)}}
        len_source_vocabulary = len(source_vocabulary)
        len_target_vocabulary = len(target_vocabulary)
        return source_vocabulary, target_vocabulary, len_source_vocabulary, len_target_vocabulary
    
    def batch_dataset(self, batch_size, if_train_dataset=True):
        feature_dataset, label_dataset, target_dataset = [], [], []
        sentence_indices = torch.randperm(len(self.sentences))[:batch_size]
        for index in sentence_indices:
            source_sentence, target_sentence = self.sentences[index]
            source_sentence_to_index_sentence = [self.source_vocabulary[token] for token in source_sentence.split()]
            target_sentence_to_index_sentence = [self.target_vocabulary['<sos>']] + [self.target_vocabulary[token] for token in target_sentence.split()] + [self.target_vocabulary['<eos>']]
            source_sentence_to_index_sentence += [self.source_vocabulary['<pad>']] * (self.max_len_source_sequence - len(source_sentence_to_index_sentence))
            target_sentence_to_index_sentence += [self.target_vocabulary['<pad>']] * (self.max_len_target_sequence - len(target_sentence_to_index_sentence))
            feature_dataset.append(source_sentence_to_index_sentence)
            label_dataset.append([self.target_vocabulary['<sos>']] + ([self.target_vocabulary['<pad>']] * (self.max_len_target_sequence - 2)) if if_train_dataset else target_sentence_to_index_sentence[:-1])
            target_dataset.append(target_sentence_to_index_sentence[1:])
        return torch.LongTensor(feature_dataset), torch.LongTensor(label_dataset), torch.LongTensor(target_dataset)

In [14]:
sentences = [
    ['咖哥 喜欢 小冰', 'KaGe likes XiaoBing'],
    ['我 爱 学习 人工智能', 'I love studying AI'],
    ['深度学习 改变 世界', 'DL changed the world'],
    ['自然语言处理 很 强大', 'NLP is powerful'],
    ['神将网络 非常 复杂', 'Neural-networks are complex']
]

In [15]:
corpus = TranslationCorpus(sentences)
dim_embedding = 512
dim_Q = dim_K = dim_V = 64
n_head = 8
dim_pwffn = 2048
n_layer = 6
batch_size = 3

In [16]:
model = Transformer(corpus, dim_embedding, dim_Q, dim_K, dim_V, n_head, dim_pwffn, n_layer)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [17]:
model.train()
epochs = 50
for epoch in range(epochs):
    optimizer.zero_grad()
    train_feature_dataset, train_label_dataset, train_target_dataset = corpus.batch_dataset(batch_size, True)
    train_decoder_logits, _, _, _ = model(train_feature_dataset, train_label_dataset)
    loss = criterion(train_decoder_logits.view(-1, corpus.len_target_vocabulary), train_target_dataset.view(-1))
    if (epoch + 1) % 1 == 0:
        print(f'Epoch:{epoch + 1} cost={loss:.6f}')
    loss.backward()
    optimizer.step()

Epoch:1 cost=3.378279
Epoch:2 cost=3.832541
Epoch:3 cost=2.703499
Epoch:4 cost=2.520695
Epoch:5 cost=2.721653
Epoch:6 cost=2.365467
Epoch:7 cost=2.063673
Epoch:8 cost=2.157883
Epoch:9 cost=1.965703
Epoch:10 cost=1.666100
Epoch:11 cost=1.383447
Epoch:12 cost=1.764755
Epoch:13 cost=1.557842
Epoch:14 cost=1.421220
Epoch:15 cost=1.245278
Epoch:16 cost=1.203177
Epoch:17 cost=1.257416
Epoch:18 cost=1.239156
Epoch:19 cost=1.131092
Epoch:20 cost=1.143828
Epoch:21 cost=1.085775
Epoch:22 cost=1.161703
Epoch:23 cost=1.035513
Epoch:24 cost=0.945574
Epoch:25 cost=0.963884
Epoch:26 cost=0.877381
Epoch:27 cost=0.896587
Epoch:28 cost=0.771263
Epoch:29 cost=0.758602
Epoch:30 cost=0.664176
Epoch:31 cost=0.653280
Epoch:32 cost=0.613909
Epoch:33 cost=0.655209
Epoch:34 cost=0.492556
Epoch:35 cost=0.405615
Epoch:36 cost=0.437012
Epoch:37 cost=0.459177
Epoch:38 cost=0.386451
Epoch:39 cost=0.455268
Epoch:40 cost=0.471318
Epoch:41 cost=0.391468
Epoch:42 cost=0.432507
Epoch:43 cost=0.344285
Epoch:44 cost=0.3562

In [18]:
model.eval()
eval_feature_dataset, eval_label_dataset, eval_target_dataset = corpus.batch_dataset(1, False)
eval_decoder_logits, eval_encoder_attn_weights, eval_decoder_attn_weights, eval_encoder_decoder_attn_weights = model(eval_feature_dataset, eval_label_dataset)
eval_decoder_logits = eval_decoder_logits.view(-1, corpus.len_target_vocabulary).max(1, keepdim=True)[1]
translate_sentence = [corpus.index_vocabulary_to_target_vocabulary[index.item()] for index in eval_decoder_logits.squeeze()]
input_sentence = ''.join(corpus.index_vocabulary_to_source_vocabulary[index.item()] for index in eval_feature_dataset[0])
print(input_sentence, '  ----->  ', translate_sentence)

我爱学习人工智能<pad>   ----->   ['I', 'I', 'I', 'AI', 'AI']
