[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](
  https://colab.research.google.com/github/czovekboti/chess_rl/blob/sft%2Bgrpo/SFT%20Trainer%20notebook.ipynb
)


In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # [NEW] Extra 30% context lengths!
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install or uv pip install
    !pip install unsloth vllm
else:
    pass # For Colab / Kaggle, we need extra instructions hidden below \/

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
!pip install python-chess
!apt-get install stockfish

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
!pip install --upgrade -qqq uv
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install!
    !pip install unsloth vllm
else:
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except: get_numpy = "numpy"; get_pil = "pillow"
    try: import subprocess; is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except: is_t4 = False
    get_vllm, get_triton = ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm==0.10.2", "triton")
    !uv pip install -qqq --upgrade \
        unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
    !uv pip install -qqq {get_triton}
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

In [6]:
#@title Load model config{ display-mode: "form" }
import os
from dotenv import load_dotenv
load_dotenv()

import yaml

# Path to your YAML config file
# path = '/content/config.yaml'
path = 'config.yaml'

def load_config(path: str):
    with open(path, 'r') as file:
        config = yaml.safe_load(file)
    return config

config = load_config(path)
config_name = "qwen4b"
print("Selected config_name:", config_name)

match config_name:
    case "llama":
        config = config["llama_config"]
    case "phi":
        config = config["PHI_config"]
    case "mistral":
        config = config["mistral_config"]
    case "qwen7b":
        config = config["qwen7b_config"]
    case "qwen4b":
        config = config["qwen4b_config"]
    case _:
        raise ValueError("Check model name â€“ perhaps the keyboard got excited.")

# Stockfish path from env
# stockfish_path = os.getenv("STOCKFISH_PATH")
stockfish_path= '/usr/games/stockfish'

print("STOCKFISH_PATH:", stockfish_path)

Selected config_name: qwen4b
STOCKFISH_PATH: /usr/games/stockfish


# Load dataset

In [7]:
from sklearn.model_selection import train_test_split
from datasets import load_dataset,concatenate_datasets
dataset = load_dataset("czovekboti/sft_training2")
print(len(dataset["train"]))

json_dataset = load_dataset("json", data_files="training.json")

# Concatenate the 'train' splits from both datasets
dataset["train"] = concatenate_datasets([
    dataset["train"],
    json_dataset["train"]
])
print(len(dataset["train"]))


14547
16108


# Load model


In [4]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = config["model"],
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = False, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
)
#add paddings
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token or "<|pad|>"

model.resize_token_embeddings(len(tokenizer))

target_modules = [
    "q_proj","k_proj","v_proj","o_proj",
    "gate_proj","up_proj","down_proj",

]
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = target_modules,
    lora_alpha = lora_rank, # *2 speeds up training
    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,
)



ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.




INFO 11-24 09:51:07 [__init__.py:244] Automatically detected platform cuda.
ERROR 11-24 09:51:09 [fa_utils.py:57] Cannot use FA version 2 is not supported due to FA2 is only supported on devices with compute capability >= 8
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.11: Fast Qwen3 patching. Transformers: 4.57.1. vLLM: 0.9.2.
   \\   /|    Tesla V100-PCIE-32GB. Num GPUs = 2. Max memory: 31.733 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 7.0. CUDA Toolkit: 12.6. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unsloth 2025.10.11 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


# Prompt, data preparation and model test
model test is commented out run for before-after comparison

In [8]:
from sklearn.model_selection import train_test_split
from datasets import load_dataset
# original_dataset = load_dataset("czovekboti/sft_chess", split="train")

split = dataset['train'].train_test_split(test_size=0.1, seed=42)
train_ds = split["train"]
test_ds = split["test"]
print(train_ds)
print(test_ds)

Dataset({
    features: ['fen', 'top_5_moves', 'answer'],
    num_rows: 14497
})
Dataset({
    features: ['fen', 'top_5_moves', 'answer'],
    num_rows: 1611
})


