# Embedding

## WordEmbedding

In [None]:
import torch
import torch.nn as nn


class WordEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x: torch.Tensor):
        """x.shape=(batch_size,seq_len)"""
        return self.embedding(x)


batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

vocab_size = 200
d_model = 6
word_embedding_cls = WordEmbedding(vocab_size=vocab_size, d_model=d_model)
word_embedding = word_embedding_cls(batch_sentences)
word_embedding

tensor([[[-0.2594,  0.4778,  0.8906,  0.6225, -0.6952, -0.0059],
         [ 0.4469, -0.0968, -0.6293,  0.4834,  0.9538,  1.4804],
         [-1.2663, -1.6941, -1.0667,  0.3993,  0.4741,  0.7654],
         [ 0.7726,  0.0513,  0.2658,  0.2169,  0.0769, -0.4826],
         [-0.5204, -0.7254, -1.4262,  0.1011,  0.0400,  0.0062],
         [-0.0191, -1.0718,  0.6736, -0.9083,  0.7761, -0.7482],
         [ 0.2084, -0.9778, -0.6865, -0.3744, -0.1471, -1.9092],
         [ 0.3867, -0.1872, -0.3631, -0.3801,  0.8723,  0.6902]],

        [[-0.2594,  0.4778,  0.8906,  0.6225, -0.6952, -0.0059],
         [ 0.2681,  0.4924, -0.3625,  0.9996, -0.7339,  0.2405],
         [-0.0191, -1.0718,  0.6736, -0.9083,  0.7761, -0.7482],
         [-1.2663, -1.6941, -1.0667,  0.3993,  0.4741,  0.7654],
         [ 0.5060, -0.6744,  0.5305, -0.6060,  0.8598, -1.4691],
         [ 0.2084, -0.9778, -0.6865, -0.3744, -0.1471, -1.9092],
         [ 0.3867, -0.1872, -0.3631, -0.3801,  0.8723,  0.6902],
         [ 0.3867, -0.1

## PositionalEmbedding

In [None]:
import math
import torch
import torch.nn as nn


class PositionalEmbedding(nn.Module):
    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        freq = torch.outer(position, div_term)
        pe[:, 0::2] = torch.sin(freq)
        pe[:, 1::2] = torch.cos(freq)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor):
        """x.shape=(batch_size,seq_len,d_model)"""
        return x + self.pe[:, : x.shape[1]]

In [6]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

max_len = 512
vocab_size = 200
d_model = 6
word_embedding_cls = WordEmbedding(vocab_size=vocab_size, d_model=d_model)
word_embedding = word_embedding_cls(batch_sentences)
positional_embedding_cls = PositionalEmbedding(max_len, d_model)
positional_embedding = positional_embedding_cls(word_embedding)
positional_embedding

tensor([[[-1.9628,  1.0211, -1.3259,  0.5956,  0.5534,  0.9839],
         [-0.2777,  2.4332,  1.5308,  0.9200, -2.2194,  1.3981],
         [ 1.0854, -1.5181, -0.1493,  1.3841,  1.2800,  2.0469],
         [-0.6702, -1.9092, -1.5013,  1.0365, -1.3446,  0.6457],
         [-1.0493, -1.3474, -0.4281, -0.1104, -0.3846,  2.6994],
         [-0.2033, -0.1855,  0.3566,  0.0544, -0.2543,  2.5764],
         [ 0.9776,  1.3213, -0.5522,  0.2842, -0.9604,  1.7178],
         [ 1.4179,  0.6697,  0.1250,  0.6449, -0.5645, -1.6632]],

        [[-1.9628,  1.0211, -1.3259,  0.5956,  0.5534,  0.9839],
         [ 1.6129,  1.6560,  0.4810,  1.9849, -0.5798,  1.1292],
         [ 1.6649, -0.8853,  0.2193,  0.0769, -0.2608,  2.5764],
         [ 0.3172, -2.0919, -0.1032,  1.3787,  1.2821,  2.0469],
         [-1.7056,  0.1021,  2.1960,  1.7241, -1.0572,  2.1324],
         [ 0.2981,  0.6448, -0.5971,  0.2959, -0.9625,  1.7178],
         [ 0.4815,  0.8760,  0.0807,  0.6587, -0.5666, -1.6631],
         [ 1.4179,  0.6

## 合并embedding

In [None]:
import torch
import torch.nn as nn


class Embedding(nn.Module):
    def __init__(self, vocab_size: int, max_len: int, d_model: int):
        super().__init__()
        self.word_embedding = WordEmbedding(vocab_size, d_model)
        self.pos_embedding = PositionalEmbedding(max_len, d_model)

    def forward(self, x: torch.Tensor):
        """x.shape=(batch_size,seq_len)"""
        return self.pos_embedding(self.word_embedding(x))

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

max_len = 512
vocab_size = 200
d_model = 6
embedding_cls = Embedding(vocab_size=vocab_size, max_len=max_len, d_model=d_model)
embedding_results = embedding_cls(batch_sentences)
embedding_results

tensor([[[-0.4572, -0.2178,  0.0781,  1.8263, -0.5870,  2.0135],
         [ 1.6376, -0.4071, -1.5502,  0.9801,  2.0680,  0.6822],
         [ 1.1654,  0.3321, -1.0168, -0.1903, -1.4823,  1.2675],
         [ 1.0693, -0.2825,  0.6275,  0.8787,  1.8801,  0.6586],
         [-0.2748,  0.4763, -1.2725,  0.7864,  0.8565,  0.7479],
         [-1.1064, -0.9123,  0.5808, -1.3677, -0.0208,  1.3122],
         [ 0.3025,  1.2600, -0.3686,  0.1991,  2.8254,  1.5712],
         [ 1.0995,  1.8471,  0.9251,  1.8300,  0.0190,  1.8744]],

        [[-0.4572, -0.2178,  0.0781,  1.8263, -0.5870,  2.0135],
         [-0.0238,  0.1447, -0.3907,  0.5942,  0.3845,  0.1024],
         [ 0.7618, -1.6121,  0.4435, -1.3452, -0.0273,  1.3123],
         [ 0.3973, -0.2417, -0.9707, -0.1957, -1.4802,  1.2675],
         [-1.8862,  0.0429, -0.4639,  0.2320, -0.4949,  0.1739],
         [-0.3770,  0.5835, -0.4135,  0.2108,  2.8233,  1.5712],
         [ 0.1631,  2.0533,  0.8808,  1.8438,  0.0168,  1.8745],
         [ 1.0995,  1.8

# Mask

## padding mask

In [None]:
import torch
import torch.nn as nn


def padding_mask(x: torch.Tensor, pad_value: int = 0):
    return x != pad_value

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

padding_mask(batch_sentences, 0)

tensor([[ True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True]])

## sequence mask

In [None]:
import torch
import torch.nn as nn


def sequence_mask(size: int):
    size = (1, size, size)
    return torch.tril(torch.ones(size), diagonal=1) != 0

In [22]:
sequence_mask(5)

tensor([[[ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]]])

## 合并mask

In [43]:
import torch
import torch.nn as nn


class Mask:
    def padding_mask(self, x: torch.Tensor, pad_value: int = 0):
        return (x != pad_value).unsqueeze(-2)

    def sequence_mask(self, size: int):
        size = (1, size, size)
        return torch.tril(torch.ones(size)) != 0

    def merge_mask(self, x: torch.Tensor, size: int, pad_value: int = 0):
        return self.padding_mask(x, pad_value) & self.sequence_mask(size)

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

mask = Mask()
mask.merge_mask(batch_sentences, batch_sentences.shape[-1])

tensor([[[ True, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True, False]],

        [[ True, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True, False, F

# MultiHeadAtten

In [58]:
import torch
import torch.nn as nn
import copy


def clone(layer, nums: int):
    return nn.ModuleList([copy.deepcopy(layer) for _ in range(nums)])


class MultiHeadAttention(nn.Module):
    def __init__(self, heads: int, d_model: int):
        super().__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.d_model = d_model
        self.linears = clone(nn.Linear(d_model, d_model), 4)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: None
    ):
        """
        query,key,value.shape=(batch_size,seq_len,d_model)
        如果是编码器或者解码器，则mask.shape=(batch_size,1,seq_len)或(batch_size,seq_len,seq_len)
        如果是编码解码,这mask.shape=(batch_size,tgt_seq_len,src_seq_len)
        """
        if mask is not None:
            mask = mask.unsqueeze(1)

        batch_size = query.shape[0]
        query, key, value = [
            func(data).reshape(batch_size, -1, self.heads, self.d_k).transpose(1, 2)
            for func, data in zip(self.linears, (query, key, value))
        ]
        scores, p_atten = self._attention(query, key, value, mask)
        x = scores.transpose(1, 2).reshape(batch_size, -1, self.d_k * self.heads)
        return self.linears[-1](x)

    def _attention(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: None
    ):
        """
        query,key,value.shape=(batch_size,heads,seq_len,d_k)
        如果是编码器或者解码器，则mask.shape=(batch_size,1,1,seq_len)或(batch_size,1,seq_len,seq_len)
        如果是编码解码,这mask.shape=(batch_size,1,tgt_seq_len,src_seq_len)

        返回结果,如果是自注意力，则shape=(batch_size,heads,seq_len,seq_len)
        如果编解码的注意力，则shape=(batch_size,heads,tgt_seq_len,src_seq_len)
        """
        d_k = query.shape[-1]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores.masked_fill(mask == 0, -1e9)
        p_atten = torch.softmax(scores, dim=-1)
        return torch.matmul(p_atten, value), p_atten

In [71]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

src_max_len = 512
src_vocab_size = 200
d_model = 6
src_heads = 2
embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
embedding_results = embedding_cls(batch_sentences)


mask = Mask()
src_mask = mask.padding_mask(batch_sentences)

attention_cls = MultiHeadAttention(src_heads, d_model)
attention_cls(embedding_results, embedding_results, embedding_results, src_mask)

tensor([[[ 0.1197,  0.2151, -0.0453, -0.4315,  0.3465,  0.2176],
         [ 0.1192,  0.2337, -0.0570, -0.4271,  0.3186,  0.2096],
         [ 0.1178,  0.2306, -0.1040, -0.4975,  0.2167,  0.1857],
         [ 0.1357,  0.2449, -0.1105, -0.4672,  0.2813,  0.1615],
         [ 0.2392,  0.2149, -0.1005, -0.4600,  0.2652,  0.1909],
         [ 0.2547,  0.2442, -0.1612, -0.4851,  0.2334,  0.1312],
         [-0.0236,  0.2142, -0.0539, -0.5058,  0.2773,  0.2103],
         [ 0.1703,  0.2669, -0.1975, -0.5378,  0.1658,  0.0973]],

        [[ 0.1032,  0.2625, -0.1649, -0.5457,  0.1226,  0.1454],
         [ 0.3144,  0.3140, -0.2243, -0.4838,  0.1277,  0.0781],
         [ 0.1356,  0.2988, -0.2279, -0.5573, -0.0030,  0.1383],
         [ 0.0794,  0.3018, -0.2447, -0.5909, -0.0651,  0.1452],
         [-0.2205,  0.2848, -0.2742, -0.7478, -0.2236,  0.1613],
         [-0.0886,  0.2686, -0.2149, -0.6556, -0.0600,  0.1642],
         [ 0.1743,  0.3550, -0.3417, -0.6023, -0.0432,  0.0146],
         [ 0.1084,  0.3

# LayerNorm

In [74]:
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gama = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """x.shape=(batch_size,seq_len,d_model)"""
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True)
        return (x - mean) / (torch.sqrt(var + self.eps)) * self.gama + self.beta

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

src_max_len = 512
src_vocab_size = 200
d_model = 6
src_heads = 2
embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
embedding_results = embedding_cls(batch_sentences)


mask = Mask()
src_mask = mask.padding_mask(batch_sentences)

attention_cls = MultiHeadAttention(src_heads, d_model)
attention_cls_output = attention_cls(
    embedding_results, embedding_results, embedding_results, src_mask
)

layer_norm_cls = LayerNorm(d_model)
layer_norm_cls(attention_cls_output)

tensor([[[-0.9220,  0.3081, -0.5607,  0.8963, -1.0646,  1.3430],
         [ 0.0922,  0.7240, -0.3764,  0.1275, -1.7239,  1.1566],
         [-0.7673,  0.4614, -0.5221,  0.6954, -1.2439,  1.3765],
         [-0.6938,  0.4627, -0.5116,  0.6840, -1.3073,  1.3660],
         [-0.3658,  0.5307, -0.5158,  0.5744, -1.5195,  1.2961],
         [-0.8288,  0.3166, -0.5720,  0.9063, -1.1456,  1.3235],
         [-0.9503,  0.3180, -0.5472,  0.8776, -1.0492,  1.3511],
         [-0.6751,  0.5215, -0.4606,  0.5788, -1.3470,  1.3823]],

        [[-0.9624,  0.2757, -0.5563,  0.9527, -1.0250,  1.3153],
         [-0.9486,  0.2419, -0.6616,  1.0943, -0.9513,  1.2254],
         [-0.9434,  0.2421, -0.6543,  1.0843, -0.9622,  1.2335],
         [-0.9538,  0.2862, -0.6569,  1.0001, -0.9612,  1.2856],
         [-0.9726,  0.1985, -0.6677,  1.1343, -0.9033,  1.2108],
         [-0.9713,  0.2723, -0.5699,  0.9756, -1.0072,  1.3005],
         [-0.9408,  0.2992, -0.5944,  0.9555, -1.0258,  1.3064],
         [-0.9344,  0.3

# FeedForward

In [None]:
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, d_model)
        self.d_model = d_model
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor):
        return self.linear2(torch.relu(self.linear1(x)))

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

src_max_len = 512
src_vocab_size = 200
d_model = 6
hidden_dim = 24
src_heads = 2
embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
embedding_results = embedding_cls(batch_sentences)


mask = Mask()
src_mask = mask.padding_mask(batch_sentences)

attention_cls = MultiHeadAttention(src_heads, d_model)
attention_cls_output = attention_cls(
    embedding_results, embedding_results, embedding_results, src_mask
)

layer_norm_cls = LayerNorm(d_model)
layer_norm_output = layer_norm_cls(attention_cls_output)

feedforward_cls = FeedForward(d_model, hidden_dim)
feedforward_cls(layer_norm_output)

tensor([[[ 0.2970, -0.2642, -0.1333,  0.1346, -0.0407, -0.0441],
         [ 0.2865, -0.2142, -0.1056,  0.0586, -0.0690,  0.0267],
         [ 0.2345, -0.2826, -0.2156,  0.3108,  0.0380, -0.1706],
         [ 0.0659, -0.1865, -0.2565,  0.2692,  0.0484, -0.2983],
         [ 0.1875, -0.2069, -0.1105,  0.1082,  0.1196, -0.2013],
         [ 0.2690, -0.2611, -0.1498,  0.1757,  0.0030, -0.1001],
         [ 0.3037, -0.2385, -0.1153,  0.0829, -0.0438, -0.0102],
         [ 0.2634, -0.1494, -0.1002,  0.0008, -0.0661,  0.0206]],

        [[ 0.2481, -0.1859, -0.1566,  0.1825, -0.1329,  0.1252],
         [ 0.4056, -0.0463, -0.2222, -0.0841, -0.1300, -0.0345],
         [ 0.2658, -0.1077, -0.1882,  0.1676, -0.1329,  0.1681],
         [ 0.0838, -0.1866, -0.3830,  0.5116,  0.0107, -0.0707],
         [ 0.0306, -0.1580, -0.3514,  0.4308,  0.0221, -0.1619],
         [ 0.2467, -0.2389, -0.1472,  0.1849, -0.1098,  0.0474],
         [ 0.3705, -0.0655, -0.1710, -0.0632, -0.1329, -0.0197],
         [ 0.3861, -0.0

# 编码器层

In [None]:
import torch
import torch.nn as nn


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int, heads: int):
        super().__init__()
        self.multi_head_atten = MultiHeadAttention(heads, d_model)
        self.layer_norm = clone(LayerNorm(d_model), 2)
        self.feedforward = FeedForward(d_model, hidden_dim)

    def forward(self, x: torch.Tensor, mask=None):
        atten_outputs = self.multi_head_atten(x, x, x, mask)
        x += self.layer_norm[0](atten_outputs)
        feedforward_outputs = self.feedforward(x)
        x += self.layer_norm[1](feedforward_outputs)
        return x

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

src_max_len = 512
src_vocab_size = 200
d_model = 6
hidden_dim = 24
src_heads = 2
embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
embedding_results = embedding_cls(batch_sentences)


mask = Mask()
src_mask = mask.padding_mask(batch_sentences)

encoder_layer = EncoderLayer(d_model, hidden_dim, src_heads)
encoder_layer(embedding_results, src_mask)

tensor([[[ 2.7471, -0.5135,  2.5910, -2.4158,  1.9173,  1.9910],
         [ 1.1906,  0.9480,  3.0668, -0.0510, -0.3191,  1.4701],
         [ 2.3394, -2.0119,  3.5881,  2.0119, -1.8817, -0.7396],
         [ 1.9286, -1.9994,  2.3473,  2.6664, -3.2357, -1.0885],
         [ 1.1343, -3.0320,  0.3227,  0.8000, -0.6706, -1.6313],
         [-0.4487, -1.5924,  0.3783,  2.8060, -2.5042, -1.9292],
         [ 0.7252,  3.0834,  0.6972,  1.5750,  0.1807, -0.7911],
         [ 1.4904,  1.9779,  0.5275, -1.0639,  0.4326, -1.0736]],

        [[ 2.2483,  0.2041,  2.0298, -2.3571,  1.9032,  2.2888],
         [ 2.4459, -2.1319,  0.4876,  3.9433, -2.5750, -1.3785],
         [ 1.7332, -2.0767,  0.0889,  2.6249, -2.6966, -1.9166],
         [ 2.0567, -2.5172,  2.6904,  3.0931, -2.1513, -1.1646],
         [ 0.8407, -2.6807,  1.2238,  3.7096, -2.4359, -0.4528],
         [-0.4456,  1.0478,  1.7615,  3.1917, -1.1143, -0.3622],
         [ 0.4674,  2.0756,  0.3029, -0.2607,  0.1678, -1.2248],
         [ 1.4184,  2.0

# Encoder

In [None]:
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, layer: EncoderLayer, nums: int):
        super().__init__()
        self.layers = clone(layer, nums)

    def forward(self, x: torch.Tensor, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)

src_max_len = 512
src_vocab_size = 200
d_model = 6
hidden_dim = 24
src_heads = 2
layer_nums = 8
embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
embedding_results = embedding_cls(batch_sentences)


mask = Mask()
src_mask = mask.padding_mask(batch_sentences)

encoder_layer = EncoderLayer(d_model, hidden_dim, src_heads)
encoder = Encoder(encoder_layer, layer_nums)
encoder(embedding_results, src_mask)

tensor([[[ -4.6929, -14.1930,   8.5558,  13.6445,  -7.7855,   6.1300],
         [ -9.1643, -14.2703,  11.4374,  13.1600,  -2.4385,   4.1526],
         [ -9.6081, -14.7413,  10.7707,  13.3700,  -6.8403,   8.1122],
         [ -7.3210, -15.1406,   8.8960,  12.9493,  -7.5348,   9.9626],
         [ -7.4417, -12.3003,   8.1422,  10.1414,  -8.1089,   7.5740],
         [-19.6019, -14.0036,  16.9428,  11.4761,   1.7704,   5.4786],
         [-16.9150, -13.0332,  17.0304,  12.7865,   2.8235,   4.9336],
         [ -8.6086, -13.7271,  13.2326,  14.8531,  -1.7082,   3.7085]],

        [[ -4.6210, -14.7225,   8.3883,  13.6017,  -6.7144,   5.7269],
         [ -9.5332, -13.4827,   7.9000,  10.6679,  -1.5611,   8.5812],
         [-13.6171, -16.5595,  14.8419,  12.9995,   1.3102,   4.1346],
         [ -9.6144, -15.8478,  10.7009,  13.0335,  -7.8042,   9.2960],
         [ -8.4473, -14.8369,  11.9023,  14.1387,  -6.4568,   6.5284],
         [-16.9718, -14.2307,  16.6397,  12.4862,   2.5405,   5.7704],
    

# 解码器

In [None]:
import torch
import torch.nn as nn


class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, hidden_dim: int, heads: int):
        super().__init__()
        self.multi_head_attens = clone(MultiHeadAttention(heads, d_model), 2)
        self.layer_norms = clone(LayerNorm(d_model), 3)
        self.feedforward = FeedForward(d_model, hidden_dim)

    def forward(
        self, memory: torch.Tensor, x: torch.Tensor, tgt_mask=None, src_mask=None
    ):
        masked_atten_outputs = self.multi_head_attens[0](x, x, x, tgt_mask)
        x += self.layer_norms[0](masked_atten_outputs)
        multi_atten_outputs = self.multi_head_attens[1](x, memory, memory, src_mask)
        x += self.layer_norms[1](multi_atten_outputs)
        feed_forward_outputs = self.feedforward(x)
        x += self.layer_norms[2](feed_forward_outputs)
        return x

In [94]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)
decoder_inputs = torch.tensor(
    [
        [101, 13, 6, 9, 2, 7, 102, 0],
        [101, 10, 18, 5, 3, 102, 0, 0],
        [101, 1, 4, 5, 19, 2, 23, 102],
    ]
)

src_max_len = 512
src_vocab_size = 200
d_model = 6
hidden_dim = 24
src_heads = 2
layer_nums = 8
src_embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
src_embedding_results = src_embedding_cls(batch_sentences)
tgt_embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
tgt_embedding_results = tgt_embedding_cls(decoder_inputs)


mask = Mask()
src_mask = mask.padding_mask(batch_sentences)

encoder_layer = EncoderLayer(d_model, hidden_dim, src_heads)
encoder = Encoder(encoder_layer, layer_nums)
encoder_outputs = encoder(src_embedding_results, src_mask)

decoder_layer_cls = DecoderLayer(d_model, hidden_dim, src_heads)
decoder_layer_cls(encoder_outputs, tgt_embedding_results, src_mask)

tensor([[[-0.2220,  0.7314, -3.4175,  0.6165,  0.4795,  0.9315],
         [ 0.5121,  1.1963, -3.1803,  0.6147,  3.7034, -0.6640],
         [ 0.2788,  2.0792, -3.8357,  0.9190,  2.3652,  0.5567],
         [-1.8392,  2.5938, -2.4730,  0.4890,  2.3009,  0.9286],
         [-0.7027,  1.3716, -1.7195,  1.6228,  2.4574,  1.1560],
         [-2.0800,  1.4047,  0.6848,  0.3811,  3.7724,  2.4783],
         [-1.3923,  0.5486, -1.3639,  1.8846,  2.0690,  1.3879],
         [-0.4787,  2.1709, -1.7801,  2.5245,  1.9235,  0.5721]],

        [[-0.1645,  0.9530, -3.6034,  0.6343,  0.5244,  0.7756],
         [-0.6099,  3.3170, -2.5119, -0.3787,  1.6383, -0.2641],
         [-0.2875,  2.3824, -2.3183, -0.4385,  1.0352, -0.5757],
         [-0.0983,  2.9323, -0.8020,  1.7290,  2.6300,  0.7892],
         [-0.8780,  1.0525, -2.5832, -0.1945,  1.9199, -0.4407],
         [-2.1099, -0.1588, -1.7264,  2.0556,  2.1696,  1.5126],
         [-1.5402,  2.3642, -1.9045,  2.7316,  1.9534,  0.5649],
         [-0.5713,  2.3

# 解码器

In [None]:
import torch
import torch.nn as nn


class Decoder(nn.Module):
    def __init__(self, layer: DecoderLayer, nums: int):
        super().__init__()
        self.layers = clone(layer, nums)

    def forward(
        self, memory: torch.Tensor, x: torch.Tensor, tgt_mask=None, src_mask=None
    ):
        for layer in self.layers:
            x = layer(memory, x, tgt_mask, src_mask)
        return x

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)
decoder_inputs = torch.tensor(
    [
        [101, 13, 6, 9, 2, 7, 102, 0],
        [101, 10, 18, 5, 3, 102, 0, 0],
        [101, 1, 4, 5, 19, 2, 23, 102],
    ]
)

src_max_len = 512
src_vocab_size = 200
d_model = 6
hidden_dim = 24
src_heads = 2
layer_nums = 8
src_embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
src_embedding_results = src_embedding_cls(batch_sentences)
tgt_embedding_cls = Embedding(
    vocab_size=src_vocab_size, max_len=src_max_len, d_model=d_model
)
tgt_embedding_results = tgt_embedding_cls(decoder_inputs)


mask = Mask()
src_mask = mask.padding_mask(batch_sentences)

encoder_layer = EncoderLayer(d_model, hidden_dim, src_heads)
encoder = Encoder(encoder_layer, layer_nums)
encoder_outputs = encoder(src_embedding_results, src_mask)

decoder_layer_cls = DecoderLayer(d_model, hidden_dim, src_heads)
decoder = Decoder(decoder_layer_cls, layer_nums)
decoder(encoder_outputs, tgt_embedding_results, src_mask)

tensor([[[ 21.2602,   9.2331, -24.4601,  -1.1073, -13.6907,  13.5291],
         [ 22.5556,   8.9649, -23.6728,  -3.4604, -18.6228,  10.3777],
         [ 23.7073,   8.2565, -21.5613,  -2.2190, -16.2544,  10.4904],
         [ 18.9561,  11.7710, -24.0340,  -1.1761, -13.7383,  11.1137],
         [ 21.6701,  10.4643, -24.0852,  -2.1130, -15.9563,  10.4645],
         [ 21.3494,   8.6791, -23.3579,  -0.9561, -12.1208,  12.1037],
         [ 22.1133,  10.8962, -23.5492,  -1.7947, -15.1828,  10.3826],
         [ 21.8016,   8.5705, -22.9546,  -2.4148, -13.9840,  12.1114]],

        [[ 22.5688,   6.3563, -22.4941,  -0.2357, -13.8495,  12.4185],
         [ 21.1378,   5.2267, -23.8733,  -0.1486, -12.6912,  14.1384],
         [ 21.7826,   6.0262, -21.8250,  -0.9273, -14.1055,  12.4486],
         [ 22.0722,   8.3679, -20.9050,  -1.2527, -15.2012,   9.1900],
         [ 22.8378,  10.7941, -19.0810,  -1.0239, -20.1041,   7.5475],
         [ 23.6377,   7.9679, -20.6503,  -1.2229, -16.4547,   8.1964],
    

# 生成器

In [105]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.generator = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor):
        """x.shape=(batch_size,seq_len,d_model)"""
        return F.log_softmax(self.generator(x), dim=-1)

# Transormer

In [109]:
import torch
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        src_d_model: int,
        tgt_d_model: int,
        src_hidden_dim: int,
        tgt_hidden_dim: int,
        src_heads: int,
        tgt_heads: int,
        layer_nums: int,
        max_len: int,
    ):
        super().__init__()
        self.encoder_layer = EncoderLayer(src_d_model, src_hidden_dim, src_heads)
        self.decoder_layer = DecoderLayer(tgt_d_model, tgt_hidden_dim, tgt_heads)
        self.encoder = Encoder(self.encoder_layer, layer_nums)
        self.decoder = Decoder(self.decoder_layer, layer_nums)
        self.src_embedding = Embedding(src_vocab_size, max_len, src_d_model)
        self.tgt_embedding = Embedding(tgt_vocab_size, max_len, tgt_d_model)
        self.generator = Generator(tgt_vocab_size, tgt_d_model)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        src.shape=(batch_size,src_seq_len)
        tgt.shape=(batch_size,tgt_seq_len)
        """
        src_embed = self.src_embedding(src)
        tgt_embed = self.tgt_embedding(tgt)
        encoder_outputs = self.encoder(src_embed, src_mask)
        decoder_outputs = self.decoder(encoder_outputs, tgt_embed, tgt_mask, src_mask)
        return self.generator(decoder_outputs)

In [None]:
batch_sentences = torch.tensor(
    [
        [101, 3, 2, 5, 7, 8, 102, 0],
        [101, 13, 8, 2, 9, 102, 0, 0],
        [101, 21, 8, 15, 9, 7, 13, 102],
    ]
)
decoder_inputs = torch.tensor(
    [
        [101, 13, 6, 9, 2, 7, 102, 0],
        [101, 10, 18, 5, 3, 102, 0, 0],
        [101, 1, 4, 5, 19, 2, 23, 102],
    ]
)
src_vocab_size = 200
tgt_vocab_size = 240
src_d_model = 8
tgt_d_model = 8
src_hidden_dim = 24
tgt_hidden_dim = 24
src_heads = 4
tgt_heads = 4
layer_nums = 8
max_len = 512


mask = Mask()
src_mask = mask.padding_mask(batch_sentences)
tgt_mask = mask.merge_mask(decoder_inputs, decoder_inputs.shape[-1])

transformer = Transformer(
    src_vocab_size,
    tgt_vocab_size,
    src_d_model,
    tgt_d_model,
    src_hidden_dim,
    tgt_hidden_dim,
    src_heads,
    tgt_heads,
    layer_nums,
    max_len,
)
transformer(batch_sentences, decoder_inputs, tgt_mask, src_mask)

tensor([[[-23.6987, -22.2254, -22.7746,  ..., -23.5941, -25.6780, -46.6218],
         [-30.9086, -26.8903, -22.5864,  ..., -16.3116, -21.9060, -56.5316],
         [-25.2848, -21.4719, -19.5253,  ..., -14.5516, -15.5766, -48.5730],
         ...,
         [-25.4831, -22.5566, -18.8937,  ..., -12.8312, -15.1014, -46.1652],
         [-23.3707, -23.2196, -22.7867,  ..., -22.6124, -24.8479, -46.8203],
         [-25.8424, -24.0414, -25.2619,  ..., -22.7277, -28.8318, -49.9659]],

        [[-16.7712, -17.3099, -28.5644,  ..., -31.7214, -30.8977, -21.3963],
         [-13.9675, -15.1343, -29.1793,  ..., -27.7112, -32.6018, -19.2178],
         [-31.4915, -25.8899, -21.2553,  ..., -13.3160, -20.6630, -52.3194],
         ...,
         [-20.3351, -22.2516, -27.1599,  ..., -34.2610, -32.0830, -33.2498],
         [-20.7398, -20.4035, -27.5340,  ..., -28.5940, -34.9267, -36.7387],
         [-21.0911, -19.9814, -27.4198,  ..., -27.1833, -34.8969, -37.5490]],

        [[-13.5771, -16.1949, -25.2118,  ...