In [1]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # [NEW] Extra 30% context lengths!
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install or uv pip install
    !pip install unsloth vllm
else:
    pass # For Colab / Kaggle, we need extra instructions hidden below \/

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
!pip install python-chess
!apt-get install stockfish

In [3]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
!pip install --upgrade -qqq uv
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install!
    !pip install unsloth vllm
else:
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except: get_numpy = "numpy"; get_pil = "pillow"
    try: import subprocess; is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except: is_t4 = False
    get_vllm, get_triton = ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm==0.10.2", "triton")
    !uv pip install -qqq --upgrade \
        unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
    !uv pip install -qqq {get_triton}
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

In [4]:
#@title Load model config{ display-mode: "form" }
import os
from dotenv import load_dotenv
load_dotenv()

import yaml

# Path to your YAML config file
path = '/content/config.yaml'

def load_config(path: str):
    with open(path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config = load_config(path)
config_name = "qwen4b"
print("Selected config_name:", config_name)

match config_name:
    case "llama":
        config = config["llama_config"]
    case "phi":
        config = config["PHI_config"]
    case "mistral":
        config = config["mistral_config"]
    case "qwen7b":
        config = config["qwen7b_config"]
    case "qwen4b":
        config = config["qwen4b_config"]
    case _:
        raise ValueError("Check model name – perhaps the keyboard got excited.")

# Stockfish path from env
# stockfish_path = os.getenv("STOCKFISH_PATH")
stockfish_path= '/usr/games/stockfish'

print("STOCKFISH_PATH:", stockfish_path)

Selected config_name: qwen4b
STOCKFISH_PATH: /usr/games/stockfish


In [6]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = config["model"],
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = False, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
)
#add paddings
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token or "<|pad|>"

model.resize_token_embeddings(len(tokenizer))

target_modules = [
    "q_proj","k_proj","v_proj","o_proj",
    "gate_proj","up_proj","down_proj",

]
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = target_modules,
    lora_alpha = lora_rank, # *2 speeds up training
    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,
)



==((====))==  Unsloth 2025.10.12: Fast Qwen3 patching. Transformers: 4.56.2. vLLM: 0.9.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unsloth 2025.10.12 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


In [7]:
from sklearn.model_selection import train_test_split
from datasets import load_dataset
original_dataset = load_dataset("czovekboti/sft_chess", split="train")
split = original_dataset.train_test_split(test_size=30, seed=42)
train_ds = split["train"]
test_ds = split["test"]
print(train_ds)
print(test_ds)

Generating train split:   0%|          | 0/100 [00:00<?, ? examples/s]

Dataset({
    features: ['fen', 'top_5_moves', 'answer'],
    num_rows: 70
})
Dataset({
    features: ['fen', 'top_5_moves', 'answer'],
    num_rows: 30
})


In [8]:
SYSTEM_PROMPT = """You are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.

Please follow this exact format in your response:

<reasoning>
(Brief explanation of what you see on the board — piece activity, threats, and candidate moves)
</reasoning>
<answer>
(best move written in correct SAN format, such as Nf3 or exd5)
</answer>

Do not invent illegal or impossible moves. The move must be legal in the given FEN position.
Do not use UCI format like e2e4 — only SAN notation like e4, Nf3, or O-O.
In case of taking a piece use the [file]x[target square] format
### Example:
FEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1

<reasoning>
White has just played e4 and developed the knight to f3. It’s Black’s turn. The e4 pawn is undefended. Capturing it with the pawn from d7 to d5 is a natural central counter.
</reasoning>
<answer>
d5
</answer>

Now solve the following position:
"""

def fen_color(fen: str) -> str:
    try:
        parts = fen.strip().split()
        side = parts[1]
        return "White" if side == "w" else "Black" if side == "b" else "Unknown"
    except Exception:
        return "Unknown"

# ---------- add this block after loading model/tokenizer ----------
import torch, re, pandas as pd
from datasets import load_dataset

# (optional but recommended) ensure a chat template is set for base models
try:
    from unsloth.chat_templates import get_chat_template
    if getattr(tokenizer, "chat_template", None) in (None, ""):
        tokenizer = get_chat_template(tokenizer, chat_template="qwen3")
