In [7]:
from transformers import AutoProcessor, AutoModelForImageTextToText
import json
import outlines
from pydantic import BaseModel, Field

In [3]:
# Load the processor and the model from the same checkpoint. The id is kept in
# one place so the two `from_pretrained` calls cannot drift apart.
# The recorded output shows this cell fetching files and loading checkpoint
# shards, so expect it to be slow on a fresh cache.
MODEL_ID = "google/medgemma-4b-it"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Fetching 2 files: 100%|██████████| 2/2 [02:15<00:00, 67.65s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:14<00:00,  7.15s/it]


In [4]:
# Single-turn, text-only chat: one user message whose content is a list of
# typed parts (here just one text part).
user_text_part = {"type": "text", "text": "What is the capital of France?"}

conversation = [
    {"role": "user", "content": [user_text_part]},
]

In [8]:
class TestOutput(BaseModel):
    # Pydantic schema with two string fields, each carrying a human-readable
    # description.
    # NOTE(review): not referenced by any of the visible cells; presumably meant
    # as the target schema for structured generation via `outlines` -- confirm.

    # Final answer text.
    answer: str = Field(
        description="The answer to the question",
    )
    # Explanation supporting the answer.
    reasoning: str = Field(
        description="The reasoning for the answer",
    )

In [9]:
# Render `conversation` with the processor's chat template and tokenize it in
# a single call.
chat_template_options = {
    # Append the opening of the model's turn so generation continues from it
    # (the decoded prompt printed further down ends with `<start_of_turn>model`).
    "add_generation_prompt": True,
    # Return token ids rather than the rendered string...
    "tokenize": True,
    # ...as a dict (the printed result has `input_ids`, `attention_mask`
    # and `token_type_ids`)...
    "return_dict": True,
    # ...holding PyTorch tensors.
    "return_tensors": "pt",
}

inputs = processor.apply_chat_template(conversation, **chat_template_options)


In [13]:
# Sanity check: show the raw tokenized batch, then decode the first (and only)
# row of token ids back to text to see the prompt the chat template produced.
print(inputs)

prompt_token_ids = inputs["input_ids"][0]
print(processor.decode(prompt_token_ids))

{'input_ids': tensor([[     2,    105,   2364,    107,   3689,    563,    506,   5279,    529,
           7001, 236881,    106,    107,    105,   4368,    107]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}
<bos><start_of_turn>user
What is the capital of France?<end_of_turn>
<start_of_turn>model



In [14]:
# Generate a continuation for the tokenized prompt. No generation arguments are
# passed, so output length and decoding strategy come from whatever defaults
# the model/library supplies.
# NOTE(review): consider passing explicit generation settings (e.g. a token
# budget) so the result does not depend on defaults -- confirm intended values.
outputs = model.generate(**inputs)

# Decode the first returned sequence back to text.
# NOTE(review): per the recorded output, the decoded text contains the echoed
# prompt and special tokens (`<bos>`, `<start_of_turn>`, ...), not just the
# answer. Kept as is; strip them here if only the answer is wanted.
response = processor.decode(
    outputs[0],
    clean_up_tokenization_spaces=True,
)

print(response)

<bos><start_of_turn>user
What is the capital of France?<end_of_turn>
<start_of_turn>model
The capital of France is **Paris**.
<end_of_turn>
