# Transformer 架构

![](md-img/Transformer.png)

<br>

# Encoder Block 的实现

<br>

### MultiHeadAttention Module

In [9]:
import torch
from torch import nn
import math

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected to ``model_size`` with their
    own linear layer, split into ``num_heads`` heads, attended per head, then
    concatenated and passed through an output projection.

    Args:
        query_size: Feature size of the incoming queries.
        key_size: Feature size of the incoming keys.
        value_size: Feature size of the incoming values.
        model_size: Projected feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``model_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, query_size, key_size, value_size, model_size, num_heads):
        super().__init__()
        if model_size % num_heads != 0:
            raise ValueError(
                f"model_size ({model_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.linear_q = nn.Linear(query_size, model_size)
        self.linear_k = nn.Linear(key_size, model_size)
        self.linear_v = nn.Linear(value_size, model_size)
        self.softmax = nn.Softmax(dim=-1)
        self.linear_o = nn.Linear(model_size, model_size)

    def _split_heads(self, x):
        """Reshape (batch, seq, model_size) to (batch * num_heads, seq, dk)."""
        batch_size, seq_len, model_size = x.shape
        dk = model_size // self.num_heads
        x = x.reshape(batch_size, seq_len, self.num_heads, dk)
        # Move the head axis next to the batch axis so that both can be merged
        # and every head is handled as an independent batch entry by bmm.
        return x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dk)

    # query: (batch_size, num_querys, query_size)
    # key: (batch_size, num_pairs, key_size)
    # value: (batch_size, num_pairs, value_size)
    def forward(self, query, key, value):
        """Attend ``query`` over the ``key``/``value`` pairs.

        Returns:
            Tensor of shape (batch_size, num_querys, model_size).
        """
        # Each input goes through its OWN projection layer.
        query = self.linear_q(query)    # (batch_size, num_querys, model_size)
        key = self.linear_k(key)        # (batch_size, num_pairs, model_size)
        value = self.linear_v(value)    # (batch_size, num_pairs, model_size)

        batch_size, num_querys, model_size = query.shape
        dk = model_size // self.num_heads  # per-head feature size (integer)

        query = self._split_heads(query)  # (batch_size * num_heads, num_querys, dk)
        key = self._split_heads(key)      # (batch_size * num_heads, num_pairs, dk)
        value = self._split_heads(value)  # (batch_size * num_heads, num_pairs, dk)

        # Scaled dot-product scores; dividing by sqrt(dk) keeps the softmax
        # input at a scale that does not grow with the head size.
        score = torch.bmm(query, key.permute(0, 2, 1)) / math.sqrt(dk)  # (batch_size * num_heads, num_querys, num_pairs)
        weight = self.softmax(score)         # (batch_size * num_heads, num_querys, num_pairs)
        output = torch.bmm(weight, value)    # (batch_size * num_heads, num_querys, dk)

        # Undo the head split: bring heads back next to the feature axis and
        # concatenate them into a single model_size vector per query.
        output = output.reshape(batch_size, self.num_heads, num_querys, dk)  # (batch_size, num_heads, num_querys, dk)
        output = output.permute(0, 2, 1, 3)  # (batch_size, num_querys, num_heads, dk)
        output = output.reshape(batch_size, num_querys, model_size)  # (batch_size, num_querys, model_size)
        return self.linear_o(output)     # (batch_size, num_querys, model_size)

<br>

### FeedForward Module

In [None]:
class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        model_size: Feature size of the input and of the output.
        hidden_size: Feature size of the intermediate activation.
    """

    def __init__(self, model_size, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(model_size, hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, model_size)

    def forward(self, x):
        """Map x of shape (batch_size, seq_len, model_size) to the same shape."""
        # Expand to hidden_size, apply the non-linearity, project back.
        hidden = self.relu(self.linear1(x))  # (batch_size, seq_len, hidden_size)
        return self.linear2(hidden)          # (batch_size, seq_len, model_size)

<br>

### EncoderBlock Module

In [None]:
class EncoderBlock(nn.Module):
    """Encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer input (residual connection) and is then layer-normalised.

    Args:
        model_size: Feature size of the input and output.
        num_heads: Number of heads passed to the attention module.
        hidden_size: Hidden size passed to the feed-forward module.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, model_size, num_heads, hidden_size, dropout):
        super().__init__()
        # Self-attention: queries, keys and values all have model_size features.
        self.attention = MultiHeadAttention(
            model_size, model_size, model_size, model_size, num_heads
        )
        self.ffn = FeedForward(model_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(model_size)
        self.norm2 = nn.LayerNorm(model_size)

    def forward(self, x):
        """Transform x of shape (batch_size, seq_len, model_size); shape is kept."""
        # Sub-layer 1: self-attention (x serves as query, key and value).
        attended = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward network.
        transformed = self.ffn(x)
        return self.norm2(x + self.dropout(transformed))

<br>

# Decoder Block 的实现

<br>

### MaskedMultiHeadAttention Module

<br>

### EncoderDecoderAttention Module

<br>

### FeedForward Module

<br>

### DecoderBlock Module