In [1]:
import torch

In [2]:
from torch import nn

In [4]:
from typing import Optional,Tuple,List

In [5]:
from torch.nn import CrossEntropyLoss

In [6]:
import math

In [7]:
import os

In [8]:
import json

In [9]:
class KVCache():
    
    def __init__(self) -> None:
        
        self.key_cache :List[torch.Tensor] = []
        self.value_cache :List[torch.Tensor] = []
    
    
    def num_items(self) -> int:
        if len(self.key_cache) == 0:
            return 0
        else:
            return self.key_cache[0].shape[-2]
    
    def update(self,key_states:torch.Tensor,value_states:torch.Tensor,
              layer_idx:int) -> Tuple[torch.Tensor,torch.Tensor]:
        
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        
        else:
            
            self.key_cache[layer_idx] = torch.cat(
             [
                 self.key_cache[layer_idx],
                 key_states
             ],dim=-2)
            self.value_cache[layer_idx] = torch.cat(
            [
                self.value_cache[layer_idx],
                value_states
            ],dim=-2)
        return self.key_cache[layer_idx],self.value_cache[layer_idx]

In [10]:
import torch

# 初始化 KVCache 实例
kv_cache = KVCache()

# 创建两个新的张量，假设每个张量的形状为 (2, 3, 4)
key_tensor = torch.randn(2, 3, 4)
value_tensor = torch.randn(2, 3, 4)

# 更新第一个层的缓存
key_cache, value_cache = kv_cache.update(key_tensor, value_tensor, 0)

# 打印更新后的缓存
print(f"Layer 0 key_cache shape: {key_cache.shape}, value_cache shape: {value_cache.shape}")

# 创建另一个张量，形状也为 (2, 3, 4)
new_key_tensor = torch.randn(2, 3, 4)
new_value_tensor = torch.randn(2, 3, 4)

# 更新第一个层（索引为0），将新张量拼接到原缓存上
key_cache, value_cache = kv_cache.update(new_key_tensor, new_value_tensor, 0)

# 打印更新后的缓存
print(f"Layer 0 key_cache shape: {key_cache.shape}, value_cache shape: {value_cache.shape}")

# 创建第二个层（索引为1）的新张量
key_tensor_layer_1 = torch.randn(2, 3, 4)
value_tensor_layer_1 = torch.randn(2, 3, 4)

# 更新第二个层的缓存
key_cache, value_cache = kv_cache.update(key_tensor_layer_1, value_tensor_layer_1, 1)

# 打印第二个层的缓存
print(f"Layer 1 key_cache shape: {key_cache.shape}, value_cache shape: {value_cache.shape}")

# 查看缓存的条目数
print(f"Number of items in cache: {kv_cache.num_items()}")


Layer 0 key_cache shape: torch.Size([2, 3, 4]), value_cache shape: torch.Size([2, 3, 4])
Layer 0 key_cache shape: torch.Size([2, 6, 4]), value_cache shape: torch.Size([2, 6, 4])
Layer 1 key_cache shape: torch.Size([2, 3, 4]), value_cache shape: torch.Size([2, 3, 4])
Number of items in cache: 6


