In [1]:
import torch

In [2]:
from torch import nn

In [3]:
from typing import Optional,Tuple,List

In [4]:
from torch.nn import CrossEntropyLoss

In [5]:
import math

In [6]:
import os

In [7]:
import json

In [8]:
class KVCache():
    
    def __init__(self) -> None:
        
        self.key_cache :List[torch.Tensor] = []
        self.value_cache :List[torch.Tensor] = []
    
    
    def num_items(self) -> int:
        if len(self.key_cache) == 0:
            return 0
        else:
            return self.key_cache[0].shape[-2]
    
    
    def update(self,key_states:torch.Tensor,value_states:torch.Tensor,
              layer_idx:int) -> Tuple[torch.Tensor,torch.Tensor]:
        
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        
        else:
            self.key_cache[layer_idx] = torch.cat([
                self.key_cache[layer_idx],
                key_states
            ],dim=-2)
            self.value_cache[layer_idx] = torch.cat(
            [
                self.value_cache[layer_idx],
                value_states
            ],dim=-2)
        return self.key_cache[layer_idx],self.value_cache[layer_idx]
        

In [9]:
import torch

# 初始化 KVCache 实例
kv_cache = KVCache()

# 创建两个新的张量，假设每个张量的形状为 (2, 3, 4)
key_tensor = torch.randn(2, 3, 4)
value_tensor = torch.randn(2, 3, 4)

# 更新第一个层的缓存
key_cache, value_cache = kv_cache.update(key_tensor, value_tensor, 0)

# 打印更新后的缓存
print(f"Layer 0 key_cache shape: {key_cache.shape}, value_cache shape: {value_cache.shape}")

# 创建另一个张量，形状也为 (2, 3, 4)
new_key_tensor = torch.randn(2, 3, 4)
new_value_tensor = torch.randn(2, 3, 4)

# 更新第一个层（索引为0），将新张量拼接到原缓存上
key_cache, value_cache = kv_cache.update(new_key_tensor, new_value_tensor, 0)

# 打印更新后的缓存
print(f"Layer 0 key_cache shape: {key_cache.shape}, value_cache shape: {value_cache.shape}")

# 创建第二个层（索引为1）的新张量
key_tensor_layer_1 = torch.randn(2, 3, 4)
value_tensor_layer_1 = torch.randn(2, 3, 4)

# 更新第二个层的缓存
key_cache, value_cache = kv_cache.update(key_tensor_layer_1, value_tensor_layer_1, 1)

# 打印第二个层的缓存
print(f"Layer 1 key_cache shape: {key_cache.shape}, value_cache shape: {value_cache.shape}")

# 查看缓存的条目数
print(f"Number of items in cache: {kv_cache.num_items()}")


Layer 0 key_cache shape: torch.Size([2, 3, 4]), value_cache shape: torch.Size([2, 3, 4])
Layer 0 key_cache shape: torch.Size([2, 6, 4]), value_cache shape: torch.Size([2, 6, 4])
Layer 1 key_cache shape: torch.Size([2, 3, 4]), value_cache shape: torch.Size([2, 3, 4])
Number of items in cache: 6


In [10]:
class GemmaConfig():
    """
    Configuration class that stores all the hyperparameters needed for the Gemma model.
    This includes things like model size (hidden_size), number of layers, attention heads, etc.
    Think of it as a recipe card that defines how big and complex the model should be.
    """

    def __init__(
        self,
        vocab_size,
        hidden_size,
        intermediate_size,
        num_hidden_layers,
        num_attention_heads,
        num_key_value_heads,
        head_dim=256,
        max_position_embeddings=8192,
        rms_norm_eps=1e-6,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        pad_token_id=None,
        **kwargs,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.pad_token_id = pad_token_id

In [11]:

class PaliGemmaConfig():

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=256000,
        vocab_size=257152,
        projection_dim=2048,
        hidden_size=2048,
        pad_token_id=None,
        **kwargs,
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.vision_config = vision_config
        self.is_encoder_decoder = False
        self.pad_token_id = pad_token_id

        self.vision_config = VisionConfig(**vision_config)
        self.text_config = text_config

        self.text_config = GemmaConfig(**text_config, pad_token_id=pad_token_id)
        self.vocab_size = self.text_config.vocab_size

        self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
        self.vision_config.projection_dim = projection_dim

In [12]:
class GemmaRMSNorm(nn.Module):
    
    def __init__(self,dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))
    
    
    def _norm(self,x):
        return x * torch.rsqrt(x.pow(2).mean(-1,keepdim=True)+self.eps)
    
    def forward(self,x):
        output = self._norm(x.float())
        
        output = output * (1.0 + self.weight.float())
        
        return output.type_as(x)

