In [1]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # [NEW] Extra 30% context lengths!
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install or uv pip install
    !pip install unsloth vllm
else:
    pass # For Colab / Kaggle, we need extra instructions hidden below \/

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
!pip install python-chess
!apt-get install stockfish

In [3]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
!pip install --upgrade -qqq uv
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install!
    !pip install unsloth vllm
else:
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except: get_numpy = "numpy"; get_pil = "pillow"
    try: import subprocess; is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except: is_t4 = False
    get_vllm, get_triton = ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm==0.10.2", "triton")
    !uv pip install -qqq --upgrade \
        unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
    !uv pip install -qqq {get_triton}
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

In [4]:
#@title Connect drive{ display-mode: "form" }
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Load Config Files
- Model parameters are loaded from config
- Stockfish has to be downloaded already, add path in env file

In [5]:
import os
from dotenv import load_dotenv
load_dotenv()

import yaml

# Path to your YAML config file
path = '/content/config.yaml'

def load_config(path: str):
    with open(path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config = load_config(path)
config_name = "qwen4b"
print("Selected config_name:", config_name)

match config_name:
    case "llama":
        config = config["llama_config"]
    case "phi":
        config = config["PHI_config"]
    case "mistral":
        config = config["mistral_config"]
    case "qwen7b":
        config = config["qwen7b_config"]
    case "qwen4b":
        config = config["qwen4b_config"]
    case _:
        raise ValueError("Check model name â€“ perhaps the keyboard got excited.")

# Stockfish path from env
# stockfish_path = os.getenv("STOCKFISH_PATH")
stockfish_path= '/usr/games/stockfish'

print("STOCKFISH_PATH:", stockfish_path)








Selected config_name: qwen4b
STOCKFISH_PATH: /usr/games/stockfish


# Load model
- Adding pad tokens so its compatible with sft trained lora adapters

In [6]:
from unsloth import FastLanguageModel
import torch
max_seq_length = config['max_seq_length']# Can increase for longer reasoning traces
lora_rank = config['lora_rank'] # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = config["model"],
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = False, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
)
#add paddings
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token or "<|pad|>" # needed for compatibility with lora

model.resize_token_embeddings(len(tokenizer))



ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 11-12 22:16:06 [__init__.py:244] Automatically detected platform cuda.
ERROR 11-12 22:16:08 [fa_utils.py:57] Cannot use FA version 2 is not supported due to FA2 is only supported on devices with compute capability >= 8
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.11.2: Fast Qwen3 patching. Transformers: 4.56.2. vLLM: 0.9.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Embedding(151669, 2560, padding_idx=151654)

# Load Lora
- Change adapter dir as neccessary

In [7]:
from peft import PeftModel
lora_path = "./sft_outputs/checkpoint-xy"
# model = FastLanguageModel.get_peft_model(
#     model,
#     r = lora_rank,
#     target_modules = [
#         "q_proj", "k_proj", "v_proj", "o_proj",
#         "gate_proj", "up_proj", "down_proj",
#     ],
#     lora_alpha = lora_rank,
#     use_gradient_checkpointing = "unsloth",
#     random_state = 3407,
# )
model = PeftModel.from_pretrained(
    model,
    lora_path,
    is_trainable=True,
    adapter_name="sft_adapter",
)




In [8]:
#@title Wandb Setup{ display-mode: "form" }
# Initialize wandb
import wandb
os.environ["WANDB_LOG_MODEL"] = "end"
os.environ["WANDB_PROJECT"] = "Chess_RL_Project"
os.environ["WANDB_ENTITY"] = "czovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem"
wandb.login()
wandb.init(
    project="Chess_RL_Project",
    entity = "",
    name=config["name"],
    config={
        "model": config["model"],
        "max_seq_length": config['max_seq_length'],
        "lora_rank": lora_rank,
        "learning_rate": config["learning_rate"],
        "max_steps": config["max_steps"],
    }
)


