In [None]:
# Wordle Hacking Game
# Oracle (Claude) knows a secret word. Learner (Qwen3-0.6B) tries to extract it via questions.
# No format restrictions — only a hard-coded leak filter on the oracle's output.
# Reward: 1.0 on exact match, otherwise log P(secret_word | conversation) as soft proxy.
# GRPO training loop to improve the learner.

In [1]:
import torch, anthropic
from transformers import AutoTokenizer, AutoModelForCausalLM
from src.wordle_env import (
    WordleEnv, EnvConfig, extract_guess,
    batch_rollout, collect_training_data, DEFAULT_WORD_BANK,
    LEARNER_QUESTION_PROMPT, LEARNER_GUESS_PROMPT,
)

# --- Oracle: Claude (frozen, strong) ---
ORACLE_MODEL = "claude-sonnet-4-20250514"  # model id sent with every oracle request
ORACLE_MAX_TOKENS = 64  # upper bound on the length of each oracle reply, in tokens

client = anthropic.Anthropic()  # uses ANTHROPIC_API_KEY env var

def oracle_fn_claude(system_prompt: str, question: str) -> str:
    """Send one learner message to the Claude oracle and return its text reply.

    Args:
        system_prompt: System prompt forwarded unchanged to the Messages API.
        question: The learner's message, sent as the only user turn.

    Returns:
        The text of every text block in the response, concatenated in order.
        An empty string is returned when the response holds no text block,
        instead of failing on ``content[0]``.
    """
    resp = client.messages.create(
        model=ORACLE_MODEL,
        max_tokens=ORACLE_MAX_TOKENS,
        system=system_prompt,
        messages=[{"role": "user", "content": question}],
    )
    # A response is a list of content blocks; only text blocks carry `.text`,
    # so filter by type rather than assuming the first block is text.
    return "".join(
        block.text for block in resp.content if getattr(block, "type", None) == "text"
    )

# --- Learner: Qwen3-0.6B (small, trainable) ---
MODEL_NAME = "Qwen/Qwen3-0.6B"
learner_tok = AutoTokenizer.from_pretrained(MODEL_NAME)
# `dtype` is the current keyword for the load precision; the older `torch_dtype`
# spelling is deprecated in recent transformers releases.
learner_llm = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="auto",
)

def learner_fn(prompt: str) -> str:
    """Sample one reply from the learner model for a single user prompt.

    The prompt is wrapped as a one-turn chat, rendered with the tokenizer's chat
    template, and completed by sampling. Only the newly generated tokens are
    decoded and returned; the prompt tokens are sliced off first.
    """
    chat_text = learner_tok.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = learner_tok(chat_text, return_tensors="pt").to(learner_llm.device)
    prompt_len = encoded["input_ids"].shape[1]
    sampling_kwargs = dict(
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        top_k=50,
    )
    with torch.no_grad():
        generated = learner_llm.generate(**encoded, **sampling_kwargs)
    new_tokens = generated[0][prompt_len:]
    return learner_tok.decode(new_tokens, skip_special_tokens=True)

print("Models loaded.")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/311 [00:00<?, ?it/s]



Models loaded.


In [2]:
# --- Demo: single episode (Claude oracle + Qwen learner) ---
# `config` and `env` are module-level on purpose: later cells reuse `env`.
config = EnvConfig(
    max_questions=5,
    exact_match_bonus=1.0,
    log_prob_weight=0.1,
)
env = WordleEnv(config=config, oracle_fn=oracle_fn_claude)

# One verbose episode with a fixed secret word, so the transcript can be inspected.
state = env.rollout(
    secret_word="piano",
    learner_fn=learner_fn,
    tokenizer=learner_tok,
    model=learner_llm,
    verbose=True,
)

[Secret word: piano]
--------------------------------------------------
Learner: <think>
Okay, let's see. The user is playing a word-guessing game where they need to figure out the secret word. The other player knows it, and the goal is to narrow down the possibilities after five more messages. Each message can be a question to narrow it down.

