In [None]:
import glob
import inspect
import json
import os
import warnings
from typing import List, Tuple, Optional, Dict, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from safetensors import safe_open
from transformers import AutoTokenizer


# Pick the fastest available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
class VisionConfig:
    """Hyper-parameters for the vision encoder.

    Stores every constructor argument unchanged and derives ``num_patches``:
    the number of non-overlapping ``patch_size`` x ``patch_size`` tiles that fit
    in an ``image_size`` x ``image_size`` image (integer division per side).
    """

    def __init__(self,
                 hidden_size=768,
                 intermediate_size=3072,
                 num_layers=12,
                 num_heads=12,
                 num_channels=3,
                 image_size=224,
                 patch_size=16,
                 layer_norm_eps=1e-6,
                 dropout_rate=0.0):
        # Transformer width and depth.
        self.hidden_size, self.intermediate_size = hidden_size, intermediate_size
        self.num_layers, self.num_heads = num_layers, num_heads
        # Input geometry.
        self.num_channels = num_channels
        self.image_size, self.patch_size = image_size, patch_size
        # Numerics / regularisation.
        self.layer_norm_eps, self.dropout_rate = layer_norm_eps, dropout_rate
        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side ** 2

In [None]:
class PatchEmbedder(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A convolution whose kernel and stride both equal ``patch_size`` cuts the
    image into non-overlapping squares and projects each to ``hidden_size``;
    a learned per-patch position embedding is then added.
    """
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.patch_conv = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            padding=0,  # stride == kernel size, so patches never overlap
        )
        self.positional_emb = nn.Embedding(config.num_patches, config.hidden_size)
        # Index tensor 0..num_patches-1 with a leading batch axis of size 1.
        self.register_buffer("pos_ids", torch.arange(config.num_patches).unsqueeze(0))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # [B, C, H, W] -> [B, hidden, H // p, W // p]
        feature_map = self.patch_conv(images)
        # -> [B, hidden, N] -> [B, N, hidden]
        tokens = feature_map.flatten(start_dim=2).transpose(1, 2)
        # Position table is [1, N, hidden] and broadcasts over the batch.
        return tokens + self.positional_emb(self.pos_ids)

class VisionAttention(nn.Module):
    """Unmasked multi-head self-attention over patch tokens.

    Query/key/value/output projections all keep width ``hidden_size``, which is
    split evenly into ``num_heads`` heads of ``hidden_size // num_heads``.
    """
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // config.num_heads
        self.scale = self.head_dim ** -0.5
        self.dropout = nn.Dropout(config.dropout_rate)

        width = config.hidden_size
        self.query = nn.Linear(width, width)
        self.key = nn.Linear(width, width)
        self.value = nn.Linear(width, width)
        self.output = nn.Linear(width, width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, tokens, width = x.shape

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, N, D] -> [B, heads, N, head_dim]
            return t.view(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

        q = to_heads(self.query(x))
        k = to_heads(self.key(x))
        v = to_heads(self.value(x))

        weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)) * self.scale, dim=-1)
        weights = self.dropout(weights)
        # [B, heads, N, head_dim] -> [B, N, D]
        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(batch, tokens, width)
        return self.output(merged)

class VisionMLP(nn.Module):
    """Position-wise feed-forward block: ``fc2(gelu_tanh(fc1(x)))``."""
    def __init__(self, config: VisionConfig):
        super().__init__()
        width, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(width, inner)
        self.fc2 = nn.Linear(inner, width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.fc1(x)
        activated = F.gelu(expanded, approximate="tanh")
        return self.fc2(activated)

class VisionEncoderLayer(nn.Module):
    """Pre-norm encoder block: ``x + attn(LN(x))`` followed by ``x + mlp(LN(x))``."""
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.attention = VisionAttention(config)
        self.mlp = VisionMLP(config)
        eps = config.layer_norm_eps
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=eps)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sub-block with residual connection.
        residual = x
        x = residual + self.attention(self.norm1(residual))
        # Feed-forward sub-block with residual connection.
        residual = x
        return residual + self.mlp(self.norm2(residual))

class VisionTransformer(nn.Module):
    """SigLIP-style vision encoder: patch embedding, a stack of encoder layers, final LayerNorm.

    Maps ``[B, C, H, W]`` images to ``[B, num_patches, hidden_size]`` features.
    """
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.embedder = PatchEmbedder(config)
        self.layers = nn.ModuleList([VisionEncoderLayer(config) for _ in range(config.num_layers)])
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        hidden = self.embedder(images)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        normed = self.final_norm(hidden)
        return normed


class LanguageConfig:
    """Hyper-parameters for the Gemma-style decoder.

    ``head_dim`` is set independently of ``hidden_size / num_heads``; the attention
    projections map ``hidden_size`` to ``num_heads * head_dim``. ``num_kv_heads`` key/value
    heads are shared among the ``num_heads`` query heads.
    """
    def __init__(self,
                 vocab_size=257152,
                 hidden_size=2048,
                 intermediate_size=8192,
                 num_layers=18,
                 num_heads=16,
                 num_kv_heads=8,
                 head_dim=256,
                 max_seq_len=8192,
                 rms_norm_eps=1e-6,
                 rope_theta=10000.0,
                 dropout_rate=0.0):
        # Vocabulary and model width.
        self.vocab_size, self.hidden_size, self.intermediate_size = vocab_size, hidden_size, intermediate_size
        # Depth and attention layout.
        self.num_layers = num_layers
        self.num_heads, self.num_kv_heads, self.head_dim = num_heads, num_kv_heads, head_dim
        # Positional / numerical settings.
        self.max_seq_len = max_seq_len
        self.rms_norm_eps, self.rope_theta = rms_norm_eps, rope_theta
        self.dropout_rate = dropout_rate



class RMSNorm(nn.Module):
    """Root-mean-square normalisation in Gemma's parameterisation.

    Output is ``x / sqrt(mean(x**2) + eps) * (1 + weight)``: ``weight`` is an
    offset from 1, so it is initialised to zeros (a freshly built layer is a pure
    RMS normalisation rather than a 2x scaling).

    The statistics are computed in float32 and the result is cast back to the
    input dtype, so half-precision inputs do not overflow in ``x.pow(2)``.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        scaled = x_fp32 * inv_rms * (1.0 + self.weight.float())
        return scaled.type_as(x)

class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embedding (RoPE) tables.

    ``cos``/``sin`` have shape ``[max_seq_len, dim // 2]`` (for even ``dim``):
    row ``p`` holds ``cos/sin(p * theta ** (-2i / dim))`` for each frequency ``i``.
    """
    def __init__(self, dim: int, max_seq_len: int = 8192, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)
        # The tables are a pure function of (dim, max_seq_len, theta): keep them out
        # of state_dict so checkpoints stay small and loading weights that never
        # contained them does not report them as missing.
        self.register_buffer("cos", angles.cos(), persistent=False)
        self.register_buffer("sin", angles.sin(), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the cos/sin rows for positions ``start_pos .. start_pos + x.shape[-2] - 1``.

        The rows are cast to ``x``'s device and dtype; otherwise float32 tables would
        promote half-precision q/k and clash with half-precision values downstream.

        Raises:
            ValueError: if the requested positions fall outside ``[0, max_seq_len)``.
        """
        seq_len = x.shape[-2]
        end = start_pos + seq_len
        if start_pos < 0 or end > self.max_seq_len:
            raise ValueError(
                f"rotary table covers positions [0, {self.max_seq_len}); "
                f"requested [{start_pos}, {end})"
            )
        return (self.cos[start_pos:end].to(device=x.device, dtype=x.dtype),
                self.sin[start_pos:end].to(device=x.device, dtype=x.dtype))

def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key vectors by the RoPE angles in ``cos``/``sin``.

    The last dimension is split into halves ``(a, b)`` and each pair is rotated to
    ``(a*cos - b*sin, a*sin + b*cos)``. ``cos``/``sin`` must broadcast against one half.
    """
    def rotate(t: torch.Tensor) -> torch.Tensor:
        half = t.shape[-1] // 2
        first, second = t[..., :half], t[..., half:]
        return torch.cat([first * cos - second * sin, first * sin + second * cos], dim=-1)

    return rotate(q), rotate(k)

class KVCache:
    """Per-layer key/value cache for autoregressive decoding.

    ``keys[i]`` / ``values[i]`` accumulate everything layer ``i`` has seen,
    concatenated along dim 2 (the sequence axis of the ``[B, heads, seq, head_dim]``
    tensors that ``LanguageAttention`` passes in).
    """
    def __init__(self):
        self.keys: List[torch.Tensor] = []
        self.values: List[torch.Tensor] = []

    def update(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append ``key``/``value`` to ``layer_idx``'s cache and return the full cached tensors.

        Raises:
            ValueError: if a layer is cached before all lower-indexed layers (which
                would otherwise silently store it under the wrong index).
        """
        if layer_idx < len(self.keys):
            self.keys[layer_idx] = torch.cat([self.keys[layer_idx], key], dim=2)
            self.values[layer_idx] = torch.cat([self.values[layer_idx], value], dim=2)
        elif layer_idx == len(self.keys):
            self.keys.append(key)
            self.values.append(value)
        else:
            raise ValueError(
                f"layer {layer_idx} cached before layer {len(self.keys)}; layers must be cached in order"
            )
        return self.keys[layer_idx], self.values[layer_idx]

    def get_seq_len(self, layer_idx: int = 0) -> int:
        """Number of cached positions for ``layer_idx`` (default layer 0); 0 if nothing is cached yet."""
        if not self.keys or layer_idx >= len(self.keys):
            return 0
        return self.keys[layer_idx].shape[2]

class LanguageAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and an optional KV cache.

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads; each K/V head is
    repeated ``num_heads // num_kv_heads`` times before the score computation. Masking
    is entirely determined by the additive ``mask`` supplied by the caller.
    """
    def __init__(self, config: LanguageConfig, layer_idx: int):
        super().__init__()
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim ** -0.5
        self.layer_idx = layer_idx
        self.dropout = nn.Dropout(config.dropout_rate)
        self.num_kv_groups = config.num_heads // config.num_kv_heads

        self.q_proj = nn.Linear(config.hidden_size, config.num_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, config.num_kv_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, config.num_kv_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_heads * config.head_dim, config.hidden_size, bias=False)
        self.rotary = RotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: Optional[KVCache] = None, pos_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend from the ``T`` new tokens in ``x`` to this layer's cached keys plus themselves.

        Args:
            x: ``[B, T, hidden_size]`` hidden states of the new tokens.
            mask: additive bias broadcastable to ``[B, num_heads, T, cached + T]``.
            cache: optional KV cache; this layer's entry is extended in place.
            pos_ids: accepted for interface compatibility but unused — rotary positions
                are derived from this layer's cache length.
        """
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # New tokens continue from *this layer's* cached length. Using layer 0's length
        # is wrong for every later layer: layer 0 has already appended the current
        # tokens by the time layer i > 0 runs, which shifts positions by T and breaks
        # the relative offsets to the keys this layer cached earlier.
        start_pos = 0
        if cache is not None and self.layer_idx < len(cache.keys):
            start_pos = cache.keys[self.layer_idx].shape[2]
        cos, sin = self.rotary(q, start_pos)
        q, k = apply_rotary_emb(q, k, cos, sin)

        if cache is not None:
            k, v = cache.update(k, v, self.layer_idx)

        # Grouped-query attention: each K/V head serves num_kv_groups consecutive query heads.
        k = k.repeat_interleave(self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_kv_groups, dim=1)

        scores = (q @ k.transpose(-2, -1)) * self.scale + mask
        # Softmax in float32 for stability with half-precision activations.
        attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        attn = self.dropout(attn)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out)

class LanguageMLP(nn.Module):
    """Gated feed-forward block: ``down(gelu_tanh(gate(x)) * up(x))`` with bias-free projections."""
    def __init__(self, config: LanguageConfig):
        super().__init__()
        width, inner = config.hidden_size, config.intermediate_size
        self.gate = nn.Linear(width, inner, bias=False)
        self.up = nn.Linear(width, inner, bias=False)
        self.down = nn.Linear(inner, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.gate(x), approximate="tanh")
        gated = activated * self.up(x)
        return self.down(gated)

class DecoderLayer(nn.Module):
    """Pre-norm decoder block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""
    def __init__(self, config: LanguageConfig, layer_idx: int):
        super().__init__()
        self.attn = LanguageAttention(config, layer_idx)
        self.mlp = LanguageMLP(config)
        eps = config.rms_norm_eps
        self.norm1 = RMSNorm(config.hidden_size, eps)
        self.norm2 = RMSNorm(config.hidden_size, eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: Optional[KVCache] = None, pos_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        attended = self.attn(self.norm1(x), mask, cache, pos_ids)
        x = x + attended
        return x + self.mlp(self.norm2(x))

class LanguageModel(nn.Module):
    """Gemma decoder stack whose output head shares the input embedding matrix."""
    def __init__(self, config: LanguageConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(DecoderLayer(config, idx) for idx in range(config.num_layers))
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding table.
        self.lm_head.weight = self.embedding.weight

    def forward(self, embeddings: torch.Tensor, mask: torch.Tensor, cache: Optional[KVCache] = None, pos_ids: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Decode already-embedded inputs (scaled by ``sqrt(hidden_size)`` first).

        Returns ``{"logits": ...}``, plus ``"kv_cache"`` when a cache was supplied.
        """
        hidden = embeddings * (self.config.hidden_size ** 0.5)
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, mask, cache, pos_ids)
        outputs = {"logits": self.lm_head(self.norm(hidden))}
        if cache:
            outputs["kv_cache"] = cache
        return outputs