except Exception:
    # fallback: minimal qwen-like template
    if getattr(tokenizer, "chat_template", None) in (None, ""):
        tokenizer.chat_template = (
            "{% for m in messages %}"
            "{% if m['role'] == 'system' %}<|im_start|>system\n{{ m['content'] }}<|im_end|>\n"
            "{% elif m['role'] == 'user' %}<|im_start|>user\n{{ m['content'] }}<|im_end|>\n"
            "{% elif m['role'] == 'assistant' %}<|im_start|>assistant\n{{ m['content'] }}<|im_end|>\n"
            "{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
        )

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token



def extract_move(text: str):
    m = ANSWER_RE.search(text or "")
    return m.group(1).strip() if m else None

def to_top5_set(v):
    if v is None:
        return set()
    if isinstance(v, (list, tuple, set)):
        return set(str(x).strip() for x in v if str(x).strip())
    parts = [p.strip() for p in str(v).replace(";", ",").split(",")]
    return set(p for p in parts if p)

def build_prompt(fen: str, color: str):
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"FEN: {fen}\nYou are with the following pieces: {color}"},
    ]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

def generate_answer(the_model, prompt: str, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=50):
    the_model.eval()
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        out = the_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    gen_ids = out[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=False)


print("=" * 80)
print("TESTING MODEL BEFORE TRAINING + SCORING")
print("=" * 80)

ANSWER_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL | re.IGNORECASE)
REASONING_RE = re.compile(r"<reasoning>\s*(.*?)\s*</reasoning>", re.DOTALL | re.IGNORECASE)

def test_model(train_ds, model,answer_re,reasoning_re):
  i=0
  rows = []
  for example in train_ds:
        fen = example["fen"]
        color = fen_color(fen)
        prompt = build_prompt(fen, color)
        output = generate_answer(model, prompt, max_new_tokens=128, temperature=0.0, top_p=1.0, top_k=0)  # deterministic

        move = extract_move(output)
        has_reasoning = bool(REASONING_RE.search(output))
        has_answer_tag = bool(ANSWER_RE.search(output))

        top5 = to_top5_set(example.get("top_5_moves"))
        in_top5 = (move in top5) if move else False

        best_move = (example.get("best_move") or "").strip() or None
        equals_best = (move == best_move) if (move and best_move) else False

        score = 0
        score += 1 if has_reasoning else 0
        score += 1 if has_answer_tag else 0
        score += 1 if move else 0
        score += 2 if in_top5 else 0
        score += 2 if equals_best else 0

        print(f"\n{'='*80}")
        print(f"Example {i+1}/{len(train_ds)}")
        print(f"{'='*80}")
        print(f"FEN: {fen}")
        print(f"Color: {color}")
        print(f"\nGenerated Answer:\n{output}")

        rows.append({
            "idx": i,
            "FEN": fen,
            "Color": color,
            "Move": move or "",
            "HasReasoningTag": has_reasoning,
            "HasAnswerTag": has_answer_tag,
            "InTop5": in_top5,
            "EqualsBest": equals_best,
            "Score": score,
            "Top5Moves": ", ".join(sorted(top5)) if top5 else "",
            "BestMove": best_move or "",
            "RawOutput": output,
        })
        i+=1
  return rows

rows = test_model(test_ds, model,ANSWER_RE,REASONING_RE)
df = pd.DataFrame(rows).sort_values("Score", ascending=False).reset_index(drop=True)
print("\n=== SCORE TABLE (top 10 rows) ===")
print(df[["idx","Move","InTop5","EqualsBest","HasReasoningTag","HasAnswerTag","Score"]].head(50))

# save
df.to_csv("sft_answer_scoring_before_training.csv", index=False)
print("\nSaved: sft_answer_scoring_before_training.csv")
print("\nStarting training...\n")


TESTING MODEL BEFORE TRAINING + SCORING

Example 1/30
FEN: 7k/p1r2b2/4pq2/1p1p1nR1/5P2/P2B4/1P2Q2P/1K4R1 w - - 3 31
Color: White

Generated Answer:
<reasoning>
The position is with White to move. The key elements are: Black has a strong king-side presence with a rook on the 8th rank and a king on the 8th file (a1), and White has a queen, several pawns, and a bishop on the queenside. The white king is on d8, and the white queen is on c7. The central pawns are on d5 and e5, and there is a knight on f6. The black queen is on d6, and the black bishop is on c8. White has a strong initiative on the queens

