In [1]:
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    AutoModelForImageTextToText,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

import torch

In [2]:
# Hub checkpoint ids: a large "target" model and a small "draft" model.
target_model_id = "google/gemma-3-12b-it"
draft_model_id = "google/gemma-3-1b-it"

# 4-bit NF4 quantization (with double quantization) via bitsandbytes.
# NOTE(review): compute dtype is float16 here while the checkpoints are loaded
# with torch_dtype="auto" — confirm float16 (vs. bfloat16) is intended.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    load_in_4bit=True,
)

In [3]:
# Draft model: pinned entirely to the first GPU.
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_id,
    quantization_config=bnb_config,
    torch_dtype="auto",
    device_map="cuda:0",
)

# Target model: device_map="auto" lets the loader decide placement, which
# allows it to be split across the available GPUs.
target_model = AutoModelForImageTextToText.from_pretrained(
    target_model_id,
    quantization_config=bnb_config,
    torch_dtype="auto",
    device_map="auto",
)

# Text-side preprocessors: a plain tokenizer for the draft checkpoint and the
# full processor for the image-text-to-text target checkpoint.
tokenizer = AutoTokenizer.from_pretrained(draft_model_id)
processor = AutoProcessor.from_pretrained(target_model_id)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
def prepare_prompt(prompt, preprocessor, max_prompt_len):
    """Turn a single user prompt into chat-templated model inputs.

    Args:
        prompt: Content of the user turn. It is interpolated into an
            f-string, so its string form is what ends up in the message.
        preprocessor: Any object exposing ``apply_chat_template`` (this
            notebook passes either a tokenizer or a processor).
        max_prompt_len: Forwarded as ``max_length`` with ``truncation=True``.

    Returns:
        Whatever ``preprocessor.apply_chat_template`` returns when asked to
        tokenize, return a dict, and produce ``"pt"`` tensors.
    """
    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": f"{prompt}"}],
    }
    # NOTE(review): truncation is applied to the already-templated prompt;
    # verify that an over-long prompt does not lose the generation prompt.
    template_options = dict(
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        truncation=True,
        max_length=max_prompt_len,
        return_tensors="pt",
    )
    return preprocessor.apply_chat_template([user_turn], **template_options)

In [5]:
def generate(
    model,
    preprocessor,
    prompt,
    max_prompt_len,
    max_new_tokens,
    skip_special_tokens=True,
):
    """Generate a reply to ``prompt`` and return only the newly decoded text.

    Args:
        model: Object exposing ``generate`` and ``device``.
        preprocessor: Tokenizer or processor; must provide
            ``apply_chat_template`` (used by ``prepare_prompt``) and ``decode``.
        prompt: User prompt text.
        max_prompt_len: Maximum prompt length passed to ``prepare_prompt``.
        max_new_tokens: Forwarded to ``model.generate``.
        skip_special_tokens: Forwarded to ``preprocessor.decode``. Defaults to
            ``True`` so control tokens (e.g. an end-of-turn marker emitted
            when generation stops early) do not leak into the returned text.
            Pass ``False`` to get the raw decoded string.

    Returns:
        The decoded string for the first sequence in the batch, with the
        leading prompt-length tokens sliced off.
    """
    with torch.inference_mode():
        inputs = prepare_prompt(prompt, preprocessor, max_prompt_len)
        # Remember the prompt length so it can be stripped from the output.
        prompt_len = inputs["input_ids"].shape[-1]
        outputs = model.generate(
            **inputs.to(model.device),
            max_new_tokens=max_new_tokens,
        )
    # Assumes outputs[0] starts with the prompt tokens — the slice keeps
    # everything after them.
    new_tokens = outputs[0][prompt_len:]
    return preprocessor.decode(new_tokens, skip_special_tokens=skip_special_tokens)

In [6]:
# Demo settings shared by the generation cells below.
prompt = "Who are you?"
# Token budgets: prompt truncation length and number of tokens to generate.
max_prompt_len, max_new_tokens = 256, 32

In [7]:
# Draft Model: generate a short reply, using the tokenizer as preprocessor.
result = generate(
    model=draft_model,
    preprocessor=tokenizer,
    prompt=prompt,
    max_prompt_len=max_prompt_len,
    max_new_tokens=max_new_tokens,
)
print(result)

Hi there! I'm Gemma, a large language model created by the Gemma team at Google DeepMind. I'm an open-weights model, which


In [8]:
# Target Model: same prompt and limits, using the processor as preprocessor.
result = generate(
    model=target_model,
    preprocessor=processor,
    prompt=prompt,
    max_prompt_len=max_prompt_len,
    max_new_tokens=max_new_tokens,
)
print(result)

I'm Gemma, an open-weights AI assistant. I'm a large language model created by the Gemma team at Google DeepMind. 

Here
