In [None]:
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import random

# Set up device: prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
# The MPS probe goes through `torch.backends.mps`, the long-standing availability
# check; `torch.mps.is_available()` only exists in recent PyTorch releases and
# raises AttributeError on older ones whenever CUDA is absent.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# device = torch.device("cpu")  # uncomment to force CPU execution
print(f"Using device: {device}")

In [None]:
# Load the tokenizer and the causal-LM weights for the chosen checkpoint.
# Both downloads come from the same repository id, so it is named once here.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# to run locally — confirm it is actually required for this model.
# torch_dtype="auto" leaves the dtype choice to transformers; the weights are
# moved to the device selected earlier as soon as they are loaded.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype="auto",
    # attn_implementation="eager",
).to(device)

In [None]:
# Token ids of the reasoning delimiters. Encoding without special tokens makes
# the lookup independent of whether this tokenizer prepends a BOS id; the
# two-way unpack still fails loudly if either tag is not a single token.
_start_think_token, end_think_token = tokenizer.encode(
    "<think></think>", add_special_tokens=False
)

# Text placed right after the opening "<think>\n" of the assistant turn.
prefill = ""

# Continuations substituted (one picked at random) whenever the model tries to
# stop thinking before the requested minimum.
replacements = ["\nWait, but", "\nHmm", "\nSo"]

def _announce(message: str) -> None:
    """Print a status message framed by rule lines so it stands out from the streamed text."""
    rule = "======================================================"
    print(f"\n{rule}\n{message}\n{rule}\n")


@torch.inference_mode()
def reasoning_effort(question: str, min_thinking_tokens: int, max_thinking_tokens: int = 500):
    """Stream a sampled answer to `question` while steering how long the model thinks.

    The chat prompt is built with an assistant turn that already starts with
    "<think>\\n" + `prefill`, then tokens are sampled one at a time from the
    softmax of the last-position logits, reusing a `DynamicCache` across steps.

    Steering rules, checked in this order for every sampled token:

    1. If the model samples `end_think_token` or EOS before the minimum is
       reached, the sampled token is discarded and a random entry of
       `replacements` is fed to the model instead.
    2. If the model samples EOS otherwise, generation ends.
    3. If thinking has not been closed and the counter has reached
       `max_thinking_tokens`, the sampled token is discarded and "</think>"
       is fed instead.
    4. Otherwise the sampled token is accepted as is.

    Args:
        question: User message placed in the chat template.
        min_thinking_tokens: Minimum number of generated tokens before the
            model is allowed to stop thinking. When it exceeds
            `max_thinking_tokens` the cap wins, so the two limits can never
            contradict each other (otherwise EOS would keep being replaced
            after the forced "</think>").
        max_thinking_tokens: Number of generated tokens after which "</think>"
            is forced if thinking is still open.

    Yields:
        First the decoded prompt, then each generated piece of text (a decoded
        token, an injected replacement, or the forced "</think>").

    Side effects:
        Prints framed status messages to stdout when an early stop is
        overridden, when thinking ends (naturally or forced) and on EOS.
        The counter reported in those messages keeps increasing after
        thinking has ended, so the EOS message reports all generated tokens.

    Relies on the module-level `tokenizer`, `model`, `prefill`, `replacements`
    and `end_think_token`.
    """
    prompt_ids = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "<think>\n" + prefill},
        ],
        continue_final_message=True,
        return_tensors="pt",
    ).to(model.device)
    target_device = prompt_ids.device
    eos_token = model.config.eos_token_id
    # Clamp the minimum to the cap so rule 1 can never fire after a forced close.
    required_minimum = min(min_thinking_tokens, max_thinking_tokens)

    kv = DynamicCache()
    n_thinking_tokens = 0
    finished_thinking = False

    yield tokenizer.decode(prompt_ids[0].tolist())

    # The first forward pass consumes the whole prompt; later passes only see
    # the newly appended ids because the cache holds everything before them.
    step_input = prompt_ids
    while True:
        out = model(input_ids=step_input, past_key_values=kv, use_cache=True)
        kv = out.past_key_values
        probs = torch.softmax(out.logits[0, -1, :], dim=-1)
        next_token = torch.multinomial(probs, 1).item()

        if next_token in (end_think_token, eos_token) and n_thinking_tokens < required_minimum:
            replacement = random.choice(replacements)
            _announce(
                f"model tried to stop thinking with {n_thinking_tokens} tokens, less than the specified minimum thinking effort, {min_thinking_tokens}. Replacing </think> token with {replacement}"
            )
            yield replacement
            # encode() adds special tokens by default; a BOS id must not be
            # injected into the middle of the sequence (nor counted).
            next_ids = tokenizer.encode(replacement, add_special_tokens=False)
        elif next_token == eos_token:
            _announce(f"model reached eos token after {n_thinking_tokens} tokens")
            break
        elif not finished_thinking and n_thinking_tokens >= max_thinking_tokens:
            # `>=` (not `>`): the counter already holds the number of emitted
            # tokens, so the cap is hit as soon as it equals the maximum.
            finished_thinking = True
            _announce(f"forcing </think> token after {n_thinking_tokens}")
            yield "</think>"
            next_ids = tokenizer.encode("</think>", add_special_tokens=False)
        else:
            if next_token == end_think_token:
                finished_thinking = True
                _announce(f"finished thinking after {n_thinking_tokens} tokens")
            yield tokenizer.decode([next_token])
            next_ids = [next_token]

        n_thinking_tokens += len(next_ids)
        step_input = torch.tensor([next_ids], dtype=torch.long, device=target_device)

In [None]:
import time
# Prompt and thinking budget for this run.
question = "Using English, only, please write a short story, around 500 words, or so."
min_thinking_tokens = 100
max_thinking_tokens = 500

# Time the generation with perf_counter(): it is monotonic, so the elapsed
# value cannot be skewed by wall-clock adjustments the way time.time() can.
start = time.perf_counter()
# Stream each piece as soon as it is produced instead of waiting for the end.
for chunk in reasoning_effort(question, min_thinking_tokens, max_thinking_tokens):
    print(chunk, end="", flush=True)
end = time.perf_counter()
print(f"produced answer in {end - start} seconds")