In [1]:
import torch
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.append("..")
torch.mps.empty_cache()

In [2]:
from diffusion_llms.dataloader.llada_2 import DataModule
config = {
    "batch_size": 8,
    "num_workers": 1,
    "pin_memory": True,
    "max_length": 1024,
    "val_test_perc": 0.05,
    "context_length": 1024,

}
tokenizer = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Instruct")
dm = DataModule(config, tokenizer)
dm.setup()

Dataset split: Train=13538, Val=712, Test=5000


In [3]:
avg_out_len = 0
n_0 = 0
n_1 = 0
max_out_len = 0
for i, batch in enumerate(dm.train_dataloader()):

    max_out_len = max(max_out_len, batch["true_length"].max())
    avg_out_len += batch["true_length"].sum()
    avg_out_len /= batch["true_length"].shape[0]

    n_0 += (batch["eos_labels"] == 0).sum()
    n_1 += (batch["eos_labels"] == 1).sum()



In [4]:
print(
    f"avg_out_len: {avg_out_len}, n_0: {n_0}, n_1: {n_1} max_out_len: {max_out_len}"
)

print(f"Pos weight: {n_0 / n_1}")

avg_out_len: 270.8659362792969, n_0: 3006714, n_1: 10856198 max_out_len: 821.0
Pos weight: 0.27695828676223755


In [8]:
tokenizer.decode([126346])

'<|start_header_id|>'

In [6]:
m = [{"role": "user", "content": "What is your name?"}, ]
tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

'<|startoftext|><|start_header_id|>user<|end_header_id|>\n\nWhat is your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

In [None]:

model_name = "Qwen/Qwen3-0.6B"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

# prepare the model input
prompt = "Hi. What is your name?"
# messages = [
#    {"role": "user", "content": prompt}
# ]
# text = tokenizer.apply_chat_template(
#    messages,
#    tokenize=False,
#    add_generation_prompt=True,
#    enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.
# )
model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=128,
    eos_token_id=tokenizer.pad_token_id,
    pad_token_id=tokenizer.pad_token_id,  # Optional but keeps things clean for batching
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()

# parsing thinking content
try:
    # rindex finding 151668 (</think>)
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(
    "\n"
)
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)