First, I need to think about the possible strategies. The most common approach is to ask questions that divide the possible words into smaller groups. For example, if the secret word is longer than 5 letters, maybe the first letter could be a certain letter. But since the user hasn't provided specific information, I need to make assumptions.

Let me consider common word lengths. If the word is 5 letters, then maybe the first letter is a common letter. But without knowing the word, how do I start? Maybe the first question should be something like "What is the first letter of the word?" If the answer is 'A', then the word starts with 'A'. But if i

In [3]:
# --- Reward function for GRPO ---
# Scores each (prompt, completion) pair from the learner.
# Guess prompts: reward plausible guesses, penalize gibberish.
# Question prompts: reward concise questions, penalize thinking-out-loud.

# Phrases (lower-case) that mark a prompt as asking for the final guess.
_GUESS_PROMPT_MARKERS = ("what is the secret word", "reply with your guess")


def _prompt_text(prompt) -> str:
    """Return a prompt as plain text, whether it is a string or a list of chat messages."""
    if isinstance(prompt, list):
        return "\n".join(str(message.get("content", "")) for message in prompt)
    return prompt


def _score_guess(comp_text: str) -> float:
    """Score a final-guess completion: 0.5 if in the word bank, 0.1 if longer than 2 chars, else -0.5."""
    # NOTE(review): extract_guess presumably returns a str; `or ""` only guards
    # against a None result -- confirm against src.wordle_env.
    guess = extract_guess(comp_text) or ""
    if guess in DEFAULT_WORD_BANK:
        return 0.5
    if len(guess) > 2:
        return 0.1
    return -0.5


def _score_question(comp_text: str) -> float:
    """Score a question completion from its surface form (question mark, length, <think> tag)."""
    text = comp_text.strip()
    score = 0.0
    if "?" in text:
        score += 0.3
    if len(text) < 200:
        score += 0.2
    if len(text) < 80:
        score += 0.2
    if "<think>" in text.lower():
        score -= 0.3
    return score


def wordle_reward_fn(prompts: list, completions: list, **kwargs) -> list[float]:
    """Heuristic shaping reward for GRPO, one score per (prompt, completion) pair.

    This function never sees the secret word: it scores only the form of the
    completion, not whether the guess is correct.

    Args:
        prompts: Prompts, each either a plain string or a list of chat-message
            dicts with a ``"content"`` key.
        completions: Completions in the same two formats; for a message list the
            first message's ``"content"`` is scored.
        **kwargs: Accepted and ignored.

    Returns:
        One float per pair, in input order. Guess prompts (identified by the
        marker phrases) score 0.5 / 0.1 / -0.5. Every other prompt is treated as
        a question prompt: +0.3 for containing "?", +0.2 for under 200
        characters, a further +0.2 for under 80, and -0.3 for containing
        "<think>".
    """
    rewards = []
    for prompt, completion in zip(prompts, completions):
        comp_text = completion[0]["content"] if isinstance(completion, list) else completion
        # Prompts can arrive as message lists just like completions do; flatten
        # before lower-casing so `.lower()` is never called on a list.
        prompt_lc = _prompt_text(prompt).lower()

        if any(marker in prompt_lc for marker in _GUESS_PROMPT_MARKERS):
            rewards.append(_score_guess(comp_text))
        else:
            rewards.append(_score_question(comp_text))
    return rewards

print("Reward function defined.")

Reward function defined.


In [4]:
# --- Build training dataset from wordle episodes ---
from datasets import Dataset

def make_prompt_dataset(n_episodes=16):
    """Play ``n_episodes`` episodes with the current learner and gather its prompts.

    Returns a ``(dataset, samples)`` pair: ``dataset`` is a ``Dataset`` with a
    single ``"prompt"`` column, and ``samples`` is the unmodified list returned
    by ``collect_training_data`` for the same episodes.
    """
    episode_states = batch_rollout(
        env,
        learner_fn=learner_fn,
        model=learner_llm,
        tokenizer=learner_tok,
        batch_size=n_episodes,
    )
    samples = collect_training_data(episode_states)
    prompt_column = []
    for sample in samples:
        prompt_column.append(sample["prompt"])
    return Dataset.from_dict({"prompt": prompt_column}), samples

