In [1]:
import math

In [2]:
import struct

In [3]:
import inspect

In [4]:
from dataclasses import dataclass

In [5]:
from typing import Any,Optional,Tuple

In [6]:
import numpy as np

In [7]:
import torch

In [8]:
import torch.nn.functional as F

In [9]:
from torch import nn

In [50]:
@dataclass
class ModelArgs:

    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = 32000
    hidden_dim: Optional[int] = None
    multiple_of: int = 256
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0

In [18]:
class RMSNorm(torch.nn.Module):

    def __init__(self,dim:int, eps:float):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(self.dim))
        
    def _norm(self,x):
        return x * torch.rsqrt(torch.square(x).mean(dim=-1,keepdim=True))

    def forward(self,x):
        return self.weight * self._norm(x)

In [31]:
x = torch.randn((5,2))

In [34]:
rms = RMSNorm(2,0.1)

In [35]:
a = rms(x)

In [36]:
a.square().mean(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<MeanBackward1>)

In [24]:
def precompute_freqs_cis(dim: int,end: int,theta: float = 10000.0):
    freqs = 1.0 / (theta ** torch.arange(0,dim,2)[:(dim // 2)].float()/dim)
    t = torch.arange(end,device=freqs.device)
    freqs = torch.outer(t,freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin

In [37]:
b = precompute_freqs_cis(4,5)

In [27]:
def reshape_for_broadcast(freqs_cis: torch.Tensor,x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1],x.shape[-1])
    shape = [d if i == 1 or i == ndim -1 else 1 for i,d in enumerate(x.shape)]
    return freqs_cis.view(shape)

In [43]:
import torch

# 假设 x 是一个形状为 [batch_size, seq_len, dim] 的张量
x = torch.randn(2, 4, 6)  # [batch_size, seq_len, dim]

# freqs_cis 是与 x 的第二个维度和最后一个维度相匹配的张量
freqs_cis = torch.randn(4, 6)  # shape (seq_len, dim)

# 调用 reshape_for_broadcast 函数
reshaped_freqs_cis = reshape_for_broadcast(freqs_cis, x)

# 输出结果
print("Original freqs_cis shape:", freqs_cis.shape)
print("Reshaped freqs_cis shape:", reshaped_freqs_cis.shape)


Original freqs_cis shape: torch.Size([4, 6])
Reshaped freqs_cis shape: torch.Size([1, 4, 6])


In [41]:
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor,torch.Tensor]:

    xq_r,xq_i = xq.float().reshape(xq.shape[:-1] + (-1,2)).unbind(-1)
    xk_r,xk_i = xk.float().reshape(xk.shape[:-1] + (-1,2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos,xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin,xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r,xq_out_i],dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r,xk_out_i],dim=-1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [42]:
import torch

# 模拟输入张量
xq = torch.randn(2, 4, 8)  # [batch_size, seq_len, dim]
xk = torch.randn(2, 4, 8)

# 模拟旋转频率（通常基于正弦和余弦函数生成）
freqs_cos = torch.randn(4, 4)  # shape (seq_len, dim // 2)
freqs_sin = torch.randn(4, 4)

# 调用 apply_rotary_emb
xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

print("Output Query Tensor Shape: ", xq_out.shape)
print("Output Key Tensor Shape: ", xk_out.shape)


Output Query Tensor Shape:  torch.Size([2, 4, 4, 2])
Output Key Tensor Shape:  torch.Size([2, 4, 4, 2])


In [46]:
def repeat_kv(x: torch.Tensor,n_rep: int) -> torch.Tensor:
    bs,slen,n_kv_heads,head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:,:,:,None,:].expand(
            bs,slen,n_kv_heads,n_rep,head_dim
        ).reshape(
            bs,slen,n_kv_heads * n_rep, head_dim
        )
    )

In [47]:
import torch

# 定义输入张量 x，形状为 [batch_size, seq_len, n_kv_heads, head_dim]
x = torch.randn(2, 3, 4, 5)  # 形状为 [2, 3, 4, 5]，batch_size=2, seq_len=3, n_kv_heads=4, head_dim=5

# 需要重复的次数 n_rep
n_rep = 3

# 调用 repeat_kv 函数
result = repeat_kv(x, n_rep)

# 输出结果
print("Original shape:", x.shape)
print("Result shape after repeating:", result.shape)


Original shape: torch.Size([2, 3, 4, 5])
Result shape after repeating: torch.Size([2, 3, 12, 5])


