<a href="https://colab.research.google.com/github/hieunguyen7337/LLM_RL/blob/main/hangman_game_rl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# (wandb on; bitsandbytes kept for big-model path)
!pip -q install "trl>=0.16.0" transformers accelerate bitsandbytes peft wandb

In [2]:
import os, torch
from datasets import Dataset, load_dataset
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import re, ast
from typing import List, Tuple

In [3]:
# --- Run configuration -------------------------------------------------------
# Base checkpoint that gets fine-tuned.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# 4-bit loading switch: leave False for the small model, set True for bigger models (e.g., ≥7B).
ENABLE_4BIT = False

# Gradient checkpointing follows the 4-bit switch; the KV cache is the inverse of it
# (avoids the "caching incompatible with checkpointing" warning spam).
GRAD_CKPT, USE_CACHE = ENABLE_4BIT, not ENABLE_4BIT

# Training data shipped inside the cloned LLM_RL repository.
DATA_PATH = "/content/LLM_RL/training_hangman_dataset.json"

In [4]:
# Weights & Biases settings, applied only when the variables are not already set
# (setdefault leaves an existing value untouched).
os.environ.setdefault("WANDB_PROJECT", "huggingface")   # or "grpo-demos"
# NOTE(review): "end" presumably asks the W&B integration to upload the model once
# training finishes — confirm against the W&B / transformers docs.
# As the last expression of the cell, this call's return value is what gets displayed.
os.environ.setdefault("WANDB_LOG_MODEL", "end")

'end'

In [5]:
!git clone https://github.com/hieunguyen7337/LLM_RL.git