In [11]:
SYSTEM_PROMPT = """You are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.

Please follow this exact format in your response:

<reasoning>
(Brief explanation of what you see on the board â€” piece activity, threats, and candidate moves)
</reasoning>
<answer>
(best move written in correct SAN format, such as Nf3 or exd5)
</answer>

Do not invent illegal or impossible moves. The move must be legal in the given FEN position.
Do not use UCI format like e2e4 â€” only SAN notation like e4, Nf3, or O-O.
In case of taking a piece use the [file]x[target square] format
### Example:
FEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1

<reasoning>
White has just played e4 and developed the knight to f3. Itâ€™s Blackâ€™s turn. The e4 pawn is undefended. Capturing it with the pawn from d7 to d5 is a natural central counter.
</reasoning>
<answer>
d5
</answer>

Now solve the following position:
"""

def fen_color(fen: str) -> str:
    try:
        parts = fen.strip().split()
        side = parts[1]
        return "White" if side == "w" else "Black" if side == "b" else "Unknown"
    except Exception:
        return "Unknown"

# ---------- add this block after loading model/tokenizer ----------
import torch, re, pandas as pd
from datasets import load_dataset

# (optional but recommended) ensure a chat template is set for base models
try:
    from unsloth.chat_templates import get_chat_template
    if getattr(tokenizer, "chat_template", None) in (None, ""):
        tokenizer = get_chat_template(tokenizer, chat_template="qwen3")
except Exception:
    # fallback: minimal qwen-like template
    if getattr(tokenizer, "chat_template", None) in (None, ""):
        tokenizer.chat_template = (
            "{% for m in messages %}"
            "{% if m['role'] == 'system' %}<|im_start|>system\n{{ m['content'] }}<|im_end|>\n"
            "{% elif m['role'] == 'user' %}<|im_start|>user\n{{ m['content'] }}<|im_end|>\n"
            "{% elif m['role'] == 'assistant' %}<|im_start|>assistant\n{{ m['content'] }}<|im_end|>\n"
            "{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
        )



if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token



def extract_move(text: str):
    m = ANSWER_RE.search(text or "")
    return m.group(1).strip() if m else None

def to_top5_set(v):
    if v is None:
        return set()
    if isinstance(v, (list, tuple, set)):
        return set(str(x).strip() for x in v if str(x).strip())
    parts = [p.strip() for p in str(v).replace(";", ",").split(",")]
    return set(p for p in parts if p)

def build_prompt(fen: str, color: str):
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"FEN: {fen}\nYou are with the following pieces: {color}"},
    ]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

def generate_answer(the_model, prompt: str, max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=50):
    the_model.eval()
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        out = the_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    gen_ids = out[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=False)


print("=" * 80)
print("TESTING MODEL BEFORE TRAINING + SCORING")
print("=" * 80)