class MultimodalConfig:
    """Vision and text sub-configs plus the settings that connect them.

    ``vision_config`` / ``text_config`` are keyword dicts for ``VisionConfig`` /
    ``LanguageConfig``; ``hidden_size`` mirrors the text model's width.
    """
    def __init__(self, vision_config: dict, text_config: dict, projection_dim=2048, image_token_id=256000):
        vision = VisionConfig(**vision_config)
        text = LanguageConfig(**text_config)
        self.vision_config = vision
        self.text_config = text
        self.projection_dim = projection_dim
        self.image_token_id = image_token_id
        # Merged image/text embeddings live in the language model's space.
        self.hidden_size = text.hidden_size

class Projector(nn.Module):
    """Single affine map (with bias) from vision-encoder width to language-model width."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return projected

class PaliGemma(nn.Module):
    """PaliGemma: vision encoder + linear projector + Gemma-style decoder.

    Projected image features are written into the embedding positions whose token
    id equals ``config.image_token_id``; the merged sequence is then decoded.
    """
    def __init__(self, config: MultimodalConfig):
        super().__init__()
        self.config = config
        self.vision = VisionTransformer(config.vision_config)
        self.projector = Projector(config.vision_config.hidden_size, config.projection_dim)
        self.language = LanguageModel(config.text_config)

    def merge_inputs(self, input_ids: torch.Tensor, image_features: torch.Tensor, embeddings: torch.Tensor, mask: torch.Tensor, cache: Optional[KVCache] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Splice image features into the token embeddings and build the attention bias.

        Args:
            input_ids: ``[B, T]`` token ids; positions equal to ``config.image_token_id``
                receive image features.
            image_features: ``[B_img, P, hidden]`` projected vision features; may have
                ``P == 0`` when ``input_ids`` contains no image tokens.
            embeddings: ``[B, T, hidden]`` token embeddings of ``input_ids``.
            mask: ``[B, L]`` 0/1 attention mask, used only to derive ``pos_ids``.
            cache: optional KV cache; its current length sets the key length of the bias.

        Returns:
            ``(merged_embeddings, attention_bias, pos_ids)``. The bias is all zeros, so
            every new query may attend to every cached and new key.

        Raises:
            ValueError: if the number of image tokens differs from the number of
                image feature vectors.
        """
        B, T = input_ids.shape
        # LanguageModel.forward multiplies its inputs by sqrt(hidden_size); dividing here
        # first means image features reach the decoder at their projected scale.
        image_features = image_features / (self.config.hidden_size ** 0.5)

        image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)  # [B, T, 1]
        # Text positions keep their token embedding; image positions are zeroed, then filled.
        final_emb = embeddings.masked_fill(image_mask, 0.0)

        num_slots = int(image_mask.sum())
        if num_slots > 0:
            num_features = image_features.shape[0] * image_features.shape[1]
            if num_slots != num_features:
                raise ValueError(
                    f"input_ids contain {num_slots} image tokens but {num_features} "
                    "image feature vectors were provided"
                )
            # masked_scatter fills True positions in row-major order, so the k-th image
            # token receives the k-th feature vector.
            final_emb = final_emb.masked_scatter(image_mask, image_features.to(final_emb.dtype))

        past_len = cache.get_seq_len() if cache is not None else 0
        # All-zero additive bias: nothing is hidden from any query.
        # NOTE(review): padding positions in `mask` are not masked out here.
        attn_bias = torch.zeros(B, 1, T, past_len + T, device=embeddings.device, dtype=embeddings.dtype)

        if past_len == 0:
            # Prefill: 1-based running count of real tokens; padded slots are set to 1.
            pos_ids = mask.cumsum(-1).masked_fill_(mask == 0, 1)
        else:
            # Decoding: position of the newest token for each batch row -> [B, 1].
            pos_ids = mask.cumsum(-1)[:, -1].unsqueeze(-1)
        return final_emb, attn_bias, pos_ids

    def forward(self, input_ids: torch.Tensor, pixel_values: torch.Tensor, mask: torch.Tensor, cache: Optional[KVCache] = None) -> Dict[str, torch.Tensor]:
        """One forward pass (prefill or decode step); returns the language model's output dict.

        ``pixel_values`` is only encoded when ``input_ids`` contains image tokens, so decode
        steps that feed just the newly sampled text token skip the vision tower.
        """
        embeddings = self.language.embedding(input_ids)
        if bool((input_ids == self.config.image_token_id).any()):
            image_features = self.projector(self.vision(pixel_values))
        else:
            image_features = embeddings.new_zeros(input_ids.shape[0], 0, embeddings.shape[-1])
        inputs, attn_bias, pos_ids = self.merge_inputs(input_ids, image_features, embeddings, mask, cache)
        return self.language(inputs, attn_bias, cache, pos_ids)
    