[34m[1mwandb[0m: Currently logged in as: [33mczovekboti[0m ([33mczovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


In [9]:
#@title Load dataset{ display-mode: "form" }
from datasets import load_dataset
dataset = load_dataset("czovekboti/chessdata", split="train")

# Training functions and prompt
- Prompt:
    - Gives instructions to the model alongside with examples
    - Same as SFT training prompt
- Functions:
  - Basic functions for extracting answer and checking existance of reasoning tags
  - correctness_reward_func: Loads board than checks if the move by the model is syntactically correct and valid. If yes always positive reward +/- the scaled evaluation given by stockfish. If the answer is incorrect negative reward is given.

In [10]:
SYSTEM_PROMPT = """
You are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.

Please follow this exact format in your response:

<reasoning>
(Brief explanation of what you see on the board â€” piece activity, threats, and candidate moves)
</reasoning>
<answer>
(best move written in correct SAN format, such as Nf3 or exd5)
</answer>

Do not invent illegal or impossible moves. The move must be legal in the given FEN position.
Do not use UCI format like e2e4 â€” only SAN notation like e4, Nf3, or O-O.
In case of taking a piece use the [file]x[target square] format
### Example:
FEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1

<reasoning>
White has just played e4 and developed the knight to f3. Itâ€™s Blackâ€™s turn. The e4 pawn is undefended. Capturing it with the pawn from d7 to d5 is a natural central counter.
</reasoning>
<answer>
d5
</answer>

Now solve the following position:
"""

# import chess libaries and load engine
import chess, chess.engine
from chess import InvalidMoveError, IllegalMoveError, AmbiguousMoveError
import math
import re
from datasets import load_dataset, Dataset
# Load and prep dataset


XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
def get_board(data, split = "train"):
    def fen_color(fen: str) -> str:
        return "White" if fen.split()[1] == 'w' else "Black"
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['FEN'] + " You are with the following pieces: " + fen_color(x['FEN'])}
        ], 'evaluation': x['Evaluation'], 'fen': x['FEN']
    }, remove_columns=data.column_names)
    print(data[0])
    return data #


dataset = get_board(dataset.select(range(2000)))
def reward_move(board, dataeval):
  result = engine.analyse(board, chess.engine.Limit(time=1.0)) # time doesn't make a real difference above this
  evaluation = result['score'].relative.score() #evaluation from opponents point of view
  print(f"\n----------------------\n")
  if evaluation is not None:
      scaled_evaluation = math.tanh(evaluation / 900.0) * 2.0 # biggest eval for position in file is around 15000 but 2000+ evals are rare
      if -evaluation > dataeval: # give reward if it improved position (-evaluation cause we need other players pov)
        scaled_evaluation -= 0.5 # -0.5 because the sign is going to be flipped
        print(f"Eval = {-evaluation}, Dataeval = {dataeval}. State was improved->reward = 0.5")
      print(f"Scaled Evaluation: {-scaled_evaluation} ")
      return -scaled_evaluation # *-1 because we need the score of the player who is not in turn
  else:
    return 0.0


# Reward functions
def correctness_reward_func(prompts,fen, completions, evaluation,**kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_moves =  [extract_xml_answer(r) for r in responses]
    fen_str = fen[0] if isinstance(fen, list) else fen
    board = chess.Board(fen_str)
    print(f"------------\nFEN: {fen}\n--------- \nResponse: {responses[0]} \n----------\nExtracted_Move: {extracted_moves[0]}")
    rewards = []
    try:
        if isinstance(evaluation, list):
            evaluation = float(evaluation[0]) # evaluation maybe a list due to a bug
    except (ValueError, TypeError) as e:
        print(f"Error: Could not convert evaluation '{evaluation}' to float. Using default value 0.0.")
        evaluation = 0.0
    # This also checks if the move is right both syntactically and legally
    for move in extracted_moves:
        try:
          board.push_san(move)
          scaled_evaluation = reward_move(board,evaluation) #evaluate board after the move was made
          rewards.append(3.0+scaled_evaluation) # +5
        except InvalidMoveError:
            print(f"\n----------------------\n-1.0 reward for illegal syntax")
            rewards.append(-5.0)
        except ValueError:
            print(f"\n----------------------\n -0.7 reward for illegal move")
            rewards.append(-3.0)
        except AmbiguousMoveError: #meaning two pieces could go to the declared square
            print(f"\n----------------------\n 0.5 reward for right syntax but ambigous move")
            rewards.append(0.5)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\s*.+?\s*</reasoning>\s*<answer>\s*.+?\s*</answer>\s*$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r,re.DOTALL) for r in responses]
    return [0.2 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.2 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

{'prompt': [{'content': '\nYou are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.\n\nPlease follow this exact format in your response:\n\n<reasoning>\n(Brief explanation of what you see on the board â€” piece activity, threats, and candidate moves)\n</reasoning>\n<answer>\n(best move written in correct SAN format, such as Nf3 or exd5)\n</answer>\n\nDo not invent illegal or impossible moves. The move must be legal in the given FEN position.\nDo not use UCI format like e2e4 â€” only SAN notation like e4, Nf3, or O-O.\nIn case of taking a piece use the [file]x[target square] format\n### Example:\nFEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1\n\n<reasoning>\nWhite has just played e4 and developed the knight to f3. Itâ€™s Blackâ€™s turn. The e4 pawn is undefended. Capturing it with the pawn from d7 to d5 is a natural central counter.\n</reasoning>

# Train model

In [11]:

max_prompt_length = 256

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = float(config["learning_rate"]),
    adam_beta1 = config["adam_beta1"],
    adam_beta2 = config["adam_beta2"],
    weight_decay = config["weight_decay"],
    warmup_ratio = config["warmup_ratio"],
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    generation_kwargs = {
        "eos_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.2,  # Discourage repetition
    },
    logging_steps = 1,
    per_device_train_batch_size = config["per_device_train_batch_size"], #2 for bigger model 4 for smaller #16 gb gpu could do 8 with 14b model
    gradient_accumulation_steps = 2, # overall batch size should be 16 or 32 -> sslows training down
    num_generations = 6, # Decrease if out of memory
    max_steps = 10,
    max_grad_norm = 0.1,
    save_total_limit =1,
    report_to = "wandb", # report to weights and biases
    output_dir = "./GPRO",
    run_name = "chess_llama_grpo",
)

engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
try:
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()

finally:
    engine.quit()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,000 | Num Epochs = 1 | Total steps = 10
O^O/ \_/ \    Batch size per device = 6 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (6 x 2 x 1) = 12
 "-____-"     Trainable parameters = 66,060,288 of 4,087,844,864 (1.62% trained)
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 262144, 'temperature': 0.7, 'top_p': 0.8}. If this is not desired, please set these values explicitly.
  out = torch_matmul(X, W.t(), out = out)


------------
FEN: ['rnbqkbnr/pp2pppp/8/3p4/3P4/8/PPP2PPP/RNBQKBNR w KQkq - 0 4', 'rnbqkbnr/pp2pppp/8/3p4/3P4/8/PPP2PPP/RNBQKBNR w KQkq - 0 4', 'rnbqkbnr/pp2pppp/8/3p4/3P4/8/PPP2PPP/RNBQKBNR w KQkq - 0 4', 'rnbqkbnr/pp2pppp/8/3p4/3P4/8/PPP2PPP/RNBQKBNR w KQkq - 0 4', 'rnbqkbnr/pp2pppp/8/3p4/3P4/8/PPP2PPP/RNBQKBNR w KQkq - 0 4', 'rnbqkbnr/pp2pppp/8/3p4/3P4/8/PPP2PPP/RNBQKBNR w KQkq - 0 4', '8/5pk1/R5p1/4Kn1p/7P/6P1/8/8 w - - 3 44', '8/5pk1/R5p1/4Kn1p/7P/6P1/8/8 w - - 3 44', '8/5pk1/R5p1/4Kn1p/7P/6P1/8/8 w - - 3 44', '8/5pk1/R5p1/4Kn1p/7P/6P1/8/8 w - - 3 44', '8/5pk1/R5p1/4Kn1p/7P/6P1/8/8 w - - 3 44', '8/5pk1/R5p1/4Kn1p/7P/6P1/8/8 w - - 3 44']
--------- 
Response: <reasoning>
The position shows a standard opening setup with white to play. The pawns are mostly centralized, with black having a strong center presence due to their three pawns on c6, d6, and e6 (implied by "3p4" on the fourth rank). White's light-squared bishop is active, but there seems to be some imbalance in development. Ho

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")