In [1]:
import torch
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer

# Put the parent directory on the import path (presumably so the project
# package imported by later cells resolves when the notebook is run from its
# own sub-directory). Guarded so re-running this cell does not keep appending
# duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

# Releasing the MPS cache only makes sense when the MPS backend is usable;
# without the guard the call can fail on CUDA / CPU-only machines.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [2]:
from diffusion_llms.dataloader.llada_2 import DataModule
# Settings handed to the data module. Key names and values are unchanged; how
# each one is consumed is decided inside DataModule (not visible here).
config = dict(
    batch_size=8,
    num_workers=1,
    pin_memory=True,
    max_length=1024,
    val_test_perc=0.05,
)

# Tokenizer of the LLaDA instruct checkpoint, passed to the data module below.
tokenizer = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Instruct")

# Build the data module and run its setup step.
dm = DataModule(config, tokenizer)
dm.setup()

Dataset split: Train=13538, Val=712, Test=5000


In [3]:
# (label, extractor) pairs: each extractor pulls one value out of a batch.
# Extractors are evaluated lazily, one per print, in the order listed.
batch_fields = (
    ("Input IDs:", lambda b: b["input_ids"]),
    ("Labels:", lambda b: b["eos_labels"]),
    ("Batch Size:", lambda b: b["input_ids"].shape[0]),
    ("Sequence Length:", lambda b: b["true_length"]),
)

for i, batch in enumerate(dm.train_dataloader()):
    print(f"Batch {i}:")
    for label, extract in batch_fields:
        print(label, extract(batch))
    # Only the first batch is inspected.
    break

Batch 0:
Input IDs: tensor([[126080, 126346,   3840,  ..., 126336, 126336, 126336],
        [126080, 126346,   3840,  ..., 126336, 126336, 126336],
        [126080, 126346,   3840,  ..., 126336, 126336, 126336],
        ...,
        [126080, 126346,   3840,  ..., 126336, 126336, 126336],
        [126080, 126346,   3840,  ..., 126336, 126336, 126336],
        [126080, 126346,   3840,  ..., 126336, 126336, 126336]])
Labels: tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])
Batch Size: 8
Sequence Length: tensor([166, 145, 486, 155, 146, 184, 366, 197])




In [None]:

# Checkpoint used for this text-completion experiment.
model_name = "Qwen/Qwen3-0.6B"

# Tokenizer and weights come from the same checkpoint; dtype and device
# placement are both left on "auto".
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)

# The prompt is tokenized as raw text (no chat template applied).
prompt = "Hi. What is your name?"
# Disabled alternative that goes through the chat template instead:
# messages = [{"role": "user", "content": prompt}]
# text = tokenizer.apply_chat_template(
#     messages,
#     tokenize=False,
#     add_generation_prompt=True,
#     enable_thinking=False,  # Switches between thinking and non-thinking modes. Default is True.
# )
model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

# The tokenizer's pad token id is used both as the stop token and for padding.
# NOTE(review): passing pad_token_id as eos_token_id looks unusual — confirm
# this is the intended stopping criterion.
stop_token_id = tokenizer.pad_token_id
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=128,
    eos_token_id=stop_token_id,
    pad_token_id=stop_token_id,
)

# Keep only the ids that come after the first `prompt_length` positions of the
# first (and only) sequence.
prompt_length = len(model_inputs.input_ids[0])
output_ids = generated_ids[0][prompt_length:].tolist()

def _split_after_last(token_ids, marker_id):
    """Return the index just past the last `marker_id` in `token_ids`, or 0 if it is absent."""
    for pos in range(len(token_ids) - 1, -1, -1):
        if token_ids[pos] == marker_id:
            return pos + 1
    return 0


# Split the generated ids into the "thinking" part (everything up to and
# including the last `</think>`) and the answer that follows it.
# The `</think>` id is looked up in the active tokenizer rather than
# hard-coded, so the split stays correct if `model_name` is changed.
think_end_id = tokenizer.convert_tokens_to_ids("</think>")
if think_end_id is None or think_end_id == tokenizer.unk_token_id:
    # The vocabulary has no `</think>` token: treat everything as answer.
    index = 0
else:
    index = _split_after_last(output_ids, think_end_id)

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(
    "\n"
)
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)