In [9]:
import warnings
# Suppress every warning for the rest of the kernel session. This keeps cell
# output short, but it also hides all warnings (deprecations included).
warnings.filterwarnings("ignore")

In [36]:
from transformers import PretrainedConfig
from typing import Optional, Union
import torch
import math
import torch.nn as nn
import torch.nn.functional as F


### 模型配置

In [11]:

class OstrichModelConfig(PretrainedConfig):
    """Hyper-parameter container for the Ostrich language model.

    Every constructor argument is stored verbatim as an attribute of the same
    name; any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        dim: Model (embedding) width.
        n_layers: Number of transformer layers.
        n_heads: Number of attention (query) heads.
        n_kv_heads: Number of key/value heads. ``None`` is accepted; the
            attention module then falls back to ``n_heads``.
        vocab_size: Vocabulary size.
        hidden_dim: Feed-forward hidden width. ``None`` lets the MLP derive it
            from ``dim`` and ``multiple_of``.
        multiple_of: The derived feed-forward width is rounded up to a
            multiple of this value.
        norm_eps: Epsilon used by the normalisation layers.
        max_seq_len: Maximum sequence length.
        dropout: Dropout probability.
        flash_attn: Whether to use Flash Attention.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "Ostrich-Llm"

    def __init__(
            self,
            dim: int = 768,  # model width
            n_layers: int = 12,  # number of transformer layers
            n_heads: int = 16,  # number of attention heads
            n_kv_heads: Optional[int] = 8,  # number of key/value heads (None -> n_heads)
            vocab_size: int = 6144,  # vocabulary size
            hidden_dim: Optional[int] = None,  # feed-forward hidden width (None -> derived)
            multiple_of: int = 64,  # derived hidden width is rounded up to a multiple of this
            norm_eps: float = 1e-5,  # epsilon of the normalisation layers
            max_seq_len: int = 512,  # maximum sequence length
            dropout: float = 0.0,  # dropout probability
            flash_attn: bool = True,  # whether to use Flash Attention
            **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.flash_attn = flash_attn
        super().__init__(**kwargs)

In [12]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Args:
        dim: Size of the last axis; the gain vector ``weight`` has this length
            and starts at all ones.
        eps: Constant added to the mean square before the reciprocal square
            root is taken.
    """

    def __init__(self, dim: int, eps: float) -> None:
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        """Scale ``x`` by the reciprocal RMS of its last axis."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor):
        """Normalise in float32, cast back to ``x``'s dtype, then apply the gain."""
        normalised = self._norm(x.float()).type_as(x)
        return self.weight * normalised
    


In [26]:
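## repeat_kv
### With grouped-query attention, q uses n_heads heads while k and v use only n_kv_heads heads, so before the attention product k and v must be repeated along the head axis until they have as many heads as q.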


def repeat_kv(x: torch.Tensor, n_rep: int):
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Args:
        x (torch.Tensor): k_states or v_states; must be 4-D, laid out as
            (bsz, seq_len, kv_heads, head_dim).
        n_rep (int): number of copies of each head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        (bsz, seq_len, kv_heads * n_rep, head_dim) in which the copies of one
        head sit next to each other.
    """
    batch, length, n_kv, width = x.shape
    if n_rep != 1:
        # Insert a singleton axis after the head axis, broadcast it to n_rep,
        # then fold it into the head axis.
        tiled = x.unsqueeze(3).expand(batch, length, n_kv, n_rep, width)
        return tiled.reshape(batch, length, n_kv * n_rep, width)
    return x

In [27]:
# 预计算 分组频率 也就是 e^i* theta = cos(theta) + i * sin(theta) i表示虚数单位

from typing import Tuple


def precompute_freq_cis(max_seq_length, dim, theta: float=10000.0):
    """Build the table of unit complex rotations used by the rotary embedding.

    Args:
        max_seq_length: Number of positions (rows) to generate.
        dim: Feature size being rotated; one frequency is produced per pair of
            features, i.e. ``dim // 2`` columns.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape (max_seq_length, dim // 2) whose entry
        ``[p, j]`` is ``exp(i * p * theta ** (-2j / dim))``.
    """
    # Per-pair inverse frequencies: theta ** (-(0, 2, 4, ...) / dim).
    exponents = (torch.arange(0, dim, 2)[: dim // 2] / dim).float()
    inv_freq = 1.0 / theta ** exponents

    # Position indices in the same dtype as the frequencies.
    positions = torch.arange(0, max_seq_length).type_as(inv_freq)

    # angle[p, j] = p * inv_freq[j]
    angles = torch.outer(positions, inv_freq)

    # Magnitude 1, phase `angles` -> cos(angle) + i * sin(angle).
    return torch.polar(torch.ones_like(angles), angles)

def apply_rope_embedding(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by the given complex rotary factors.

    The last axis of each input is split into consecutive (x, y) pairs, each
    pair is read as the complex number ``x + iy`` and multiplied by
    ``freqs_cis``; the result is flattened back to the original layout.

    Args:
        xq: Query tensor; its last axis must have even length.
        xk: Key tensor; its last axis must have even length.
        freqs_cis: Complex rotation factors. Must broadcast against the inputs
            once their last axis has been halved (e.g. (seq_len, head_dim // 2)
            against (..., seq_len, head_dim // 2)).

    Returns:
        ``(xq_rotated, xk_rotated)``, each with the shape and dtype of the
        corresponding input.
    """

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # torch.view_as_complex rejects bfloat16 (and float16 complex math is
        # poorly supported), so low-precision inputs are rotated in float32 and
        # cast back at the end. float32/float64 inputs are left untouched.
        work = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
        # (..., d) -> (..., d // 2, 2): group features into (x, y) pairs.
        pairs = work.reshape(*work.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * freqs_cis
        # Back to real pairs, merge the pair axis, restore the input dtype.
        return torch.view_as_real(rotated).flatten(-2).type_as(x)

    return _rotate(xq), _rotate(xk)
    




In [None]:
class Attention(nn.Module):
    """Causal self-attention with grouped key/value heads and rotary embeddings.

    Queries use ``args.n_heads`` heads; keys and values use ``args.n_kv_heads``
    heads (or ``n_heads`` when that is ``None``) and are repeated so every
    query head has a matching key/value head.

    Two execution paths exist:

    * fused: ``F.scaled_dot_product_attention`` with ``is_causal=True``, used
      when ``args.flash_attn`` is true and the function is available;
    * manual: explicit ``softmax(QK^T / sqrt(head_dim) + mask) V`` with an
      additive upper-triangular ``-inf`` mask.

    The mask buffer is only created for the manual path and is registered as
    non-persistent, so ``state_dict()`` has the same keys on both paths.

    Args:
        args: Model configuration; reads ``dim``, ``n_heads``, ``n_kv_heads``,
            ``max_seq_len``, ``dropout`` and ``flash_attn``.
    """

    def __init__(self, args: OstrichModelConfig) -> None:
        super().__init__()

        # Fall back to one key/value head per query head when n_kv_heads is unset.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads

        # Query heads must split evenly into key/value groups.
        assert args.n_heads % self.n_kv_heads == 0

        model_parallel_size = 1

        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size

        # How many query heads share one key/value head.
        self.n_reps = self.n_local_heads // self.n_local_kv_heads

        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)

        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        freq_cis = precompute_freq_cis(args.max_seq_len, dim=self.head_dim)
        self.register_buffer("freq_cis", freq_cis)
        self.dropout_prob = args.dropout

        # Use the fused kernel only when the config asks for it AND this
        # PyTorch build provides it.
        self.flash = bool(getattr(args, "flash_attn", True)) and hasattr(
            F, "scaled_dot_product_attention"
        )
        if not self.flash:
            # Manual path: additive causal mask, -inf strictly above the diagonal.
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), fill_value=float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor):
        """Apply causal self-attention.

        Args:
            x: Input of shape (bsz, seq_len, dim).

        Returns:
            Tensor of shape (bsz, seq_len, dim).

        Raises:
            ValueError: If ``seq_len`` exceeds the number of precomputed
                rotary positions (``max_seq_len``).
        """
        bsz, seq_len = x.shape[:2]
        max_len = self.freq_cis.shape[0]  # pyright: ignore[reportIndexIssue]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )

        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)
        # Split the projection width into (heads, head_dim).
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        # Give every query head its own copy of the shared key/value head.
        xv = repeat_kv(xv, self.n_reps)
        xk = repeat_kv(xk, self.n_reps)
        # (bsz, seq_len, heads, head_dim) -> (bsz, heads, seq_len, head_dim)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # freq_cis[:seq_len] is (seq_len, head_dim // 2) and broadcasts over
        # the batch and head axes of the transposed tensors.
        xq, xk = apply_rope_embedding(xq, xk, self.freq_cis[:seq_len, :])  # pyright: ignore[reportArgumentType, reportIndexIssue]

        if self.flash:
            # The functional API applies dropout whenever dropout_p > 0,
            # regardless of module mode, so it must be disabled in eval.
            output = F.scaled_dot_product_attention(
                xq,
                xk,
                xv,
                attn_mask=None,
                dropout_p=self.dropout_prob if self.training else 0.0,
                is_causal=True,
            )
        else:
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            scores = scores + self.mask[:, :, :seq_len, :seq_len]  # pyright: ignore[reportIndexIssue]

            attn = F.softmax(scores, dim=-1)
            attn = self.attn_dropout(attn)
            output = torch.matmul(attn, xv)

        # (bsz, heads, seq_len, head_dim) -> (bsz, seq_len, heads * head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)

        output = self.wo(output)
        output = self.resid_dropout(output)
        return output



        
        
        




        


In [None]:
# Smoke test: build an Attention layer from the default config and check that
# it maps (batch, seq, dim) to the same shape. `args` is reused by later cells.
args = OstrichModelConfig()
attention_model = Attention(args)

# Dummy input dimensions
batch_size = 1
seq_len = 50
dim = args.dim

x = torch.rand(batch_size, seq_len, dim)  # random input tensor


# Run the Attention layer
output = attention_model(x)

# The attention output keeps the shape [batch_size, seq_len, dim]
print("Output shape:", output.shape)

Output shape: torch.Size([1, 50, 768])


In [15]:
import torch

# Scratch check of torch.flatten on the trailing two axes (the same operation
# apply_rope_embedding uses via .flatten(-2)).
x = torch.randn(2, 3, 4, 5)   # shape = [2, 3, 4, 5]
y = torch.flatten(x, -2, -1)  # merge the last two dimensions into one
print(y.shape)

torch.Size([2, 3, 20])


In [None]:
class MLP(nn.Module):
    """Gated (SwiGLU) feed-forward block: ``dropout(w3(silu(w1(x)) * w2(x)))``.

    ``w1`` is the gate projection, ``w2`` the up projection and ``w3`` the down
    projection; none of them has a bias.

    Args:
        dim: Input and output width.
        hidden_size: Width of the gated hidden layer. When ``None`` it is
            derived as two thirds of ``4 * dim`` (i.e. ``8 * dim // 3``) and
            rounded up to the next multiple of ``multiple_of``.
        multiple_of: Granularity of the derived hidden width; unused when
            ``hidden_size`` is given.
        dropout_prob: Dropout probability applied to the block output.
    """

    def __init__(self, dim: int, hidden_size: Optional[int], multiple_of: int, dropout_prob: float) -> None:
        super().__init__()
        if hidden_size is None:
            # Start from the classic 4 * dim, shrink to 2/3 of it because the
            # gated variant has three matrices instead of two, then round up
            # to a multiple of `multiple_of`. Integer arithmetic throughout.
            hidden_size = 2 * (4 * dim) // 3
            hidden_size = ((hidden_size + multiple_of - 1) // multiple_of) * multiple_of

        self.w1 = nn.Linear(dim, hidden_size, bias=False)
        self.w2 = nn.Linear(dim, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, dim, bias=False)

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated feed-forward transform of ``x`` (last axis = dim)."""
        # SiLU-activated gate multiplies the up projection element-wise
        # (previously SELU and an addition, which is not a gate).
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.dropout(self.w3(gated))
    
    







In [38]:
# Smoke test: the MLP must preserve the input shape. Relies on the
# notebook-global `args` created in the attention demo cell.
mlp = MLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)
# Random input
x = torch.randn(1, 50, args.dim)
# Run the MLP
output = mlp(x)
print(output.shape)

torch.Size([1, 50, 768])


In [40]:
## 构建Decoder

class DecoderLayer(nn.Module):
    """One pre-norm transformer block.

    Each of the two sub-layers (attention, feed-forward) is applied to an
    RMS-normalised copy of its input and added back to that input; dropout is
    applied to the block output.

    Args:
        args: Model configuration; reads ``dim``, ``norm_eps``, ``hidden_dim``,
            ``multiple_of`` and ``dropout`` (plus whatever ``Attention`` reads).
    """

    def __init__(self, args: OstrichModelConfig) -> None:
        super().__init__()
        width, eps = args.dim, args.norm_eps

        # Attention sub-layer
        self.attn_norm = RMSNorm(width, eps)
        self.attn = Attention(args)

        # Feed-forward sub-layer
        self.ffn_norm = RMSNorm(width, eps)
        self.feed_forward = MLP(width, args.hidden_dim, args.multiple_of, args.dropout)

        self.dropout = nn.Dropout(args.dropout)

    def forward(self, x):
        """Return the block output; same shape as ``x``."""
        attn_out = self.attn(self.attn_norm(x))
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return self.dropout(hidden + ffn_out)


class Decoder(nn.Module):
    """Stack of ``args.n_layers`` decoder layers followed by a final RMSNorm.

    The layers are held in an ``nn.ModuleList`` so that their parameters are
    registered with this module: they appear in ``parameters()`` and
    ``state_dict()``, follow ``.to()`` / ``.cuda()``, and receive
    ``train()`` / ``eval()`` mode switches. A plain Python list provides none
    of that.

    Args:
        args: Model configuration; reads ``n_layers``, ``dim`` and ``norm_eps``
            and is passed on to every ``DecoderLayer``.
    """

    def __init__(self, args: OstrichModelConfig) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [DecoderLayer(args) for _ in range(args.n_layers)]
        )
        self.norm = RMSNorm(args.dim, args.norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every layer in order, then normalise the result."""
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
        


In [41]:
# Smoke test: the full decoder stack must preserve the input shape. Relies on
# the notebook-global `args` created in the attention demo cell.
decoder = Decoder(args)
# Random input
x = torch.randn(1, 50, args.dim)
# Run the Decoder
output = decoder(x)
print(output.shape)

torch.Size([1, 50, 768])


In [None]:
## 至此我们开始组装我们的模型
from transformers import PreTrainedModel
from typing import Optional
class OstrichModel(PreTrainedModel):
    """Top-level language model: token embedding -> decoder stack -> vocabulary projection.

    Only construction and weight initialisation are defined here; this cell
    does not define ``forward``.

    Args:
        config: Model configuration; reads ``vocab_size`` and ``dim`` and is
            passed on to ``Decoder``.
        *inputs, **kwargs: Forwarded to ``PreTrainedModel.__init__``.
    """

    config_class = OstrichModelConfig
    last_loss: Optional[torch.Tensor]

    def __init__(self, config: OstrichModelConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # Size every layer from the config that was passed in, not from the
        # notebook-global `args`, so a model built from a different config (or
        # in a fresh kernel) gets the right shapes.
        self.embed = nn.Embedding(config.vocab_size, config.dim)

        self.decoder = Decoder(config)

        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

    def _init_weights(self, module: nn.Module):
        """Initialise one sub-module in place.

        Linear layers get normally distributed weights and zero biases; other
        module types are left unchanged.
        """
        if isinstance(module, nn.Linear):
            # NOTE(review): the original statement here was the unfinished
            # expression `torch.init`, which raises AttributeError. std=0.02 is
            # the conventional small-normal choice - confirm the intended
            # initialiser, and whether nn.Embedding should be covered as well.
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

In [3]:
import torch
import torch.random
# Scratch cell: draw a 5 x 512 standard-normal matrix and display it.
# No seed is set, so the values differ on every run.
a = torch.randn((5, 512))
a

tensor([[ 1.0597,  1.4129,  0.1038,  ..., -1.7825,  1.9595, -1.5581],
        [-1.4532, -0.3955,  0.7904,  ..., -2.2277,  0.6763, -0.2762],
        [-0.6718, -0.5793,  0.3264,  ..., -0.6065,  0.4360, -1.7508],
        [-1.3422, -1.0430,  1.4487,  ..., -2.1306, -1.6287, -1.1258],
        [-0.0550,  1.4950,  0.1686,  ..., -0.8293,  1.2068, -0.6586]])

In [4]:
# Boolean-mask assignment: zero every element smaller than 0.1. This mutates
# `a` in place, so the tensor displayed by the previous cell is now stale.
a[a<0.1] = 0
a

tensor([[1.0597, 1.4129, 0.1038,  ..., 0.0000, 1.9595, 0.0000],
        [0.0000, 0.0000, 0.7904,  ..., 0.0000, 0.6763, 0.0000],
        [0.0000, 0.0000, 0.3264,  ..., 0.0000, 0.4360, 0.0000],
        [0.0000, 0.0000, 1.4487,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 1.4950, 0.1686,  ..., 0.0000, 1.2068, 0.0000]])