In [11]:
class GemmaConfig():
    """
    Configuration class that stores all the hyperparameters needed for the Gemma model.
    This includes things like model size (hidden_size), number of layers, attention heads, etc.
    Think of it as a recipe card that defines how big and complex the model should be.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size,
        intermediate_size,
        num_hidden_layers,
        num_attention_heads,
        num_key_value_heads,
        head_dim=256,
        max_position_embeddings=8192,
        rms_norm_eps=1e-6,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=None,
        **kwargs,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.pad_token_id = pad_token_id

In [12]:

class PaliGemmaConfig():

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=256000,
        vocab_size=257152,
        projection_dim=2048,
        hidden_size=2048,
        pad_token_id=None,
        **kwargs,
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.vision_config = vision_config
        self.is_encoder_decoder = False
        self.pad_token_id = pad_token_id

        self.vision_config = VisionConfig(**vision_config)
        self.text_config = text_config

        self.text_config = GemmaConfig(**text_config, pad_token_id=pad_token_id)
        self.vocab_size = self.text_config.vocab_size

        self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
        self.vision_config.projection_dim = projection_dim

In [14]:
class GemmaRMSNorm(nn.Module):
    
    def __init__(self,dim: int,eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))
    
    def _norm(self,x):
        return x * torch.rsqrt(x.pow(2).mean(-1,keepdim=True)+self.eps)
    
    def forward(self,x):
        output = self._norm(x.float())
        output = output * (1.0 + self.weight.float())
        
        return output.type_as(x)

In [15]:
# 初始化 GemmaRMSNorm
dim = 4  # 输入张量的特征维度
eps = 1e-6
rms_norm = GemmaRMSNorm(dim, eps)

# 创建一个输入张量，形状为 (batch_size, dim)
x = torch.randn(2,2, dim)  # batch_size = 2, dim = 4

# 计算归一化后的输出
output = rms_norm(x)

# 打印输出结果
print("Input x:")
print(x)
print("Output after RMSNorm:")
print(output)
    

Input x:
tensor([[[-0.4246, -0.7982, -0.8357,  0.1792],
         [ 0.6600, -0.3522, -0.7154, -0.3926]],

        [[-0.1582,  0.5388,  0.2481, -0.0237],
         [-0.0523, -0.7688, -1.4402, -0.6932]]])
Output after RMSNorm:
tensor([[[-0.6826, -1.2831, -1.3434,  0.2881],
         [ 1.1923, -0.6363, -1.2925, -0.7092]],

        [[-0.5149,  1.7540,  0.8076, -0.0773],
         [-0.0589, -0.8665, -1.6233, -0.7814]]], grad_fn=<MulBackward0>)


In [21]:
class GemmaRotaryEmbedding(nn.Module):
    
    def __init__(self,dim,max_position_embeddings=2048,
                base=10000,device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        
        inv_freq = 1.0 / (self.base ** (torch.arange(0,self.dim,2,dtype=torch.int64).float()/self.dim))
        
        self.register_buffer('inv_freq',tensor=inv_freq,persistent=False)
    
    @torch.no_grad()
    def forward(self,x,position_ids,seq_len=None):
        ## position_ids  B,seq_len
        self.inv_freq.to(x.device)
        
        # seq_len
#         self.inv_freq[None,:,None].float() 1,seq_len,1
        inv_freq_expand = self.inv_freq[None,:,None].float().expand(
          position_ids.shape[0],-1,1
        )
        # B,seq_len,1 
        
        position_ids_expanded = position_ids[:,None,:].float()
        ## B,1,seq_len
        
        device_type = x.device.type
        device_type = device_type if isinstance(device_type,str) and device_type != "mps" else "cpu"
        
        with torch.autocast(device_type=device_type,enabled=False):
            
            t = (inv_freq_expand.float() @ position_ids_expanded.float())
            # B,seq_len,1  @ B, 1,seq_len
            freqs = t.transpose(1,2)
            # B, seq_len,seq_len
            
            emb = torch.cat((freqs,freqs),dim=-1)
            
            print('embedd....',emb.shape,emb)
            
            cos = emb.cos()
            sin = emb.sin()
        
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            

In [22]:
# 初始化 GemmaRotaryEmbedding
dim = 8  # 嵌入的维度
embedding_layer = GemmaRotaryEmbedding(dim)

# 创建一个输入张量，形状为 (batch_size, seq_len, dim)
batch_size = 2
seq_len = 4
x = torch.randn(batch_size, seq_len, dim)

# 创建位置ID，形状为 (batch_size, seq_len)
position_ids = torch.arange(0, seq_len).unsqueeze(0).expand(batch_size, -1)

print('pos shape',position_ids.shape)
print('pos',position_ids)

# 计算旋转位置编码
cos, sin = embedding_layer(x, position_ids)

# 打印输出结果
print("Cosine encoding:")
print(cos)
print("Sine encoding:")
print(sin)


pos shape torch.Size([2, 4])
pos tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])
embedd.... torch.Size([2, 4, 8]) tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03, 1.0000e+00,
          1.0000e-01, 1.0000e-02, 1.0000e-03],
         [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03, 2.0000e+00,
          2.0000e-01, 2.0000e-02, 2.0000e-03],
         [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03, 3.0000e+00,
          3.0000e-01, 3.0000e-02, 3.0000e-03]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03, 1.0000e+00,
          1.0000e-01, 1.0000e-02, 1.0000e-03],
         [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03, 2.0000e+00,
          2.0000e-01, 2.0000e-02, 2.0000e-03],
         [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03, 3.0000e+00

In [23]:
def rotate_half(x):
    x1 = x[...,:x.shape[-1]//2]
    x2 = x[...,x.shape[-1]//2:]
    return torch.cat((-x2,x1),dim=-1)

In [24]:
x = torch.tensor([[[1., 2., 3., 4.],
                   [5., 6., 7., 8.],
                   [9., 10., 11., 12.]],

                  [[13., 14., 15., 16.],
                   [17., 18., 19., 20.],
                   [21., 22., 23., 24.]]])

output = rotate_half(x)
print(output)


tensor([[[ -3.,  -4.,   1.,   2.],
         [ -7.,  -8.,   5.,   6.],
         [-11., -12.,   9.,  10.]],

        [[-15., -16.,  13.,  14.],
         [-19., -20.,  17.,  18.],
         [-23., -24.,  21.,  22.]]])


In [25]:
def apply_rotary_pos_emb(q,k,cos,sin,unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In [26]:
# 模拟数据
batch_size = 2
seq_len = 5
dim = 8

# 随机生成查询和键张量
q = torch.randn(batch_size, 1,seq_len, dim)
k = torch.randn(batch_size, 1,seq_len, dim)

# 旋转位置编码的余弦和正弦值
# 这里我们假设 dim 为 8，所以下面 dim // 2 = 4
cos = torch.randn(batch_size,seq_len, dim)
sin = torch.randn(batch_size,seq_len, dim)

# 使用 apply_rotary_pos_emb 函数
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)

# 打印结果
print("q_embed shape:", q_embed.shape)
print("k_embed shape:", k_embed.shape)

# 查看 q_embed 和 k_embed 的具体值
print("q_embed:", q_embed)
print("k_embed:", k_embed)


q_embed shape: torch.Size([2, 1, 5, 8])
k_embed shape: torch.Size([2, 1, 5, 8])
q_embed: tensor([[[[ 2.2672, -1.2059,  1.1212, -1.8954, -0.6763,  0.5578, -1.0251,
            3.3326],
          [ 0.1687,  1.0473,  0.1882, -0.2636,  0.5831, -0.3073, -0.9306,
            1.1093],
          [-0.0251, -2.3684, -0.1824,  1.7625,  2.9609,  0.5655, -0.1282,
            0.8968],
          [-0.6032,  0.0208, -0.5212,  0.5081,  0.2427, -0.7860, -2.0760,
            0.0394],
          [ 0.4414,  1.2566,  1.1808, -0.7668,  0.0767, -0.0609,  1.3863,
           -1.8431]]],


        [[[ 2.6425,  1.4354, -0.4998, -1.7827,  0.0269, -0.4905,  0.7460,
            0.5805],
          [ 0.0751,  2.4062,  1.3051,  0.8908,  0.8581,  1.7343, -0.3459,
            0.3036],
          [ 0.6975,  4.6343,  0.7887, -1.8581,  1.8242, -1.2034,  0.9783,
            1.1471],
          [ 0.1794, -0.2345,  1.5733,  0.3319, -0.1704, -1.3123, -0.4432,
           -0.3902],
          [-0.1505, -1.5761, -1.2177, -0.0479,  0.44

In [29]:
class GemmaMLP(nn.Module):
    
    def __init__(self,config):
        
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size,
                                  self.intermediate_size,
                                  bias=False)
        self.up_proj = nn.Linear(
            self.hidden_size,
            self.intermediate_size,
            bias=False
        )
        
        self.down_proj = nn.Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=False
        )
    
    
    def forward(self,x):
        return self.down_proj(nn.functional.gelu(self.gate_proj(x),
                                                approximate='tanh')*self.up_proj(x))

In [30]:
import torch

# 模拟一个配置对象
class Config:
    def __init__(self):
        self.hidden_size = 512
        self.intermediate_size = 1024

config = Config()

# 创建 GemmaMLP 模型
model = GemmaMLP(config)

# 随机输入数据
x = torch.randn(2, 10, config.hidden_size)  # batch_size=2, seq_len=10, hidden_size=512

# 前向传播
output = model(x)
print(output.shape)  # 应该是 (2, 10, 512)


torch.Size([2, 10, 512])


In [31]:
def repeat_kv(hidden_states: torch.Tensor,
             n_rep: int) -> torch.Tensor:
    batch,num_key_value_heads,slen,head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:,:,None,:,:].expand(
        batch,
        num_key_value_heads,
        n_rep,
        slen,
        head_dim
    )
    return hidden_states.reshape(batch,num_key_value_heads*n_rep,
                                slen,head_dim)
    

In [32]:
import torch

# 创建一个形状为 (2, 3, 4, 5) 的随机张量，表示一个 batch size 为 2，3 个键值头，序列长度为 4，head_dim 为 5 的输入
hidden_states = torch.randn(2, 3, 4, 5)

# 设置重复次数为 2
n_rep = 2

# 调用 repeat_kv 函数
result = repeat_kv(hidden_states, n_rep)

print("Result shape:", result.shape)


Result shape: torch.Size([2, 6, 4, 5])


In [41]:
class GemmaAttention(nn.Module):
    
    def __init__(self,config: GemmaConfig, layer_idx:Optional[int]=None):
        
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_casual = True
        
        assert self.hidden_size % self.num_heads == 0
        
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim,bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size,self.num_key_value_heads*self.head_dim,bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size,self.num_key_value_heads*self.head_dim,bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim,self.hidden_size,bias=config.attention_bias)
        
        
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta
        )
        
        
        
    
    def forward(self,hidden_states:torch.Tensor,
               attention_mask: Optional[torch.Tensor] = None,
               position_ids: Optional[torch.LongTensor] = None,
               kv_cache: Optional[KVCache] = None,
               **kwargs) -> Tuple[torch.Tensor,Optional[torch.Tensor],
                                 Optional[Tuple[torch.Tensor]]]:
        
        bsz,q_len,_ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        
        # [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,2)
        key_states = key_states.view(bsz,q_len,self.num_key_value_heads, self.head_dim).transpose(1,2)
        value_states = value_states.view(bsz,q_len,self.num_key_value_heads, self.head_dim).transpose(1,2)
        
        
        cos,sin = self.rotary_emb(value_states,position_ids,seq_len=None)

        query_states, key_states = apply_rotary_pos_emb(
            query_states,
            key_states,
            cos,
            sin
        )
        
        if kv_cache is not None:
            key_states, value_states = kv_cache.update(
             key_states,
                value_states,
                self.layer_idx
            )
        
        key_states = repeat_kv(key_states,self.num_key_value_groups)
        value_states = repeat_kv(value_states,self.num_key_value_groups)
         
        
        attn_weights = torch.matmul(
            query_states,
            key_states.transpose(2,3)
        ) / math.sqrt(self.head_dim)
        
        assert attention_mask is not None
        
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights,dim=-1,dtype=torch.float32).to(query_states.dtype)

        attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)
        
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        
        attn_output = attn_output.transpose(1,2).contiguous()
        
        attn_output = attn_output.view(bsz,q_len,-1)

        attn_output = self.o_proj(attn_output)
        
        return attn_output, attn_weights


In [42]:
import torch

# Example configuration (assuming GemmaConfig is properly defined)

class Config:
    
    hidden_size=128
    num_heads=8
    head_dim=16
    num_key_value_heads=4
    attention_dropout=0.1
    max_position_embeddings=512
    rope_theta=10000.0
    attention_bias=False


config = Config()
# 初始化 GemmaAttention 层
# 初始化 GemmaAttention 层
attention_layer = GemmaAttention(config)

# 随机生成 hidden_states，形状 [batch_size=2, seq_len=10, hidden_size=128]
hidden_states = torch.randn(2, 10, 128)

# 生成 position_ids (假设序列长度为 10，batch_size 为 2)
position_ids = torch.arange(10).unsqueeze(0).repeat(2, 1)  # 2 个批次，序列长度为 10

# 生成 attention_mask，假设0表示填充，1表示有效token
input_sequence = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
                              [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]])
attention_mask = (input_sequence != 0).unsqueeze(1).unsqueeze(2).to(torch.float)

# 调用 forward 方法，传递 hidden_states、attention_mask 和 position_ids
attn_output, attn_weights = attention_layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

# 打印输出形状
print("Attention Output Shape:", attn_output.shape)
print("Attention Weights Shape:", attn_weights.shape)





embedd.... torch.Size([2, 10, 16]) tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [1.0000e+00, 3.1623e-01, 1.0000e-01, 3.1623e-02, 1.0000e-02,
          3.1623e-03, 1.0000e-03, 3.1623e-04, 1.0000e+00, 3.1623e-01,
          1.0000e-01, 3.1623e-02, 1.0000e-02, 3.1623e-03, 1.0000e-03,
          3.1623e-04],
         [2.0000e+00, 6.3246e-01, 2.0000e-01, 6.3246e-02, 2.0000e-02,
          6.3246e-03, 2.0000e-03, 6.3246e-04, 2.0000e+00, 6.3246e-01,
          2.0000e-01, 6.3246e-02, 2.0000e-02, 6.3246e-03, 2.0000e-03,
          6.3246e-04],
         [3.0000e+00, 9.4868e-01, 3.0000e-01, 9.4868e-02, 3.0000e-02,
          9.4868e-03, 3.0000e-03, 9.4868e-04, 3.0000e+00, 9.4868e-01,
          3.0000e-01, 9.4868e-02, 3.0000e-02, 9.4868e-03, 3.0000e-03,
          9.4868e-04],
         [4.0000e+00, 1.2649e+00,