ANSWER_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL | re.IGNORECASE)
REASONING_RE = re.compile(r"<reasoning>\s*(.*?)\s*</reasoning>", re.DOTALL | re.IGNORECASE)
DEVICE = next(model.parameters()).device if 'model' in globals() else (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
def test_model(train_ds, model,answer_re,reasoning_re, max = len(train_ds)):
  i=0
  rows = []
  for example in train_ds:
        if i== max:
          return rows
        fen = example["fen"]
        color = fen_color(fen)
        prompt = build_prompt(fen, color)
        output = generate_answer(model, prompt, max_new_tokens=1024, temperature=0.0, top_p=1.0, top_k=0)  # deterministic

        move = extract_move(output)
        has_reasoning = bool(REASONING_RE.search(output))
        has_answer_tag = bool(ANSWER_RE.search(output))

        top5 = to_top5_set(example.get("top_5_moves"))
        in_top5 = (move in top5) if move else False

        best_move = (example.get("best_move") or "").strip() or None
        equals_best = (move == best_move) if (move and best_move) else False

        score = 0
        score += 1 if has_reasoning else 0
        score += 1 if has_answer_tag else 0
        score += 1 if move else 0
        score += 2 if in_top5 else 0
        score += 2 if equals_best else 0

        print(f"\n{'='*80}")
        print(f"Example {i+1}/{len(train_ds)}")
        print(f"{'='*80}")
        print(f"FEN: {fen}")
        print(f"Color: {color}")
        print(f"\nGenerated Answer:\n{output}")

        rows.append({
            "idx": i,
            "FEN": fen,
            "Color": color,
            "Move": move or "",
            "HasReasoningTag": has_reasoning,
            "HasAnswerTag": has_answer_tag,
            "InTop5": in_top5,
            "EqualsBest": equals_best,
            "Score": score,
            "Top5Moves": ", ".join(sorted(top5)) if top5 else "",
            "BestMove": best_move or "",
            "RawOutput": output,
        })
        i+=1
  return rows



TESTING MODEL BEFORE TRAINING + SCORING


In [8]:
# MODEL TEST =====
rows = test_model(test_ds, model,ANSWER_RE,REASONING_RE,10)
df = pd.DataFrame(rows).sort_values("Score", ascending=False).reset_index(drop=True)
print("\n=== SCORE TABLE (top 10 rows) ===")
print(df[["idx","Move","InTop5","EqualsBest","HasReasoningTag","HasAnswerTag","Score"]].head(50))

# # save
df.to_csv("sft_answer_scoring_before_training.csv", index=False)
print("\nSaved: sft_answer_scoring_before_training.csv")
print("\nStarting training...\n")


Example 1/1611
FEN: 1rbq1rk1/p3nppp/1p2pb2/3n4/3PB2P/P1N2N2/1P3PP1/R1BQR1K1 w - - 0 14
Color: White

Generated Answer:
<reasoning>
The position is with White to move. The white king is on e1, and the queen is on d1. White has a strong central presence with pawns on c3, d4, e4, and f3, and pieces including a bishop on c5 and a knight on f3. Black has a knight on c8 and a bishop on c7, and the black king is on e8. The white pawns on d4 and e4 are central and supported by the knight on f3. The black knight on c8 is active and threatens to attack the d4 pawn. White

Example 2/1611
FEN: 8/p2q2kp/5np1/2Qp1n2/8/1P6/P4PPP/4R1K1 w - - 0 34
Color: White

Generated Answer:
<reasoning>
The position is with White to move. The key elements are: White has a rook on the back rank (on h8), a queen on d4, and pawns on the queenside and center. Black has a queen on c5, a knight on e6, and a pawn on d7. Whiteâ€™s king is on e8, and the pawns on the queenside are advancing. The most immediate threat is th

# WANDB

In [7]:


# Toggle W&B logging (set to False if you don't want to log)
USE_WANDB = True

if USE_WANDB:
    import wandb
    os.environ["WANDB_LOG_MODEL"] = "end"
    os.environ["WANDB_PROJECT"] = "Chess_RL_Project"
    os.environ["WANDB_ENTITY"] = "czovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem"
    wandb.login()
    wandb.init(
        project="Chess_RL_Project",
        entity="czovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem",
        name= "QWEN4B last sft test run (TARS)",
        config={
            "model": config["model"],
            "max_seq_length": config['max_seq_length'],
            "lora_rank": lora_rank,
            "learning_rate": config["learning_rate"],
            "max_steps": config["max_steps"],
        }
    )
    print("W&B initialized.")
else:
    print("W&B disabled.")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlepkepukivadasz[0m ([33mczovekboti-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


W&B initialized.


# Formatting dataset to match gpro training input

In [9]:
from datasets import Dataset
print(train_ds)
def format_dataset(ds):
    formatted_texts = []
    for ex in ds:
        fen = (ex.get("fen") or "").strip()
        ans = (ex.get("answer") or "").strip()
        if not fen or not ans:
            continue
        color = fen_color(fen)
        messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": f"\nFEN: {fen}\nYou are with the following pieces: {color}"},
                {"role": "assistant", "content": ans},

        ]
        formatted_texts.append({"conversations": messages})
    dataset = Dataset.from_list(formatted_texts)
    dataset = dataset.map(formatting_prompts_func, batched = True,)
    return dataset
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
final_train_ds = format_dataset(train_ds)
print(final_train_ds[0])
final_test_ds = format_dataset(test_ds)
print(final_test_ds[0])
print(len(final_train_ds))
print(len(final_test_ds))


Dataset({
    features: ['fen', 'top_5_moves', 'answer'],
    num_rows: 14497
})


Map:   0%|          | 0/14497 [00:00<?, ? examples/s]

{'conversations': [{'content': 'You are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.\n\nPlease follow this exact format in your response:\n\n<reasoning>\n(Brief explanation of what you see on the board â€” piece activity, threats, and candidate moves)\n</reasoning>\n<answer>\n(best move written in correct SAN format, such as Nf3 or exd5)\n</answer>\n\nDo not invent illegal or impossible moves. The move must be legal in the given FEN position.\nDo not use UCI format like e2e4 â€” only SAN notation like e4, Nf3, or O-O.\nIn case of taking a piece use the [file]x[target square] format\n### Example:\nFEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1\n\n<reasoning>\nWhite has just played e4 and developed the knight to f3. Itâ€™s Blackâ€™s turn. The e4 pawn is undefended. Capturing it with the pawn from d7 to d5 is a natural central counter.\n</reaso

Map:   0%|          | 0/1611 [00:00<?, ? examples/s]

{'conversations': [{'content': 'You are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and suggest the best legal move for the player whose turn it is.\n\nPlease follow this exact format in your response:\n\n<reasoning>\n(Brief explanation of what you see on the board â€” piece activity, threats, and candidate moves)\n</reasoning>\n<answer>\n(best move written in correct SAN format, such as Nf3 or exd5)\n</answer>\n\nDo not invent illegal or impossible moves. The move must be legal in the given FEN position.\nDo not use UCI format like e2e4 â€” only SAN notation like e4, Nf3, or O-O.\nIn case of taking a piece use the [file]x[target square] format\n### Example:\nFEN: rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1\n\n<reasoning>\nWhite has just played e4 and developed the knight to f3. Itâ€™s Blackâ€™s turn. The e4 pawn is undefended. Capturing it with the pawn from d7 to d5 is a natural central counter.\n</reaso

# TRAIN

Trains on whole dataset, saves every 1000th step.
Uses validation

In [10]:
from trl import SFTConfig, SFTTrainer
from transformers import DataCollatorForSeq2Seq
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = final_train_ds,
    eval_dataset=final_test_ds,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    packing = False, # Can make training 5x faster for short sequences.
    args = SFTConfig(
        completion_only_loss=True,
        per_device_train_batch_size = 2,
        per_device_eval_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,   # TODO
        num_train_epochs = 1, # Set this for 1 full training run.
        # max_steps = len(final_train_ds)+ len(final_test_ds),
        max_steps=500,
        metric_for_best_model = "eval_loss",
        eval_strategy="steps",   # evaluate every N steps
        eval_steps=50,                 # <-- evaluate every 10 steps
        save_strategy="steps",
        save_steps= 250,
        save_total_limit = 5,
        learning_rate = 1e-3,  # TODO
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear", # ->
        seed = 3407,
        output_dir = "./sft_outputs",
        report_to = "wandb", # Use TrackIO/WandB
    ),
)

Unsloth: Tokenizing ["text"] (num_proc=64):   0%|          | 0/14497 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=64):   0%|          | 0/1611 [00:00<?, ? examples/s]

In [11]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.
The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 14,497 | Num Epochs = 9 | Total steps = 16,108
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 66,060,288 of 4,087,844,864 (1.62% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
10,0.4708,0.455616
20,0.4208,0.413933
30,0.4261,0.460411
40,0.3922,0.379437
50,0.395,0.363148
60,0.368,0.355242
70,0.3622,0.348036
80,0.3629,0.343962
90,0.3249,0.348798
100,0.3859,0.342536


Unsloth: Not an error, but Qwen3ForCausalLM does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient


KeyboardInterrupt: 

# Test model answers after training

In [12]:
trained_model = trainer.model

In [2]:
from unsloth import FastLanguageModel
import torch

max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "./sft_outputs/checkpoint-1000/",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = False, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
    resize_model_vocab = 151669,
)

# if NO_LORA:
#     model_args["resize_model_vocab"] = 151669

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.




INFO 11-25 10:06:19 [__init__.py:244] Automatically detected platform cuda.
ERROR 11-25 10:06:21 [fa_utils.py:57] Cannot use FA version 2 is not supported due to FA2 is only supported on devices with compute capability >= 8
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.11: Fast Qwen3 patching. Transformers: 4.57.1. vLLM: 0.9.2.
   \\   /|    Tesla V100-PCIE-32GB. Num GPUs = 2. Max memory: 31.733 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 7.0. CUDA Toolkit: 12.6. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unsloth 2025.10.11 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


In [16]:
# 1. Load the BASE model first (use the same config string you used for training)
# Do NOT put the checkpoint path here yet.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = config["model"], # e.g., "unsloth/Qwen2.5-..."
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = False,
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.4,
    device_map='cuda:1'
)

# 2. Re-apply the EXACT tokenizer padding and resizing logic used in training
# This ensures the base model shape matches what the adapter expects
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token or "<|pad|>"

model.resize_token_embeddings(len(tokenizer))

# 3. Now load your local checkpoint adapter on top
model.load_adapter("./sft_outputs/checkpoint-1000/")

# 4. (Optional) Set inference mode
FastLanguageModel.for_inference(model)

print("Local model with resized embeddings loaded successfully!")

==((====))==  Unsloth 2025.10.11: Fast Qwen3 patching. Transformers: 4.57.1. vLLM: 0.9.2.
   \\   /|    Tesla V100-PCIE-32GB. Num GPUs = 2. Max memory: 31.733 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 7.0. CUDA Toolkit: 12.6. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Local model with resized embeddings loaded successfully!


In [3]:
trained_model = model

In [13]:
rows  = test_model(test_ds,trained_model,ANSWER_RE,REASONING_RE, 20)
df = pd.DataFrame(rows).sort_values("Score", ascending=False).reset_index(drop=True)
print("\n=== SCORE TABLE (top 10 rows) ===")
print(df[["idx","Move","InTop5","EqualsBest","HasReasoningTag","HasAnswerTag","Score"]].head(50))
df.to_csv("sft_answer_scoring_after_training.csv", index=False)



Example 1/1611
FEN: 1rbq1rk1/p3nppp/1p2pb2/3n4/3PB2P/P1N2N2/1P3PP1/R1BQR1K1 w - - 0 14
Color: White

Generated Answer:
<think>

</think>

<reasoning>
c3 develops the central pawn and prepares to push it to c4, creating pressure on the black king and controlling the center.
Nc3 develops a piece to a strong square and prepares to support the center and control the center.
Nf3 is solid but less ambitious, focusing on developing a piece without immediately threatening black's position.
h3 and c3 are slow moves that fail to fight for space and control the center effectively.
Therefore, c3 is best because it develops a pawn to a strong square and prepares to push it to c4, creating pressure on the black king and controlling the center.
</reasoning>
<answer>
c3
</answer><|im_end|>

Example 2/1611
FEN: 8/p2q2kp/5np1/2Qp1n2/8/1P6/P4PPP/4R1K1 w - - 0 34
Color: White

Generated Answer:
<think>

</think>

<reasoning>
f4 is a strong move as it develops the knight to a powerful square and prepares 

If ran test before training comparison of legal moves is possible.


In [14]:
import chess
import pandas as pd

# simple legality check
def is_legal_move(fen, san):
    try:
        board = chess.Board(fen)
        move = board.parse_san(san)  # will raise error if move invalid
        return move in board.legal_moves
    except Exception:
        return False
before_df = pd.read_csv("sft_answer_scoring_before_training.csv")
before_df["is_legal"] = before_df.apply(lambda x: is_legal_move(x["FEN"], x["Move"]), axis=1)
after_df = pd.read_csv("sft_answer_scoring_after_training.csv")
after_df["is_legal"] = after_df.apply(lambda x: is_legal_move(x["FEN"], x["Move"]), axis=1)
legal_count = before_df["is_legal"].sum()
print("Number of legal moves:", legal_count)
legal_count = after_df["is_legal"].sum()
print("Number of legal moves:", legal_count)






Number of legal moves: 0
Number of legal moves: 5
