In [1]:
import warnings
warnings.filterwarnings("ignore")


In [2]:
from pydantic import BaseModel

class ModelArgs(BaseModel):
    """Hyper-parameters shared by the modules in this notebook.

    ``embed_dim`` is the width of the residual stream: the input/output size of
    the attention projections and the feature size of every ``LayerNormal``.
    ``dim`` is the total width of the attention projections
    (``n_heads * head_dim``). The default ``embed_dim`` equals ``dim`` and is
    even, so the residual stream, MLP and sinusoidal position table all agree.
    """

    dim: int = 1024  # total attention width; must be divisible by n_heads
    n_heads: int = 8
    dropout: float = 0.1
    max_seq_len: int = 2048  # side length of the precomputed causal mask
    embed_dim: int = 1024  # residual-stream width (see class docstring)
    hidden_dim: int = 512  # inner width of the MLP
    num_encoder_layers: int = 8
    num_decoder_layres: int = 8  # (sic) misspelled name kept: Decoder reads this field
    theta: int = 10000  # presumably the sinusoidal frequency base for PositionEmbedding — confirm at call site

In [3]:
## multi attention
import torch.nn as nn
import torch.nn.functional as F
import math
import torch

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected from ``args.embed_dim`` to
    ``args.dim`` (``args.n_heads`` heads of ``args.dim // args.n_heads``),
    attended, and projected back to ``args.embed_dim``.

    Args:
        args: configuration; reads ``dim``, ``n_heads``, ``embed_dim``,
            ``dropout`` and, when causal, ``max_seq_len``.
        is_causal: if True, query position ``i`` may only attend to key
            positions ``<= i`` (additive ``-inf`` upper-triangular mask).
    """

    def __init__(self, args: "ModelArgs", is_causal: bool = False) -> None:
        super().__init__()

        assert args.dim % args.n_heads == 0

        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads

        # Q/K/V and output projection matrices.
        self.qw = nn.Linear(args.embed_dim, self.head_dim * self.n_heads, bias=False)
        self.kw = nn.Linear(args.embed_dim, self.head_dim * self.n_heads, bias=False)
        self.vw = nn.Linear(args.embed_dim, self.head_dim * self.n_heads, bias=False)
        self.ow = nn.Linear(self.head_dim * self.n_heads, args.embed_dim, bias=False)

        self.dropout = nn.Dropout(args.dropout)
        self.is_causal = is_causal
        if self.is_causal:
            # Precompute the largest mask once; forward slices it to the
            # actual (q_len, kv_len) so any length up to max_seq_len works.
            mask = torch.full((args.max_seq_len, args.max_seq_len), fill_value=float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: (bsz, q_len, embed_dim) queries.
            k, v: (bsz, kv_len, embed_dim) keys/values; ``kv_len`` may differ
                from ``q_len`` (e.g. cross-attention over an encoder output).

        Returns:
            (bsz, q_len, embed_dim) tensor.

        Raises:
            ValueError: if causal and a length exceeds the precomputed mask.
        """
        bsz, q_len = q.shape[:2]
        kv_len = k.shape[1]  # keys/values carry their own length

        # Project and split into heads: (bsz, n_heads, len, head_dim).
        xq = self.qw(q).view(bsz, q_len, self.n_heads, self.head_dim).transpose(1, 2)
        xk = self.kw(k).view(bsz, kv_len, self.n_heads, self.head_dim).transpose(1, 2)
        xv = self.vw(v).view(bsz, kv_len, self.n_heads, self.head_dim).transpose(1, 2)

        # (bsz, n_heads, q_len, kv_len)
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)

        if self.is_causal:
            max_len = self.mask.shape[0]  # pyright: ignore[reportIndexIssue]
            if q_len > max_len or kv_len > max_len:
                raise ValueError(
                    f"sequence length ({q_len}, {kv_len}) exceeds max_seq_len={max_len}"
                )
            # Slice the full mask to the current lengths so it broadcasts over (bsz, heads).
            scores = scores + self.mask[:q_len, :kv_len]  # pyright: ignore[reportIndexIssue]

        scores = F.softmax(scores, dim=-1).type_as(xq)
        output = torch.matmul(scores, xv)  # (bsz, n_heads, q_len, head_dim)

        # Merge heads back: (bsz, q_len, n_heads * head_dim).
        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.head_dim * self.n_heads)

        output = self.ow(output)
        return self.dropout(output)










In [4]:
## MLP
class MLP(nn.Module):
    """Position-wise feed-forward network: ``dropout(l2(relu(l1(x))))``.

    Both projections are bias-free: ``dim -> hidden_dim -> dim``.
    """

    def __init__(self, dim, hidden_dim, dropout) -> None:
        super().__init__()
        # Expansion then contraction; attribute names are state_dict keys.
        self.l1 = nn.Linear(dim, hidden_dim, bias=False)
        self.l2 = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.relu(self.l1(x))
        projected = self.l2(activated)
        return self.dropout(projected)




In [5]:
## LayerNormal

class LayerNormal(nn.Module):
    """Layer normalisation over the last dimension with learnable scale/shift.

    Computes ``aw * (x - mean) / (std + eps) + bw`` with mean/std taken over
    the last dimension. ``std`` is ``Tensor.std``'s default unbiased
    (Bessel-corrected) estimate, and ``eps`` is added outside the square
    root, so results differ slightly from ``nn.LayerNorm``.

    Args:
        features: size of the input's last dimension.
        eps: added to the std so rows with zero spread do not divide by zero.
    """

    def __init__(self, features, eps: float = 1e-6) -> None:
        super().__init__()

        self.aw = nn.Parameter(torch.ones(features))
        self.bw = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        avg = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        # eps was previously accepted but unused, so a constant row gave 0/0 = NaN.
        x = (x - avg) / (std + self.eps)

        return self.aw * x + self.bw
    

    

In [6]:
## Encoder

class EncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer.

    ``x = x + Attn(LN1(x))`` followed by ``x = x + MLP(LN2(x))``. Attention is
    bidirectional (not causal).
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        self.attention = MultiHeadAttention(args, is_causal=False)
        # The MLP acts on the residual stream, whose width is embed_dim (the same
        # size the norms and attention output use), not the attention width dim.
        self.mlp = MLP(args.embed_dim, args.hidden_dim, args.dropout)
        # Norm before attention (name kept for state_dict compatibility).
        self.layer_normal = LayerNormal(args.embed_dim)
        # Separate norm before the MLP, so each sublayer learns its own scale/shift
        # (matches DecoderLayer's one-norm-per-sublayer layout).
        self.mlp_norm = LayerNormal(args.embed_dim)

    def forward(self, x):
        norm_x = self.layer_normal(x)
        x = x + self.attention(norm_x, norm_x, norm_x)

        output = x + self.mlp(self.mlp_norm(x))
        return output

class Encoder(nn.Module):
    """Stack of ``args.num_encoder_layers`` EncoderLayers followed by a final LayerNormal."""

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        # nn.ModuleList (not a plain list) registers every layer as a submodule, so
        # its parameters show up in .parameters()/state_dict and follow .to(device).
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(args) for _ in range(args.num_encoder_layers)]
        )
        self.norm = LayerNormal(args.embed_dim)

    def forward(self, x):
        """Apply each encoder layer in order, then the final normalisation."""
        for layer in self.encoder_layers:
            x = layer(x)

        return self.norm(x)


In [7]:
## Decoder: each decoder layer contains a self-attention sublayer and a cross-attention (encoder-decoder) sublayer

"""
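# Transformer Decoder Layer structure (ASCII diagram)

Note: the diagram shows the classic post-norm layout (Add & LayerNorm after each
sublayer). The DecoderLayer code below is pre-norm: it applies LayerNorm to the
input of each sublayer and then adds the residual.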

```
        +-----------------------------+
        |        Input (from prev)    |
        +-----------------------------+
                     |
                     v
        +-----------------------------+
        | Masked Multi-Head Attention |
        +-----------------------------+
                     |
                     v
        +-----------------------------+
        |      Add & LayerNorm        |
        +-----------------------------+
                     |
                     v
        +-----------------------------+
        | Encoder-Decoder Attention   |
        +-----------------------------+
                     |
                     v
        +-----------------------------+
        |      Add & LayerNorm        |
        +-----------------------------+
                     |
                     v
        +-----------------------------+
        |   Feed Forward Network (FFN)|
        +-----------------------------+
                     |
                     v
        +-----------------------------+
        |      Add & LayerNorm        |
        +-----------------------------+
                     |
                     v
        +-----------------------------+
        |         Output              |
        +-----------------------------+
```


"""



class DecoderLayer(nn.Module):
    """Pre-norm Transformer decoder layer.

    1. ``x = x + CausalSelfAttn(LN(x))``: each position sees only earlier ones.
    2. ``x = x + CrossAttn(LN(x), encoder_output)``: full view of the encoder output.
    3. ``x = x + MLP(LN(x))``.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        # Self-attention over decoder tokens must be masked (causal). The flags
        # were previously swapped, masking cross-attention instead.
        self.self_attention = MultiHeadAttention(args, is_causal=True)
        self.self_attn_norm = LayerNormal(args.embed_dim)
        # Cross-attention may look at every encoder position: no causal mask.
        self.attention = MultiHeadAttention(args, is_causal=False)
        self.attn_norm = LayerNormal(args.embed_dim)

        # MLP runs on the residual stream, whose width is embed_dim.
        self.mlp = MLP(args.embed_dim, args.hidden_dim, args.dropout)
        self.mlp_norm = LayerNormal(args.embed_dim)

    def forward(self, x, encoder_output):
        norm_x = self.self_attn_norm(x)
        x = x + self.self_attention(norm_x, norm_x, norm_x)
        norm_x = self.attn_norm(x)
        # Queries from the decoder, keys/values from the encoder.
        x = x + self.attention(norm_x, encoder_output, encoder_output)

        output = x + self.mlp(self.mlp_norm(x))
        return output


class Decoder(nn.Module):
    """Stack of ``args.num_decoder_layres`` DecoderLayers followed by a final LayerNormal."""

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        # nn.ModuleList (not a plain list) registers every layer as a submodule, so
        # its parameters show up in .parameters()/state_dict and follow .to(device).
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(args) for _ in range(args.num_decoder_layres)]
        )
        self.norm = LayerNormal(args.embed_dim)

    def forward(self, x, encoder_output):
        """Apply each decoder layer (all attending to ``encoder_output``), then normalise."""
        for layer in self.decoder_layers:
            x = layer(x, encoder_output)

        return self.norm(x)

        

In [8]:
## position embedding

class PositionEmbedding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a batch of embeddings.

    ``pe[pos, 2i]   = sin(pos / theta**(2i / d))``
    ``pe[pos, 2i+1] = cos(pos / theta**(2i / d))``

    Args:
        max_sql_length: maximum supported sequence length (rows of the table).
        embedding_dim: feature size ``d``; odd sizes are supported.
        theta: frequency base.
    """

    def __init__(self, max_sql_length: int, embedding_dim: int, theta: int = 10000) -> None:
        super().__init__()
        pe = torch.zeros((max_sql_length, embedding_dim))

        position = torch.arange(0, max_sql_length, dtype=torch.float).unsqueeze(1)  # (L, 1)
        # 1 / theta**(2i/d) for each even feature index 2i, computed in log space.
        div_theta = torch.exp(
            torch.arange(0, embedding_dim, 2, dtype=torch.float) / embedding_dim * -math.log(theta)
        )
        angles = position * div_theta  # (L, ceil(d/2))

        pe[:, 0::2] = torch.sin(angles)
        # For odd d there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])

        # Stored as (max_len, dim); it is a buffer, so it gets no gradients.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Assumes ``x`` is (bsz, seq_len, embedding_dim).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_sql_length``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)  # pyright: ignore[reportCallIssue]
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_sql_length={max_len}")
        # Slice positions (rows), not features; (seq_len, dim) broadcasts over the batch.
        return x + self.pe[:seq_len]  # pyright: ignore[reportIndexIssue]