In [None]:
import math
import torch
import torch.nn as nn

In [None]:
class SimpleDecoderLayer(nn.Module):
    """Single post-LayerNorm Transformer decoder layer.

    The layer runs causal multi-head self-attention followed by a
    position-wise feed-forward network (FFN). Each sub-layer is wrapped as
    ``LayerNorm(x + sublayer(x))``.

    Args:
        hidden_dim: Size of the model dimension (last axis of the input).
        head_num: Number of attention heads; must evenly divide ``hidden_dim``.
        attention_dropout_rate: Dropout probability applied to the attention
            weights.
        ffn_dropout_rate: Dropout probability applied to the FFN output.

    Raises:
        ValueError: If ``head_num`` is not positive or does not evenly divide
            ``hidden_dim``.
    """

    def __init__(self, hidden_dim, head_num, attention_dropout_rate=0.1, ffn_dropout_rate=0.1):
        super().__init__()

        # * Fail early: otherwise head_dim is silently truncated and the
        # * per-head reshape in MultiHeadAttention fails with an obscure error.
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"head_num ({head_num})"
            )

        self.head_num = head_num
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // head_num

        self.attention_layer_norm = nn.LayerNorm(hidden_dim)
        self.ffn_layer_norm = nn.LayerNorm(hidden_dim)

        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(attention_dropout_rate)
        self.ffn_dropout = nn.Dropout(ffn_dropout_rate)

        # * FFN expands to 4 * hidden_dim and projects back.
        self.up_dim = nn.Linear(hidden_dim, hidden_dim * 4)
        self.down_dim = nn.Linear(hidden_dim * 4, hidden_dim)

        self.activation_layer = nn.GELU()

    def ffn(self, x):
        """Apply the feed-forward sub-layer with residual connection and LayerNorm.

        Args:
            x: Tensor whose last dimension is ``hidden_dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # * The activation function is used only inside the FFN.
        expanded = self.activation_layer(self.up_dim(x))
        projected = self.ffn_dropout(self.down_dim(expanded))

        return self.ffn_layer_norm(x + projected)

    # * Kept as a separate method so the attention computation can be reused.
    def attention_layer(self, query, key, value, attention_mask=None):
        """Causal scaled dot-product attention followed by the output projection.

        Args:
            query: Tensor of shape (batch, head_num, query_len, head_dim).
            key: Tensor of shape (batch, head_num, key_len, head_dim).
            value: Tensor of shape (batch, head_num, key_len, head_dim).
            attention_mask: Optional mask in which zero entries mark positions
                that must not be attended to. Any shape broadcastable to
                (batch, head_num, query_len, key_len) is accepted, e.g. the
                fully expanded mask or a (batch, 1, 1, key_len) padding mask.
                The causal (lower-triangular) constraint is always applied in
                addition to this mask.

        Returns:
            Tensor of shape (batch, query_len, hidden_dim).
        """
        # * shape is (batch, head_num, query_len, key_len)
        scores = query @ key.transpose(-1, -2) / math.sqrt(self.head_dim)

        # * Lower-triangular boolean mask: position i may attend to j <= i.
        query_len, key_len = scores.shape[-2:]
        allowed = torch.ones(
            query_len, key_len, dtype=torch.bool, device=scores.device
        ).tril()
        if attention_mask is not None:
            allowed = allowed & (attention_mask != 0)
        blocked = ~allowed

        scores = scores.masked_fill(blocked, float("-inf"))

        # * A row with every key blocked (e.g. left padding) would be a softmax
        # * over all -inf, i.e. NaN, and that NaN would leak into every later
        # * layer. Give such rows finite scores and zero their weights instead.
        fully_blocked = blocked.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(fully_blocked, 0.0)

        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = attention_weights.masked_fill(fully_blocked, 0.0)
        attention_weights = self.attention_dropout(attention_weights)

        # * (batch, head_num, query_len, head_dim) -> (batch, query_len, head_num, head_dim)
        context = (attention_weights @ value).transpose(1, 2).contiguous()

        # * Merge the heads back into a single hidden dimension.
        batch, seq_len = context.shape[:2]
        context = context.view(batch, seq_len, -1)

        # * output shape is (batch, query_len, hidden_dim)
        return self.output_proj(context)

    def MultiHeadAttention(self, x, mask=None):
        """Apply the self-attention sub-layer with residual connection and LayerNorm.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_dim).
            mask: Optional attention mask, forwarded to ``attention_layer``.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        batch, seq_len, _ = x.size()

        def split_heads(projected):
            # * (batch, seq_len, hidden_dim) -> (batch, head_num, seq_len, head_dim)
            return projected.view(
                batch, seq_len, self.head_num, self.head_dim
            ).transpose(2, 1)

        query_head_state = split_heads(self.query_proj(x))
        key_head_state = split_heads(self.key_proj(x))
        value_head_state = split_heads(self.value_proj(x))

        output = self.attention_layer(
            query_head_state, key_head_state, value_head_state, mask
        )

        return self.attention_layer_norm(x + output)

    def forward(self, x, attention_mask=None):
        """Run self-attention and then the FFN on ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_dim).
            attention_mask: Optional attention mask, see ``attention_layer``.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        x = self.MultiHeadAttention(x, attention_mask)
        x = self.ffn(x)

        return x
    
    


class Decoder(nn.Module):
    """Toy decoder-only Transformer: embedding, stacked decoder layers, softmax head.

    Args:
        vocab_size: Number of rows in the embedding table and size of the
            output distribution. Defaults to 12.
        hidden_dim: Model dimension used by the embedding and every layer.
            Defaults to 64.
        head_num: Number of attention heads per layer. Defaults to 8.
        num_layers: Number of stacked ``SimpleDecoderLayer`` modules.
            Defaults to 5.
    """

    def __init__(self, vocab_size=12, hidden_dim=64, head_num=8, num_layers=5):
        super().__init__()

        # * Stack of identical decoder layers (5 by default).
        self.layer_list = nn.ModuleList(
            [SimpleDecoderLayer(hidden_dim, head_num) for _ in range(num_layers)]
        )

        # * nn.Embedding's first argument is the number of embedding vectors,
        # * the second is the dimension of each vector.
        # * At call time nn.Embedding takes index values as input.
        # * NOTE: no positional encoding is added anywhere in this model.
        self.output_emb = nn.Embedding(vocab_size, hidden_dim)

        self.out_linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, mask=None):
        """Map token indices to a probability distribution over the vocabulary.

        Args:
            x: Integer tensor of token indices, indexed into the embedding
                table; the demo in this file uses shape (batch, seq_len).
            mask: Optional attention mask passed unchanged to every layer;
                the demo in this file uses shape
                (batch, head_num, seq_len, seq_len).

        Returns:
            Tensor with a trailing dimension of size ``vocab_size`` holding
            softmax probabilities (not logits).
        """
        x = self.output_emb(x)

        for layer in self.layer_list:
            x = layer(x, mask)

        output = self.out_linear(x)

        return torch.softmax(output, dim=-1)



# * Demo input: token indices drawn from [0, 12) with batch_size 3 and
# * seq_len 4, i.e. 3 * 4 = 12 embedding lookups in total.
X = torch.randint(low=0, high=12, size=(3, 4))

simple_decoder = Decoder()

# * One row of flags per sequence (0 marks a position to mask out), with a
# * head axis and a query axis inserted and then tiled so that the
# * mask shape is (batch, head_num, seq, seq) = (3, 8, 4, 4).
mask = (
    torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]])[:, None, None, :]
    .repeat(1, 8, 4, 1)
)

simple_decoder(X, mask)