Cloning into 'LLM_RL'...
remote: Enumerating objects: 86, done.[K
remote: Counting objects: 100% (86/86), done.[K
remote: Compressing objects: 100% (80/80), done.[K
remote: Total 86 (delta 39), reused 13 (delta 5), pack-reused 0 (from 0)[K
Receiving objects: 100% (86/86), 497.71 KiB | 1.05 MiB/s, done.
Resolving deltas: 100% (39/39), done.


In [None]:
# Read the JSON training file; the json builder exposes the file as the "train" split,
# which is requested directly here.
dataset = load_dataset("json", data_files=DATA_PATH, split="train")

In [7]:
# Display the dataset summary (feature names and row count) as the cell output.
dataset

Dataset({
    features: ['prompt', 'word'],
    num_rows: 1825
})

In [8]:
# --- helpers ---
def _parse_prompt(prompt: str) -> Tuple[List[str], List[str]]:
    """Extract guessed letters (uppercased) and current state tokens ['_', 'A', ...]."""
    guessed = []
    m = re.search(r"Guessed letters:\s*(\[[^\]]*\])", prompt, flags=re.IGNORECASE | re.DOTALL)
    if m:
        try:
            guessed = [s.upper() for s in ast.literal_eval(m.group(1)) if isinstance(s, str)]
        except Exception:
            guessed = []
    guessed = list(dict.fromkeys(guessed))  # dedupe, keep order

    state_tokens = []
    m2 = re.search(r"The current state is:\s*([A-Za-z_ ]+)", prompt, flags=re.IGNORECASE)
    if m2:
        state_tokens = [t.upper() for t in m2.group(1).split()]  # e.g. ["_", "_", "P", "O", "_", "_", "E", "_"]
    return guessed, state_tokens

In [9]:
def _count_new_reveals(state_tokens: List[str], word_upper: str, guess: str) -> int:
    """How many new positions this guess would newly reveal given current visible state."""
    n = 0
    L = len(word_upper)
    for i in range(L):
        st = state_tokens[i] if i < len(state_tokens) else "_"
        if st == "_" and word_upper[i] == guess:
            n += 1
    return n

In [10]:
# Reward funcs
def _message_text(item) -> str:
    """Return the plain text of a prompt/completion in either dataset format.

    Strings are returned unchanged. A list/tuple is treated as a conversational
    message list (dicts carrying a ``"content"`` key) whose contents are joined
    with newlines. Anything else falls back to ``str(item)``.
    """
    if isinstance(item, str):
        return item
    if isinstance(item, (list, tuple)):
        return "\n".join(
            str(msg.get("content", "")) if isinstance(msg, dict) else str(msg)
            for msg in item
        )
    return str(item)


def hangman_reward_func(prompts, completions, word, **kwargs):
    """
    Score one Hangman guess per (prompt, completion, word) triple.

    The guess is the single alphabetic character found in the completion;
    surrounding whitespace/punctuation (e.g. " A" or "A\\n") is ignored and the
    letter is compared case-insensitively.

    Rewards, each clamped to [-1, 1]:
      +0.8  correct new reveal (plus +0.2 per extra same-letter reveal)
      +0.1  letter not in the word (``neg_wrong``; positive by default)
      -0.5  letter already listed under "Guessed letters"
      -1.0  no alphabetic character in the completion
      -0.6  more than one alphabetic character in the completion
       0.0  letter is in the word but reveals no hidden position and is not
            listed as guessed

    Args:
        prompts: prompt per sample (string, or conversational message list).
        completions: model output per sample (string, or message list).
        word: target word per sample.
        **kwargs: optional tunables ``pos_base``, ``pos_multi_bonus``,
            ``neg_wrong``, ``neg_repeat``, ``neg_missing``, ``neg_multi_alpha``;
            any other keys are ignored.

    Returns:
        list[float]: one reward per sample, in input order.
    """
    # tunables
    pos_base        = float(kwargs.get("pos_base", 0.8))
    pos_multi_bonus = float(kwargs.get("pos_multi_bonus", 0.2))
    neg_wrong       = float(kwargs.get("neg_wrong", 0.1))
    neg_repeat      = float(kwargs.get("neg_repeat", -0.5))
    neg_missing     = float(kwargs.get("neg_missing", -1.0))
    neg_multi_alpha = float(kwargs.get("neg_multi_alpha", -0.6))

    rewards = []
    for prompt, completion, w in zip(prompts, completions, word):
        # Classify by alphabetic characters only, so a lone letter wrapped in
        # whitespace or punctuation still counts as a valid guess.
        letters = [ch for ch in _message_text(completion) if ch.isalpha()]

        if not letters:
            r = neg_missing
        elif len(letters) > 1:
            r = neg_multi_alpha
        else:
            guess = letters[0].upper()
            # The prompt is parsed only once the completion is known to be valid.
            guessed, state = _parse_prompt(_message_text(prompt))
            w_up = str(w).strip().upper()

            if guess in guessed:
                r = neg_repeat
            elif guess not in set(w_up):
                r = neg_wrong
            else:
                new_reveals = _count_new_reveals(state, w_up, guess)
                if new_reveals > 0:
                    r = pos_base + pos_multi_bonus * (new_reveals - 1)
                else:
                    # Letter is already visible on the board: neutral.
                    r = 0.0

        # clamp to [-1, 1] for stability (applies to every branch)
        rewards.append(float(max(-1.0, min(1.0, r))))
    return rewards

In [11]:
# Sanity check of the reward function on two hand-written game states.
# Both prompts share the same wording; only the game-specific fields differ.
_HANGMAN_PROMPT = (
    "You are playing a game of Hangman.\n\n"
    "Your task is to guess a single character.\n\n"
    "The word has a certain number of letters.\n"
    "The current state of the word is shown with guessed letters filled in and blanks for the unknown letters.\n"
    "The number of incorrect guesses remaining is listed.\n"
    "All letters that have been guessed so far are listed.\n\n"
    "You will format your response as a single uppercase letter at the end\n\n"
    "The word has {n_letters} letters.\n"
    "The current state is: {state}\n"
    "Incorrect guesses remaining: {remaining}\n"
    "Guessed letters: {guessed}\n\n"
    "Correct response:"
)

prompts = [
    _HANGMAN_PROMPT.format(n_letters=5, state="_ _ _ _ _", remaining=4, guessed=["F", "Q"]),
    _HANGMAN_PROMPT.format(n_letters=6, state="_ o _ _ _ _", remaining=6, guessed=["O"]),
]
completions = ["A", "o"]          # a fresh correct letter, then a repeat of an already guessed one
words = ["above", "policy"]
hangman_reward_func(prompts=prompts, completions=completions, word=words)

[0.8, -0.5]

In [None]:
# ---------------------------
# Tokenizer
# ---------------------------
# Fast tokenizer for the base checkpoint, configured to pad on the left.
tok = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="left",
    use_fast=True,
)

# Fall back to the EOS token when the checkpoint defines no pad token.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

In [None]:
# ---------------------------
# Model (no quant for small; 4-bit kept for big)
# ---------------------------
# Quantisation settings exist only on the 4-bit path; otherwise None is passed through.
quant = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    if ENABLE_4BIT
    else None
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant,
    # 4-bit path lets the library place the weights; otherwise pin to the first GPU.
    device_map="auto" if ENABLE_4BIT else "cuda:0",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
    use_cache=USE_CACHE,
)

In [14]:
# ---------------------------
# LoRA (kept for both; safe with/without quant)
# ---------------------------
peft_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter size / regularisation
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # Modules that receive adapters: attention projections, then MLP projections.
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

In [15]:
# ---------------------------
# GRPO config (W&B enabled; will save at epoch end)
# ---------------------------
args = GRPOConfig(
    # --- run identity / output location ---
    output_dir="qwen2.5-0.5b-Instruct-grpo",
    run_name="qwen2.5-0.5b-Instruct-GRPO-2",   # change per run
    # --- batching: the per-device batch must stay divisible by num_generations ---
    per_device_train_batch_size=16,
    num_generations=16,
    gradient_accumulation_steps=16,
    # --- sequence lengths ---
    max_prompt_length=512,
    max_completion_length=1,                   # the answer is expected to be a single letter
    # --- precision / memory ---
    fp16=True,
    gradient_checkpointing=GRAD_CKPT,
    # --- logging and checkpoints ---
    report_to="wandb",                         # keep W&B reporting
    logging_steps=3,
    save_strategy="epoch",
    save_total_limit=2,
)

In [16]:
# Wire the model, data, reward function and configs into the GRPO trainer.
trainer = GRPOTrainer(
    args=args,
    model=model,
    peft_config=peft_cfg,              # LoRA reduces trainable params & VRAM
    train_dataset=dataset,
    reward_funcs=hangman_reward_func,
)

In [18]:
# Run GRPO training. The returned TrainOutput is displayed as the cell result.
trainer.train()

  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhieunn16[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
3,-0.0
6,-0.0
9,0.0
12,-0.0
15,-0.0
18,-0.0
21,0.0
24,-0.0
27,-0.0
30,-0.0


TrainOutput(global_step=342, training_loss=-2.5037463689059542e-08, metrics={'train_runtime': 4595.483, 'train_samples_per_second': 1.191, 'train_steps_per_second': 0.074, 'total_flos': 0.0, 'train_loss': -2.5037463689059542e-08})

In [19]:
# Final artifacts go to the same directory the trainer was configured to write to.
save_dir = args.output_dir

In [20]:
# ---------------------------
# Save: adapter/weights + tokenizer + trainer state
# ---------------------------
# Order matters: the tokenizer is written after the model so its files are the
# ones left in save_dir.
trainer.save_model(save_dir)        # saves PEFT adapter (and weights) appropriately
tok.save_pretrained(save_dir)
# NOTE(review): save_state() takes no path here; it presumably writes the trainer
# state under args.output_dir (the same directory as save_dir) — confirm.
trainer.save_state()

In [21]:
# (Optional) If you're NOT using 4-bit, also export a merged FP16 model without LoRA adapters:
# The whole export is skipped when ENABLE_4BIT is True.
if not ENABLE_4BIT:
    try:
        # NOTE(review): merge_and_unload() is expected to return the base model with
        # the adapter weights merged in, and may modify trainer.model in place —
        # confirm against the PEFT docs before reusing the trainer afterwards.
        merged = trainer.model.merge_and_unload()
        # The merged copy lives in a sub-folder of save_dir, next to the adapter files.
        merged_dir = os.path.join(save_dir, "merged-fp16")
        merged.save_pretrained(merged_dir)
        tok.save_pretrained(merged_dir)
        print(f"Merged full model saved to: {merged_dir}")
    except Exception as e:
        # Deliberately best-effort: any failure is reported and the notebook carries on.
        print("Merge skipped (not a PEFT model or unsupported):", e)

Merged full model saved to: qwen2.5-0.5b-Instruct-grpo/merged-fp16


In [22]:
import shutil

# Compress the output folder. The archive is named after save_dir and written next
# to it, so the cell keeps working when output_dir is changed in the config.
archive_base = os.path.normpath(save_dir)   # drop any trailing slash before ".zip" is appended
archive_path = shutil.make_archive(archive_base, "zip", save_dir)

# Now download to your computer (Colab-only helper, hence the local import).
from google.colab import files
files.download(archive_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>