In [285]:
import math
import torch.nn as nn
import torch
import torch.nn.functional as F

### Position-wise Feed-Forward Networks

$$ FFN(x) = \max(0, xW_1 + b_1)W_2 + b_2 $$

In [286]:
class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network applied to the last axis of its input.

    Implements ``dense2(relu(dense1(X)))``.

    Args:
        ffn_num_input: feature size expected on the last axis of the input.
        ffn_num_hiddens: width of the hidden layer.
        ffn_num_outputs: feature size produced on the last axis of the output.
    """

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        """Return the network output; only the last axis of ``X`` changes size."""
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)

In [287]:
# Smoke test: only the last axis changes size (4 -> 8); batch and sequence axes are kept.
ffn = PositionWiseFFN(ffn_num_input=4, ffn_num_hiddens=5, ffn_num_outputs=8)
ffn.eval()
ffn(torch.ones(2, 3, 4)).shape

torch.Size([2, 3, 8])

In [288]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    ``forward(X, Y)`` returns ``LayerNorm(Dropout(Y) + X)``.

    Args:
        normalized_shape: passed unchanged to ``nn.LayerNorm``; it must match
            the trailing dimensions of the tensors given to ``forward``.
        dropout: dropout probability applied to ``Y`` before the addition.
    """

    def __init__(self, normalized_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Add the (dropped-out) sublayer output ``Y`` to its input ``X`` and normalize."""
        residual_sum = self.dropout(Y) + X
        return self.ln(residual_sum)

In [289]:
# Smoke test: AddNorm keeps the shape of its inputs.
add_norm = AddNorm(normalized_shape=[3, 4], dropout=0.5)
add_norm.eval()
add_norm(torch.ones(2, 3, 4), torch.ones(2, 3, 4)).shape

torch.Size([2, 3, 4])

In [290]:
def masked_softmax(X, valid_lens=None):
    """Softmax over the last axis, ignoring positions beyond ``valid_lens``.

    The input ``X`` is never modified: masking is done out-of-place, so the
    caller's score tensor keeps its original values.

    Args:
        X: 3-D score tensor. Rows are obtained by flattening the first two
            axes; the softmax runs over the last axis.
        valid_lens: ``None`` for a plain softmax, a 1-D tensor holding one
            length per item along ``X``'s first axis (reused for every row of
            that item), or a 2-D tensor holding one length per row.

    Returns:
        Tensor with the same shape as ``X``. Within each row, entries at index
        ``>= valid_len`` were replaced by -1e6 before the softmax, so their
        weight is (numerically) zero.
    """

    def sequence_mask(X, valid_len, value=0):
        """Return a copy of 2-D ``X`` with columns at index >= ``valid_len[row]`` set to ``value``."""
        maxlen = X.size(1)
        mask = torch.arange(maxlen, device=X.device)[None, :] < valid_len[:, None]
        # masked_fill allocates a new tensor; the previous in-place indexed
        # assignment wrote through the reshape view into the caller's tensor.
        return X.masked_fill(~mask, value)

    if valid_lens is None:
        return F.softmax(X, dim=-1)

    shape = X.shape
    # The comparison inside sequence_mask needs both operands on one device.
    valid_lens = valid_lens.to(X.device)
    if valid_lens.dim() == 1:
        # One length per batch item: repeat it for each of the shape[1] rows.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # Masked entries get a very large negative score so their softmax output is 0.
    X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return F.softmax(X.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
    """Scaled dot-product attention.

    After each ``forward`` call the (pre-dropout) attention weights are kept
    on ``self.attention_weights``.
    """

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Attend from ``queries`` over ``keys``/``values``.

        Shapes: queries (b, q, d), keys (b, k, d), values (b, k, v);
        the result is (b, q, v). ``valid_lens`` is forwarded to
        ``masked_softmax`` to mask key positions.
        """
        scale = math.sqrt(queries.shape[-1])
        # (b, q, d) x (b, d, k) -> (b, q, k)
        keys_t = keys.transpose(1, 2)
        scores = torch.bmm(queries, keys_t) / scale
        self.attention_weights = masked_softmax(scores, valid_lens)
        # (b, q, k) x (b, k, v) -> (b, q, v)
        weights = self.dropout(self.attention_weights)
        return torch.bmm(weights, values)