In [58]:
class Attention(nn.Module):

    def __init__(self,args: ModelArgs):

        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False)
        self.wk = nn.Linear(args.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.wv = nn.Linear(args.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim,bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_drop = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        self.flash = hasattr(torch.nn.functional,'scaled_dot_product_attention')

        if not self.flash:
            print("Warning")
            mask = torch.full((1,1,args.max_seq_len,args.max_seq_len),float("-inf"))
            mask = torch.triu(mask,diagonal=1)
            self.register_buffer("mask",mask)
            
    def forward(self,x:torch.Tensor,freqs_cos:torch.Tensor,freqs_sin:torch.Tensor):
        bsz,seq_len,_ = x.shape
        xq,xk,xv = self.wq(x),self.wk(x),self.wv(x)
        xq = xq.view(bsz,seq_len,self.n_local_heads,self.head_dim)
        xk = xk.view(bsz,seq_len,self.n_local_kv_heads,self.head_dim)
        xv = xv.view(bsz,seq_len,self.n_local_kv_heads,self.head_dim)
        xq,xk = apply_rotary_emb(xq,xk,freqs_cos,freqs_sin)

        xk = repeat_kv(xk,self.n_rep)
        xv = repeat_kv(xv,self.n_rep)


        xq = xq.transpose(1,2)
        xk = xk.transpose(1,2)
        xv = xv.transpose(1,2)

        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq,xk,xv,attn_mask=None,
                                                                     dropout_p=self.dropout if self.training else 0.0,is_causal=True)
        else:
            scores = torch.matmul(xq,xk.transpose(2,3))
            assert hasattr(self,'mask')
            scores = scores + self.mask[:,:,:seq_len,:seq_len]
            scores = F.softmax(scores.float(),dim=-1).tyype_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores,xv)
        output = output.transpose(1,2).contiguous().view(bsz,seq_len,-1)
        output = self.wo(output)
        output = self.resid_drop(output)
        return output

        




In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math



