In [1]:
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
from transformers.models.siglip import SiglipVisionConfig

from bge_gemma2_multimodal import BgeGemma2MultimodalConfig, BgeGemma2MultimodalModel

In [2]:
text_config = Gemma2Config.from_dict({
    "attention_bias":          False,
    "attention_dropout":       0.0,
    "attn_logit_softcapping":  50.0,
    "bos_token_id":            2,
    "cache_implementation":    "hybrid",
    "eos_token_id":            1,
    "final_logit_softcapping": 30.0,
    "head_dim":                64,
    "hidden_act":              "gelu_pytorch_tanh",
    "hidden_activation":       "gelu_pytorch_tanh",
    "hidden_size":             128,
    "initializer_range":       0.02,
    "intermediate_size":       128,
    "max_position_embeddings": 8192,
    "model_type":              "gemma2",
    "num_attention_heads":     2,
    "num_hidden_layers":       2,
    "num_key_value_heads":     2,
    "pad_token_id":            0,
    "query_pre_attn_scalar":   64,
    "rms_norm_eps":            1e-06,
    "rope_theta":              10000.0,
    "sliding_window":          256,
    "sliding_window_size":     256,
    "use_cache":               False,
    "vocab_size":              256002
    })
siglip_model = "google/siglip-base-patch16-224"
vision_config = SiglipVisionConfig.from_pretrained(siglip_model)

In [3]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(siglip_model)

config = BgeGemma2MultimodalConfig(vision_config=vision_config,
                                   projection_dim=128,
                                   text_config=text_config)
model = BgeGemma2MultimodalModel(config)

In [6]:
processor

SiglipProcessor:
- image_processor: SiglipImageProcessor {
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "SiglipImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "processor_class": "SiglipProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

- tokenizer: SiglipTokenizer(name_or_path='google/siglip-base-patch16-224', vocab_size=32000, model_max_length=64, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	1: AddedToken("</s>", rstrip=True, lstrip=True, single_word=False, normalized=False, special=True),
	2: AddedToken("<unk>", rstrip=True, lstrip=True, single_word=False, normalized=False, special=True),
}
)

{
  "processor_class": "

In [7]:
import requests
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

texts = ["a photo of 2 cats", "a photo of 2 dogs"]
inputs = processor(text=None, images=image, padding="max_length", return_tensors="pt")


In [8]:
inputs

{'pixel_values': tensor([[[[ 0.1137,  0.1686,  0.1922,  ..., -0.1922, -0.1843, -0.1922],
          [ 0.1373,  0.1686,  0.1843,  ..., -0.1922, -0.1922, -0.2078],
          [ 0.1137,  0.1529,  0.1608,  ..., -0.2392, -0.2235, -0.2078],
          ...,
          [ 0.8431,  0.7882,  0.7255,  ...,  0.7098,  0.6549,  0.6314],
          [ 0.8275,  0.7961,  0.7725,  ...,  0.6157,  0.4902,  0.4196],
          [ 0.8275,  0.7569,  0.7647,  ...,  0.0275, -0.1059, -0.2471]],

         [[-0.8118, -0.8118, -0.8118,  ..., -0.8902, -0.8902, -0.8980],
          [-0.7882, -0.7882, -0.7882,  ..., -0.8824, -0.8745, -0.8824],
          [-0.8196, -0.8039, -0.7882,  ..., -0.8980, -0.8902, -0.8902],
          ...,
          [-0.2627, -0.3255, -0.3725,  ..., -0.4196, -0.4510, -0.4745],
          [-0.2627, -0.2863, -0.3412,  ..., -0.4667, -0.5373, -0.5686],
          [-0.2784, -0.3412, -0.3490,  ..., -0.7569, -0.8039, -0.8588]],

         [[-0.5451, -0.4588, -0.4824,  ..., -0.7412, -0.6941, -0.7098],
          [-0

In [5]:
import torch
from torchviz import make_dot
from IPython.display import Image

# Simulate input data based on configurations
text_input_ids = torch.randint(0, config.text_config.vocab_size, (1, 512))  # Batch size 1, sequence length 512
image_pixels = torch.randn(1, 3, 224, 224)  # Batch size 1, channels 3, image size 224x224

# Pass the input through the model
outputs = model(input_ids=text_input_ids, pixel_values=image_pixels)

# Generate and render the graph
dot = make_dot(outputs, params=dict(model.named_parameters()))
dot.format = 'png'
dot.render('model_graph')

# Display the graph
Image("model_graph.png")

ValueError: Number of images does not match number of special image tokens in the input text. Got 0 image tokens in the text but 196 tokens from image embeddings.

In [10]:
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch

model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

texts = ["a photo of 2 cats", "a photo of 2 dogs"]
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

logits_per_image = outputs.logits_per_image
probs = torch.sigmoid(logits_per_image) # these are the probabilities
print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")


31.9% that image 0 is 'a photo of 2 cats'
