In [9]:
## Imports

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


In [10]:
## Model hyperparameters (ModelArgs)
from pydantic import BaseModel

class ModelArgs(BaseModel):
    """Hyperparameters read by the model components defined in this notebook."""

    # Width of the token embeddings and of the residual stream.
    embedding_dim: int = 1024
    # Number of attention heads; MultiHeadAttention asserts it divides embedding_dim.
    n_heads: int = 8
    # Longest sequence the position table and the causal mask are built for.
    max_seq_len: int = 8192
    # Width of the hidden activation inside each MLP block.
    hidden_size: int = 1024
    # Dropout probability used by the MLP blocks and the embedding dropout.
    dropout: float = 0.2
    # Number of stacked encoder layers.
    num_encoder_layer: int = 8
    # Number of stacked decoder layers.
    num_decoder_layer: int = 8
    # Base of the sinusoidal position-encoding frequencies.
    theta: float = 10000.0
    # Number of rows in the token embedding / columns in the output head.
    vocab_size: int = 32796

In [11]:
## Multi-head attention
### Supports an optional causal mask via the `is_causal` flag
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention with an optional causal mask.

    Args:
        args: Hyperparameters; ``embedding_dim``, ``n_heads`` and
            ``max_seq_len`` are read. ``embedding_dim`` must be divisible by
            ``n_heads``.
        is_causal: When True, query position ``i`` may only attend to key
            positions ``<= i``.
    """

    def __init__(self, args: ModelArgs, is_causal: bool = False) -> None:
        super().__init__()

        assert args.embedding_dim % args.n_heads == 0

        # Width of a single head.
        self.head_dim = args.embedding_dim // args.n_heads
        self.n_heads = args.n_heads

        self.qw = nn.Linear(args.embedding_dim, self.head_dim * args.n_heads, bias=False)
        self.kw = nn.Linear(args.embedding_dim, self.head_dim * args.n_heads, bias=False)
        self.vw = nn.Linear(args.embedding_dim, self.head_dim * args.n_heads, bias=False)

        self.ow = nn.Linear(self.head_dim * args.n_heads, args.embedding_dim)
        self.is_causal = is_causal
        if self.is_causal:
            # Strictly-upper-triangular -inf matrix: adding it to the scores
            # zeroes the softmax weight of every "future" key position.
            # NOTE: this buffer holds max_seq_len * max_seq_len floats.
            mask = torch.full((args.max_seq_len, args.max_seq_len), fill_value=float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: Query input, reshaped here as (batch, q_len, embedding_dim).
            k: Key input, (batch, kv_len, embedding_dim); ``kv_len`` may differ
                from ``q_len`` (cross-attention).
            v: Value input with the same sequence length as ``k``.

        Returns:
            Tensor of shape (batch, q_len, embedding_dim).
        """
        bsz, q_len = q.shape[:2]
        # Keys/values have their own length; reusing q_len here would make the
        # reshape fail (or silently mis-shape) for cross-attention.
        kv_len = k.shape[1]

        xq = self.qw(q).view(bsz, q_len, self.n_heads, self.head_dim).transpose(1, 2)
        xk = self.kw(k).view(bsz, kv_len, self.n_heads, self.head_dim).transpose(1, 2)
        xv = self.vw(v).view(bsz, kv_len, self.n_heads, self.head_dim).transpose(1, 2)

        # (batch, heads, q_len, kv_len)
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)

        if self.is_causal:
            scores = scores + self.mask[:q_len, :kv_len]

        attn = F.softmax(scores, dim=-1)

        # Weight the values with the normalised attention, not the raw scores
        # (raw scores are unnormalised and contain -inf under the causal mask).
        output = torch.matmul(attn, xv)
        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.n_heads * self.head_dim)

        return self.ow(output)








        



In [12]:
## MLP

