In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

In [None]:
# Checkpoint ids for the small "draft" model and the larger "target" model.
draft_id = "google/gemma-3-270m-it"
target_id = "google/gemma-3-4b-it"

# 4-bit NF4 weights with double quantization; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)


def _load_eval_model(model_id):
    """Load ``model_id`` with the shared quantization settings and return it in eval mode.

    Both checkpoints are loaded with identical arguments: bfloat16 dtype, placed on
    ``cuda:1``, quantized with ``bnb_config`` and using FlashAttention-2.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda:1",
        quantization_config=bnb_config,
        attn_implementation="flash_attention_2",
    )
    return model.eval()


draft = _load_eval_model(draft_id)
target = _load_eval_model(target_id)

# The tokenizer comes from the draft checkpoint and is used for both models.
# NOTE(review): assumes the target checkpoint shares the same vocabulary — confirm.
tok = AutoTokenizer.from_pretrained(draft_id, use_fast=True)

In [None]:
# First 16 examples of the GSM8K "main" configuration, test split.
ds = load_dataset("openai/gsm8k", "main", split="test[:16]")

In [None]:
def tokenize_function(examples, tok, max_length):
    """Render a batch of questions as chat prompts and tokenize them to a fixed length.

    Args:
        examples: Batch mapping handed in by ``Dataset.map(batched=True)``; only the
            ``"question"`` column (a list of strings) is read.
        tok: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        max_length: Length every prompt is padded or truncated to.

    Returns:
        Dict with the ``"input_ids"`` and ``"attention_mask"`` produced by ``tok``
        for this batch.
    """
    # Read the questions from the batch passed in by ``map`` rather than from the
    # global dataset: otherwise every batch would re-tokenize the *entire* dataset,
    # and the function could not be used on any other dataset.
    prompt_text = [
        tok.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for text in examples["question"]
    ]

    # The rendered chat template already contains the special tokens it needs, so
    # the tokenizer must not add them a second time (would duplicate e.g. BOS).
    # NOTE(review): batched generation with decoder-only models expects left
    # padding — confirm ``tok.padding_side`` is "left".
    tokenized_prompt = tok(
        prompt_text,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        add_special_tokens=False,
        return_tensors="pt",
        return_attention_mask=True,
    )

    return {
        "input_ids": tokenized_prompt["input_ids"],
        "attention_mask": tokenized_prompt["attention_mask"],
    }


# 1) Tokenize the questions in batches, dropping the original columns so that only
#    the tokenizer outputs remain.
# 2) Tell datasets to return torch tensors for those outputs on indexing.
tokenized_dataset = ds.map(
    tokenize_function,
    batched=True,
    remove_columns=ds.column_names,
    fn_kwargs={"tok": tok, "max_length": 256},
).with_format("torch", columns=["input_ids", "attention_mask"])

In [None]:
# One prompt per step, in dataset order, keeping every example.
dl = DataLoader(
    dataset=tokenized_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Greedy-decode every prompt with the draft model.
device = "cuda:1"

# ``d_out`` is rebound on every iteration, so on its own it only ever holds the
# last prompt's result. Collect the generated token ids (prompt + continuation)
# for every prompt as plain Python lists; ``d_out`` still refers to the last
# batch's full output after the loop.
draft_sequences = []
for batch in tqdm(dl):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.inference_mode():
        d_out = draft.generate(
            **batch,
            past_key_values=None,
            use_cache=True,
            do_sample=False,  # greedy decoding
            max_new_tokens=256,
            return_dict_in_generate=True,
            output_scores=True,
            top_p=None,
            top_k=None,
            pad_token_id=tok.pad_token_id,
        )
    draft_sequences.extend(d_out["sequences"].tolist())

In [None]:
for batch in tqdm(dl):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.inference_mode():
        t_out = target.generate(
            **batch,
            past_key_values=None,
            use_cache=True,
            do_sample=False,
            max_new_tokens=256,
            return_dict_in_generate=True,
            output_scores=True,
            top_p=None,
            top_k=None,
            pad_token_id=tok.pad_token_id,
        )

In [None]:
resulting_text = tok.batch_decode(
    t_out["sequences"], skip_special_tokens=True, clean_up_tokenization_spaces=True
)

for t in resulting_text:
    print("#####")
    print("START")
    print("#####")
    print(t)