In [1]:
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
from transformers.models.siglip import SiglipVisionConfig

from bge_gemma2_multimodal import BgeGemma2MultimodalConfig, BgeGemma2MultimodalModel
from bge_gemma2_multimodal import BgeGemma2MultimodalProcessor

In [3]:
text_config = Gemma2Config.from_dict({
    "attention_bias":          False,
    "attention_dropout":       0.0,
    "attn_logit_softcapping":  50.0,
    "bos_token_id":            2,
    "cache_implementation":    "hybrid",
    "eos_token_id":            1,
    "final_logit_softcapping": 30.0,
    "head_dim":                64,
    "hidden_act":              "gelu_pytorch_tanh",
    "hidden_activation":       "gelu_pytorch_tanh",
    "hidden_size":             128,
    "initializer_range":       0.02,
    "intermediate_size":       128,
    "max_position_embeddings": 1024,
    "model_type":              "gemma2",
    "num_attention_heads":     2,
    "num_hidden_layers":       2,
    "num_key_value_heads":     2,
    "pad_token_id":            0,
    "query_pre_attn_scalar":   64,
    "rms_norm_eps":            1e-06,
    "rope_theta":              10000.0,
    "sliding_window":          256,
    "sliding_window_size":     256,
    "use_cache":               False,
    "vocab_size":              256002
    })
siglip_model = "google/siglip-base-patch16-224"
vision_config = SiglipVisionConfig.from_pretrained(siglip_model)

In [4]:
processor = BgeGemma2MultimodalProcessor.from_pretrained("bge_gemma2_multimodal_hub_files")

config = BgeGemma2MultimodalConfig(vision_config=vision_config,
                                   projection_dim=128,
                                   text_config=text_config)
model = BgeGemma2MultimodalModel(config)

In [5]:
model

BgeGemma2MultimodalModel(
  (text_model): Gemma2Model(
    (embed_tokens): Embedding(256002, 128, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=128, out_features=128, bias=False)
          (k_proj): Linear(in_features=128, out_features=128, bias=False)
          (v_proj): Linear(in_features=128, out_features=128, bias=False)
          (o_proj): Linear(in_features=128, out_features=128, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=128, out_features=128, bias=False)
          (up_proj): Linear(in_features=128, out_features=128, bias=False)
          (down_proj): Linear(in_features=128, out_features=128, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((128,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((128,), eps=1e-06)


In [6]:
import requests
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

texts = ["a photo of 2 cats", "a photo of 2 dogs"]
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")


Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.


In [7]:
inputs

{'input_ids': tensor([[     2, 255999, 235250,   2686,    576, 235248, 235284,  19493,      1],
        [     2, 255999, 235250,   2686,    576, 235248, 235284,  12075,      1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'pixel_values': tensor([[[[ 0.1137,  0.1686,  0.1922,  ..., -0.1922, -0.1843, -0.1922],
          [ 0.1373,  0.1686,  0.1843,  ..., -0.1922, -0.1922, -0.2078],
          [ 0.1137,  0.1529,  0.1608,  ..., -0.2392, -0.2235, -0.2078],
          ...,
          [ 0.8431,  0.7882,  0.7255,  ...,  0.7098,  0.6549,  0.6314],
          [ 0.8275,  0.7961,  0.7725,  ...,  0.6157,  0.4902,  0.4196],
          [ 0.8275,  0.7569,  0.7647,  ...,  0.0275, -0.1059, -0.2471]],

         [[-0.8118, -0.8118, -0.8118,  ..., -0.8902, -0.8902, -0.8980],
          [-0.7882, -0.7882, -0.7882,  ..., -0.8824, -0.8745, -0.8824],
          [-0.8196, -0.8039, -0.7882,  ..., -0.8980, -0.8902, -0.8902],
          ...,
          [-0.2627, -0.3255, -

In [9]:
inputs.pop("formated_prompt")

['<vision>a photo of 2 cats', '<vision>a photo of 2 dogs']

In [10]:
import torch
from torchviz import make_dot
from IPython.display import Image

# Pass the input through the model
outputs = model(**inputs)

# Generate and render the graph
dot = make_dot(outputs, params=dict(model.named_parameters()))
dot.format = 'png'
dot.render('model_graph')

# Display the graph
Image("model_graph.png")

ValueError: Number of images does not match number of special image tokens in the input text. Got 0 image tokens in the text but 196 tokens from image embeddings.