In [2]:
from transformers import PretrainedConfig
from typing import List

In [3]:
import math
import struct
import inspect
import time
from typing import Any, Optional, Tuple, List
import numpy as np
from torch import nn 
from transformers import PreTrainedModel 
import torch


In [4]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-channel gain.

    Args:
        dim: Size of the last (normalised) dimension; also the length of ``weight``.
        eps: Constant added to the mean square before taking the reciprocal
            square root, to avoid division by zero.
    """
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` over its last dimension and scale by ``weight``.

        The normalised activations are cast back to ``x``'s dtype before the
        multiplication with ``weight``.
        """
        # Statistics are computed in float32. Squaring a float16 tensor directly
        # overflows to inf once |x| exceeds ~255, which made rsqrt return 0 and
        # silently zeroed the whole output.
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x_fp32 * inv_rms).type_as(x)


In [5]:
def precompute_pos_cis(dim: int, end: int = int(32*1024), theta: float = 1e6):
    """Build the table of complex rotary-embedding factors.

    Args:
        dim: Per-head feature size; ``dim // 2`` rotation frequencies are produced.
        end: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[p, k]`` is
        ``exp(i * p * theta ** (-2k / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()

    # Unit magnitude + angle -> points on the complex unit circle.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)


In [6]:
def apply_rotray_emb(xq, xk, pos_cis):
    """Rotate query and key tensors by the given rotary position factors.

    Consecutive pairs along the last axis of ``xq`` / ``xk`` are treated as
    (real, imag) parts of complex numbers and multiplied by ``pos_cis``.

    Args:
        xq: Query tensor; axis 1 is the position axis, the last axis has even size.
        xk: Key tensor laid out like ``xq`` (head count may differ).
        pos_cis: Complex factors of shape ``(xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the shapes and dtypes of the inputs.

    Raises:
        AssertionError: if ``pos_cis`` does not match the position / feature axes.
    """
    def _as_complex(t):
        # (..., d) -> (..., d/2, 2) -> complex (..., d/2)
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)

    # Reshape pos_cis so it broadcasts over every axis except position and feature.
    rank = q_complex.ndim
    assert 0 <= 1 < rank
    assert pos_cis.shape == (q_complex.shape[1], q_complex.shape[-1])
    broadcast_shape = [
        size if axis in (1, rank - 1) else 1
        for axis, size in enumerate(q_complex.shape)
    ]
    rotation = pos_cis.view(*broadcast_shape)

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)

In [7]:
def repeat_kv(x: torch.Tensor, n_rep: int):
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Args:
        x: Tensor of shape ``(batch, seq_len, n_kv_heads, head_dim)``.
        n_rep: Number of copies of every head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        ``(batch, seq_len, n_kv_heads * n_rep, head_dim)`` in which the copies
        of one head are adjacent (same layout as
        ``torch.repeat_interleave(x, n_rep, dim=2)``).
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # The new axis must come AFTER the kv-head axis. The previous
    # ``x[:, :, None, :]`` inserted it before, so ``expand`` either raised or
    # produced a wrongly ordered result.
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

In [8]:
class LMConfig(PretrainedConfig):
    """Hyper-parameter container for the language model built in this notebook.

    Every constructor argument is stored on the instance under the same name;
    any additional keyword arguments are forwarded to ``PretrainedConfig``.

    Core parameters (as consumed by the model code in this notebook):
        dim: Embedding / residual width.
        n_layers: Number of transformer blocks.
        n_heads: Number of query heads; ``dim // n_heads`` is the head size.
        n_kv_heads: Number of key/value heads (must divide ``n_heads``).
        vocab_size: Size of the embedding table and output projection.
        hidden_dim: Feed-forward width; when ``None`` the feed-forward layer
            derives it from ``dim`` and ``multiple_of`` and writes it back here.
        multiple_of: Rounding granularity for the derived ``hidden_dim``.
        norm_eps: Epsilon passed to ``RMSNorm``.
        max_seq_len: Side length of the attention mask buffer.
        rope_theta: Frequency base passed to ``precompute_pos_cis``.
        dropout: Dropout probability used by attention and feed-forward layers.
        flash_attn: Allow ``F.scaled_dot_product_attention`` when available.

    The MoE parameters are only stored here; nothing in this notebook reads
    them yet.
    """
    model_type = "miaodeeai"
    def __init__(self,
                 dim: int = 512,
                 n_layers: int = 1,
                 n_heads: int = 8,
                 n_kv_heads: int = 2,
                 vocab_size: int = 6400,
                 hidden_dim: Optional[int] = None,
                 multiple_of: int = 64,
                 norm_eps: float = 1e-5,
                 max_seq_len: int = 8192,
                 rope_theta: float = 1e6,
                 dropout: float = 0.0,
                 flash_attn: bool = True,
                 ### The parameters below are only needed when MoE is used
                 use_moe: bool = False,
                 num_experts_per_tok: int =2,
                 num_routed_experts: int=4,
                 # NOTE(review): the name suggests a count but the default is a bool — confirm intended type.
                 n_shared_experts: bool = True,
                 scoring_func: str = 'softmax',
                 aux_loss_alpha: float = 0.1,
                 seq_aux: bool = True,
                 norm_topk_prob: bool= True,
                 **kwargs,
                 ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.dropout = dropout
        self.flash_attn = flash_attn
        self.norm_eps = norm_eps
        ### MoE-related parameters
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok
        self.num_routed_experts = num_routed_experts
        self.n_shared_experts = n_shared_experts
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        self.norm_topk_prob = norm_topk_prob
        # Attributes are set before delegating so the base class sees them
        # while it processes the remaining kwargs.
        super().__init__(**kwargs)

In [9]:
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os 
import ast


In [10]:
class PretrainDataset(Dataset):
    """Map-style dataset of JSON-lines text samples for next-token pretraining.

    Each non-empty line of ``data_path`` must be a JSON object with a ``"text"``
    field. Items are returned as ``(X, Y, loss_mask)`` where ``Y`` is ``X``
    shifted left by one token.

    Args:
        data_path: Path to a UTF-8 JSON-lines file.
        tokenizer: Callable tokenizer exposing ``bos_token``, ``eos_token`` and
            ``pad_token_id`` and returning a mapping with ``'input_ids'``.
        max_length: Length every sample is padded / truncated to.
    """
    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, data_path: str):
        """Read every JSON object from ``data_path`` into a list.

        Blank lines (e.g. a trailing newline at end of file) are skipped
        instead of crashing ``json.loads``.
        """
        samples = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                samples.append(json.loads(line))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index: int):
        """Tokenise sample ``index`` and return ``(X, Y, loss_mask)``.

        Returns:
            X: ``input_ids[:-1]`` as ``torch.long``.
            Y: ``input_ids[1:]`` as ``torch.long`` (next-token targets).
            loss_mask: ``torch.long`` tensor aligned with ``Y``; 1 where the
                target is not the pad token, 0 otherwise.
        """
        sample = self.samples[index]

        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Drop only the batch axis added by return_tensors='pt'.
        input_ids = encoding['input_ids'].squeeze(0)
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # detach().clone() copies without the UserWarning that
        # torch.tensor(<tensor>) emits for copy-construction.
        X = input_ids[:-1].detach().clone().to(torch.long)
        Y = input_ids[1:].detach().clone().to(torch.long)
        loss_mask = loss_mask[1:].detach().clone().to(torch.long)
        return X, Y, loss_mask
    

In [11]:
# Smoke test inputs: random queries/keys laid out as (batch=2, seq_len=16, heads=4, head_dim=64)
# and the matching rotary table for head_dim=64 over 16 positions.
xq,xk = torch.randn((2,16,4,64)), torch.randn((2,16,4,64))
pos_cis = precompute_pos_cis(64,16)
# The table has one row per position and head_dim // 2 complex factors per row; position 0 is 1+0j (no rotation).
print(f"pos_cis shape: {pos_cis.shape}, pos_cis[0,0]: {pos_cis[0,0]}")

pos_cis shape: torch.Size([16, 32]), pos_cis[0,0]: (1+0j)


In [12]:
# Apply the rotary factors to the tensors from the previous cell; the shapes must be unchanged.
xq_rope, xk_rope = apply_rotray_emb(xq, xk, pos_cis)
print(f"xq_rope shape: {xq_rope.shape}, xk_rope shape: {xk_rope.shape}")

xq_rope shape: torch.Size([2, 16, 4, 64]), xk_rope shape: torch.Size([2, 16, 4, 64])


In [13]:
from typing import Any, Optional, Tuple, List 
import torch.nn as nn
import math 
import torch
import  torch.nn.functional as F

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times along the head axis.

    ``x`` has shape ``(batch, seq_len, n_kv_heads, head_dim)``. The result has
    ``n_kv_heads * n_rep`` heads with the copies of each original head stored
    next to each other. When ``n_rep == 1`` the input tensor is returned as is.
    """
    batch, seq_len, kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it into the head axis.
    expanded = x.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, head_dim)
    return expanded.reshape(batch, seq_len, kv_heads * n_rep, head_dim)