# 随机生成频率张量和输入张量
def generate_random_tensors(args):
    x = torch.randn(2, args.max_seq_len, args.dim)  # 输入 (bsz, seq_len, dim)
    freqs_cos = torch.randn(args.max_seq_len, args.dim // args.n_heads)
    freqs_sin = torch.randn(args.max_seq_len, args.dim // args.n_heads)
    return x, freqs_cos, freqs_sin

# 使用模型参数和Attention类
args = ModelArgs(dim=512, n_heads=8)
attention = Attention(args)

# 随机输入
x, freqs_cos, freqs_sin = generate_random_tensors(args)

# 前向传播
output = attention(x, freqs_cos, freqs_sin)

# 输出结果
print("Output shape:", output.shape)


AssertionError: 

In [61]:
class FeedForward(nn.Module):

    def __init__(self,dim:int,hidden_dim:int,multiple_of:int,dropout:float):

        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim,hidden_dim,bias=False)
        self.w2 = nn.Linear(hidden_dim,dim,bias=False)
        self.w3 = nn.Linear(dim,hidden_dim,bias=False)
        self.dropout = nn.Dropout(dropout)
    def forward(self,x):
        return self.dropout(self.w2(F.silu(self.w1(x))*self.w3(x)))
    

In [62]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义FeedForward类（如上所述）

# 模拟输入张量 x
x = torch.randn(10, 512)  # 假设输入大小为 (batch_size=10, dim=512)

# 定义FeedForward层
ffn = FeedForward(dim=512, hidden_dim=None, multiple_of=64, dropout=0.1)

# 前向传播
output = ffn(x)

# 打印输出形状
print("Output shape:", output.shape)


Output shape: torch.Size([10, 512])


In [63]:
class TransformerBlock(nn.Module):

    def __init__(self,layer_id: int,args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim = args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim,eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim,eps=args.norm_eps)
    def forward(self,x,freqs_cos,freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x),freqs_cos,freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

In [64]:
# 随机输入
args = ModelArgs(dim=512, n_heads=8, hidden_dim=1024, multiple_of=64, dropout=0.1, norm_eps=1e-5)
block = TransformerBlock(layer_id=1, args=args)

# 创建随机输入
x = torch.randn(10, 128, 512)  # 假设 (batch_size=10, seq_len=128, dim=512)
freqs_cos = torch.randn(128, 512 // 8)  # 频率张量
freqs_sin = torch.randn(128, 512 // 8)

# 前向传播
output = block(x, freqs_cos, freqs_sin)

# 输出结果
print("Output shape:", output.shape)

AssertionError: 

In [77]:

class Transformer(nn.Module):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying

        # some useful precompute for the RoPE relative positional embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
            self.last_loss = None

        return logits

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = sum(p.numel() for p in self.parameters())
        cfg = self.params
        L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.inference_mode()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        Also note this is a super inefficient version of sampling with no key/value cache.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
            # forward the model to get the logits for the index in the sequence
            logits = self(idx_cond)
            logits = logits[:, -1, :] # crop to just the final time step
            if temperature == 0.0:
                # "sample" the single most likely index
                _, idx_next = torch.topk(logits, k=1, dim=-1)
            else:
                # pluck the logits at the final step and scale by desired temperature
                logits = logits / temperature
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [78]:
import torch

# # 假设 ModelArgs 是一个包含模型参数的类
# class ModelArgs:
#     vocab_size = 1000  # 词汇表大小
#     dim = 512          # 嵌入和隐藏层维度
#     n_layers = 6       # Transformer 层数
#     n_heads = 8        # 注意力头数
#     dropout = 0.1      # Dropout 概率
#     norm_eps = 1e-6    # 归一化的 epsilon
#     max_seq_len = 100  # 最大序列长度

params = ModelArgs()

# 初始化 Transformer 模型
model = Transformer(params)

# 输入数据：tokens 和 targets
tokens = torch.randint(0, params.vocab_size, (2, 10))  # 假设有2个序列，每个序列10个 token
targets = torch.randint(0, params.vocab_size, (2, 10))

# 前向传播
logits = model(tokens, targets)

# 输出 logits
print(logits)

# 配置优化器
optimizer = model.configure_optimizers(weight_decay=0.01, learning_rate=1e-3, betas=(0.9, 0.999), device_type='cuda')

# 生成新 token
generated_tokens = model.generate(tokens, max_new_tokens=5, temperature=1.0, top_k=50)

print(generated_tokens)


tensor([[[ 0.5163,  0.8264,  0.1863,  ...,  1.0220, -2.0998,  0.3624],
         [ 1.2761,  0.5518, -0.5195,  ...,  1.4396, -0.0480, -0.0997],
         [-0.6359,  1.4571, -0.1298,  ...,  0.3464, -3.3644, -1.8127],
         ...,
         [ 0.2648,  2.2320,  0.3610,  ...,  0.7244, -2.0868, -2.7610],
         [ 1.1659,  1.1208,  0.8284,  ..., -0.9833, -1.4072, -1.0673],
         [ 1.4421,  1.6555,  2.3052,  ...,  1.1808, -2.7671, -1.1113]],

        [[ 2.0097, -0.9510,  0.9372,  ..., -0.6056,  0.4065, -0.4302],
         [ 1.0257,  0.4921,  0.0514,  ...,  0.6228, -0.0052,  1.3652],
         [ 1.1492, -0.1643, -2.6517,  ...,  1.0021, -0.1882, -0.1027],
         ...,
         [-0.2462, -0.4856,  2.1299,  ..., -0.6836,  2.3092, -1.5769],
         [ 1.9551,  0.8440,  1.7641,  ..., -0.3515,  0.8568, -0.0155],
         [ 1.1114, -1.3413, -0.4130,  ..., -1.4152,  0.8416, -1.8739]]],
       grad_fn=<UnsafeViewBackward0>)
num decayed parameter tensors: 225, with 6,607,077,376 parameters
num non-deca


KeyboardInterrupt



In [2]:
import torch

def precompute_theta_pos_frequencies(head_dim, seq_len, device, theta=10000.0):
    theta_numerator = torch.arange(0, head_dim, 2).float()
    theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)
    m = torch.arange(seq_len, device=device)
    freqs = torch.outer(m, theta).float()
    print('freqs',freqs)
    print('theta',theta)
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex

# 参数设置
head_dim = 8  # 头维度
seq_len = 5   # 序列长度
device = 'cpu'  # 设备选择

# 调用函数
frequencies = precompute_theta_pos_frequencies(head_dim, seq_len, device)

# 打印结果
print(frequencies)


freqs tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03],
        [4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03]])
theta tensor([1.0000, 0.1000, 0.0100, 0.0010])
tensor([[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j],
        [-0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j],
        [-0.9900+0.1411j,  0.9553+0.2955j,  0.9996+0.0300j,  1.0000+0.0030j],
        [-0.6536-0.7568j,  0.9211+0.3894j,  0.9992+0.0400j,  1.0000+0.0040j]])
