In [None]:
import torch
from datasets import load_dataset
from tqdm.notebook import trange
from transformers import (
    AutoModelForCausalLM,  # type: ignore
    AutoTokenizer,  # type: ignore
    DynamicCache,  # type: ignore
)

model_name = "Qwen/Qwen3-0.6B"

# Tokenizer and causal LM are loaded from the same checkpoint; dtype and device
# placement are both left to transformers ("auto").
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
)

# The input-embedding module is kept as a handle: the decoding cells embed tokens
# with it directly and read its `.weight` matrix for soft-token mixtures.
embed_layer = model.get_input_embeddings()

In [None]:
ds = load_dataset("openai/gsm8k", name="main", split="train")

# One single-turn user conversation per GSM8K question.
question_messages = [[{"role": "user", "content": q}] for q in ds["question"]]  # type: ignore

# Render every conversation to a prompt string (tokenization happens later),
# ending with the assistant generation prompt and with thinking mode enabled.
question_text = [
    tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        enable_thinking=True,
        add_generation_prompt=True,
    )
    for conversation in question_messages
]

# Preview the first two rendered prompts.
for preview_index in (0, 1):
    print(question_text[preview_index])

In [None]:
# Reference solutions for the two previewed questions.
gsm8k_answers = ds["answer"]  # type: ignore
for preview_index in (0, 1):
    print(gsm8k_answers[preview_index])

# Standard greedy decoding, with the token-embedding lookup done explicitly

In [None]:
max_new_tokens = 4096
batch_size = 2

# Tokenize the first `batch_size` prompts with left padding, so the last column of
# every row holds a real prompt token, then move the batch to the model's device.
prompt_batch = question_text[:batch_size]
inputs = tokenizer(prompt_batch, padding=True, padding_side="left", return_tensors="pt")
inputs = inputs.to(model.device)

# Append one <think> token to the right of every row (mask value 1, i.e. a real
# position) so decoding begins inside the reasoning block.
thinking_tag = tokenizer.encode("<think>")[0]
input_ids = torch.nn.functional.pad(inputs.input_ids, (0, 1), value=thinking_tag)
attention_mask = torch.nn.functional.pad(inputs.attention_mask, (0, 1), value=1)

past_key_values = DynamicCache()
# Accumulator for greedy picks; starts with zero columns.
generated_ids = torch.empty((batch_size, 0), dtype=torch.long, device=input_ids.device)

with torch.no_grad():
    # Prefill: one forward pass over the whole prompt populates `past_key_values`;
    # only the final position's logits are kept.
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
    ).logits[:, -1:]
    # Greedy choice of the first generated token, shape (batch, 1).
    next_token_id = logits.argmax(dim=-1)
    generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
    # Reserve one more mask slot for the token fed on the next step.
    attention_mask = torch.nn.functional.pad(attention_mask, (0, 1), value=1)

is_complete = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