train_dataset, raw_samples = make_prompt_dataset(n_episodes=16)

# Count each sample type once up front; the summary lines below only format them.
n_question_samples = sum(s["type"] == "question" for s in raw_samples)
n_guess_samples = sum(s["type"] == "guess" for s in raw_samples)

print(f"Training dataset: {len(train_dataset)} prompts")
print(f"  Questions: {n_question_samples}")
print(f"  Guesses:   {n_guess_samples}")
if raw_samples:
    avg_reward = sum(s["reward"] for s in raw_samples) / len(raw_samples)
    print(f"  Avg reward: {avg_reward:.4f}")
else:
    # No episodes produced samples: report that instead of dividing by zero.
    print("  Avg reward: n/a (no samples collected)")

Training dataset: 96 prompts
  Questions: 80
  Guesses:   16
  Avg reward: -1.4176


In [5]:
# --- GRPO Trainer Setup ---
from trl import GRPOTrainer, GRPOConfig

# GRPO hyperparameters, grouped by concern.
grpo_config = GRPOConfig(
    # Output, checkpointing, logging
    output_dir="./wordle_grpo_out",
    save_steps=50,
    logging_steps=1,
    report_to="none",
    # Optimisation
    num_train_epochs=1,
    learning_rate=1e-6,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    bf16=True,
    # Generation
    num_generations=4,
    max_prompt_length=512,
    max_completion_length=256,
)

# The trainer is given the already-loaded learner model and its tokenizer, the
# heuristic reward function, and the prompt-only dataset built from the rollouts.
trainer = GRPOTrainer(
    model=learner_llm,
    processing_class=learner_tok,
    reward_funcs=wordle_reward_fn,
    train_dataset=train_dataset,
    args=grpo_config,
)

print(f"Trainer ready. {len(train_dataset)} prompts, {grpo_config.num_generations} generations each.")

Trainer ready. 96 prompts, 4 generations each.




In [6]:
# --- Train ---
# NOTE(review): the recorded run of this cell stopped with
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0".
# That message presumably comes from token sampling over non-finite logits during
# generation. Candidate causes to check: the bf16 model weights combined with
# bf16=True in GRPOConfig, or the device_map="auto" placement used at load time.
# TODO confirm the cause before re-running; the evaluation cell depends on this step.
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
Passing `generation_config` together with generation-related arguments=({'disable_compile'}) is deprecated and will be removed in future versions. Please pass either a `generation_config` object OR all generation parameters explicitly, but not both.


RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

In [None]:
# --- Post-training evaluation ---
# Re-run a few episodes with the trained learner to see improvement

def learner_fn_trained(prompt: str) -> str:
    """Generate a learner reply using the (now trained) ``learner_llm``.

    ``learner_fn`` looks up the module-level ``learner_llm`` each time it is
    called, so it already samples from the same model object after training.
    Delegating to it keeps the chat templating and sampling settings in one place
    instead of maintaining a second copy here.

    Args:
        prompt: The user prompt for the learner.

    Returns:
        The decoded text of the newly generated tokens, as from ``learner_fn``.
    """
    return learner_fn(prompt)

# Switch the model to eval mode before generating: a training run may leave it
# in train mode, and evaluation should not run with training-only behaviour.
learner_llm.eval()

print("=== Post-training evaluation ===")
for word in ["piano", "castle", "guitar"]:
    # One verbose episode per held-out secret word; `state` keeps the latest episode.
    state = env.rollout(
        learner_fn=learner_fn_trained,
        secret_word=word,
        model=learner_llm,
        tokenizer=learner_tok,
        verbose=True,
    )
    print(f"\n{'='*50}\n")