Example 2/30
FEN: r1b1k2r/pp1n2pp/1qn1p3/3pp3/1b1P1P2/3B1N2/PP1BN1PP/R2QK2R w KQkq - 0 12
Color: White

Generated Answer:
<reasoning>
The position is with White to move. The white king is on e1, and the queen is on d4. There are pawns on the queenside and center, with a bishop on c5 and a knight on f6. The black pieces include a rook on a8, a knight on c8, and a bishop on b7. White has a s

In [None]:


# Toggle W&B logging (set to False if you don't want to log)
USE_WANDB = True

if USE_WANDB:
    import wandb
    os.environ["WANDB_LOG_MODEL"] = "end"

    os.environ["WANDB_PROJECT"] = "Chess_RL_Project"
    os.environ["WANDB_ENTITY"] = "czovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem"
    wandb.login()
    wandb.init(
        project="Chess_RL_Project",
        entity="",
        name=config.get("name", "run"),
        config={
            "model": config["model"],
            "max_seq_length": config['max_seq_length'],
            "lora_rank": lora_rank,
            "learning_rate": config["learning_rate"],
            "max_steps": config["max_steps"],
        }
    )
    print("W&B initialized.")
else:
    print("W&B disabled.")

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Currently logged in as: [33mczovekboti[0m ([33mczovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


W&B initialized.


In [10]:
formatted_texts = []
from datasets import Dataset
print(train_ds)
for ex in train_ds:
    fen = (ex.get("fen") or "").strip()
    ans = (ex.get("answer") or "").strip()
    if not fen or not ans:
        continue
    color = fen_color(fen)
    messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"\nFEN: {fen}\nYou are with the following pieces: {color}"},
            {"role": "assistant", "content": ans},

    ]
    formatted_texts.append({"conversations": messages})
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
from datasets import Dataset
print(formatted_texts)
# formatted_texts is just a list of dicts
dataset = Dataset.from_list(formatted_texts)

dataset = dataset.map(formatting_prompts_func, batched = True,)


Dataset({
    features: ['fen', 'top_5_moves', 'answer'],
    num_rows: 70
})
[{'conversations': [{'role': 'system', 'content': 'You are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.\n\nPlease follow this exact format in your response:\n\n<reasoning>\n(Brief explanation of what you see on the board — piece activity, threats, and candidate moves)\n</reasoning>\n<answer>\n(best move written in correct SAN format, such as Nf3 or exd5)\n</answer>\n\nDo not invent illegal or impossible moves. The move must be legal in the given FEN position.\nDo not use UCI format like e2e4 — only SAN notation like e4, Nf3, or O-O.\nIn case of taking a piece use the [file]x[target square] format\n### Example:\nFEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1\n\n<reasoning>\nWhite has just played e4 and developed the knight to f3. It’s Black’s turn. The e4 pawn is un

Map:   0%|          | 0/70 [00:00<?, ? examples/s]

In [11]:
from trl import SFTConfig, SFTTrainer
from transformers import DataCollatorForSeq2Seq
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    packing = False, # Can make training 5x faster for short sequences.
    args = SFTConfig(
        completion_only_loss=True,
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 5, # Set this for 1 full training run.
        max_steps = len(dataset)*5,
        learning_rate = 1e-3,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "wandb", # Use TrackIO/WandB etc
    ),
)

Unsloth: Tokenizing ["text"] (num_proc=6):   0%|          | 0/70 [00:00<?, ? examples/s]

In [12]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 70 | Num Epochs = 39 | Total steps = 350
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 66,060,288 of 4,087,844,864 (1.62% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,2.6972
2,2.6721
3,1.9809
4,1.5444
5,1.1829
6,0.8517
7,0.6324
8,0.5585
9,0.4798
10,0.4357




TrainOutput(global_step=350, training_loss=0.09355128367564508, metrics={'train_runtime': 2166.5516, 'train_samples_per_second': 1.292, 'train_steps_per_second': 0.162, 'total_flos': 2.967348449939251e+16, 'train_loss': 0.09355128367564508, 'epoch': 38.91428571428571})

In [13]:
trained_model = trainer.model

In [17]:
# for i in range(num_test_examples):
#     example = original_dataset[i]
#     fen = example["fen"]
#     color = fen_color(fen)

#     prompt = build_prompt(fen, color)
#     output = generate_answer(trained_model, prompt, max_new_tokens=128, temperature=0.0, top_p=1.0, top_k=0)  # deterministic

#     move = extract_move(output)
#     has_reasoning = bool(REASONING_RE.search(output))
#     has_answer_tag = bool(ANSWER_RE.search(output))

#     top5 = to_top5_set(example.get("top_5_moves"))
#     in_top5 = (move in top5) if move else False

#     best_move = (example.get("best_move") or "").strip() or None
#     equals_best = (move == best_move) if (move and best_move) else False
#     score = 0
#     score += 1 if has_reasoning else 0
#     score += 1 if has_answer_tag else 0
#     score += 1 if move else 0
#     score += 2 if in_top5 else 0
#     score += 2 if equals_best else 0

#     print(f"\n{'='*80}")
#     print(f"Example {i+1}/{num_test_examples}")
#     print(f"{'='*80}")
#     print(f"FEN: {fen}")
#     print(f"Color: {color}")
#     print(f"\nGenerated Answer:\n{output}")

#     rows.append({
#         "idx": i,
#         "FEN": fen,
#         "Color": color,
#         "Move": move or "",
#         "HasReasoningTag": has_reasoning,
#         "HasAnswerTag": has_answer_tag,
#         "InTop5": in_top5,
#         "EqualsBest": equals_best,
#         "Score": score,
#         "Top5Moves": ", ".join(sorted(top5)) if top5 else "",
#         "BestMove": best_move or "",
#         "RawOutput": output,
#     })
rows  = test_model(test_ds,trained_model,ANSWER_RE,REASONING_RE)
df = pd.DataFrame(rows).sort_values("Score", ascending=False).reset_index(drop=True)
print("\n=== SCORE TABLE (top 10 rows) ===")
print(df[["idx","Move","InTop5","EqualsBest","HasReasoningTag","HasAnswerTag","Score"]].head(50))
df.to_csv("sft_answer_scoring_after_training.csv", index=False)



Example 1/30
FEN: 7k/p1r2b2/4pq2/1p1p1nR1/5P2/P2B4/1P2Q2P/1K4R1 w - - 3 31
Color: White

Generated Answer:
<think>

</think>

<reasoning>
g2 is a weak move as it fails to develop a piece effectively and does not provide immediate threats or positional advantages. It also allows the opponent to recapture on g2 without losing material.

Rg2 is a strong move as it develops the rook to a powerful square and threatens to open up the center. This move puts pressure on the black king and prepares for future attacks.

Therefore, Rg2 is the best move because it maximizes the potential of the rook and creates immediate threats.
</reasoning>
<answer>
Rg2
</answer><|im_end|>

Example 2/30
FEN: r1b1k2r/pp1n2pp/1qn1p3/3pp3/1b1P1P2/3B1N2/PP1BN1PP/R2QK2R w KQkq - 0 12
Color: White

Generated Answer:
<think>

</think>

<reasoning>
e2 is a weak move as it fails to develop a piece meaningfully and does not advance the pawn structure effectively. 

Ne2 is the best move as it develops a knight to a strong

In [19]:
import chess
import pandas as pd

# simple legality check
def is_legal_move(fen, san):
    try:
        board = chess.Board(fen)
        move = board.parse_san(san)  # will raise error if move invalid
        return move in board.legal_moves
    except Exception:
        return False
before_df = pd.read_csv("sft_answer_scoring_before_training.csv")
before_df["is_legal"] = before_df.apply(lambda x: is_legal_move(x["FEN"], x["Move"]), axis=1)
after_df = pd.read_csv("sft_answer_scoring_after_training.csv")
after_df["is_legal"] = after_df.apply(lambda x: is_legal_move(x["FEN"], x["Move"]), axis=1)
legal_count = before_df["is_legal"].sum()
print("Number of legal moves:", legal_count)
legal_count = after_df["is_legal"].sum()
print("Number of legal moves:", legal_count)






Number of legal moves: 0
Number of legal moves: 8