In [13]:
# 初始化 GemmaRMSNorm
dim = 4  # 输入张量的特征维度
eps = 1e-6
rms_norm = GemmaRMSNorm(dim, eps)

# 创建一个输入张量，形状为 (batch_size, dim)
x = torch.randn(2,2, dim)  # batch_size = 2, dim = 4

# 计算归一化后的输出
output = rms_norm(x)

# 打印输出结果
print("Input x:")
print(x)
print("Output after RMSNorm:")
print(output)

Input x:
tensor([[[ 1.0647,  0.3436,  0.1405,  0.7307],
         [ 0.7142,  0.1823, -0.6825, -1.5986]],

        [[-1.3156, -0.3296, -1.3917, -0.6262],
         [-0.6318, -1.7972,  0.2147, -2.4969]]])
Output after RMSNorm:
tensor([[[ 1.5848,  0.5114,  0.2092,  1.0877],
         [ 0.7566,  0.1932, -0.7230, -1.6934]],

        [[-1.2888, -0.3229, -1.3633, -0.6134],
         [-0.4014, -1.1418,  0.1364, -1.5864]]], grad_fn=<MulBackward0>)


In [14]:
output.shape

torch.Size([2, 2, 4])

In [15]:
class GemmaRotaryEmbedding(nn.Module):
    
    def __init__(self,dim,max_position_embeddings=2048,base=10000,device=None):
        
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        
        inv_freq = 1.0 / (self.base ** (torch.arange(0,self.dim,2,dtype=torch.int64).float()/self.dim))
        
        self.register_buffer('inv_freq',tensor=inv_freq,persistent=False)
    
    
    
    
    @torch.no_grad()
    def forward(self,x,position_ids,seq_len=None):
        
        self.inv_freq.to(x.device)
        print('xxxx',self.inv_freq)
        print('yyyy',self.inv_freq[None,:,None].shape,self.inv_freq[None,:,None])
        
        inv_freq_expand = self.inv_freq[None,:,None].float().expand(
          position_ids.shape[0],-1,1
        )
        print('zzzzz',inv_freq_expand)
        
        position_ids_expanded = position_ids[:,None,:].float()
        
        print('position_ids_expanded',position_ids_expanded.shape,position_ids_expanded)
        
        device_type = x.device.type
        
        device_type = device_type if isinstance(device_type,str) and device_type != "mps" else "cpu"
        
        with torch.autocast(device_type=device_type,enabled=False):
            print('xxxxyyyuuu',inv_freq_expand.shape,position_ids_expanded.shape)
            t = (inv_freq_expand.float() @ position_ids_expanded.float())
            print('tttt',t.shape,t)
            freqs = t.transpose(1,2)
            
            

            emb = torch.cat((freqs,freqs),dim=-1)
            
            cos = emb.cos()
            sin = emb.sin()
        
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
    

In [16]:
# 初始化 GemmaRotaryEmbedding
dim = 8  # 嵌入的维度
embedding_layer = GemmaRotaryEmbedding(dim)

# 创建一个输入张量，形状为 (batch_size, seq_len, dim)
batch_size = 2
seq_len = 4
x = torch.randn(batch_size, seq_len, dim)

# 创建位置ID，形状为 (batch_size, seq_len)
position_ids = torch.arange(0, seq_len).unsqueeze(0).expand(batch_size, -1)

print('pos shape',position_ids.shape)
print('pos',position_ids)

# 计算旋转位置编码
cos, sin = embedding_layer(x, position_ids)

# 打印输出结果
print("Cosine encoding:")
print(cos)
print("Sine encoding:")
print(sin)


