In [1]:
import copy

import requests
import torch
import torch.nn as nn
from PIL import Image
from rich import print
from torchvision import transforms

%load_ext autoreload
%autoreload 2

In [2]:
# Download the sample image. `timeout` keeps the cell from blocking indefinitely
# on a stalled connection, and `raise_for_status()` surfaces an HTTP error here
# instead of handing a non-image error body to PIL.
sample_image_response = requests.get(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    stream=True,
    timeout=30,
)
sample_image_response.raise_for_status()
# Image.open reads from the raw streamed body, so the response is deliberately
# left open: it must stay readable until the image data has been loaded.
sample_image = Image.open(sample_image_response.raw)

In [3]:
from siglip import SiglipVisionModel
from config import SiglipVisionConfig, GemmaLMConfig
from gemma import *

In [4]:
# Both configs are constructed with no arguments, i.e. with whatever defaults
# the config classes define.
siglip_config, gemma_config = SiglipVisionConfig(), GemmaLMConfig()

In [5]:
# Preprocessing pipeline: resize to a square whose side comes from the vision
# config, then convert the image to a tensor.
image_hw = (siglip_config.image_size,) * 2
transform = transforms.Compose([transforms.Resize(image_hw), transforms.ToTensor()])


def reverse_transform(image):
    """Convert ``image`` to a PIL image using ``transforms.ToPILImage``.

    A new ``ToPILImage`` converter is created on every call and applied to
    ``image``; the converted result is returned.
    """
    to_pil = transforms.ToPILImage()
    return to_pil(image)

In [None]:
# Smoke test: run the sample image through a freshly constructed vision model
# and display the output shape.
siglip_vision_model = SiglipVisionModel(siglip_config)
# unsqueeze(0) adds a leading dimension of size 1 before the model call.
siglip_pixel_batch = transform(sample_image).unsqueeze(0)
siglip_vision_model(siglip_pixel_batch).shape

torch.Size([1, 16, 768])

In [6]:
# Single module-level cache instance; it is passed as `kv_cache=` to the
# GemmaLM call further down in this notebook.
kv_cache = KVCache()

In [20]:
# Smoke test: push a random (1, 768) input through a freshly constructed MLP
# block and display the output shape.
mlp_block = GemmaMLP(gemma_config)
mlp_input = torch.randn(1, 768)
mlp_block(mlp_input).shape

torch.Size([1, 768])

In [None]:
# Display the config's instance attributes as a dict.
vars(gemma_config)

{'d_vocab': 256000,
 'hidden_size': 768,
 'intermediate_size': 3072,
 'num_hidden_layers': 12,
 'num_attention_heads': 12,
 'num_kv_heads': 1,
 'd_head': 256,
 'max_position_embeddings': 8192,
 'rms_norm_eps': 1e-06,
 'rope_theta': 10000.0,
 'attention_bias': True,
 'attention_dropout': 0.0,
 'pad_token_id': None}

In [39]:
# Smoke test for the attention module. The module is built before the random
# input is drawn, matching the original evaluation order.
# NOTE(review): the second positional argument (2) is presumably a layer index — confirm.
attn_layer = GemmaAttention(gemma_config, 2)
attn_inputs = dict(
    x=torch.randn(1, 1024, 768),
    position_ids=torch.arange(1024).unsqueeze(0),
    # All-False boolean mask, allocated directly as bool.
    attention_mask=torch.zeros(1, 1024, 1024, dtype=torch.bool),
)
# Last expression: the call result is displayed as the cell output.
attn_layer(**attn_inputs)


