[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](
  https://colab.research.google.com/github/czovekboti/chess_rl/blob/sft%2Bgrpo/GRPO%20trainer%20notebook.ipynb
)

# Set the config file path and the output directory for the model weights

In [None]:
#@title Colab Extra Install { display-mode: "form" }
#%%capture
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # [NEW] Extra 30% context lengths!
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install or uv pip install
    !pip install unsloth vllm
else:
    pass # For Colab / Kaggle, we need extra instructions hidden below \/

In [None]:
#@title Install python-chess and Stockfish { display-mode: "form" }
#%%capture
!pip install python-chess
!apt-get install stockfish

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
stockfish is already the newest version (14.1-1).
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.
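The config cell further down expects a Stockfish binary at a known path. As a minimal sketch (stdlib only, not part of the original notebook), the path could be resolved in order from the `STOCKFISH_PATH` environment variable, then the apt location `/usr/games/stockfish`, and finally the `PATH`. `resolve_stockfish_path` is a hypothetical helper name.

```python
import os
import shutil


def resolve_stockfish_path(env_var: str = "STOCKFISH_PATH",
                           fallback: str = "/usr/games/stockfish") -> str:
    """Return the first usable Stockfish binary: env var, then fallback, then PATH."""
    candidates = [os.getenv(env_var), fallback, shutil.which("stockfish")]
    for path in candidates:
        # Skip unset entries and anything that is not an executable file
        if path and os.path.isfile(path) and os.access(path, os.X_OK):
            return path
    raise FileNotFoundError(
        f"Stockfish not found: set {env_var} or install it (e.g. apt-get install stockfish)"
    )


try:
    print("Using Stockfish at:", resolve_stockfish_path())
except FileNotFoundError as err:
    print(err)
```

Failing early here gives a clearer error than `popen_uci` failing later inside the training cell.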


In [None]:
#@title Colab Extra Install { display-mode: "form" }
#%%capture
import os
!pip install --upgrade -qqq uv
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install!
    !pip install unsloth vllm
else:
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except Exception: get_numpy = "numpy"; get_pil = "pillow"
    try: import subprocess; is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except Exception: is_t4 = False
    get_vllm, get_triton = ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm==0.10.2", "triton")
    !uv pip install -qqq --upgrade \
        unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
    !uv pip install -qqq {get_triton}
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

Using Python 3.12.12 environment at: /usr
Resolved 18 packages in 101ms
Uninstalled 1 package in 205ms
Installed 1 package in 115ms
 - transformers==4.57.2
 + transformers==4.56.2
Using Python 3.12.12 environment at: /usr
Resolved 1 package in 1ms
Uninstalled 1 package in 2ms
Installed 1 package in 7ms
 - trl==0.24.0
 + trl==0.22.2


# Load Config Files
- Model parameters are loaded from the YAML config file
- Stockfish must already be installed; set its path as `STOCKFISH_PATH` in the `.env` file

In [None]:
import os
from dotenv import load_dotenv
load_dotenv()

import yaml

# Path to your YAML config file
path = 'config.yaml'
output_path = 'grpo_outputs'

def load_config(path: str):
    with open(path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config = load_config(path)
config_name = "qwen4b"
print("Selected config_name:", config_name)

match config_name:
    case "llama":
        config = config["llama_config"]
    case "phi":
        config = config["PHI_config"]
    case "mistral":
        config = config["mistral_config"]
    case "qwen7b":
        config = config["qwen7b_config"]
    case "qwen4b":
        config = config["qwen4b_config"]
    case _:
        raise ValueError("Check model name – perhaps the keyboard got excited.")

# Stockfish path: read STOCKFISH_PATH from the .env file, falling back to the apt install location
stockfish_path = os.getenv("STOCKFISH_PATH", "/usr/games/stockfish")

print("STOCKFISH_PATH:", stockfish_path)

Selected config_name: qwen4b
STOCKFISH_PATH: /usr/games/stockfish
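The `match` statement above maps a short `config_name` to a top-level key of `config.yaml`. A minimal sketch of the same lookup as a dictionary, with an early check that the keys read later in this notebook are present. The key list is collected from the wandb and `GRPOConfig` cells; the example values are hypothetical placeholders, not the project's real settings.

```python
# Short names used in this notebook -> top-level keys of config.yaml
CONFIG_KEYS = {
    "llama": "llama_config",
    "phi": "PHI_config",
    "mistral": "mistral_config",
    "qwen7b": "qwen7b_config",
    "qwen4b": "qwen4b_config",
}

# Keys read later by the wandb and GRPOConfig cells
REQUIRED_KEYS = (
    "name", "model", "max_seq_length", "learning_rate", "max_steps",
    "adam_beta1", "adam_beta2", "weight_decay", "warmup_ratio",
    "per_device_train_batch_size",
)


def select_config(full_config: dict, config_name: str) -> dict:
    """Return the sub-config for config_name, failing early if keys are missing."""
    try:
        section = full_config[CONFIG_KEYS[config_name]]
    except KeyError as err:
        raise ValueError(f"Unknown or missing config: {config_name!r}") from err
    missing = [k for k in REQUIRED_KEYS if k not in section]
    if missing:
        raise ValueError(f"{config_name}: missing keys {missing}")
    return section


# Hypothetical example: placeholder values, not the real config
example = {
    "qwen4b_config": {
        "name": "qwen4b", "model": "example/model-id", "max_seq_length": 1024,
        "learning_rate": "5e-6", "max_steps": 100, "adam_beta1": 0.9,
        "adam_beta2": 0.99, "weight_decay": 0.1, "warmup_ratio": 0.1,
        "per_device_train_batch_size": 6,
    }
}
cfg = select_config(example, "qwen4b")
print(cfg["name"], float(cfg["learning_rate"]))
```

The training cell wraps `learning_rate` in `float(...)`, which matters: PyYAML follows YAML 1.1, where a value such as `5e-6` (no decimal point) is loaded as a string rather than a float.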


# Load model
- Pad tokens are added (the vocabulary is resized with `resize_model_vocab = 151669`) so the model is compatible with the SFT-trained LoRA adapters

In [None]:
from unsloth import FastLanguageModel
import torch

max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "czovekboti/stf_model",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = False, # True would enable vLLM fast inference; kept off here
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.8, # Reduce if out of memory
    resize_model_vocab = 151669,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 12-02 18:50:48 [__init__.py:244] Automatically detected platform cuda.
ERROR 12-02 18:50:52 [fa_utils.py:57] Cannot use FA version 2 is not supported due to FA2 is only supported on devices with compute capability >= 8
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.11.6: Fast Qwen3 patching. Transformers: 4.56.2. vLLM: 0.9.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.11.6 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


In [None]:
# @title
# from unsloth import FastLanguageModel
# import torch
# max_seq_length = config['max_seq_length']# Can increase for longer reasoning traces
# lora_rank = config['lora_rank'] # Larger rank = smarter, but slower
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name = "czovekboti/qwen-sft",
#     max_seq_length = max_seq_length,
#     load_in_4bit = False, # False for LoRA 16bit
#     fast_inference = False, # Enable vLLM fast inference
#     max_lora_rank = lora_rank,
#     gpu_memory_utilization = 0.7, # Reduce if out of memory
#     resize_model_vocab = 151669,
# )

# Load Lora
- Change the adapter directory as necessary

In [None]:
# @title
# from peft import PeftModel
# model = FastLanguageModel.get_peft_model(
#     model,
#     r = lora_rank,
#     target_modules = [
#         "q_proj", "k_proj", "v_proj", "o_proj",
#         "gate_proj", "up_proj", "down_proj",
#     ],
#     lora_alpha = lora_rank,
#     use_gradient_checkpointing = "unsloth",
#     random_state = 3407,
# )
# model = PeftModel.from_pretrained(
#     model,
#     "/content/lora_model/sft_adapter",
#     is_trainable=True,
#     adapter_name="sft_adapter",
# )


In [None]:
# @title
import torch

# 1) Make sure adapter is attached and active
print("Adapters:", getattr(model, "peft_config", {}).keys())
try:
    active = model.get_active_adapters()
    print("Active:", active)
except Exception:
    pass

# 2) Ensure train mode + cache off
model.train()
if hasattr(model, "config"):
    model.config.use_cache = False

# 3) Enable low-mem training niceties
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
try:
    model.gradient_checkpointing_enable()
except Exception:
    pass

# 4) Only LoRA params trainable; base frozen
trainable, total = 0, 0
for n, p in model.named_parameters():
    total += p.numel()
    is_lora = "lora" in n.lower()
    p.requires_grad = is_lora  # train LoRA weights, freeze everything else
    if is_lora:
        trainable += p.numel()

# Verify that no base (non-LoRA) parameter is still trainable
only_lora_trainable = all(
    not p.requires_grad for n, p in model.named_parameters() if "lora" not in n.lower()
)

print(f"Trainable params (LoRA): {trainable:,} / {total:,}")
print("Only LoRA trainable? ", only_lora_trainable)

# 5) Tiny forward/backward probe
x = tokenizer("probe", return_tensors="pt").to(next(model.parameters()).device)
out = model(**x, labels=x["input_ids"])
print("Loss requires_grad? ", out.loss.requires_grad)
out.loss.backward()
has_lora_grad = any(p.grad is not None and "lora" in n.lower() for n,p in model.named_parameters())
has_base_grad = any(p.grad is not None and "lora" not in n.lower() for n,p in model.named_parameters())
print("LoRA grads present? ", has_lora_grad)   # should be True
print("Base grads present? ", has_base_grad)   # should be False
model.zero_grad(set_to_none=True)  # drop the probe's gradients so they don't leak into training


Adapters: dict_keys(['default'])
Trainable params (LoRA): 66,060,288 / 4,087,844,864
Only LoRA trainable?  True
Loss requires_grad?  True
LoRA grads present?  True
Base grads present?  False
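As a cross-check of the 66,060,288 trainable parameters reported above: LoRA adds two matrices, `A` (`r × d_in`) and `B` (`d_out × r`), to each adapted linear layer, i.e. `r · (d_in + d_out)` parameters. The sketch below assumes Qwen3-4B's published dimensions (hidden size 2560, MLP size 9728, 32 query heads and 8 key/value heads of size 128, 36 layers), rank 32, and the seven target modules listed in the commented-out `get_peft_model` cell; these shapes are assumptions, not read from the checkpoint.

```python
# Assumed Qwen3-4B dimensions (from its public config, not read from the checkpoint)
HIDDEN = 2560
INTERMEDIATE = 9728
N_HEADS, N_KV_HEADS, HEAD_DIM = 32, 8, 128
N_LAYERS = 36
RANK = 32  # assumed to match lora_rank in this notebook

# (d_in, d_out) of every LoRA target module in one decoder layer
LAYER_MODULES = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters of one LoRA pair: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


per_layer = sum(lora_params(d_in, d_out, RANK) for d_in, d_out in LAYER_MODULES.values())
total_lora = per_layer * N_LAYERS
print(f"{per_layer:,} per layer, {total_lora:,} total")
```

That the result equals the logged count suggests all seven projections in all 36 layers carry rank-32 adapters, with nothing else left trainable.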


In [None]:
#@title Wandb Setup{ display-mode: "form" }
# Initialize wandb
import wandb
# os.environ["WANDB_LOG_MODEL"] = "end"
os.environ["WANDB_PROJECT"] = "Chess_RL_Project"
os.environ["WANDB_ENTITY"] = "czovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem"
wandb.login()
wandb.init(
    project="huggingface",  # explicit argument; takes precedence over WANDB_PROJECT set above
    entity = "czovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem",
    name=config["name"] + "_grpo_TARS_test_board",
    tags=["grpo"],
    config={
        "model": config["model"],
        "max_seq_length": config['max_seq_length'],
        "lora_rank": lora_rank,
        "learning_rate": config["learning_rate"],
        "max_steps": config["max_steps"],
    }
)


wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: czovekboti (czovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


wandb: Detected [huggingface_hub.inference, openai] in use.
wandb: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
wandb: For more information, check out the docs at: https://weave-docs.wandb.ai/


In [None]:
#@title Load dataset{ display-mode: "form" }
from datasets import load_dataset
dataset = load_dataset("czovekboti/chessdata", split="train")

# Training functions and prompt
- Prompt:
    - Gives the model instructions along with a worked example
    - Same prompt as used for SFT training
- Functions:
  - Basic helpers for extracting the answer and checking that the reasoning tags are present
  - correctness_reward_func: loads the board, then checks whether the model's move is syntactically correct and legal. A legal move always earns a positive reward (5.0), adjusted up or down by the scaled Stockfish evaluation. Incorrect answers are penalized: -2.0 for invalid SAN syntax and -1.0 for an illegal move, while an ambiguous SAN move gets +1.0.
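The reward rules above, condensed into a pure-Python sketch that mirrors the arithmetic of `reward_move` and `correctness_reward_func` without calling Stockfish. `eval_after_cp` stands in for the engine score after the move (centipawns, from the opponent's point of view, as `reward_move` reads it) and `dataset_eval_cp` for the dataset's `Evaluation` value; both names, and the helper functions, are hypothetical.

```python
import math

# Fixed part of the reward per parse outcome (values from correctness_reward_func)
BASE_REWARD = {
    "valid": 5.0,
    "ambiguous": 1.0,
    "illegal": -1.0,
    "invalid_syntax": -2.0,
}


def engine_term(eval_after_cp: float, dataset_eval_cp: float) -> float:
    """Same arithmetic as reward_move, minus the Stockfish call."""
    scaled = math.tanh(eval_after_cp / 900.0) * 2.0  # squashed into (-2, 2)
    if -eval_after_cp > dataset_eval_cp:  # mover's view beats the stored evaluation
        scaled -= 0.5  # turns into a +0.5 bonus after the sign flip
    return -scaled  # flip to the point of view of the side that just moved


def correctness_reward(outcome: str, eval_after_cp: float = 0.0,
                       dataset_eval_cp: float = 0.0) -> float:
    reward = BASE_REWARD[outcome]
    if outcome == "valid":
        reward += engine_term(eval_after_cp, dataset_eval_cp)
    return reward


for outcome, cp in [("valid", -300), ("valid", 300), ("ambiguous", 0), ("invalid_syntax", 0)]:
    print(outcome, cp, round(correctness_reward(outcome, cp), 3))
```

Because the engine term stays within roughly (-2, 2.5), a legal move always scores between about 3 and 7.5, above every ambiguous (+1), illegal (-1) or malformed (-2) answer.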

In [None]:
SYSTEM_PROMPT = """
You are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.

Please follow this exact format in your response:

<reasoning>
(Brief explanation of what you see on the board - piece activity, threats, and candidate moves)
</reasoning>
<answer>
(best move written in correct SAN format, such as Nf3 or exd5)
</answer>

Do not invent illegal or impossible moves. The move must be legal in the given FEN position.
Do not use UCI format like e2e4 - only SAN notation like e4, Nf3, or O-O.
In case of taking a piece use the [file]x[target square] format
### Example:
FEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1

<reasoning>
White has just played e4 and developed the knight to f3. It’s Black’s turn. The e4 pawn is undefended. Capturing it with the pawn from d7 to d5 is a natural central counter.
</reasoning>
<answer>
d5
</answer>

Now solve the following position:
"""

# Import the chess libraries (the Stockfish engine is started later, in the training cell)
import chess, chess.engine

from chess import InvalidMoveError, IllegalMoveError, AmbiguousMoveError
import math
import re
from datasets import load_dataset, Dataset
# Load and prep dataset


XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
# Difficulty heuristic: fewer pieces, a later move number and fewer legal moves all raise the score
def calculate_difficulty(fen: str) -> float:
    board = chess.Board(fen)
    piece_count = len(board.piece_map())
    legal_moves_count = len(list(board.legal_moves))
    fullmove_number = int(fen.split()[-1])
    difficulty = (32 - piece_count) * 2.0 + fullmove_number * 0.5 - legal_moves_count * 0.2
    return difficulty

def get_board(data, split = "train", samples_per_bucket=1000, num_buckets=10):
    total_samples_needed = samples_per_bucket * num_buckets
    sample_size = min(total_samples_needed * 3, len(data))
    print(f"Sampling {sample_size} positions from {len(data)} total...")
    import random
    sampled_indices = random.sample(range(len(data)), sample_size)
    data = data.select(sampled_indices)
    def fen_color(fen: str) -> str:
        return "White" if fen.split()[1] == 'w' else "Black"
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['FEN'] + " You are with the following pieces: " + fen_color(x['FEN'])}
        ], 'evaluation': x['Evaluation'], 'fen': x['FEN']
    }, remove_columns=data.column_names)
    print(data[0])
    data = data.map(lambda x: {'difficulty': calculate_difficulty(x['fen'])})
    data = data.sort('difficulty')
    curriculum_dataset = []
    # Buckets are consecutive slices of the difficulty-sorted data, easiest first
    for i in range(num_buckets):
        start_idx = i * samples_per_bucket
        end_idx = min((i + 1) * samples_per_bucket, len(data))
        if start_idx >= end_idx:
            break
        bucket = data.select(range(start_idx, end_idx))
        # Keep the first positions of each bucket (shuffle here to sample randomly instead);
        # a separate name avoids overwriting sample_size, which bounds the loop above
        n_take = min(samples_per_bucket, len(bucket))
        sampled = bucket.select(range(n_take))
        curriculum_dataset.append(sampled)

    from datasets import concatenate_datasets
    final_dataset = concatenate_datasets(curriculum_dataset)

    print(f"Created curriculum dataset with {len(final_dataset)} examples from {num_buckets} difficulty buckets")
    print(final_dataset[0])
    return final_dataset


dataset = get_board(dataset)

def reward_move(board, dataeval):
    result = engine.analyse(board, chess.engine.Limit(time=1.0))  # longer time limits make no real difference
    # Score from the side to move, i.e. the opponent of the player who just moved.
    # mate_score converts forced mates to large centipawn values instead of None.
    evaluation = result['score'].relative.score(mate_score=100000)
    print("\n----------------------\n")
    if evaluation is not None:
        # Biggest eval in the dataset is around 15000, but evals above 2000 are rare
        scaled_evaluation = math.tanh(evaluation / 900.0) * 2.0
        if -evaluation > dataeval:  # the move improved the position (-evaluation = mover's view)
            scaled_evaluation -= 0.5  # becomes a +0.5 bonus once the sign is flipped below
            print(f"Eval = {-evaluation}, Dataeval = {dataeval}. Position improved -> bonus +0.5")
        print(f"Scaled Evaluation: {-scaled_evaluation} ")
        return -scaled_evaluation  # flip sign: reward the player who just moved
    else:
        return 0.0

import wandb
from collections import Counter

def correctness_reward_func(prompts, fen, completions, evaluation, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_moves = [extract_xml_answer(r) for r in responses]

    rewards = []
    move_results = []  # Track: 'valid', 'invalid_syntax', 'illegal', 'ambiguous'

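    def to_float(value):
        # Dataset evaluations may be missing or non-numeric; fall back to 0.0
        try:
            return float(value)
        except (ValueError, TypeError):
            return 0.0

    # One dataset evaluation per completion, aligned with fen
    if not isinstance(evaluation, list):
        evaluation = [evaluation] * len(fen)
    dataevals = [to_float(e) for e in evaluation]

    for fen_str, move, dataeval in zip(fen, extracted_moves, dataevals):
        board = chess.Board(fen_str)
        try:
            board.push_san(move)
            scaled_evaluation = reward_move(board, dataeval)
            rewards.append(5.0 + scaled_evaluation)
            move_results.append('valid')
        # All three errors subclass ValueError, so the specific ones must be caught first
        except InvalidMoveError:
            rewards.append(-2.0)
            move_results.append('invalid_syntax')
        except AmbiguousMoveError:
            rewards.append(1.0)
            move_results.append('ambiguous')
        except IllegalMoveError:
            rewards.append(-1.0)
            move_results.append('illegal')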

    # === Batch Statistics ===
    result_counts = Counter(move_results)
    total = len(move_results)
    unique_fens = list(set(fen))

    valid_pct = 100 * result_counts['valid'] / total if total > 0 else 0
    illegal_pct = 100 * (result_counts['illegal'] + result_counts['invalid_syntax']) / total if total > 0 else 0

    # Log to wandb
    wandb.log({
        "batch/valid_moves": result_counts['valid'],
        "batch/illegal_moves": result_counts['illegal'],
        "batch/invalid_syntax": result_counts['invalid_syntax'],
        "batch/ambiguous_moves": result_counts['ambiguous'],
        "batch/valid_pct": valid_pct,
        "batch/mean_reward": sum(rewards) / len(rewards) if rewards else 0,
        "batch/unique_positions": len(unique_fens),
    })

    # Console logging (concise)
    print(f"\n{'='*60}")
    print(f"GRPO BATCH SUMMARY")
    print(f"{'='*60}")
    print(f"Positions: {len(unique_fens)} unique / {total} total")
    print(f"Results: ✓ {result_counts['valid']} valid ({valid_pct:.1f}%) | "
          f"✗ {result_counts['illegal']} illegal | "
          f"⚠ {result_counts['invalid_syntax']} bad syntax | "
          f"? {result_counts['ambiguous']} ambiguous")
    print(f"Rewards: mean={sum(rewards)/len(rewards):.2f}, "
          f"min={min(rewards):.2f}, max={max(rewards):.2f}")

    # Sample example (first one)
    print(f"\n--- Example ---")
    print(f"FEN: {fen[0]}")
    prompt_text = str(prompts[0])  # prompts are chat-message lists, so truncate their text form
    print(f"Prompt: {prompt_text[:200]}..." if len(prompt_text) > 200 else f"Prompt: {prompt_text}")
    print(f"Generated: {responses[0][:300]}..." if len(responses[0]) > 300 else f"Generated: {responses[0]}")
    print(f"Extracted move: {extracted_moves[0]} | Result: {move_results[0]} | Reward: {rewards[0]:.2f}")

    # Show unique FENs (truncated if many)
    if len(unique_fens) <= 5:
        print(f"\nUnique FENs: {unique_fens}")
    else:
        print(f"\nUnique FENs (first 3): {unique_fens[:3]}...")

    print(f"{'='*60}\n")

    return rewards


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\s*.+?\s*</reasoning>\s*<answer>\s*.+?\s*</answer>\s*$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.2 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.2 if match else 0.0 for match in matches]

def gentle_length_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content in contents:
        word_count = len(content.split())
        # Small bonus for reasonable length (50-150 words), no penalty otherwise
        if 50 <= word_count <= 150:
            rewards.append(0.1)  # Small, non-noisy bonus
        else:
            rewards.append(0.0)  # Neutral, not negative
    return rewards

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

Sampling 30000 positions from 12958035 total...

{'prompt': [{'content': '\nYou are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.\n\nPlease follow this exact format in your response:\n\n<reasoning>\n(Brief explanation of what you see on the board - piece activity, threats, and candidate moves)\n</reasoning>\n<answer>\n(best move written in correct SAN format, such as Nf3 or exd5)\n</answer>\n\nDo not invent illegal or impossible moves. The move must be legal in the given FEN position.\nDo not use UCI format like e2e4 - only SAN notation like e4, Nf3, or O-O.\nIn case of taking a piece use the [file]x[target square] format\n### Example:\nFEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1\n\n<reasoning>\nWhite has just played e4 and developed the knight to f3. It’s Black’s turn. The e4 pawn is undefended. Capturing it with the pawn from d7 to d5 is a natural central counter.\n</reasoning>

Created curriculum dataset with 1000 examples from 10 difficulty buckets
{'prompt': [{'content': '\nYou are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.\n\nPlease follow this exact format in your response:\n\n<reasoning>\n(Brief explanation of what you see on the board - piece activity, threats, and candidate moves)\n</reasoning>\n<answer>\n(best move written in correct SAN format, such as Nf3 or exd5)\n</answer>\n\nDo not invent illegal or impossible moves. The move must be legal in the given FEN position.\nDo not use UCI format like e2e4 - only SAN notation like e4, Nf3, or O-O.\nIn case of taking a piece use the [file]x[target square] format\n### Example:\nFEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1\n\n<reasoning>\nWhite has just played e4 and developed the knight to f3. It’s Black’s turn. The e4 pawn is undefended. Capturing i
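`get_board` scores each position with `calculate_difficulty`, sorts by that score and concatenates difficulty buckets, easiest first. The sketch below applies the same formula to plain counts (the start position has 32 pieces and 20 legal moves at move 1) and shows one possible bucketing scheme that spreads the buckets over the whole sorted range instead of only its first `samples_per_bucket * num_buckets` rows. `difficulty_from_counts` and `bucket_indices` are hypothetical helpers, not functions from the notebook.

```python
import random


def difficulty_from_counts(piece_count: int, fullmove_number: int,
                           legal_moves_count: int) -> float:
    """Same formula as calculate_difficulty, applied to precomputed counts."""
    return (32 - piece_count) * 2.0 + fullmove_number * 0.5 - legal_moves_count * 0.2


def bucket_indices(difficulties, samples_per_bucket, num_buckets, seed=0):
    """Hypothetical helper: cut the difficulty-sorted positions into num_buckets
    equal ranges (easiest first) and draw up to samples_per_bucket from each."""
    order = sorted(range(len(difficulties)), key=lambda i: difficulties[i])
    bucket_size = len(order) // num_buckets
    rng = random.Random(seed)
    picked = []
    for b in range(num_buckets):
        start = b * bucket_size
        # The last bucket absorbs the remainder of an uneven split
        end = len(order) if b == num_buckets - 1 else start + bucket_size
        bucket = order[start:end]
        chosen = rng.sample(bucket, min(samples_per_bucket, len(bucket)))
        # Keep easy-to-hard order inside each bucket too
        picked.extend(sorted(chosen, key=lambda i: difficulties[i]))
    return picked


# Start position: 32 pieces, move 1, 20 legal moves
print(difficulty_from_counts(32, 1, 20))

rng = random.Random(1)
scores = [rng.uniform(-5.0, 60.0) for _ in range(100)]
idx = bucket_indices(scores, samples_per_bucket=5, num_buckets=4)
print(len(idx), [round(scores[i], 1) for i in idx[:5]])
```

Because the result is ordered easy to hard, the curriculum only survives if the trainer does not reshuffle the training set; recent TRL releases expose a `shuffle_dataset` flag on `GRPOConfig` for that (worth confirming against the installed version).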

# Train model

In [None]:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = float(config["learning_rate"]),
    adam_beta1 = config["adam_beta1"],
    adam_beta2 = config["adam_beta2"],
    weight_decay = config["weight_decay"],
    warmup_ratio = config["warmup_ratio"],
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    generation_kwargs = {
        "eos_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.2,  # Discourage repetition
    },
    logging_steps = 1,
    per_device_train_batch_size = config["per_device_train_batch_size"], # 2 for bigger models, 4 for smaller; a 16 GB GPU could do 8 with a 14B model
    gradient_accumulation_steps = 2, # an effective batch size of 16 or 32 would be better, but slows training down
    num_generations = 6, # Decrease if out of memory
    max_steps = 5, # short test run; config["max_steps"] is only logged to wandb
    max_grad_norm = 0.1,
    shuffle_dataset = False, # keep the easy-to-hard order produced by get_board
    save_total_limit =1,
    report_to = "wandb", # report to weights and biases
    output_dir = output_path,
    run_name = "chess_llama_grpo",
)

engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
try:
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()

finally:
    engine.quit()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,000 | Num Epochs = 1 | Total steps = 5
O^O/ \_/ \    Batch size per device = 6 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (6 x 2 x 1) = 12
 "-____-"     Trainable parameters = 66,060,288 of 4,087,844,864 (1.62% trained)
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 262144, 'temperature': 0.7, 'top_p': 0.8}. If this is not desired, please set these values explicitly.
  out = torch_matmul(X, W.t(), out = out)


Unsloth: Will smartly offload gradients to save VRAM!

----------------------

Scaled Evaluation: -0.3797784003398156 

GRPO BATCH SUMMARY
Positions: 2 unique / 12 total
Results: ✓ 1 valid (8.3%) | ✗ 6 illegal | ⚠ 5 bad syntax | ? 0 ambiguous
Rewards: mean=-0.95, min=-2.00, max=4.62

--- Example ---
FEN: r1bqkbr1/pp3p1p/2n1pn2/3p2p1/3pP3/2PB1N1P/PP1N1PP1/R1BQK2R w KQq - 0 9
Prompt: [{'content': '\nYou are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.\n\nPlease follow this exact format in your response:\n\n<reasoning>\n(Brief explanation of what you see on the board - piece activity, threats, and candidate moves)\n</reasoning>\n<answer>\n(best move written in correct SAN format, such as Nf3 or exd5)\n</answer>\n\nDo not invent illegal or impossible moves. The move must be legal in the given FEN position.\nDo not use UCI format like e2e4 - only SAN nota

| Metric | Step 1 | Step 2 | Step 3 | Step 4 | Step 5 |
|---|---|---|---|---|---|
| Training loss | 0.0011 | 0.001 | 0.0011 | 0.0014 | 0.0012 |
| reward | -0.727898 | -1.848917 | -1.305124 | -0.894347 | -1.374667 |
| reward_std | 1.470575 | 0.570558 | 1.276922 | 1.807642 | 0.530642 |
| completions / mean_length | 174.166672 | 225.166672 | 231.75 | 173.916672 | 192.583344 |
| completions / min_length | 74.0 | 65.0 | 152.0 | 67.0 | 83.0 |
| completions / max_length | 256.0 | 256.0 | 256.0 | 256.0 | 256.0 |
| completions / clipped_ratio | 0.416667 | 0.75 | 0.666667 | 0.25 | 0.5 |
| completions / mean_terminated_length | 115.714294 | 132.666672 | 183.25 | 146.555557 | 129.166672 |
| completions / min_terminated_length | 74.0 | 65.0 | 152.0 | 67.0 | 83.0 |
| completions / max_terminated_length | 177.0 | 253.0 | 222.0 | 204.0 | 184.0 |
| kl | 1.07162 | 0.982575 | 1.10782 | 1.354649 | 1.153604 |
| rewards / xmlcount_reward_func / mean | -0.012917 | -0.08225 | -0.05925 | -0.281333 | -0.074667 |
| rewards / xmlcount_reward_func / std | 0.198512 | 0.48824 | 0.351485 | 0.361237 | 0.241599 |
| rewards / soft_format_reward_func / mean | 0.116667 | 0.033333 | 0.066667 | 0.166667 | 0.1 |
| rewards / soft_format_reward_func / std | 0.102986 | 0.07785 | 0.098473 | 0.07785 | 0.104447 |
| rewards / strict_format_reward_func / mean | 0.116667 | 0.033333 | 0.066667 | 0.166667 | 0.1 |
| rewards / strict_format_reward_func / std | 0.102986 | 0.07785 | 0.098473 | 0.07785 | 0.104447 |
| rewards / correctness_reward_func / mean | -0.948315 | -1.833333 | -1.379207 | -0.946347 | -1.5 |
| rewards / correctness_reward_func / std | 1.822955 | 0.389249 | 1.568805 | 1.829515 | 0.522233 |
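The trainer sums the four reward functions into one score per completion (by default with equal weights), and GRPO then compares each completion only with the other completions generated for the same prompt. A minimal sketch of that group normalization, assuming the standard formulation (subtract the group mean, divide by the group's standard deviation plus a small epsilon); the epsilon of 1e-4 and the use of the sample standard deviation are assumptions about the trainer's defaults, and the reward values are hypothetical. It also reproduces the batch arithmetic from the trainer banner: 6 completions per device × 2 accumulation steps = 12 completions, i.e. 2 prompts per step, matching the "2 unique / 12 total" lines in the batch summaries.

```python
import statistics

NUM_GENERATIONS = 6   # num_generations in GRPOConfig
PER_DEVICE_BATCH = 6  # batch size per device reported in the trainer banner
GRAD_ACCUM = 2        # gradient_accumulation_steps

completions_per_step = PER_DEVICE_BATCH * GRAD_ACCUM
prompts_per_step = completions_per_step // NUM_GENERATIONS


def group_advantages(rewards, group_size, eps=1e-4):
    """Normalize each reward against the other completions of the same prompt."""
    if len(rewards) % group_size != 0:
        raise ValueError("rewards must split evenly into prompt groups")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.fmean(group)
        std = statistics.stdev(group)  # sample standard deviation
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Hypothetical rewards for 2 prompts x 6 completions:
# prompt 1 has one legal move, prompt 2 has only malformed answers.
rewards = [4.62, -2.0, -2.0, -1.0, -2.0, -1.0,
           -2.0, -2.0, -2.0, -2.0, -2.0, -2.0]
adv = group_advantages(rewards, NUM_GENERATIONS)
print(completions_per_step, prompts_per_step)
print([round(a, 2) for a in adv])
```

The second group shows why batches without any legal move give little to learn from: when every completion for a prompt receives the same total reward, all of its advantages are zero.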


  out = torch_matmul(X, W.t(), out = out)



GRPO BATCH SUMMARY
Positions: 2 unique / 12 total
Results: ✓ 0 valid (0.0%) | ✗ 2 illegal | ⚠ 10 bad syntax | ? 0 ambiguous
Rewards: mean=-1.83, min=-2.00, max=-1.00

--- Example ---
FEN: rnb1k2r/1p1p1pbp/p3pnp1/q1P5/4P3/2N1BN1P/PPPQ1PP1/R3KB1R b KQkq - 0 8
Prompt: [{'content': '\nYou are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.\n\nPlease follow this exact format in your response:\n\n<reasoning>\n(Brief explanation of what you see on the board - piece activity, threats, and candidate moves)\n</reasoning>\n<answer>\n(best move written in correct SAN format, such as Nf3 or exd5)\n</answer>\n\nDo not invent illegal or impossible moves. The move must be legal in the given FEN position.\nDo not use UCI format like e2e4 - only SAN notation like e4, Nf3, or O-O.\nIn case of taking a piece use the [file]x[target square] format\n### Example:\nFEN: rnbqkb



| Metric | History |
|---|---|
| batch/ambiguous_moves | ▁▁▁▁▁ |
| batch/illegal_moves | █▁▁██ |
| batch/invalid_syntax | ▁█▇▁▂ |
| batch/mean_reward | █▁▅█▄ |
| batch/unique_positions | ▁▁▁▁▁ |
| batch/valid_moves | █▁██▁ |
| batch/valid_pct | █▁██▁ |
| profiling/Time taken: UnslothGRPOTrainer._calculate_rewards | █▁██▁ |
| profiling/Time taken: UnslothGRPOTrainer._prepare_inputs | █▁▇▁▇▁▇▁▇▁ |
| profiling/Time taken: UnslothGRPOTrainer.correctness_reward_func | █▁██▁ |

| Metric | Final value |
|---|---|
| batch/ambiguous_moves | 0 |
| batch/illegal_moves | 6 |
| batch/invalid_syntax | 6 |
| batch/mean_reward | -1.5 |
| batch/unique_positions | 2 |
| batch/valid_moves | 0 |
| batch/valid_pct | 0 |
| profiling/Time taken: UnslothGRPOTrainer._calculate_rewards | 0.01043 |
| profiling/Time taken: UnslothGRPOTrainer._prepare_inputs | 1e-05 |
| profiling/Time taken: UnslothGRPOTrainer.correctness_reward_func | 0.00525 |


In [None]:
from huggingface_hub import login
login()


In [None]:
# Requires the Hugging Face login above (paste a token with write access when prompted)
repo_id = "czovekboti/grpo_model"
model.push_to_hub(repo_id)

Saved model to https://huggingface.co/czovekboti/grpo_model


In [None]:
adapter_dir = os.path.join(output_path, "test_sft_board_fix")

# Save ONLY LoRA adapter weights + adapter config (no base model weights)
trainer.model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