# Inference only. The embedding lookup is kept inside no_grad as well: model
# parameters normally have requires_grad=True, so calling `embed_layer` outside
# no_grad would record an autograd graph on every step.
with torch.no_grad():
    for _ in trange(max_new_tokens, desc="Generating"):
        # Stop once every row has emitted EOS. Finished rows keep producing tokens;
        # they are not masked out here.
        is_complete |= next_token_id[:, -1] == tokenizer.eos_token_id
        if is_complete.all():
            break

        # Embed the previous token explicitly and feed embeddings instead of ids.
        next_embeds = embed_layer(next_token_id)
        outputs = model(
            inputs_embeds=next_embeds,
            # The mask must be passed on every step; without it the left-padding
            # positions stored in the KV cache are attended to by shorter prompts.
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        logits = outputs.logits

        next_token_id = logits.argmax(dim=-1)
        generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
        # Grow the mask by one slot for the token fed on the next step.
        attention_mask = torch.nn.functional.pad(attention_mask, pad=(0, 1), value=1)

for seq in generated_ids:
    text = tokenizer.decode(seq.tolist())
    # Show the answer part: from the closing </think> tag up to (not including) the
    # first <|im_end|>. str.find returns -1 for a missing marker (e.g. the token
    # budget ran out mid-thought), which would otherwise slice at the last character.
    answer_end = text.find("<|im_end|>")
    if answer_end == -1:
        answer_end = len(text)
    answer_start = text.find("</think>", 0, answer_end)
    if answer_start == -1:
        # Thinking never closed: fall back to the whole (truncated) output.
        answer_start = 0
    print(text[answer_start:answer_end])

# Soft thinking via embedding interpolation, followed by hard-token answer decoding

In [None]:
batch_size = 2

# Left-padded batch of the first `batch_size` prompts, moved to the model's device.
inputs = tokenizer(
    question_text[:batch_size],
    return_tensors="pt",
    padding_side="left",
    padding=True,
).to(model.device)

# Append one <think> token per row; it is a real position, so its mask entry is 1.
thinking_tag = tokenizer.encode("<think>")[0]
append_one = (0, 1)  # (left, right) padding amounts on the last dimension
input_ids = torch.nn.functional.pad(inputs.input_ids, append_one, value=thinking_tag)
attention_mask = torch.nn.functional.pad(inputs.attention_mask, append_one, value=1)
past_key_values = DynamicCache()

with torch.no_grad():
    # Prefill the KV cache; the last position's distribution seeds soft thinking.
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
    ).logits[:, -1:]
    # Make room in the mask for the first soft-thinking input.
    attention_mask = torch.nn.functional.pad(attention_mask, append_one, value=1)

# soft thinking
num_thinking_tokens = 512

# Inference only. The embedding mixture is computed inside no_grad too: the
# embedding weight normally has requires_grad=True, so the matmul below would
# otherwise record an autograd graph on every step.
with torch.no_grad():
    for _ in trange(num_thinking_tokens, desc="Soft Thinking"):
        # Softmax in float32, then cast back to the model's dtype for the matmul.
        probabilities = logits.float().softmax(dim=-1).to(model.dtype)
        # Soft token: the probability-weighted sum of all token embeddings.
        next_embeds = probabilities @ embed_layer.weight
        logits = model(
            inputs_embeds=next_embeds,
            # Pass the mask on every step so left-padding positions held in the
            # KV cache stay masked for the shorter prompts.
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        ).logits

        # Grow the mask by one slot for the input fed on the next step.
        attention_mask = torch.nn.functional.pad(attention_mask, pad=(0, 1), value=1)

# hard answer generation
max_new_tokens = 512

# Close the reasoning block explicitly: the first hard token fed is </think>, and
# the logits from the final soft-thinking step are not used.
next_token_id = torch.full(
    (batch_size, 1),
    tokenizer.encode("</think>")[0],
    dtype=torch.long,
    device=model.device,
)
generated_ids = torch.zeros((batch_size, 0), dtype=torch.long, device=input_ids.device)
is_complete = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

# Inference only; the embedding lookup is inside no_grad so no autograd graph is
# recorded for the embedding weights.
with torch.no_grad():
    for _ in trange(max_new_tokens, desc="Generating"):
        # Stop once every row has emitted EOS. Finished rows keep producing tokens;
        # they are not masked out here.
        is_complete |= next_token_id[:, -1] == tokenizer.eos_token_id
        if is_complete.all():
            break

        next_embeds = embed_layer(next_token_id)
        outputs = model(
            inputs_embeds=next_embeds,
            # Without the mask, left-padding positions in the KV cache would be
            # attended to by the shorter prompts.
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        logits = outputs.logits

        next_token_id = logits.argmax(dim=-1)
        generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
        # Grow the mask by one slot for the token fed on the next step.
        attention_mask = torch.nn.functional.pad(attention_mask, pad=(0, 1), value=1)

for seq in generated_ids:
    text = tokenizer.decode(seq.tolist())
    # Cut at the first <|im_end|>. If it was never emitted (max_new_tokens reached),
    # keep the whole text instead of slicing at -1, which would drop the last character.
    answer_end = text.find("<|im_end|>")
    print(text if answer_end == -1 else text[:answer_end])