class MultimodalProcessor:
    """Turn a (prompt, image) pair into model inputs.

    Text becomes ``num_image_tokens`` ``<image>`` placeholders, the BOS token, the
    prompt and a trailing newline; the image becomes a normalised pixel tensor.
    """
    def __init__(self, tokenizer, image_size: int, num_image_tokens: int):
        self.tokenizer = tokenizer
        self.image_size = image_size
        self.num_image_tokens = num_image_tokens
        self.image_token = "<image>"
        # Register the image placeholder plus 1024 <locNNNN> and 128 <segNNN> tokens.
        self.tokenizer.add_tokens([self.image_token] + [f"<loc{i:04d}>" for i in range(1024)] + [f"<seg{i:03d}>" for i in range(128)])
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)

    def process_image(self, image: Image.Image) -> torch.Tensor:
        """Resize to ``image_size`` x ``image_size`` (bicubic) and scale pixels to [-1, 1].

        Returns a float32 tensor of shape ``[1, 3, image_size, image_size]``.
        """
        # Force 3 channels: RGBA / grayscale / palette images would otherwise give
        # arrays that break the CHW transpose or the 3-channel patch convolution.
        rgb = image.convert("RGB")
        resized = rgb.resize((self.image_size, self.image_size), Image.Resampling.BICUBIC)
        pixels = np.asarray(resized, dtype=np.float32) / 255.0
        # Keep mean/std float32: float64 constants would promote the whole array to
        # float64, which does not match the model's float32 convolution weights.
        mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        pixels = (pixels - mean) / std
        # HWC -> CHW, plus a leading batch axis.
        return torch.from_numpy(np.ascontiguousarray(pixels.transpose(2, 0, 1))).unsqueeze(0)

    def __call__(self, text: str, image: Image.Image) -> Dict[str, torch.Tensor]:
        """Tokenise the PaliGemma-style prompt and preprocess the image."""
        prompt = f"{self.image_token * self.num_image_tokens}{self.tokenizer.bos_token}{text}\n"
        inputs = self.tokenizer(prompt, return_tensors="pt", padding="longest", truncation=True)
        pixel_values = self.process_image(image)
        return {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"], "pixel_values": pixel_values}



# Hugging Face config.json key -> keyword accepted by this notebook's config classes.
_VISION_CONFIG_RENAMES = {
    "num_hidden_layers": "num_layers",
    "num_attention_heads": "num_heads",
    "attention_dropout": "dropout_rate",
}
_TEXT_CONFIG_RENAMES = {
    "num_hidden_layers": "num_layers",
    "num_attention_heads": "num_heads",
    "num_key_value_heads": "num_kv_heads",
    "max_position_embeddings": "max_seq_len",
    "attention_dropout": "dropout_rate",
}

# Substring renames applied to the remainder of a checkpoint key after its prefix.
_VISION_KEY_RENAMES = (
    ("embeddings.patch_embedding.", "embedder.patch_conv."),
    ("embeddings.position_embedding.", "embedder.positional_emb."),
    ("encoder.layers.", "layers."),
    ("self_attn.q_proj.", "attention.query."),
    ("self_attn.k_proj.", "attention.key."),
    ("self_attn.v_proj.", "attention.value."),
    ("self_attn.out_proj.", "attention.output."),
    ("layer_norm1.", "norm1."),
    ("layer_norm2.", "norm2."),
    ("post_layernorm.", "final_norm."),
)
_TEXT_KEY_RENAMES = (
    ("embed_tokens.", "embedding."),
    ("self_attn.", "attn."),
    ("mlp.gate_proj.", "mlp.gate."),
    ("mlp.up_proj.", "mlp.up."),
    ("mlp.down_proj.", "mlp.down."),
    ("input_layernorm.", "norm1."),
    ("post_attention_layernorm.", "norm2."),
)
# (checkpoint prefix, local prefix, remainder renames); first matching prefix wins.
# NOTE(review): presumed Hugging Face PaliGemma safetensors layout (older
# "vision_tower.…"/"language_model.model.…" and newer "model.…" naming). Keys that
# do not map are reported by load_model — verify against your checkpoint.
_CHECKPOINT_PREFIXES = (
    ("model.vision_tower.vision_model.", "vision.", _VISION_KEY_RENAMES),
    ("vision_tower.vision_model.", "vision.", _VISION_KEY_RENAMES),
    ("model.multi_modal_projector.linear.", "projector.proj.", ()),
    ("multi_modal_projector.linear.", "projector.proj.", ()),
    ("model.language_model.", "language.", _TEXT_KEY_RENAMES),
    ("language_model.model.", "language.", _TEXT_KEY_RENAMES),
    ("language_model.lm_head.", "language.lm_head.", ()),
    ("lm_head.", "language.lm_head.", ()),
)


def _config_kwargs(raw: dict, renames: Dict[str, str], target: type) -> dict:
    """Rename ``raw``'s keys via ``renames`` and keep only keywords ``target(...)`` accepts."""
    accepted = inspect.signature(target).parameters
    kwargs = {}
    for key, value in raw.items():
        local_key = renames.get(key, key)
        if local_key in accepted:
            kwargs[local_key] = value
    return kwargs


def _remap_checkpoint_key(key: str) -> str:
    """Translate a checkpoint tensor name to this notebook's module path (unchanged if no prefix matches)."""
    for prefix, local_prefix, renames in _CHECKPOINT_PREFIXES:
        if key.startswith(prefix):
            rest = key[len(prefix):]
            for old, new in renames:
                rest = rest.replace(old, new)
            return local_prefix + rest
    return key


def load_model(model_path: str, device: str) -> Tuple[PaliGemma, MultimodalProcessor]:
    """Build a PaliGemma model and its processor from a checkpoint directory.

    ``model_path`` must contain ``config.json``, tokenizer files and one or more
    ``*.safetensors`` shards. Config keys are renamed to this notebook's names
    (unknown keys are dropped) and checkpoint tensor names are remapped onto the
    local module tree. Parameters that remain unloaded, and checkpoint tensors
    that go unused, are reported with ``warnings.warn``.

    Raises:
        FileNotFoundError: if ``config.json`` or every safetensors shard is missing.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    with open(os.path.join(model_path, "config.json"), "r") as f:
        config_dict = json.load(f)

    top_level = {}
    if "projection_dim" in config_dict:
        top_level["projection_dim"] = config_dict["projection_dim"]
    config = MultimodalConfig(
        _config_kwargs(config_dict["vision_config"], _VISION_CONFIG_RENAMES, VisionConfig),
        _config_kwargs(config_dict["text_config"], _TEXT_CONFIG_RENAMES, LanguageConfig),
        **top_level,
    )

    processor = MultimodalProcessor(tokenizer, config.vision_config.image_size, config.vision_config.num_patches)
    # The model must look for exactly the id the processor writes for "<image>";
    # otherwise no position is recognised as an image slot and the image is ignored.
    config.image_token_id = processor.image_token_id

    model = PaliGemma(config).to(device)

    shard_paths = sorted(glob.glob(os.path.join(model_path, "*.safetensors")))
    if not shard_paths:
        raise FileNotFoundError(f"no *.safetensors files found in {model_path!r}")
    tensors = {}
    for shard in shard_paths:
        with safe_open(shard, framework="pt", device="cpu") as f:
            for k in f.keys():
                tensors[_remap_checkpoint_key(k)] = f.get_tensor(k)

    result = model.load_state_dict(tensors, strict=False)
    # Derived buffers and the tied lm_head legitimately have no checkpoint entry;
    # named_parameters() lists each shared parameter once, so only report those.
    param_names = {name for name, _ in model.named_parameters()}
    missing = [k for k in result.missing_keys if k in param_names]
    if missing:
        warnings.warn(
            f"{len(missing)} parameters were not found in the checkpoint and keep "
            f"their initial values, e.g. {missing[:5]}"
        )
    if result.unexpected_keys:
        warnings.warn(
            f"{len(result.unexpected_keys)} checkpoint tensors were not used, "
            f"e.g. {result.unexpected_keys[:5]}"
        )
    return model, processor



def generate_text(model: PaliGemma, processor: MultimodalProcessor, prompt: str, image_path: str, max_tokens: int = 100, temp: float = 0.8, top_p: float = 0.9) -> str:
    """Generate a continuation of ``prompt`` conditioned on the image at ``image_path``.

    Args:
        model: PaliGemma model; inputs are moved to the device of its parameters.
        processor: builds token ids / pixel values for the prompt and image.
        prompt: text prompt.
        image_path: path of the image file (converted to RGB).
        max_tokens: maximum number of tokens to generate.
        temp: softmax temperature; ``temp <= 0`` selects greedy (argmax) decoding.
        top_p: nucleus threshold — a token is dropped once the probability mass of
            strictly more likely tokens already exceeds ``top_p``.

    Returns:
        The decoded generated tokens, special tokens stripped.
    """
    model_device = next(model.parameters()).device
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = {k: v.to(model_device) for k, v in processor(prompt, image).items()}

    cache = KVCache()
    generated: List[int] = []
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    pixel_values = inputs["pixel_values"]
    eos_id = processor.tokenizer.eos_token_id

    model.eval()
    with torch.no_grad():
        for _ in range(max_tokens):
            outputs = model(input_ids, pixel_values, attention_mask, cache)
            next_logits = outputs["logits"][:, -1, :]
            if temp <= 0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(next_logits / temp, dim=-1)
                probs_sorted, idxs = probs.sort(dim=-1, descending=True)
                outside_nucleus = (probs_sorted.cumsum(dim=-1) - probs_sorted) > top_p
                probs_sorted = probs_sorted.masked_fill(outside_nucleus, 0.0)
                probs_sorted = probs_sorted / probs_sorted.sum(dim=-1, keepdim=True)
                next_token = torch.gather(idxs, -1, torch.multinomial(probs_sorted, 1))

            token_id = next_token.item()
            generated.append(token_id)
            if token_id == eos_id:
                break
            # Feed only the new token (earlier ones live in the KV cache). It must stay
            # [B, 1]: merge_inputs unpacks input_ids.shape as (B, T).
            input_ids = next_token
            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

    return processor.tokenizer.decode(generated, skip_special_tokens=True)



In [None]:
# Inputs for the demo run; relative paths resolve against the current working directory.
model_path, image_path, prompt = "./model", "./bird.jpg", "Describe this image."

In [None]:
model, processor = load_model(model_path, device)
output = generate_text(model=model, processor=processor, prompt=prompt, image_path=image_path)
print(f"Prompt: {prompt}\nGenerated: {output}")