### Transformer Decoder (Causal LM: self-attention + FFN)
![Transformer Decoder][def]

[def]: ./pics/Transformer_Decoder.png

In [11]:
import math
import torch
import torch.nn as nn

class SimpleDecoderLayer(nn.Module):
    """One post-LayerNorm Transformer decoder layer (causal self-attention + FFN).

    The layer applies multi-head self-attention with a lower-triangular
    (causal) mask, adds the residual and normalizes, then applies a
    4x-expansion feed-forward network with GELU, dropout, residual and a
    second LayerNorm.

    Args:
        hidden_dim: Size of the model dimension; must be divisible by
            ``head_num``.
        head_num: Number of attention heads.
        attention_dropout: Dropout probability used both on the attention
            weights and on the feed-forward output.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, hidden_dim, head_num, attention_dropout=0.1):
        super().__init__()

        # Fail early with a clear message instead of an opaque ``view`` error
        # during the first forward pass.
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"head_num ({head_num})"
            )

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.att_dropout = nn.Dropout(attention_dropout)

        self.att_ln = nn.LayerNorm(hidden_dim, eps=1e-12)

        # FFN: up-projection -> activation -> down-projection -> LayerNorm
        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)
        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)

        # Activation function
        self.act_fn = nn.GELU()

        # Dropout (shares the attention dropout rate) and post-FFN LayerNorm
        self.ffn_dropout = nn.Dropout(attention_dropout)
        self.ffn_ln = nn.LayerNorm(hidden_dim, eps=1e-12)

    def attention_Layer(self, query, key, value, attention_mask=None):
        """Scaled dot-product attention with a built-in causal mask.

        Args:
            query, key, value: Tensors of shape
                ``[batch_size, head_num, seq_len, head_dim]``.
            attention_mask: Optional tensor broadcastable to
                ``[batch_size, head_num, seq_len, seq_len]``. Entries equal
                to 0 mark positions that must not be attended to. It is
                combined with the causal (lower-triangular) mask.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]`` after the
            output projection.
        """
        scores = query @ key.transpose(2, 3) / math.sqrt(self.head_dim)

        # Causal mask: query position i may only see key positions <= i.
        # A single [seq, seq] boolean matrix broadcasts over batch and heads,
        # so no full-size tensor has to be allocated.
        q_len, k_len = scores.shape[-2:]
        visible = torch.ones(
            q_len, k_len, dtype=torch.bool, device=scores.device
        ).tril()
        if attention_mask is not None:
            # AND with the caller's mask by broadcasting, so both full
            # [b, h, s, s] masks and padding masks like [b, 1, 1, s] work.
            visible = visible & (attention_mask != 0)

        # Use the dtype's own minimum: a hard-coded -1e20 is not
        # representable in float16.
        scores = scores.masked_fill(~visible, torch.finfo(scores.dtype).min)

        weights = torch.softmax(scores, dim=-1)
        weights = self.att_dropout(weights)

        context = weights @ value
        # [batch_size, head_num, seq_len, head_dim]
        #              -> [batch_size, seq_len, hidden_dim]
        context = context.transpose(1, 2).contiguous()
        batch_size, seq_len, _, _ = context.size()
        context = context.view(batch_size, seq_len, -1)

        return self.out_proj(context)

    def MHA(self, x, attention_mask=None):
        """Multi-head self-attention sub-layer with residual and LayerNorm.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
            attention_mask: Optional mask forwarded to ``attention_Layer``.

        Returns:
            Tensor of shape ``[batch_size, seq_len, hidden_dim]``.
        """
        batch_size, seq_len, _ = x.size()

        # Q, K, V shape: [batch_size, seq_len, hidden_dim]
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Split hidden_dim into head_num * head_dim:
        # [batch_size, seq_len, hidden_dim]
        #       -> [batch_size, head_num, seq_len, head_dim]
        Q_state = Q.view(batch_size, seq_len, self.head_num, -1).transpose(1, 2)
        K_state = K.view(batch_size, seq_len, self.head_num, -1).transpose(1, 2)
        V_state = V.view(batch_size, seq_len, self.head_num, -1).transpose(1, 2)

        output = self.attention_Layer(Q_state, K_state, V_state, attention_mask)
        # Post-LayerNorm over the residual sum.
        return self.att_ln(x + output)

    def FFN(self, x):
        """Feed-forward sub-layer: up-project, GELU, down-project, dropout.

        The result is added to the input and normalized (post-LayerNorm).
        """
        up = self.up_proj(x)
        act_up = self.act_fn(up)
        down = self.down_proj(act_up)
        drop_down = self.ffn_dropout(down)
        # Post-LayerNorm over the residual sum.
        x = self.ffn_ln(x + drop_down)
        return x

    def forward(self, x, attention_mask=None):
        """Apply the attention sub-layer followed by the feed-forward sub-layer."""
        x = self.MHA(x, attention_mask)
        x = self.FFN(x)
        return x


class Decoder(nn.Module):
    """Stack of ``SimpleDecoderLayer`` blocks with token embedding and output head.

    Args:
        vocab_size: Number of token ids; sizes the embedding table and the
            output projection.
        hidden_dim: Model dimension used by the embedding and every layer.
        head_num: Number of attention heads per layer.
        layer_num: Number of stacked decoder layers.

    The defaults reproduce the original hard-coded configuration
    (vocab 12, hidden 64, 8 heads, 5 layers).
    """

    def __init__(self, vocab_size=12, hidden_dim=64, head_num=8, layer_num=5):
        super().__init__()
        self.layer_list = nn.ModuleList(
            [
                SimpleDecoderLayer(hidden_dim, head_num) for _ in range(layer_num)
            ]
        )
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, mask=None):
        """Return per-position probabilities over the vocabulary.

        Args:
            x: Integer token ids of shape ``(batch, seq)``.
            mask: Optional attention mask handed unchanged to every layer;
                ``None`` leaves only the layers' own causal masking.

        Returns:
            Tensor of shape ``(batch, seq, vocab_size)`` whose last dimension
            is softmax-normalized.
        """
        # (b, s) -> (b, s, hidden_dim)
        x = self.emb(x)
        for layer in self.layer_list:
            x = layer(x, mask)
        # Shape trace kept from the original; it prints on every call.
        print(x.shape)
        out = self.out(x)
        return torch.softmax(out, dim=-1)

# Demo input: 3 random token-id sequences of length 4, ids drawn from [0, 12).
# (A float tensor such as torch.rand(3, 4, 64) could instead be fed straight
# into a bare SimpleDecoderLayer(64, 8), which has no embedding step.)
x = torch.randint(0, 12, (3, 4))

net = Decoder()

# Per-sequence key visibility: 1 = real token, 0 = padding.
key_visible = torch.tensor(
    [
        [1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 1, 1, 0],
    ]
)
# (3, 4) -> (3, 1, 1, 4) -> tiled to (3, 8, 4, 4): one copy per head and per
# query position.
mask = key_visible[:, None, None, :].repeat(1, 8, 4, 1)
net(x, mask)

torch.Size([3, 4, 64])


tensor([[[0.0767, 0.1227, 0.0396, 0.0906, 0.2155, 0.0697, 0.0790, 0.0774,
          0.0394, 0.0785, 0.0360, 0.0748],
         [0.0293, 0.0605, 0.0523, 0.0345, 0.0697, 0.0578, 0.0306, 0.2561,
          0.2396, 0.0836, 0.0413, 0.0447],
         [0.0643, 0.0681, 0.0486, 0.0475, 0.2576, 0.0567, 0.0422, 0.0796,
          0.0195, 0.1085, 0.0461, 0.1613],
         [0.1664, 0.1270, 0.0334, 0.0332, 0.1078, 0.0468, 0.0539, 0.1207,
          0.1179, 0.0612, 0.0472, 0.0845]],

        [[0.0976, 0.1298, 0.0698, 0.0789, 0.0697, 0.1016, 0.0583, 0.1136,
          0.0700, 0.0611, 0.0662, 0.0835],
         [0.1013, 0.0590, 0.1061, 0.0553, 0.1135, 0.0456, 0.0824, 0.1160,
          0.0413, 0.1021, 0.0585, 0.1189],
         [0.1001, 0.1262, 0.1524, 0.0966, 0.0322, 0.0819, 0.0294, 0.0571,
          0.0670, 0.0922, 0.0446, 0.1201],
         [0.2437, 0.0558, 0.0503, 0.0646, 0.1308, 0.0480, 0.0422, 0.0509,
          0.0355, 0.1032, 0.0266, 0.1484]],

        [[0.1289, 0.0274, 0.0966, 0.0739, 0.1011, 0.0564, 0.