class MLP(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear -> Dropout.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the intermediate activation.
        dropout: Drop probability applied to the block output.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        # Both projections are bias-free.
        self.l1 = nn.Linear(in_features=dim, out_features=hidden_dim, bias=False)
        self.l2 = nn.Linear(in_features=hidden_dim, out_features=dim, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return the transformed input; the last dimension stays ``dim``."""
        expanded = F.relu(self.l1(x))
        projected = self.l2(expanded)
        return self.dropout(projected)






In [13]:
## LayerNormal


class LayerNormal(nn.Module):
    """Normalise the last dimension, then apply a learnable gain and shift.

    The input is centred by its mean and divided by ``std + eps`` (``eps`` is
    added to the standard deviation, not to the variance), then multiplied by
    ``a`` and offset by ``b``.

    Args:
        features: Size of the last dimension of the input.
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self, features, eps:float=1e-6) -> None:
        super().__init__()
        self.a = nn.Parameter(torch.ones(features))   # gain, starts at 1
        self.b = nn.Parameter(torch.zeros(features))  # shift, starts at 0
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Return the normalised tensor; the shape is unchanged."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.a * (centered / spread) + self.b
    
    


In [14]:
## encoder 
class EncoderLayer(nn.Module):
    """One pre-norm encoder block: self-attention, then MLP, each with a residual."""

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        width = args.embedding_dim
        self.attn = MultiHeadAttention(args)
        self.mlp = MLP(width, args.hidden_size, args.dropout)
        self.attn_norm = LayerNormal(width)
        self.mlp_norm = LayerNormal(width)

    def forward(self, x: torch.Tensor):
        """Run both sub-layers on ``x`` and return the updated tensor."""
        # Self-attention: query, key and value are all the normalised input.
        normed = self.attn_norm(x)
        after_attn = x + self.attn(normed, normed, normed)
        # Feed-forward on the normalised result, added back onto the stream.
        return after_attn + self.mlp(self.mlp_norm(after_attn))
    
class Encoder(nn.Module):
    """Stack of ``args.num_encoder_layer`` encoder layers plus a final norm."""

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        # nn.ModuleList registers every layer as a sub-module. A plain Python
        # list would hide the layers from parameters(), state_dict(), .to(),
        # .train()/.eval() and apply(), so they would never be trained, saved,
        # moved to the right device or initialised.
        self.layers = nn.ModuleList(
            [EncoderLayer(args) for _ in range(args.num_encoder_layer)]
        )
        self.layer_norm = LayerNormal(args.embedding_dim)

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every layer in order, then normalise the result."""
        for layer in self.layers:
            x = layer(x)

        return self.layer_norm(x)





    








In [15]:
## decoder
class DecoderLayer(nn.Module):
    """One pre-norm decoder block.

    Sub-layers, each wrapped in a residual connection:
      1. causal self-attention over the decoder stream,
      2. cross-attention from the decoder stream to ``encoder_out``,
      3. the feed-forward MLP.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        # Self-attention is the masked one: a position must not see later
        # positions of the sequence being decoded.
        self.self_attn = MultiHeadAttention(args, is_causal=True)
        self.self_attn_norm = LayerNormal(args.embedding_dim)

        # Cross-attention reads the whole encoder output, so it is unmasked.
        self.attn = MultiHeadAttention(args)
        self.attn_norm = LayerNormal(args.embedding_dim)

        self.mlp = MLP(args.embedding_dim, args.hidden_size, args.dropout)
        self.mlp_norm = LayerNormal(args.embedding_dim)

    def forward(self, x: torch.Tensor, encoder_out: torch.Tensor):
        """Update the decoder stream ``x`` using the encoder output.

        Args:
            x: Decoder-side hidden states.
            encoder_out: Encoder output used as key and value for
                cross-attention.

        Returns:
            The updated decoder hidden states, same shape as ``x``.
        """
        norm_x = self.self_attn_norm(x)
        x = x + self.self_attn(norm_x, norm_x, norm_x)
        norm_x = self.attn_norm(x)
        x = x + self.attn(norm_x, encoder_out, encoder_out)

        # Keep the residual around the MLP, matching EncoderLayer; without it
        # the attention results only reach the output through the MLP.
        x = x + self.mlp(self.mlp_norm(x))
        return x


class Decoder(nn.Module):
    """Stack of ``args.num_decoder_layer`` decoder layers plus a final norm."""

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        # nn.ModuleList registers every layer as a sub-module. A plain Python
        # list would hide the layers from parameters(), state_dict(), .to(),
        # .train()/.eval() and apply(), so they would never be trained, saved,
        # moved to the right device or initialised.
        self.layers = nn.ModuleList(
            [DecoderLayer(args) for _ in range(args.num_decoder_layer)]
        )
        self.norm = LayerNormal(args.embedding_dim)

    def forward(self, x: torch.Tensor, encoder_out: torch.Tensor):
        """Run ``x`` through every layer (each sees ``encoder_out``), then normalise."""
        for layer in self.layers:
            x = layer(x, encoder_out)

        return self.norm(x)

















             

In [17]:
class PositionEmbeddings(nn.Module):
    """Fixed sinusoidal position encodings added to the input.

    Even columns hold ``sin(pos * freq_i)`` and odd columns hold
    ``cos(pos * freq_i)`` with ``freq_i = theta ** (-2i / embedding_dim)``.

    Args:
        max_seq_len: Number of positions to precompute.
        embedding_dim: Width of each position vector.
        theta: Base of the frequency progression.
    """

    def __init__(self, max_seq_len, embedding_dim,theta: float=10000.0) -> None:
        super().__init__()
        pe = torch.zeros((max_seq_len, embedding_dim))
        position = torch.arange(0, max_seq_len).unsqueeze(1)  # (max_seq_len, 1)
        # One frequency per (sin, cos) column pair, so the range must run over
        # the embedding dimension; iterating over max_seq_len produces the
        # wrong number of columns and the assignments below fail.
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(math.log(theta) / embedding_dim))
        angles = position * div_term  # (max_seq_len, ceil(embedding_dim / 2))

        pe[:, 0::2] = torch.sin(angles)
        # With an odd embedding_dim there is one fewer cosine column than sine
        # column, so trim the angles to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
        pe = pe.unsqueeze(0)  # (1, max_seq_len, embedding_dim), broadcasts over batch
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor):
        """Add the encodings of the first ``x.size(1)`` positions to ``x``."""
        seq_len = x.size(1)

        return x + self.pe[:, :seq_len, :].requires_grad_(False)  # pyright: ignore[reportIndexIssue]