In [14]:
class Attention(nn.Module):
    """Causal self-attention with rotary embeddings and grouped key/value heads.

    Queries use ``n_heads`` heads; keys and values use ``n_kv_heads`` heads that
    are repeated ``n_heads // n_kv_heads`` times to line up with the queries.

    Args:
        args: Model configuration. Reads ``dim``, ``n_heads``, ``n_kv_heads``
            (``None`` means "same as ``n_heads``"), ``dropout``, ``flash_attn``
            and ``max_seq_len``.
    """
    def __init__(self,args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Use the resolved value from here on: the raw ``args.n_kv_heads`` may be None.
        assert args.n_heads % self.n_kv_heads == 0

        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # q, k, v, o projections
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

        # Additive causal mask: -inf strictly ABOVE the diagonal (future keys),
        # 0 elsewhere. ``triu`` is required here; ``tril(diagonal=1)`` kept -inf
        # on the diagonal and below, masking the past and yielding NaN for
        # single-token decoding.
        mask = torch.full((1,1,args.max_seq_len, args.max_seq_len), float('-inf'))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer('mask', mask, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
                use_cache=False):
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            pos_cis: Rotary factors for the ``seq_len`` positions of ``x``.
            past_key_value: Optional cached ``(keys, values)``, each of shape
                ``(batch, past_len, n_kv_heads, head_dim)``; new keys/values
                are appended along the position axis.
            use_cache: When true, also return the updated ``(keys, values)``.

        Returns:
            ``(output, past_kv)`` where ``output`` has the shape of ``x`` and
            ``past_kv`` is the key/value pair or ``None``.
        """
        bsz, seq_len, _ = x.shape

        #### Forward Q, K, V and RoPE ####
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotray_emb(xq, xk, pos_cis)

        #### KV cache ####
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)

        past_kv = (xk, xv) if use_cache else None
        kv_len = xk.shape[1]  # cached + current positions
        xq, xk, xv = (
            xq.transpose(1,2),
            repeat_kv(xk, self.n_rep).transpose(1,2),
            repeat_kv(xv, self.n_rep).transpose(1,2)
        )

        #### Scaled dot-product attention ####
        # ``is_causal=True`` aligns the mask to the top-left corner, which is
        # only right when queries and keys cover the same positions, so the
        # fused kernel is used only without cached keys.
        if self.flash and seq_len != 1 and kv_len == seq_len:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(xq,xk,xv,
                                                    attn_mask = None,
                                                    dropout_p = dropout_p,
                                                    is_causal = True)
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            # Queries occupy the last ``seq_len`` of the ``kv_len`` positions,
            # so take the matching rows of the causal mask.
            scores = scores + self.mask[:, :, kv_len - seq_len:kv_len, :kv_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1,2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv


In [15]:
# LMConfig = LMConfig(n_layers=2)
# attn = Attention(LMConfig)
# x = torch.randn((4,16,512)) # (batch_size, seq_len, embed_dim)
# pos_cis = precompute_pos_cis(64,16) # args: (head_dim, seq_len), where head_dim = embed_dim // n_heads
# output , past_kv = attn(x, pos_cis=pos_cis, use_cache=True)

# print(f"输入张量x 的形状 {x.shape}, RoPE 旋转角度: pos_cis.shape = {pos_cis.shape}")
# print(f"输出张量 output 的形状: {output.shape}, kv_cache 的形状: {past_kv[0].shape}, size_value = {past_kv[1].shape}")

In [16]:
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    Args:
        config: Model configuration. Reads ``dim``, ``hidden_dim``,
            ``multiple_of`` and ``dropout``. When ``config.hidden_dim`` is
            ``None`` it is derived as 2/3 of ``4 * dim`` rounded up to a
            multiple of ``multiple_of`` and written back onto ``config``, so
            later layers built from the same config reuse the value.
    """
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = int(2 * (config.dim * 4) / 3)
            # Round up to the next multiple of ``multiple_of``.
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the gated projection; the output has the same shape as ``x``."""
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))

In [17]:
# Smoke test: the feed-forward layer must map (batch, seq_len, dim) to the same shape.
ffn = FeedForward(LMConfig(n_layers=2))
x = torch.randn((4,16,512)) # batch_size, seq_len, embed_dim
output = ffn(x)
print(f"输入张量x 的形状 {x.shape}, 输出张量 output 的形状: {output.shape}")

输入张量x 的形状 torch.Size([4, 16, 512]), 输出张量 output 的形状: torch.Size([4, 16, 512])


In [18]:
class GPTBlock(nn.Module):
    """One pre-norm transformer block.

    The input is RMS-normalised and passed through attention, the result is
    added back to the input, and the same norm -> sub-layer -> residual
    pattern is repeated with the feed-forward layer.

    Args:
        layer_id: Index stored on the block as ``layer_id``.
        config: Model configuration shared by all sub-modules.
    """
    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.attention = Attention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config)

    def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
        """Return ``(block_output, past_kv)`` for input ``x``.

        ``past_kv`` is whatever the attention layer returns as its cache
        (``None`` unless ``use_cache`` is true).
        """
        attn_out, present_kv = self.attention(
            self.attention_norm(x),
            pos_cis,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )

        residual = x + attn_out
        block_out = residual + self.feed_forward(self.ffn_norm(residual))
        return block_out, present_kv


In [19]:
# Smoke test: run one block on random input and inspect the output and KV-cache shapes.
miniblock = GPTBlock(layer_id=1, config=LMConfig(n_layers=2))
x = torch.randn((4,16,512)) # batch_size, seq_len, embed_dim
pos_cis = precompute_pos_cis(64,16) # args: dim=head_dim (512 // 8 heads), end=seq_len
out, past_kv = miniblock(x,pos_cis=pos_cis, use_cache=True)
print(f"输入张量x 的形状 {x.shape}, RoPE 旋转角度: pos_cis.shape = {pos_cis.shape}")
print(f"输出张量 out 的形状: {out.shape}, kv_cache 的形状: {past_kv[0].shape}, size_value = {past_kv[1].shape}")

输入张量x 的形状 torch.Size([4, 16, 512]), RoPE 旋转角度: pos_cis.shape = torch.Size([16, 32])
输出张量 out 的形状: torch.Size([4, 16, 512]), kv_cache 的形状: torch.Size([4, 16, 2, 64]), size_value = torch.Size([4, 16, 2, 64])


In [20]:
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

class cqxGPT(PreTrainedModel):
    """Decoder-only causal language model.

    Pipeline: token embedding -> dropout -> ``n_layers`` ``GPTBlock`` s ->
    ``RMSNorm`` -> output projection. The embedding table and the output
    projection share one weight tensor.

    Args:
        params: Model configuration; a default ``LMConfig()`` is used when
            omitted or ``None``.
    """
    config_class = LMConfig

    def __init__(self, params: Optional[LMConfig] = None):
        # Resolve the fallback first so every later access uses a real config
        # (the raw argument may be None).
        params = params or LMConfig()
        super().__init__(params)
        self.params = params
        self.vocab_size , self.n_layers = params.vocab_size, params.n_layers
        self.token_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        # Each block receives its own index as layer_id.
        self.layers = nn.ModuleList([GPTBlock(layer_id, params) for layer_id in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        # Weight tying: the embedding reuses the output projection's matrix.
        self.token_embeddings.weight = self.output.weight
        self.register_buffer(
            "pos_cis",
            precompute_pos_cis(dim=params.dim//params.n_heads, theta = params.rope_theta),
            persistent=False
        )
        # NOTE: a single output object is reused and overwritten by every forward() call.
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Compute logits for ``input_ids``.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``.
            past_key_values: Optional per-layer ``(keys, values)`` caches.
            use_cache: When true, each layer returns its updated cache.
            **args: ``start_pos`` (default 0) selects the first rotary position
                for ``input_ids``; other keys are ignored here.

        Returns:
            ``self.OUT`` with ``logits`` of shape ``(batch, seq_len, vocab_size)``,
            ``aux_loss`` (always 0 here) and ``past_key_values`` (one entry per
            layer; entries are ``None`` when ``use_cache`` is false).
        """
        past_key_values = past_key_values or [None] * self.n_layers
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.token_embeddings(input_ids))

        pos_cis = self.pos_cis[start_pos: start_pos + input_ids.size(1)]
        past_kvs = []

        # Debug prints were removed: they indexed ``past_kv[0]`` and therefore
        # crashed whenever ``use_cache`` was false.
        for l, layer in enumerate(self.layers):
            h, past_kv = layer(h, pos_cis=pos_cis, past_key_value=past_key_values[l], use_cache=use_cache)
            past_kvs.append(past_kv)

        logits = self.output(self.norm(h))
        aux_loss = 0
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__("aux_loss", aux_loss)
        self.OUT.__setitem__("past_key_values", past_kvs)
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids: torch.Tensor, eos_token_id= 2, max_new_tokens: int = 512, temperature=0.75, top_p=0.90, stream=False, rp=1, use_cache=True, pad_token_id=0, **args):
        """Sample continuations for a batch of prompts.

        Args:
            input_ids: Prompt token ids of shape ``(batch, seq_len)``.
            eos_token_id: Sampling stops after this token is produced.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Logits are divided by this value before sampling.
            top_p: Nucleus-sampling threshold; ``None`` or ``>= 1`` disables it.
            stream: When true, return the generator from ``_stream`` for the
                whole ``input_ids`` instead of a tensor.
            rp: Logits of tokens already in the sequence are divided by this.
            use_cache: Reuse key/value caches between steps.
            pad_token_id: Removed from each prompt and used to right-pad results.

        Returns:
            With ``stream=False``: a ``(batch, max_len)`` tensor holding each
            un-padded prompt followed by its generated tokens, right-padded
            with ``pad_token_id``. With ``stream=True``: a generator yielding
            all tokens generated so far after every step.
        """
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, pad_token_id, **args)

        generated = []

        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, pad_token_id, **args)
            # Every yielded tensor ends with the newest token.
            tokens_list = [tokens[:,-1:] for tokens in out]
            # Concatenating prompt + new tokens in one call also covers the
            # "nothing generated" case without duplicating the prompt.
            full_sequence = torch.cat([non_pad] + tokens_list, dim=-1)
            generated.append(full_sequence)
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat([seq, torch.full((1,max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)], dim=-1)
            for seq in generated
        ]
        return torch.cat(generated, dim=0)

    def _stream(self, input_ids: torch.Tensor, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, pad_token_id, **args):
        """Generator that samples one token per step for a single sequence.

        Yields ``input_ids[:, start:]`` (all tokens generated so far) after each
        step and stops after ``max_new_tokens`` steps or once ``eos_token_id``
        has been produced. ``input_ids`` must have batch size 1 because the
        EOS check uses ``.item()``.
        """
        # No cache exists before the first forward pass, so start from None.
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        # Bound the number of NEW tokens rather than the total sequence length.
        while input_ids.shape[1] < start + max_new_tokens:
            # The decorator on generate() does not cover a generator that is
            # consumed later, so inference mode is entered for each step here.
            with torch.inference_mode():
                if first_seq or not use_cache:
                    out , first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache,**args) , False
                else:
                    # With a cache only the newest token is fed, at its absolute position.
                    out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,start_pos=input_ids.shape[1] - 1, **args)

                logits, past_kvs = out.logits[:,-1,:], out.past_key_values
                # Repetition penalty on every token id already present.
                logits[:, list(set(input_ids.tolist()[0]))] /= rp
                logits /= (temperature + 1e-9)
                if top_p is not None and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    sorted_probs = F.softmax(sorted_logits, dim=-1)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:,0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1,sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float('-inf')
                input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                input_ids = torch.cat((input_ids, input_ids_next),dim=1)
                new_tokens = input_ids[:, start:]
                reached_eos = input_ids_next.item() == eos_token_id
            yield new_tokens
            if reached_eos:
                break