# Test Latent Qwen Model

This notebook manually checks the `LatentQwen2ForCausalLM` definition in three steps: loading the model, running a forward pass, and generating text.

In [1]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import torch
from transformers import AutoTokenizer
from latent_qwen import LatentQwen2ForCausalLM


## 1. Load Model and Tokenizer

In [2]:
model_name = "Qwen/Qwen2-0.5B-Instruct"
num_latent_thoughts = 4

print(f"Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Add special tokens
special_tokens = ["<think>", "</think>", "<answer>", "</answer>"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

print("Loading Latent Model...")
model = LatentQwen2ForCausalLM.from_pretrained(
    model_name,
    num_latent_thoughts=num_latent_thoughts,
    device_map="auto",
    torch_dtype="auto"
)
model.resize_token_embeddings(len(tokenizer))

# Set Token IDs
think_id = tokenizer.convert_tokens_to_ids("<think>")
close_think_id = tokenizer.convert_tokens_to_ids("</think>")
answer_id = tokenizer.convert_tokens_to_ids("<answer>")

model.set_special_token_ids(think_id)
model.close_think_id = close_think_id
model.answer_id = answer_id

print(f"IDs: <think>={think_id}, </think>={close_think_id}, <answer>={answer_id}")

Loading model: Qwen/Qwen2-0.5B-Instruct
Loading Latent Model...
IDs: <think>=151646, </think>=151647, <answer>=151648
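
The later cells rely on each marker mapping to exactly one vocabulary ID, so a quick sanity check may be useful before running the model. The following is a minimal sketch: `check_special_token_ids` is a hypothetical helper, not part of `latent_qwen`, and it assumes only that the tokenizer exposes `convert_tokens_to_ids` and `unk_token_id`, as Hugging Face tokenizers do.

```python
def check_special_token_ids(tokenizer, tokens):
    """Return {token: id}, raising if a token is unknown or two tokens share an ID.

    Hypothetical helper for illustration; `tokenizer` is any object exposing
    `convert_tokens_to_ids` and `unk_token_id`.
    """
    ids = {}
    for token in tokens:
        token_id = tokenizer.convert_tokens_to_ids(token)
        # Unregistered tokens come back as None or as the unknown-token ID.
        if token_id is None or token_id == tokenizer.unk_token_id:
            raise ValueError(f"{token!r} is not registered in the vocabulary")
        ids[token] = token_id
    if len(set(ids.values())) != len(ids):
        raise ValueError(f"special tokens share an ID: {ids}")
    return ids
```

In this notebook it could be called as `check_special_token_ids(tokenizer, special_tokens)` right after `add_special_tokens`.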


## 2. Test Forward Pass
We check that the forward pass accepts input containing `<think>` and runs without crashing. Note that the logits cover the expanded sequence: their length is the input length plus `num_latent_thoughts`.

In [3]:
text = "User: Solve 2+2. <think>"
inputs = tokenizer(text, return_tensors="pt").to(model.device)

print(f"Input IDs: {inputs.input_ids}")

with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
print(f"Logits Shape: {logits.shape}")

# Verify shape
# If <think> is present in the input IDs, the model expands the sequence by `num_latent_thoughts`
expected_len = inputs.input_ids.shape[1]
if (inputs.input_ids == think_id).any():  # check token IDs rather than the raw text
    expected_len += num_latent_thoughts

print(f"Expected Shape: {(inputs.input_ids.shape[0], expected_len, model.config.vocab_size)}")

assert logits.shape[1] == expected_len, f"Output shape mismatch! Got {logits.shape[1]}, expected {expected_len}"

Input IDs: tensor([[  1474,     25,  63284,    220,     17,     10,     17,     13,    220,
         151646]], device='cuda:0')
Logits Shape: torch.Size([1, 14, 151650])
Expected Shape: (1, 14, 151650)
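
The length arithmetic above can be reproduced without loading the model. This is a sketch under a stated assumption: the notebook only exercises a single `<think>` token, so the idea that each occurrence adds `num_latent_thoughts` positions is a guess, and `expected_logits_length` is a hypothetical helper name.

```python
def expected_logits_length(input_ids, think_id, num_latent_thoughts):
    """Expected sequence length of the logits for one example.

    Assumption (not confirmed by the model code): every `<think>` token in the
    input adds `num_latent_thoughts` positions to the sequence.
    """
    n_think = sum(1 for token_id in input_ids if token_id == think_id)
    return len(input_ids) + n_think * num_latent_thoughts


# Token IDs printed by the forward-pass cell: 10 tokens, the last one is <think>.
ids = [1474, 25, 63284, 220, 17, 10, 17, 13, 220, 151646]
print(expected_logits_length(ids, think_id=151646, num_latent_thoughts=4))  # 14
```

The result matches the second dimension of the logits shape printed above.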


## 3. Test Generation
We test whether the model emits `</think>` right after `<think>`, as our forced generation logic intends.

In [4]:
prompt = "User: What is 3+3? <think>"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate a few tokens. Expected continuation: </think> ...
output_ids = model.generate(
    inputs.input_ids,
    max_new_tokens=10,
    do_sample=False # Greedy
)

print("Generated IDs:", output_ids[0])
decoded = tokenizer.decode(output_ids[0], skip_special_tokens=False)
print("Generated Text:", decoded)

# Validation
if "</think>" in decoded:
    print("SUCCESS: </think> found in output.")
else:
    print("FAILURE: </think> NOT found.")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Generated IDs: tensor([  1474,     25,   3555,    374,    220,     18,     10,     18,     30,
           220, 151646, 151647,    271,     40,   2776,     68,   3970,    279,
          4226,    304,    264], device='cuda:0')
Generated Text: User: What is 3+3? <think></think>

I'measure the answer in a
SUCCESS: </think> found in output.
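
The substring check above passes whenever `</think>` appears anywhere in the decoded text. A stricter check on the token IDs is sketched below. It assumes the latent thoughts are not emitted as output tokens, which is what the generated IDs in this run suggest; `close_follows_think` is a hypothetical helper, not part of `latent_qwen`.

```python
def close_follows_think(token_ids, think_id, close_think_id):
    """True if every `<think>` token is immediately followed by `</think>`."""
    positions = [i for i, t in enumerate(token_ids) if t == think_id]
    if not positions:
        return False  # nothing to verify if <think> never appears
    return all(
        i + 1 < len(token_ids) and token_ids[i + 1] == close_think_id
        for i in positions
    )


# Generated IDs printed by the generation cell above.
generated = [1474, 25, 3555, 374, 220, 18, 10, 18, 30, 220, 151646, 151647,
             271, 40, 2776, 68, 3970, 279, 4226, 304, 264]
print(close_follows_think(generated, think_id=151646, close_think_id=151647))  # True
```

With the real tensor, the call would be `close_follows_think(output_ids[0].tolist(), think_id, close_think_id)`.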