class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are each projected to ``query_size`` (E_q)
    features, split into ``num_heads`` heads that are folded into the batch
    axis, attended independently, merged back and passed through ``W_o``.

    Args:
        query_size: feature size of the queries (E_q); also the output size.
        key_size: feature size of the keys (E_k).
        value_size: feature size of the values (E_v).
        num_heads: number of heads; must divide ``query_size``.
        dropout: dropout probability used inside the attention.
        bias: whether the four linear projections have a bias term.
    """

    def __init__(self, query_size, key_size, value_size,
                 num_heads, dropout, bias=False):  # argument layout mirrors PyTorch's
        super().__init__()
        self.num_heads = num_heads
        assert query_size % num_heads == 0, "query_size must be divisible by num_heads"
        # Every projection maps into the query feature space (E_q).
        self.W_q = nn.Linear(query_size, query_size, bias=bias)
        self.W_k = nn.Linear(key_size, query_size, bias=bias)
        self.W_v = nn.Linear(value_size, query_size, bias=bias)
        self.W_o = nn.Linear(query_size, query_size, bias=bias)
        self.attention = DotProductAttention(dropout)

    @staticmethod
    def transpose_qkv(X, num_heads):
        """Fold the heads into the batch axis.

        Input shape (N, T, E_q) -> output shape (N * num_heads, T, E_q / num_heads),
        where T is the query length L or the key/value length S.
        """
        batch, seq_len = X.shape[0], X.shape[1]
        # (N, T, num_heads, E_q / num_heads)
        split = X.reshape(batch, seq_len, num_heads, -1)
        # (N, num_heads, T, E_q / num_heads)
        heads_first = split.permute(0, 2, 1, 3)
        return heads_first.reshape(-1, heads_first.shape[2], heads_first.shape[3])

    def forward(self, queries, keys, values, valid_lens):
        """Run multi-head attention.

        Shapes: queries (N, L, E_q), keys (N, S, E_k), values (N, S, E_v);
        the result is (N, L, E_q). ``valid_lens`` may be ``None``.
        """
        heads = self.num_heads
        # After projection + split: (N * num_heads, L or S, E_q / num_heads).
        queries = self.transpose_qkv(self.W_q(queries), heads)
        keys = self.transpose_qkv(self.W_k(keys), heads)
        values = self.transpose_qkv(self.W_v(values), heads)

        if valid_lens is not None:
            # The batch axis now holds N * num_heads entries, so every length
            # is repeated once per head.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=heads, dim=0)

        # (N * num_heads, L, E_q / num_heads)
        per_head = self.attention(queries, keys, values, valid_lens)
        _, tgt_len, head_dim = per_head.shape
        # (N, num_heads, L, head_dim) -> (N, L, num_heads, head_dim)
        merged = per_head.reshape(-1, heads, tgt_len, head_dim).permute(0, 2, 1, 3)
        # (N, L, E_q)
        output_concat = merged.reshape(merged.shape[0], merged.shape[1], -1)
        return self.W_o(output_concat)

In [291]:
class EncoderBlock(nn.Module):
    """Transformer encoder block.

    Multi-head self-attention followed by a position-wise feed-forward
    network, each wrapped in an ``AddNorm`` (residual + layer norm).
    """

    def __init__(self, query_size, key_size, value_size, norm_shape,
                 ffn_num_hiddens, num_heads, dropout, use_bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(
            query_size, key_size, value_size, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # The attention output has shape (N, L, E_q), so the FFN maps
        # query_size -> ffn_num_hiddens -> query_size.
        self.ffn = PositionWiseFFN(query_size, ffn_num_hiddens, query_size)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        """Apply the block to ``X``; the output has the same shape as ``X``."""
        # Self-attention: X serves as queries, keys and values.
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        return self.addnorm2(Y, self.ffn(Y))

In [292]:
# Build a single encoder block and run it on a dummy batch.
encoder_blk = EncoderBlock(
    query_size=24, key_size=24, value_size=24,
    norm_shape=[100, 24],
    ffn_num_hiddens=48, num_heads=8, dropout=0.5)
encoder_blk.eval()

X = torch.ones(2, 100, 24)
valid_lens = torch.tensor([3, 2])
# No layer in the Transformer encoder changes the shape of its input.
print(encoder_blk(X, valid_lens).shape)

torch.Size([2, 100, 24])


In [293]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    A table ``P`` of shape (1, max_len, num_hiddens) is precomputed: even
    feature indices hold ``sin(pos / 10000^(2j / num_hiddens))`` and odd
    indices hold the matching cosine. ``P`` is registered as a non-persistent
    buffer, so it follows ``module.to(device)`` but is not written to
    ``state_dict``.

    Args:
        num_hiddens: feature size of the inputs (may be odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length the table supports.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        inv_scale = torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        # angles.shape = (max_len, ceil(num_hiddens / 2))
        angles = positions / inv_scale
        P = torch.zeros((1, max_len, num_hiddens))
        P[:, :, 0::2] = torch.sin(angles)
        # For odd num_hiddens there is one cosine column fewer than sine
        # columns, so trim the angle table to the number of odd slots.
        P[:, :, 1::2] = torch.cos(angles[:, :num_hiddens // 2])
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        """Add the encoding for the first ``X.shape[1]`` positions, then apply dropout."""
        # .to() is a no-op once the module lives on X's device; it is kept so a
        # module that was never moved still works with inputs on another device.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

In [294]:
class TransformerEncoder(nn.Module):
    """Transformer encoder: embedding + positional encoding + a stack of ``EncoderBlock``s.

    After each ``forward`` call, ``self.attention_weights[i]`` holds the
    self-attention weights produced by block ``i``.
    """

    def __init__(self, vocab_size, query_size, key_size, value_size, norm_shape,
                 ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False):
        super().__init__()
        self.query_size = query_size
        self.pos_encoding = PositionalEncoding(query_size, dropout)
        # vocab_size: number of tokens in the vocabulary.
        self.embedding = nn.Embedding(vocab_size, query_size)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(
                query_size=query_size,
                key_size=key_size,
                value_size=value_size,
                norm_shape=norm_shape,
                ffn_num_hiddens=ffn_num_hiddens,
                num_heads=num_heads,
                dropout=dropout,
                use_bias=use_bias)
            self.blks.add_module("block" + str(layer_idx), block)

    def forward(self, X, valid_lens):
        """Encode token indices ``X``; ``valid_lens`` is passed to every block."""
        # Positional encodings lie in [-1, 1], so the embeddings are scaled by
        # the square root of the embedding size before the two are added.
        scaled_embeddings = self.embedding(X) * math.sqrt(self.query_size)
        X = self.pos_encoding(scaled_embeddings)
        self.attention_weights = [None] * len(self.blks)
        for layer_idx, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[layer_idx] = blk.attention.attention.attention_weights
        return X

In [295]:
# Two-layer encoder over a 200-token vocabulary; check the output shape.
encoder = TransformerEncoder(
    vocab_size=200,
    query_size=24, key_size=24, value_size=24,
    norm_shape=[100, 24],
    ffn_num_hiddens=48,
    num_heads=8, num_layers=2,
    dropout=0.5)

encoder.eval()
valid_lens = torch.tensor([3, 2])
encoder(torch.ones(2, 100, dtype=torch.long), valid_lens).shape

torch.Size([2, 100, 24])

In [296]:
class DecoderBlock(nn.Module):
    """Transformer decoder block.

    Three sublayers, each followed by ``AddNorm``: self-attention over the
    decoder inputs, attention over the encoder outputs, and a position-wise
    feed-forward network.

    Args:
        i: index of this block; selects this block's slot in ``state[2]``.
    """

    def __init__(self, query_size, key_size, value_size, norm_shape,
                 ffn_num_hiddens, num_heads, dropout, i):
        super().__init__()
        self.i = i
        self.attention1 = MultiHeadAttention(
            query_size=query_size, key_size=key_size,
            value_size=value_size, num_heads=num_heads, dropout=dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(
            query_size=query_size, key_size=key_size,
            value_size=value_size, num_heads=num_heads, dropout=dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(query_size, ffn_num_hiddens, query_size)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """Run the block and return ``(output, state)``.

        ``state`` is ``[enc_outputs, enc_valid_lens, cache]``; ``cache[self.i]``
        is updated in place with everything this block has been fed so far.
        """
        enc_outputs, enc_valid_lens = state[0], state[1]

        # Append the current inputs to whatever this block has already seen.
        cached = state[2][self.i]
        key_values = X if cached is None else torch.cat((cached, X), dim=1)
        state[2][self.i] = key_values

        dec_valid_lens = None
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Each row is [1, 2, ..., num_steps]: query position t is limited
            # to the first t key positions.
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)

        # Self-attention over the (possibly cached) decoder inputs.
        self_attended = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, self_attended)
        # Keys and values both come from the encoder outputs.
        cross_attended = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, cross_attended)
        return self.addnorm3(Z, self.ffn(Z)), state

In [297]:
# Single decoder block in training mode, fed with the encoder block's output.
decoder_blk = DecoderBlock(
    query_size=24, key_size=24, value_size=24,
    norm_shape=[100, 24],
    ffn_num_hiddens=48, num_heads=8, dropout=0.5, i=0)

decoder_blk.train()
X = torch.ones(2, 100, 24)
# state = [encoder outputs, encoder valid lengths, per-block cache]
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
decoder_blk(X, state)[0].shape

torch.Size([2, 100, 24])

In [298]:
class TransformerDecoder(nn.Module):
    """Transformer decoder: embedding + positional encoding + ``DecoderBlock`` stack + output projection.

    ``forward`` returns logits of shape (..., vocab_size) together with the
    updated decoding state.
    """

    def __init__(self, vocab_size, query_size, key_size, value_size, norm_shape,
                 ffn_num_hiddens, num_heads, num_layers, dropout):
        super().__init__()
        self.query_size = query_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, query_size)
        self.pos_encoding = PositionalEncoding(query_size, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = DecoderBlock(
                query_size=query_size,
                key_size=key_size,
                value_size=value_size,
                norm_shape=norm_shape,
                ffn_num_hiddens=ffn_num_hiddens,
                num_heads=num_heads,
                dropout=dropout,
                i=layer_idx)
            self.blks.add_module("block" + str(layer_idx), block)
        self.dense = nn.Linear(query_size, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        """Build the initial state: encoder outputs, their valid lengths, and one empty cache slot per block."""
        per_block_cache = [None] * self.num_layers
        return [enc_outputs, enc_valid_lens, per_block_cache]

    def forward(self, X, state):
        """Decode token indices ``X`` given ``state``; returns ``(logits, state)``."""
        scaled_embeddings = self.embedding(X) * math.sqrt(self.query_size)
        X = self.pos_encoding(scaled_embeddings)
        num_blks = len(self.blks)
        # Row 0: decoder self-attention weights; row 1: encoder-decoder attention weights.
        self._attention_weights = [[None] * num_blks, [None] * num_blks]
        for layer_idx, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][layer_idx] = blk.attention1.attention.attention_weights
            self._attention_weights[1][layer_idx] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        """Attention weights recorded by the most recent ``forward`` call."""
        return self._attention_weights

In [299]:
# Hyperparameters for the full decoder demo.
num_heads = 4
num_layers = 2
dropout = 0.1
key_size = 24
query_size = 24
value_size = 24
ffn_num_hiddens = 48
norm_shape = [24]

decoder = TransformerDecoder(
    vocab_size=1000,
    query_size=query_size, key_size=key_size, value_size=value_size,
    norm_shape=norm_shape,
    ffn_num_hiddens=ffn_num_hiddens,
    num_heads=num_heads, num_layers=num_layers,
    dropout=dropout)

valid_lens = torch.tensor([3, 2])
X = torch.ones(2, 100, dtype=torch.long)
# The decoder state is seeded with the encoder outputs and their valid lengths.
state = decoder.init_state(encoder(X, valid_lens), valid_lens)
output, state = decoder(X, state)
print(output.shape)

torch.Size([2, 100, 1000])


In [300]:
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture.

    Args:
        encoder: module called as ``encoder(enc_X, *args)``.
        decoder: module providing ``init_state(enc_outputs, *args)`` and
            callable as ``decoder(dec_X, state)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        """Encode ``enc_X``, initialise the decoder state and decode ``dec_X``.

        Extra positional arguments (for example the encoder valid lengths) are
        forwarded to both the encoder and ``decoder.init_state``. Without them,
        encoders and decoders that require such arguments could not be driven
        through this wrapper. Calling with no extra arguments behaves exactly
        as before.

        Returns:
            Whatever ``self.decoder(dec_X, dec_state)`` returns.
        """
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