(tensor([[[-0.0224, -0.0239,  0.0114,  ..., -0.0153, -0.0082,  0.0128],
          [-0.0287, -0.0184,  0.0215,  ..., -0.0078, -0.0002,  0.0166],
          [-0.0219, -0.0303,  0.0162,  ..., -0.0017, -0.0028,  0.0126],
          ...,
          [-0.0257, -0.0216,  0.0191,  ..., -0.0126, -0.0051,  0.0129],
          [-0.0297, -0.0217,  0.0151,  ..., -0.0111, -0.0045,  0.0105],
          [-0.0256, -0.0348,  0.0234,  ..., -0.0154, -0.0087,  0.0192]]],
        grad_fn=<ViewBackward0>),
 tensor([[[[0.0011, 0.0011, 0.0008,  ..., 0.0014, 0.0008, 0.0009],
           [0.0011, 0.0007, 0.0008,  ..., 0.0010, 0.0007, 0.0010],
           [0.0011, 0.0012, 0.0006,  ..., 0.0009, 0.0007, 0.0012],
           ...,
           [0.0017, 0.0011, 0.0012,  ..., 0.0012, 0.0005, 0.0012],
           [0.0009, 0.0005, 0.0006,  ..., 0.0008, 0.0012, 0.0008],
           [0.0011, 0.0011, 0.0005,  ..., 0.0012, 0.0009, 0.0006]],
 
          [[0.0006, 0.0008, 0.0012,  ..., 0.0009, 0.0009, 0.0013],
           [0.0005, 0.0012, 0

In [47]:
# Smoke test for a single decoder layer; only the output shape is displayed.
# NOTE(review): the second positional argument (1) is presumably a layer index — confirm.
decoder_layer = GemmaDecoderLayer(gemma_config, 1)
decoder_hidden_in = torch.randn(1, 1024, 768)
decoder_position_ids = torch.arange(1024)[None, :]
decoder_attention_mask = torch.zeros((1, 1024, 1024), dtype=torch.bool)  # all False
decoder_layer(
    x=decoder_hidden_in,
    position_ids=decoder_position_ids,
    attention_mask=decoder_attention_mask,
).shape


torch.Size([1, 1024, 768])

In [52]:
# Smoke test for the full decoder stack on random input embeddings.
# The model is built before the random input is drawn, matching the original
# evaluation order.
gemma_model = GemmaModel(gemma_config)
model_inputs = {
    "input_embd": torch.randn(1, 1024, 768),
    "position_ids": torch.arange(1024).unsqueeze(0),
    # All-False boolean mask, allocated directly as bool.
    "attention_mask": torch.zeros(1, 1024, 1024, dtype=torch.bool),
}
# Last expression: the call result is displayed as the cell output.
gemma_model(**model_inputs)


tensor([[[ 0.0875,  0.4269,  0.5383,  ..., -1.0789,  0.5787,  0.4501],
         [ 0.1616, -0.4688,  0.6504,  ...,  3.1006,  0.1916, -0.8189],
         [-1.2053,  0.4819,  0.5114,  ...,  0.1484,  0.2750,  0.3186],
         ...,
         [-1.1503, -0.8338,  1.0890,  ..., -0.3209, -1.0701,  1.7089],
         [-1.0619, -0.8395, -1.0407,  ..., -0.9515, -0.2752,  0.6259],
         [ 0.3239,  0.4989, -0.3559,  ...,  1.7687, -1.2566,  0.8716]]],
       grad_fn=<MulBackward0>)

In [18]:
# Smoke test for the language-model wrapper using a reduced 2-layer config.
#
# The 2-layer override is applied to a shallow copy so the shared `gemma_config`
# is not mutated in place. Other cells that read `gemma_config` therefore see
# the same values regardless of whether (or when) this cell has been run.
lm_config = copy.copy(gemma_config)
lm_config.num_hidden_layers = 2

# NOTE(review): KVCache presumably accumulates state across forward calls, which
# would make a second run of this cell see a non-empty cache — confirm against
# the KVCache implementation. Rebinding a fresh instance here keeps the cell
# idempotent while still leaving `kv_cache` populated for any later cell.
kv_cache = KVCache()

GemmaLM(lm_config)(
    input_embd=torch.randn(1, 1024, 768),
    position_ids=torch.arange(0, 1024).unsqueeze(0),
    attention_mask=torch.zeros(1, 1024, 1024).to(torch.bool),
    kv_cache=kv_cache,
)