pos shape torch.Size([2, 4])
pos tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])
xxxx tensor([1.0000, 0.1000, 0.0100, 0.0010])
yyyy torch.Size([1, 4, 1]) tensor([[[1.0000],
         [0.1000],
         [0.0100],
         [0.0010]]])
zzzzz tensor([[[1.0000],
         [0.1000],
         [0.0100],
         [0.0010]],

        [[1.0000],
         [0.1000],
         [0.0100],
         [0.0010]]])
position_ids_expanded torch.Size([2, 1, 4]) tensor([[[0., 1., 2., 3.]],

        [[0., 1., 2., 3.]]])
xxxxyyyuuu torch.Size([2, 4, 1]) torch.Size([2, 1, 4])
tttt torch.Size([2, 4, 4]) tensor([[[0.0000e+00, 1.0000e+00, 2.0000e+00, 3.0000e+00],
         [0.0000e+00, 1.0000e-01, 2.0000e-01, 3.0000e-01],
         [0.0000e+00, 1.0000e-02, 2.0000e-02, 3.0000e-02],
         [0.0000e+00, 1.0000e-03, 2.0000e-03, 3.0000e-03]],

        [[0.0000e+00, 1.0000e+00, 2.0000e+00, 3.0000e+00],
         [0.0000e+00, 1.0000e-01, 2.0000e-01, 3.0000e-01],
         [0.0000e+00, 1.0000e-02, 2.0000e-02, 3.0000e-02],
         [

In [17]:
def rotate_half(x):
    x1 = x[...,:x.shape[-1] // 2]
    x2 = x[...,x.shape[-1] // 2 :]
    return torch.cat((-x2,x1),dim=-1)

In [18]:
x = torch.tensor([[[1., 2., 3., 4.],
                   [5., 6., 7., 8.],
                   [9., 10., 11., 12.]],

                  [[13., 14., 15., 16.],
                   [17., 18., 19., 20.],
                   [21., 22., 23., 24.]]])

output = rotate_half(x)
print(output)


tensor([[[ -3.,  -4.,   1.,   2.],
         [ -7.,  -8.,   5.,   6.],
         [-11., -12.,   9.,  10.]],

        [[-15., -16.,  13.,  14.],
         [-19., -20.,  17.,  18.],
         [-23., -24.,  21.,  22.]]])


In [19]:
def apply_rotary_pos_emb(q,k,cos,sin,unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    
    print('q',q.shape,'cos',cos.shape)
    
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    
    return q_embed, k_embed

In [20]:
# 模拟数据
batch_size = 2
seq_len = 5
dim = 8

# 随机生成查询和键张量
q = torch.randn(batch_size, 1,seq_len, dim)
k = torch.randn(batch_size, 1,seq_len, dim)

# 旋转位置编码的余弦和正弦值
# 这里我们假设 dim 为 8，所以下面 dim // 2 = 4
cos = torch.randn(batch_size,seq_len, dim)
sin = torch.randn(batch_size,seq_len, dim)

# 使用 apply_rotary_pos_emb 函数
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)

# 打印结果
print("q_embed shape:", q_embed.shape)
print("k_embed shape:", k_embed.shape)

# 查看 q_embed 和 k_embed 的具体值
print("q_embed:", q_embed)
print("k_embed:", k_embed)


q torch.Size([2, 1, 5, 8]) cos torch.Size([2, 1, 5, 8])
q_embed shape: torch.Size([2, 1, 5, 8])
k_embed shape: torch.Size([2, 1, 5, 8])
q_embed: tensor([[[[-7.6948e-01, -5.6814e-01, -1.7095e-01,  7.2094e-01, -1.5539e+00,
           -3.3883e-01,  9.2523e-03,  1.2150e+00],
          [-1.4535e+00,  7.4666e-01, -2.8978e-01,  1.0115e+00,  8.8850e-02,
           -6.9532e-01,  8.8218e-02, -8.7510e-01],
          [-2.1211e+00, -1.1748e+00,  2.1315e+00, -9.0456e-01,  1.7848e+00,
           -1.6587e+00, -3.1561e+00,  5.0773e-01],
          [ 5.0666e-01,  1.5722e+00, -1.4029e+00, -2.9377e-01,  2.1282e+00,
           -2.2508e+00,  7.1495e-01,  2.6070e-01],
          [-3.5374e-02, -4.6350e-01,  1.2723e+00,  2.4211e-01,  3.5788e-03,
            8.9395e-02, -1.0620e+00,  3.7735e+00]]],


        [[[-1.2408e+00, -6.1835e-01, -1.1471e+00, -1.0227e+00, -9.1090e-01,
           -1.6743e+00,  3.6710e-01,  8.3530e-03],
          [ 5.8127e-01,  3.2002e+00, -4.4563e-01,  1.5812e-01, -3.2814e-01,
           -8

In [21]:
class GemmaMLP(nn.Module):
    
    def __init__(self,config):
        super().__init__()
        
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size,self.intermediate_size,bias=False)
        self.up_proj = nn.Linear(self.hidden_size,self.intermediate_size,bias=False)
        self.down_proj = nn.Linear(self.intermediate_size,self.hidden_size,bias=False)
    
    
    def forward(self,x):
        return self.down_proj(nn.functional.gelu(self.gate_proj(x),approximate='tanh')*self.up_proj(x))

In [22]:
import torch

# 模拟一个配置对象
class Config:
    def __init__(self):
        self.hidden_size = 512
        self.intermediate_size = 1024

config = Config()

# 创建 GemmaMLP 模型
model = GemmaMLP(config)

# 随机输入数据
x = torch.randn(2, 10, config.hidden_size)  # batch_size=2, seq_len=10, hidden_size=512

# 前向传播
output = model(x)
print(output.shape)  # 应该是 (2, 10, 512)


torch.Size([2, 10, 512])


In [23]:

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

In [24]:
class GemmaAttention(nn.Module):
    
    def __init__(self,config: GemmaConfig, layer_idx: Optional[int] = None):
        
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_casual = True
        
        assert self.hidden_size % self.num_heads == 0
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim,bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size,self.num_key_value_heads*self.head_dim,bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size,self.num_key_value_heads*self.head_dim,bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim,self.hidden_size,bias=config.attention_bias)
        
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
        
    
    
    def forward(self,hidden_states: torch.Tensor,
                     attention_mask: Optional[torch.Tensor] = None,
                     position_ids: Optional[torch.LongTensor] = None,
                     kv_cache: Optional[KVCache] = None,
                     **kwargs) -> Tuple[torch.Tensor,Optional[torch.Tensor],
                                       Optional[Tuple[torch.Tensor]]]:
        
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        
        # [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1,2)
        key_states = key_states.view(bsz,q_len,self.num_key_value_heads, self.head_dim).transpose(1,2)
        value_states = value_states.view(bsz,q_len,self.num_key_value_heads, self.head_dim).transpose(1,2)
        
        cos,sin = self.rotary_emb(value_states,position_ids,seq_len=None)
        
        query_states,key_states = apply_rotary_pos_emb(query_states,key_states,cos,sin)
        
        if kv_cache is not None:
            
            key_states, value_states = kv_cache.update(key_states,value_states,self.layer_idx)
        
        
        key_states = repeat_kv(key_states,self.num_key_value_groups)
        value_states = repeat_kv(value_states,self.num_key_value_groups)
        
        attn_weights = torch.matmul(query_states,key_states.transpose(2,3)) / math.sqrt(self.head_dim)
        
        assert attention_mask is not None
        attn_weights = attn_weights + attention_mask
        
        attn_weights = nn.functional.softmax(attn_weights,dim=-1,dtype=torch.float32).to(query_states.dtype)
        
        attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=self.training)
        
        attn_output = torch.matmul(attn_weights, value_states)
        
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        
        
        attn_output = attn_output.transpose(1,2).contiguous()
        
        attn_output = attn_output.view(bsz,q_len,-1)
        
        attn_output = self.o_proj(attn_output)
        
        return attn_output, attn_weights
        
        

In [25]:
class GemmaConfig:
    def __init__(self):
        self.hidden_size = 512  # 隐藏层大小
        self.num_heads = 8   # 注意力头数
        self.head_dim=32
        self.num_key_value_heads = 4  # Key-Value头数
        self.attention_dropout = 0.1  # Dropout率
        self.max_position_embeddings = 1024  # 最大位置嵌入数
        self.rope_theta = 10000  # 旋转位置编码的基数
        self.attention_bias = False  # 是否使用偏置

# 实例化配置
config = GemmaConfig()
# 实例化GemmaAttention
gemma_attention = GemmaAttention(config=config)


In [26]:
gemma_attention

GemmaAttention(
  (q_proj): Linear(in_features=512, out_features=256, bias=False)
  (k_proj): Linear(in_features=512, out_features=128, bias=False)
  (v_proj): Linear(in_features=512, out_features=128, bias=False)
  (o_proj): Linear(in_features=256, out_features=512, bias=False)
  (rotary_emb): GemmaRotaryEmbedding()
)

In [27]:
class GemmaDecoderLayer(nn.Module):
    
    def __init__(self,config: GemmaConfig, layer_idx:int):
        
        super().__init__()
        
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(config=config,
                                       layer_idx=layer_idx)
        self.mlp = GemmaMLP(config)
        
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)
        
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)
    
    
    def forward(self,hidden_states: torch.Tensor,
                     attention_mask:Optional[torch.Tensor]=None,
                     position_ids: Optional[torch.LongTensor] = None,
                     kv_cache: Optional[KVCache] = None
               ) -> Tuple[torch.FloatTensor,Optional[Tuple[torch.FloatTensor,torch.FloatTensor]]]:
        
        residual = hidden_states
        
        hidden_states = self.input_layernorm(hidden_states)
        
        hidden_states, _ = self.self_attn(
           hidden_states=hidden_states,
           attention_mask=attention_mask,
           position_ids=position_ids,
           kv_cache=kv_cache
        )
        
        hidden_states = residual + hidden_states
        
        residual = hidden_states
        
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        
        hidden_states = residual + hidden_states
        
        return hidden_states

In [28]:
class GemmaModel(nn.Module):
    
    def __init__(self,config: GemmaConfig):
        
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size,config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
           [GemmaDecoderLayer(config,layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)
        
    
    def get_input_embeddings(self):
        return self.embed_tokens
    
    def forward(self,attention_mask: Optional[torch.Tensor]=None,
                     position_ids: Optional[torch.LongTensor] = None,
                     input_embeds: Optional[torch.FloatTensor] = None,
                     kv_cache: Optional[KVCache] = None):
        
        hidden_states = input_embeds
        normalizer = torch.tensor(self.config.hidden_size**0.5,dtype=hidden_states.dtype)
        
        hidden_states = hidden_states * normalizer
        
        for decoder_layer in self.layers:
            
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                kv_cache=kv_cache
            )
        
        hidden_states = self.norm(hidden_states)
        
        return hidden_states

In [29]:
class GemmaForCausalLM(nn.Module):
    
    def __init__(self,config):
        super().__init__()
        self.config = config
        self.model = GemmaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size,config.vocab_size,
                                bias=False)
    
    def get_input_embeddings(self):
        return self.model.embed_tokens
    
    def tie_weights(self):
        self.lm_head.weight = self.model.embed_tokens.weight
    
    def forward(self,
               attention_mask: Optional[torch.Tensor]=None,
               position_ids: Optional[torch.LongTensor]=None,
               inputs_embeds: Optional[torch.FloatTensor]=None,
               kv_cache: Optional[KVCache]=None,
               )->Tuple:
        outputs = self.model(
           attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache
        )
        
        hidden_states = outputs
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        
        return_data = {
            "logits": logits
        }
        
        if kv_cache is not None:
            return_data['kv_cache'] = kv_cache
        return return_data
        

In [30]:
 class PaliGemmaMultiModalProjector(nn.Module):
        
        def __init__(self,config:PaliGemmaConfig):
            super().__init__()
            self.linear = nn.Linear(config.vision_config.hidden_size,
                                   config.vision_config.projection_dim, 
                                    bias=True)
        
        def forward(self,image_features):
            hidden_states = self.linear(image_features)
            return hidden_states

In [31]:
from vision_transformer_1 import VisionModel

In [34]:
class PaliGemmaForConditionalGeneration(nn.Module):
    
    
    def __init__(self,config: PaliGemmaConfig):
        
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.vocab_size
        
        language_model = GemmaForCausalLM(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
    
    
    
    
    
    def tie_weights(self):
        return self.language_model.tie_weights()
    
    
    def _merge_input_ids_with_image_features(
       self,image_features: torch.Tensor,
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_cache: Optional[KVCache]=None
    ):
        _,_,embed_dim = image_features.shape
        batch_size,sequence_length = input_ids.shape
        dtype,device = inputs_embeds.dtype, inputs_embeds.device
        
        
        scaled_image_features = image_features / (self.config.hidden_size**0.5)
        
        final_embedding = torch.zeros(batch_size,sequence_length,
                                     embed_dim,dtype=inputs_embeds.dtype,
                                     device=inputs_embeds.device)
        
        
        text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
        
        
        image_mask = input_ids == self.config.image_token_index
        
        
        pad_mask = input_ids == self.pad_token_id
        
        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1,-1,embed_dim)
        
        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1,-1,embed_dim)
        
        image_mask_expanded = image_mask.unsqueeze(-1).expand(-1,-1,embed_dim)
        
        
        final_embedding = torch.where(text_mask_expanded,inputs_embeds,
                                     final_embedding)
        
        final_embedding = final_embedding.masked_scatter(
             image_mask_expanded,scaled_image_features
        )
        
        final_embedding = torch.where(
           pad_mask_expanded,torch.zeros_like(final_embedding),final_embedding
        )
        
        dtype,device = inputs_embeds.dtype,inputs_embeds.device
        min_dtype = torch.finfo(dtype).min
        
        
        q_len = inputs_embeds.shape[1]
        
        
        if kv_cache is None or kv_cache.num_items() == 0:
            
            causual_mask = torch.full(
                (batch_size,q_len,q_len),fill_value=0,
                dtype=dtype,
                device=device
            )
        
        else:
            
            assert q_len == 1
            kv_len = kv_cache.num_items() + q_len
            
            causual_mask = troch.full(
              (batch_size,q_len,kv_len),fill_value=0,
                dtype=dtype,
                device=device
            )
        
        
        causual_mask = causual_mask.unsqueeze(1)
        
        if kv_cache is not None and kv_cache.num_items() > 0:
            position_ids = attention_mask.cumsum(-1)[:,-1]
            if position_ids.dim() == 1:
                position_ids = position_ids.unsqueeze(0)
        else:
            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask==0),1).to(device)
        
        
        return final_embedding,causual_mask,position_ids

    
    
    def forward(self,input_ids:torch.LongTensor=None,
                     pixel_values: torch.FloatTensor = None,
                     attention_mask: Optional[torch.Tensor] = None,
                     kv_cache: Optional[KVCache] = None,
               ) -> Tuple:
        
        assert torch.all(attention_mask == 1), "the input can not be padded"

        inputs_embeds = self.language_model.get_input_embeddings()[input_ids]
        
        selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
        
        
        image_features = self.multi_modal_projector(selected_image_feature)
        
        
        inputs_embeds,attention_mask,position_ids = self._merge_input_ids_with_image_features(image_features,
                                                                                             inputs_embeds,
                                                                                             input_ids,
                                                                                             attention_mask,
                                                                                             kv_cache)
        
        
        
        outputs = self.language_model(
           attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache
        )
        return outputs
    
    def save_pretrained(self,save_directory:str):
        
        os.makedirs(save_directory,exist_ok=True)
        
        config_dict = {
            'vision_condif': self.config.vision_config.__dict__,
            'text_config': self.config.text_config.__dict__,
            'ignore_index': self.config.ignore_index,
            'image_token_index': self.config.image_token_index,
            'vocab_size':self.config.vocab_size,
            'projection_dim':self.config.projection_dim,
            'hidden_size':self.config.hidden_size,
            'pad_token_id':self.config.pad_token_id
        }
        
        with open(os.path.join(save_directory,'config.json'),'w') as f:
            json.dump(config_dict,f)
        
        
        model_state = self.state_dict()
        torch.save(model_state,os.path.join(save_directory,'pytorch_model.bin'))
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
    
    
    
    
    