In [19]:
## transformer

from uu import decode
from torch.nn import L1Loss, parameter


class Transformers(nn.Module):
    """Encoder-decoder model: embeddings, encoder, decoder and a vocabulary head.

    Args:
        args: Hyperparameters used to size every sub-module.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        self.args = args

        cfg = self.args
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(cfg.vocab_size, cfg.embedding_dim),
                "wpe": PositionEmbeddings(cfg.max_seq_len, cfg.embedding_dim, cfg.theta),
                "drop": nn.Dropout(cfg.dropout),
                "encoder": Encoder(cfg),
                "decoder": Decoder(cfg),
            }
        )

        # Projects hidden states to vocabulary logits.
        self.ln_head = nn.Linear(cfg.embedding_dim, cfg.vocab_size, bias=False)

        # Re-initialise every registered Linear / Embedding sub-module.
        self.apply(self._init_model_weight)

        print(f"total parmas: {self.get_num_params()}")

    def _init_model_weight(self, module: nn.Module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def get_num_params(self, non_embedding: bool = False):
        """Count registered parameters.

        Args:
            non_embedding: When True, the token-embedding weights are left out
                of the count.
        """
        total = sum(p.numel() for p in self.parameters())
        if non_embedding:
            total -= self.transformer.wte.weight.numel()  # pyright: ignore[reportAttributeAccessIssue, reportCallIssue]
        return total

    def forward(self, index: torch.Tensor, target=None):
        """Compute logits and, when ``target`` is given, the cross-entropy loss.

        Args:
            index: Token ids fed to the embedding table
                (presumably (batch, seq) -- confirm against the caller).
            target: Optional target ids; entries equal to -1 are ignored by
                the loss.

        Returns:
            ``(logits, loss)``. With ``target`` the logits cover every
            position; without it only the last position is projected and
            ``loss`` is None.
        """
        parts = self.transformer

        # Token embedding -> add position encodings -> dropout.
        hidden = parts["drop"](parts["wpe"](parts["wte"](index)))

        # The same embedded sequence feeds both the encoder and the decoder.
        memory = parts["encoder"](hidden)
        decoded = parts["decoder"](hidden, memory)

        if target is None:
            # Inference: only the final position is needed for the next token.
            return self.ln_head(decoded[:, [-1], :]), None

        logits = self.ln_head(decoded)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1), ignore_index=-1)
        return logits, loss






  from uu import decode
