<a href="https://colab.research.google.com/github/czovekboti/chess_rl/blob/sft%2Bgrpo/Final_models_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](
  https://colab.research.google.com/github/czovekboti/chess_rl/blob/sft%2Bgrpo/Model%20Evaluation.ipynb
)


In [None]:
#@title Colab Extra Install { display-mode: "form" }
#%%capture
!pip install unsloth python-chess datasets matplotlib pandas tqdm python-dotenv colorama transformers
!pip install python-chess
!apt-get install stockfish

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
stockfish is already the newest version (14.1-1).
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.


# Set how many examples the script should evaluate

In [None]:
number_of_examples = 10
model_path = "unsloth/Qwen2.5-Coder-7B-Instruct"

In [None]:
!pip install vllm



In [None]:
# ================== SETUP IMPORTS ==================
# It is best to run the pip installs in a separate cell:
#!pip install unsloth python-chess datasets matplotlib pandas tqdm python-dotenv colorama transformers

from dotenv import load_dotenv
load_dotenv()
import json, os, re, random, warnings, textwrap
from datetime import datetime
from pathlib import Path
from collections import defaultdict
from types import SimpleNamespace

import torch
import chess
import matplotlib
matplotlib.use("Agg")  # headless-safe backend for saving plots
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from unsloth import FastLanguageModel
from transformers import AutoTokenizer

# progress + color
from colorama import Fore, Style, init as colorama_init
colorama_init(autoreset=True)
try:
    from tqdm import tqdm
except Exception:
    tqdm = None

# -------------------- utils --------------------
# Matches a UCI-style move token: from-square, to-square and an optional
# promotion letter, case-insensitively, delimited by word boundaries.
UCI_RE = re.compile(r"\b([a-h][1-8])([a-h][1-8])([qrbn])?\b", re.IGNORECASE)


def sanitize(s):
    """Replace every run of characters outside [a-zA-Z0-9._-] with one underscore."""
    return re.sub(r"[^a-zA-Z0-9._-]+", "_", s)


def now_tag():
    """Return the current local time formatted as YYYYmmdd_HHMMSS."""
    stamp = datetime.now()
    return stamp.strftime("%Y%m%d_%H%M%S")


def side_to_move_text(fen):
    """Return "White" when the second FEN field is "w", otherwise "Black"."""
    active_colour = fen.split()[1]
    if active_colour == "w":
        return "White"
    return "Black"

def uci_is_legal(fen, u):
    """Return True if the UCI string `u` is a legal move in the position `fen`.

    Any ordinary failure while building the board or parsing the move (bad
    FEN, malformed UCI, non-string input) is reported as False, not raised.
    """
    try:
        board = chess.Board(fen)
        move = chess.Move.from_uci(u.lower())
        return move in board.legal_moves
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
        # still propagate instead of being reported as "illegal move".
        return False

def san_is_legal(fen, san):
    """Return True if the SAN string `san` parses to a legal move in `fen`.

    Any ordinary failure while building the board or parsing the SAN is
    reported as False, not raised.
    """
    try:
        board = chess.Board(fen)
        move = board.parse_san(san)
        return move in board.legal_moves
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
        # still propagate instead of being reported as "illegal move".
        return False

def san_to_uci(fen, san):
    """Parse `san` in the position `fen` and return the move's UCI string.

    Errors raised while building the board or parsing the SAN are not caught.
    """
    board = chess.Board(fen)
    return board.parse_san(san).uci()

def extract_first_uci(text):
    """Return the first UCI_RE match in `text`, lower-cased, or None if none."""
    hit = UCI_RE.search(text)
    if hit is None:
        return None
    return hit.group(0).lower()

def extract_answer_block(text):
    """Return the first whitespace-delimited token inside <answer>...</answer>.

    The tag match is case-insensitive and may span several lines. Returns
    None when there is no <answer> block, or when the block is empty or
    contains only whitespace (previously that case raised IndexError).
    """
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, flags=re.DOTALL | re.IGNORECASE)
    if not m:
        return None
    tokens = m.group(1).split()
    # An empty block yields no tokens; report "no answer" instead of crashing.
    return tokens[0] if tokens else None

def variant_label(v):
    """Build the label "<move_format>/R=<normal|legal>/T=<temperature>" for a variant dict.

    "normal" is used when v["include_reasoning"] is truthy, "legal" otherwise;
    a missing "temperature" key is rendered as None.
    """
    if v["include_reasoning"]:
        mode = "normal"
    else:
        mode = "legal"
    return "{}/R={}/T={}".format(v["move_format"], mode, v.get("temperature"))


# -------------------- ILLEGAL MOVE ANALYSIS --------------------

def diagnose_illegal_move(fen, pred_move, move_format):
    """Classify why `pred_move` is not a legal move in the position `fen`.

    Args:
        fen: FEN string of the position.
        pred_move: The predicted move text (UCI or SAN); may be None or empty.
        move_format: "UCI" (case-insensitive) selects the UCI branch; anything
            else is treated as SAN.

    Returns:
        A short label string. "EMPTY_OUTPUT" for a missing prediction,
        "LEGAL" when the move is actually legal, "EXCEPTION_<Type>" when an
        unexpected error occurs, otherwise one of:

        UCI: MALFORMED_UCI, INVALID_UCI_SYNTAX, EMPTY_SQUARE, WRONG_COLOR,
             ILLEGAL_PIECE_MOVE, DOESNT_ESCAPE_CHECK, PINNED_PIECE,
             LEAVES_KING_IN_CHECK.
        SAN: AMBIGUOUS_SAN, ILLEGAL_SAN, INVALID_SAN_SYNTAX, SAN_PARSE_ERROR,
             WRONG_COLOR, DOESNT_ESCAPE_CHECK, ILLEGAL_MOVE.

    Raises:
        Whatever chess.Board(fen) raises for an invalid FEN (the board is
        built outside the try block).
    """
    if not pred_move:
        return "EMPTY_OUTPUT"

    b = chess.Board(fen)

    try:
        if move_format.upper() == "UCI":
            uci = pred_move.lower()
            if not re.match(r'^[a-h][1-8][a-h][1-8][qrbn]?$', uci):
                return "MALFORMED_UCI"
            try:
                mv = chess.Move.from_uci(uci)
            except Exception:
                return "INVALID_UCI_SYNTAX"

            if mv not in b.legal_moves:
                from_sq = mv.from_square
                piece = b.piece_at(from_sq)

                if piece is None:
                    return "EMPTY_SQUARE"
                if piece.color != b.turn:
                    return "WRONG_COLOR"

                # The piece cannot make this move at all (wrong geometry,
                # blocked path, missing promotion, ...).
                if mv not in b.pseudo_legal_moves:
                    return "ILLEGAL_PIECE_MOVE"

                # From here on the move is pseudo-legal but not legal, i.e. it
                # would leave the mover's own king attacked. The previous code
                # used Board.push() as a legality probe, but push() does not
                # validate moves, so every case was labelled
                # LEAVES_KING_IN_CHECK and the checks below were unreachable.
                if b.is_check():
                    return "DOESNT_ESCAPE_CHECK"
                if b.is_pinned(piece.color, from_sq):
                    return "PINNED_PIECE"
                return "LEAVES_KING_IN_CHECK"

        else:  # SAN
            try:
                mv = b.parse_san(pred_move)
            except ValueError as e:
                # The category is inferred from the exception message text.
                error_str = str(e).lower()
                if "ambiguous" in error_str:
                    return "AMBIGUOUS_SAN"
                elif "illegal" in error_str:
                    return "ILLEGAL_SAN"
                else:
                    return "INVALID_SAN_SYNTAX"
            except Exception:
                return "SAN_PARSE_ERROR"

            # Defensive: kept in case parse_san ever returns a move that is
            # not in legal_moves.
            if mv not in b.legal_moves:
                piece = b.piece_at(mv.from_square)
                if piece and piece.color != b.turn:
                    return "WRONG_COLOR"
                if b.is_check():
                    return "DOESNT_ESCAPE_CHECK"
                return "ILLEGAL_MOVE"

        return "LEGAL"

    except Exception as e:
        return f"EXCEPTION_{type(e).__name__}"

def analyze_illegal_patterns(logs_path):
    """Count diagnosis labels for every invalid prediction in a JSONL log.

    Blank lines and lines that are not valid JSON are skipped. For each entry
    whose entry['eval']['valid'] is falsy, the move is classified with
    diagnose_illegal_move and tallied both overall and per variant label.
    A missing log file prints a warning and yields empty results.

    Returns:
        (counts_by_diagnosis, counts_by_variant_then_diagnosis)
    """
    type_counts = defaultdict(int)
    variant_counts = defaultdict(lambda: defaultdict(int))

    try:
        with open(logs_path, 'r', encoding='utf-8') as fh:
            for raw in fh:
                # Tolerate empty lines and corrupt records in the log.
                if raw.strip() == "":
                    continue
                try:
                    record = json.loads(raw)
                except json.JSONDecodeError:
                    continue

                if record['eval']['valid']:
                    continue

                label = variant_label(record['variant'])
                reason = diagnose_illegal_move(
                    record['fen'],
                    record['eval'].get('pred_move'),
                    record['variant']['move_format'],
                )
                type_counts[reason] += 1
                variant_counts[label][reason] += 1
    except FileNotFoundError:
        print("Warning: Log file not found for illegal pattern analysis.")

    return dict(type_counts), dict(variant_counts)

# -------------------- STRATIFIED ANALYSIS --------------------

def classify_position_complexity(fen):
    """Summarise structural features of the position `fen`.

    Returns a dict with:
        num_legal_moves: number of legal moves for the side to move.
        move_bucket: "very_few" (<=5), "few" (<=15), "medium" (<=30) or "many".
        phase: "endgame" (<=6 pieces on the board), "middlegame" (<=16) or
            "opening" (more); the count includes kings and pawns.
        num_pieces, num_pawns, num_queens: piece counts over both colours.
        in_check: whether the side to move is in check.
        has_checks_available / has_captures_available: whether any legal move
            gives check / is a capture.
        has_en_passant: whether the board has an en-passant square set.
        can_castle: whether any castling right remains.
        has_queens: whether at least one queen is on the board. This key is
            one of the strata aggregate_by_strata looks for and was
            previously never emitted.
    """
    b = chess.Board(fen)
    legal_moves = list(b.legal_moves)
    num_legal = len(legal_moves)
    piece_map = b.piece_map()
    num_pieces = len(piece_map)

    # Phase is a coarse heuristic based only on how many pieces remain.
    if num_pieces <= 6:
        phase = "endgame"
    elif num_pieces <= 16:
        phase = "middlegame"
    else:
        phase = "opening"

    num_pawns = sum(1 for p in piece_map.values() if p.piece_type == chess.PAWN)
    num_queens = sum(1 for p in piece_map.values() if p.piece_type == chess.QUEEN)

    in_check = b.is_check()
    has_checks = any(b.gives_check(m) for m in legal_moves)
    has_captures = any(b.is_capture(m) for m in legal_moves)

    has_ep = b.ep_square is not None
    can_castle = bool(b.castling_rights)

    if num_legal <= 5:
        move_bucket = "very_few"
    elif num_legal <= 15:
        move_bucket = "few"
    elif num_legal <= 30:
        move_bucket = "medium"
    else:
        move_bucket = "many"

    return {
        "num_legal_moves": num_legal,
        "move_bucket": move_bucket,
        "phase": phase,
        "num_pieces": num_pieces,
        "in_check": in_check,
        "has_checks_available": has_checks,
        "has_captures_available": has_captures,
        "has_en_passant": has_ep,
        "can_castle": can_castle,
        "num_pawns": num_pawns,
        "num_queens": num_queens,
        "has_queens": num_queens > 0,
    }

def aggregate_by_strata(logs_path):
    """Aggregate validity / correctness / top-k counts per position stratum.

    Reads a JSONL log, skipping blank lines, lines that are not valid JSON and
    entries without a non-empty 'complexity' dict. For each of the complexity
    keys 'phase', 'move_bucket', 'in_check' and 'has_queens' present in an
    entry, three buckets are updated:

        "<key>=<value>"                       (all variants)
        "<MOVE_FORMAT>|<key>=<value>"         (per move format)
        "<variant_label>|<key>=<value>"       (per full variant)

    Every bucket accumulates total, valid, correct and topk_hit. Previously
    topk_hit was only accumulated for the first bucket, so topk_hit_rate was
    always 0 for the per-format and per-variant buckets. The 'engine_best'
    counter is initialised but not updated here.

    A missing log file prints a warning and yields an empty dict.

    Returns:
        dict mapping bucket name -> stats dict, including the derived
        valid_rate, correct_rate and topk_hit_rate.
    """
    interest_keys = ('phase', 'move_bucket', 'in_check', 'has_queens')
    strata_results = defaultdict(lambda: {
        "total": 0, "valid": 0, "correct": 0,
        "topk_hit": 0, "engine_best": 0
    })

    def _tally(bucket, valid, correct, topk_hit):
        # Add one evaluated example to the named bucket.
        stats = strata_results[bucket]
        stats['total'] += 1
        stats['valid'] += valid
        stats['correct'] += correct
        stats['topk_hit'] += topk_hit

    try:
        with open(logs_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Tolerate empty lines and corrupt records in the log.
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue

                complexity = entry.get('complexity', {})
                if not complexity:
                    continue
                fmt = entry['variant']['move_format'].upper()

                present = [key for key in interest_keys if key in complexity]
                if not present:
                    continue

                outcome = entry['eval']
                valid = int(outcome['valid'])
                correct = int(outcome['correct'])
                topk_hit = int(outcome.get('topk_hit', 0))
                variant_key = variant_label(entry['variant'])

                for key in present:
                    stratum = f"{key}={complexity[key]}"
                    for bucket in (stratum, f"{fmt}|{stratum}", f"{variant_key}|{stratum}"):
                        _tally(bucket, valid, correct, topk_hit)
    except FileNotFoundError:
        print("Warning: Log file not found for strata analysis.")

    for stats in strata_results.values():
        total = max(1, stats['total'])  # guard against division by zero
        stats['valid_rate'] = stats['valid'] / total
        stats['correct_rate'] = stats['correct'] / total
        stats['topk_hit_rate'] = stats['topk_hit'] / total

    return dict(strata_results)

# -------------------- dataset mapping --------------------
# Accepted column names per logical field, in priority order.
ALIASES = {
    "fen": ["fen","FEN","position_fen"],
    "uci": ["best_move_uci","uci","BestMoveUCI","bestUCI","best_uci"],
    "san": ["best_move_san","san","BestMoveSAN","bestSAN","best_san"],
    "topk": ["top_moves_uci","TopMovesUCI","engine_top_uci","topk_uci","multipv_uci"],
}


def pick_key(row, keys):
    """Return the actual key in `row` for the first usable alias in `keys`.

    An alias is usable when the row holds a value for it that is not None,
    "" or []. Exact-case matches are tried for every alias first; only then
    is a case-insensitive lookup attempted. Returns None if nothing matches.
    """
    empties = (None, "", [])

    # Pass 1: exact-case match, in alias priority order.
    for alias in keys:
        if alias in row and row[alias] not in empties:
            return alias

    # Pass 2: case-insensitive match against the row's real key names.
    by_lower = {}
    for actual in row.keys():
        by_lower[actual.lower()] = actual
    for alias in keys:
        folded = alias.lower()
        if folded not in by_lower:
            continue
        actual = by_lower[folded]
        if row[actual] not in empties:
            return actual
    return None

def normalize_row(row):
    """Map a raw dataset row onto the pipeline's canonical example dict.

    Column names are resolved through ALIASES / pick_key. Returns None when
    the row has no usable FEN column; otherwise a dict with:

        fen: the FEN value as found in the row.
        best_move_uci: stripped, lower-cased string, or None if absent.
        best_move_san: stripped string, or None if absent.
        top_moves_uci: list of lower-cased strings, or None. A string value
            is accepted only if it is JSON that decodes to a list; a list
            value is used directly.
    """
    k_fen = pick_key(row, ALIASES["fen"])
    if not k_fen:
        return None
    fen = row[k_fen]

    k_uci = pick_key(row, ALIASES["uci"])
    k_san = pick_key(row, ALIASES["san"])
    k_top = pick_key(row, ALIASES["topk"])

    best_uci = str(row[k_uci]).strip().lower() if k_uci else None
    best_san = str(row[k_san]).strip() if k_san else None

    top_moves = None
    if k_top:
        vals = row[k_top]
        if isinstance(vals, str):
            try:
                arr = json.loads(vals)
            except ValueError:
                # Not JSON (json.JSONDecodeError is a ValueError): treat the
                # column as absent. Narrowed from a bare `except: pass`.
                arr = None
            if isinstance(arr, list):
                top_moves = [str(m).lower() for m in arr]
        elif isinstance(vals, list):
            top_moves = [str(m).lower() for m in vals]

    return {"fen": fen, "best_move_uci": best_uci, "best_move_san": best_san, "top_moves_uci": top_moves}

def load_hf(dataset_name, split, n=None, seed=42, shuffle=True, show_progress=True, log_stats=True, sample_failures=3):
    """Load a Hugging Face dataset split and normalise its rows.

    Args:
        dataset_name, split: passed to datasets.load_dataset.
        n: if not None, keep only the first min(n, len) rows (after shuffling).
        seed, shuffle: shuffle the split with this seed when `shuffle` is true.
        show_progress: wrap the iteration in tqdm when tqdm is available.
        log_stats: print a summary (and hints when fewer than 10 rows survive).
        sample_failures: how many key lists of rejected rows to keep as examples.

    Returns:
        (rows, stats): the rows for which normalize_row returned a truthy
        value, and a dict of loading statistics.
    """
    print(f"[HF] Loading dataset '{dataset_name}' split '{split}'")
    ds = load_dataset(dataset_name, split=split)
    raw_len = len(ds)

    if shuffle:
        ds = ds.shuffle(seed=seed)
        print(f"[HF] Shuffled with seed={seed}")
    if n is not None:
        keep = min(n, len(ds))
        ds = ds.select(range(keep))
        print(f"[HF] Selected first {len(ds)} rows (requested n={n})")

    use_bar = show_progress and tqdm is not None
    source = tqdm(ds, desc="HF normalize", leave=False) if use_bar else ds

    rows, failures = [], []
    first_keys = None
    for record in source:
        if first_keys is None:
            first_keys = list(record.keys())
        normalized = normalize_row(record)
        if normalized:
            rows.append(normalized)
        elif len(failures) < sample_failures:
            # Remember the column names of a few rejected rows for debugging.
            failures.append(list(record.keys()))

    selected_len = len(ds)
    kept_ratio = len(rows) / max(1, selected_len)
    stats = {
        "dataset": dataset_name, "split": split,
        "raw_len": raw_len, "post_select_len": selected_len,
        "kept": len(rows), "kept_ratio": kept_ratio,
        "first_row_keys": first_keys or [],
        "example_failed_keys": failures,
    }

    if log_stats:
        print(f"[HF] Summary: raw={raw_len}, after_select={selected_len}, kept={len(rows)} ({kept_ratio:.1%})")
        if len(rows) < 10:
            print(f"[HF] Few rows kept; sample keys: {stats['first_row_keys']}")
            if failures:
                print(f"[HF] Example keys from non-normalized rows (up to {sample_failures}): {failures}")
            print("[HF] Hint: this pipeline assumes the dataset has at least 'FEN'.")
    return rows, stats

# -------------------- prompts --------------------
# System message placed in the "system" role by generate().
SYSTEM_BASE = "You are a chess coach assistant. You will be given a board position in FEN format. Your job is to analyze the board and choose the best move from the candidate list."

# Output-format instructions used by build_prompt when include_reasoning is
# true; "{fmt}" is filled with the upper-cased move format.
# NOTE(review): the sample moves "Nf3 or exd5" are SAN even when {fmt} is UCI.
# NOTE(review): the dash-like character sequence in the <reasoning> placeholder
# looks like a mis-encoded em dash -- confirm against the original notebook.
XML_EXPECTED = """Return output in the following XML structure:

<reasoning>
(Brief explanation of what you see on the board â€” piece activity, threats, and candidate moves)
</reasoning>
<answer>
(best move written in correct {fmt} format, such as Nf3 or exd5)
</answer>
"""
# Output-format instructions used by build_prompt when include_reasoning is
# false (the "legal move" task); "{fmt}" is filled the same way.
# NOTE(review): "choosen" is a typo in the prompt text; left untouched because
# the string is sent to the model verbatim.
XML_LEGAL= """Return output in the following XML structure:
<reasoning>
(Brief explanation of why the choosen move is legal)
</reasoning>
<answer>
(legal move written in correct {fmt} format, such as Nf3 or exd5)
</answer>
"""
# Worked example for the "legal move" task, SAN flavour.
XML_LEGAL_EXAMPLE_SAN = """Example (SAN format):
FEN: 8/8/8/8/4K3/8/8/4k3 w - - 0 1
<reasoning>
White can move the king one square in any direction as long as it does not move into check.
The move Kd4 is legal because the king moves exactly one square horizontally from e4 to d4,
the destination square is empty, and the king is not moving into check or through check.
</reasoning>
<answer>
Kd4
</answer>
"""
# Worked example for the "legal move" task, UCI flavour (same position).
XML_LEGAL_EXAMPLE_UCI = """Example (UCI format):
FEN: 8/8/8/8/4K3/8/8/4k3 w - - 0 1
<reasoning>
In UCI notation, a king move is specified by its from-square and to-square.
The move e4d4 is legal because the king moves exactly one square horizontally from e4 to d4,
the destination square is empty, and the king does not move into check.
</reasoning>
<answer>
e4d4
</answer>
"""
# Worked example for the "best move with reasoning" task, SAN flavour.
XML_EXAMPLE_SAN = """Example:
FEN: 8/8/8/2k5/8/2K5/8/8 w - - 0 1
<reasoning>
White to move in a trivial king opposition scenario; moving the king towards the center maintains opposition.
</reasoning>
<answer>
Kd3
</answer>
"""
# Worked example for the "best move with reasoning" task, UCI flavour
# (same position and move as the SAN example, written as c3d3).
XML_EXAMPLE_UCI = """Example:
FEN: 8/8/8/2k5/8/2K5/8/8 w - - 0 1
<reasoning>
White to move in a trivial king opposition scenario; moving the king towards the center maintains opposition.
</reasoning>
<answer>
c3d3
</answer>
"""

def build_prompt(
    fen,
    move_format: str = "SAN",
    include_reasoning: bool = True,
    include_example: bool = False,
    candidate_list=None,
    use_xml: bool = True,
) -> str:
    """Build the user prompt for one position.

    Args:
        fen: FEN string of the position.
        move_format: "SAN" or "UCI" (case-insensitive); upper-cased in the text.
        include_reasoning: with use_xml, selects the "best move" template
            (XML_EXPECTED) when true and the "legal move" template (XML_LEGAL)
            when false.
        include_example: append the worked example matching the template and
            move format.
        candidate_list: optional list of move strings. When non-empty, a
            "Candidate moves to consider" line is inserted right after the
            FEN line. This line used to be built but never added to the
            prompt. With no candidates the prompt is unchanged.
        use_xml: when false, a plain "single move on one line" prompt is built.

    Returns:
        The prompt text.
    """
    fmt = move_format.upper()
    side = side_to_move_text(fen)
    base = f"It is {side} to move.\nFEN: {fen}\n"
    cand_txt = (
        f"Candidate moves to consider ({fmt}): {', '.join(candidate_list)}\n"
        if candidate_list else ""
    )
    header = base + cand_txt

    if not use_xml:
        return f"{header}Respond ONLY with a single {fmt} move on one line.\nOutput:\n"

    # Pick the template and its (SAN, UCI) examples for the requested task.
    if include_reasoning:
        template, example_san, example_uci = XML_EXPECTED, XML_EXAMPLE_SAN, XML_EXAMPLE_UCI
    else:
        template, example_san, example_uci = XML_LEGAL, XML_LEGAL_EXAMPLE_SAN, XML_LEGAL_EXAMPLE_UCI

    fmt_block = template.format(fmt=fmt)
    example_block = ""
    if include_example:
        example_block = example_san if fmt == "SAN" else example_uci

    instr = (
        "Respond ONLY with the XML block; do not add extra text, labels, or markdown.\n"
        "If you are unsure, still output a single legal move.\n"
    )
    return header + fmt_block + (example_block + "\n" if example_block else "") + instr

# -------------------- candidates --------------------
def get_candidate_list(fen, target_uci, move_format, K, rng, cand_mode="target_plus_random", dataset_topk=None):
    """Build the list of candidate moves shown to the model.

    Args:
        fen: FEN string of the position.
        target_uci: the reference move in UCI, or a falsy value if unknown.
        move_format: "UCI" returns UCI strings; anything else returns SAN.
        K: None (no candidates, returns None), "ALL", or a value accepted by int().
        rng: object with a shuffle(list) method (e.g. random.Random).
        cand_mode: "dataset_topk" uses `dataset_topk` when it is non-empty;
            any other value (or an empty `dataset_topk`) uses the target plus
            random legal moves.
        dataset_topk: optional list of UCI moves from the dataset.

    Returns:
        A list of move strings in the requested format, or None when K is None.
    """
    if K is None:
        return None

    board = chess.Board(fen)
    legal_ucis = [mv.uci() for mv in board.legal_moves]
    take_all = (K == "ALL")

    if cand_mode == "dataset_topk" and dataset_topk:
        ranked = [m.lower() for m in dataset_topk]
        if target_uci and target_uci not in ranked:
            ranked = [target_uci] + ranked
        shortlist = ranked if take_all else ranked[:int(K)]
        # Drop anything that is not legal in this position.
        ucis = [u for u in shortlist if u in legal_ucis]
        if not take_all and len(ucis) < int(K):
            # Top up with random legal moves that are not already listed.
            filler = [u for u in legal_ucis if u not in ucis]
            rng.shuffle(filler)
            ucis += filler[: int(K) - len(ucis)]
    elif take_all:
        ucis = sorted(legal_ucis)
    else:
        others = [u for u in legal_ucis if u != target_uci]
        rng.shuffle(others)
        if target_uci:
            ucis = [target_uci] + others[: max(0, int(K) - 1)]
        else:
            ucis = others[: max(0, int(K))]
        # Shuffle again so the target's position carries no information.
        rng.shuffle(ucis)

    if move_format.upper() == "UCI":
        return ucis
    return [board.san(chess.Move.from_uci(u)) for u in ucis]

# -------------------- model I/O --------------------
def load_model(model_name, dtype="float16", load_in_4bit=False, device_map="auto", max_seq_length=1024):
    """Load a model and tokenizer through Unsloth and switch to inference mode.

    Args:
        model_name: model identifier or path passed to from_pretrained. This
            argument was previously ignored in favour of a hard-coded
            "czovekboti/7B_qwen", so results could be labelled with a model
            that was never loaded.
        dtype: name of a torch dtype attribute, e.g. "float16".
        load_in_4bit: forwarded to from_pretrained.
        device_map: forwarded to from_pretrained.
        max_seq_length: forwarded to from_pretrained (default 1024, the value
            that used to be hard-coded).

    Returns:
        (model, tokenizer)
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=getattr(torch, dtype),
        load_in_4bit=load_in_4bit,
        device_map=device_map,
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer

@torch.inference_mode()
def generate(model, tokenizer, prompt, max_new_tokens=96, temperature=0.2, top_p=0.95,
             repetition_penalty=1.05, use_chat_template=True):
    """Generate one completion for `prompt` and return only the new text.

    The prompt is sent as a system message (SYSTEM_BASE) plus a user message.
    When `use_chat_template` is true and the tokenizer has a chat template it
    is applied; otherwise a plain "ROLE:\\ncontent" transcript is used.
    Sampling is enabled only when `temperature` is not None and > 0;
    otherwise temperature/top_p are omitted from the generation arguments.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped, surrounding whitespace stripped).
    """
    msgs = [{"role":"system","content":SYSTEM_BASE},{"role":"user","content":prompt}]

    if use_chat_template and getattr(tokenizer, "chat_template", None):
        text_in = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    else:
        text_in = "\n\n".join(f"{m['role'].upper()}:\n{m['content']}" for m in msgs) + "\nASSISTANT:\n"

    enc = tokenizer(text_in, return_tensors="pt", truncation=True)
    enc = {k: v.to(model.device) for k, v in enc.items()}

    # Compare against None explicitly: 0 is a valid token id, and `a or b`
    # would silently replace it with the fallback.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else pad_id

    do_sample = temperature is not None and temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature if do_sample else None,
        top_p=top_p if do_sample else None,
        repetition_penalty=repetition_penalty,
        pad_token_id=pad_id,
        eos_token_id=eos_id,
        return_dict_in_generate=True,
        use_cache=True,
    )
    # Drop unset options so they are not passed as explicit None values.
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}

    out = model.generate(**enc, **gen_kwargs)

    # The returned sequence starts with the prompt tokens; keep only the tail.
    prompt_len = enc["input_ids"].shape[1]
    seq = out.sequences[0]
    new_tokens = seq[prompt_len:]
    raw_resp = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return raw_resp

# -------------------- evaluation --------------------
def eval_one(example, variant, answer_text):
    """Score one model answer against one example.

    The answer token is the first token of the <answer> block when
    variant["use_xml"] is true and such a block yields a token; otherwise it
    is the first whitespace-delimited token of the whole answer text.

    Returns:
        dict with keys pred_move, valid, correct, topk_hit. For UCI variants
        pred_move is the extracted UCI move if one was found, else the raw
        answer token; for SAN variants it is the raw answer token.
    """
    fen = example["fen"]
    tgt_uci = example.get("best_move_uci")

    def _first_token():
        stripped = answer_text.strip()
        return stripped.split()[0] if stripped else None

    if variant["use_xml"]:
        ans = extract_answer_block(answer_text) or _first_token()
    else:
        ans = _first_token()

    result = {"pred_move": ans, "valid": False, "correct": False, "topk_hit": False}
    if not ans:
        return result

    if variant["move_format"].upper() == "UCI":
        pred_uci = extract_first_uci(ans)
        valid = bool(pred_uci and uci_is_legal(fen, pred_uci))
        result["valid"] = valid
        result["correct"] = bool(valid and tgt_uci and pred_uci == tgt_uci)
        topk = example.get("top_moves_uci")
        result["topk_hit"] = bool(valid and topk and pred_uci in topk)
        result["pred_move"] = pred_uci or ans
        return result

    # SAN: legality first, then compare in UCI space.
    if san_is_legal(fen, ans):
        result["valid"] = True
        try:
            pred_uci = san_to_uci(fen, ans)
            result["correct"] = (tgt_uci is not None and pred_uci == tgt_uci)
            topk = example.get("top_moves_uci")
            result["topk_hit"] = bool(topk and pred_uci in topk)
        except:
            # Best effort: a conversion failure leaves correct/topk_hit False.
            pass
    return result

# -------------------- Stockfish wrapper --------------------
def with_stockfish(stockfish_path, multipv=5, movetime_ms=30):
    """Start a UCI engine and return helpers to analyse positions with it.

    Args:
        stockfish_path: path to the engine binary; a falsy value returns None
            without starting anything.
        multipv: number of principal variations requested per position.
        movetime_ms: analysis time per position, in milliseconds.

    Returns:
        None, or a pair (analyze_fen, close):
            analyze_fen(fen) -> {"engine_top_moves_uci": [...],
                                 "engine_best_uci": str or None}
                Engine failures yield an empty list and None; an invalid FEN
                still raises from chess.Board.
            close() -> shuts the engine down, ignoring errors.
    """
    if not stockfish_path:
        return None
    import chess.engine
    engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
    limits = chess.engine.Limit(time=movetime_ms / 1000.0)

    def analyze_fen(fen):
        b = chess.Board(fen)
        try:
            info = engine.analyse(b, limits, multipv=multipv)
            if isinstance(info, dict):
                info = [info]
            top = []
            for entry in info:
                # A line may have no "pv" or an empty one; skip it instead of
                # raising IndexError and discarding the whole analysis.
                pv = entry.get("pv") or []
                if not pv or pv[0] is None:
                    continue
                top.append(pv[0].uci())
            best = top[0] if top else None
            return {"engine_top_moves_uci": top, "engine_best_uci": best}
        except Exception:
            return {"engine_top_moves_uci": [], "engine_best_uci": None}

    def close():
        try:
            engine.quit()
        except Exception:
            # Best effort shutdown; the engine may already be gone.
            pass

    return analyze_fen, close

# -------------------- hardcoded easy boards ----------
def seed_scholars_mate_frame():
    """Return a mate-in-one seed: White to play Qxf7# after 1.e4 e5 2.Bc4 Nc6 3.Qh5 Nf6.

    The mating move is played and taken back to assert that it is legal and
    delivers checkmate before the seed is returned.
    """
    board = chess.Board()
    for uci in ("e2e4", "e7e5", "f1c4", "b8c6", "d1h5", "g8f6"):
        board.push(chess.Move.from_uci(uci))

    mating_move = chess.Move.from_uci("h5f7")
    assert mating_move in board.legal_moves
    board.push(mating_move)
    assert board.is_checkmate()
    board.pop()

    return {
        "group": "MATE_IN_ONE",
        "title": "Scholar's-mate frame (Qxf7#)",
        "fen": board.fen(),
        "best_move_uci": "h5f7",
        "best_move_san": "Qxf7#",
        "top_moves_uci": None,
    }

def seed_fools_mate_frame():
    """Return a mate-in-one seed: Black to play Qh4# after 1.f3 e5 2.g4.

    The mating move is played and taken back to assert that it is legal and
    delivers checkmate before the seed is returned.
    """
    board = chess.Board()
    for uci in ("f2f3", "e7e5", "g2g4"):
        board.push(chess.Move.from_uci(uci))

    mating_move = chess.Move.from_uci("d8h4")
    assert mating_move in board.legal_moves
    board.push(mating_move)
    assert board.is_checkmate()
    board.pop()

    return {
        "group": "MATE_IN_ONE",
        "title": "Fool's-mate frame (Qh4# for Black)",
        "fen": board.fen(),
        "best_move_uci": "d8h4",
        "best_move_san": "Qh4#",
        "top_moves_uci": None,
    }

def seed_back_rank_mate_white():
    """Return a mate-in-one seed: a back-rank mate where White's rook takes on e8.

    Position: Black Kg8, pawns f7/g7/h7, Re8; White Re1, Kg1; White to move.
    The move e1e8 captures the black rook, so its SAN is derived from the
    board (giving "Rxe8#") instead of the previous hard-coded "Re8#", which
    omitted the capture marker. The title is left unchanged so any existing
    result keys stay stable.
    """
    b = chess.Board.empty(); b.clear_board()
    b.set_piece_at(chess.G8, chess.Piece(chess.KING, chess.BLACK))
    b.set_piece_at(chess.G7, chess.Piece(chess.PAWN, chess.BLACK))
    b.set_piece_at(chess.H7, chess.Piece(chess.PAWN, chess.BLACK))
    b.set_piece_at(chess.F7, chess.Piece(chess.PAWN, chess.BLACK))
    b.set_piece_at(chess.E8, chess.Piece(chess.ROOK, chess.BLACK))
    b.set_piece_at(chess.E1, chess.Piece(chess.ROOK, chess.WHITE))
    b.set_piece_at(chess.G1, chess.Piece(chess.KING, chess.WHITE))
    b.turn = chess.WHITE
    mv = chess.Move.from_uci("e1e8"); assert mv in b.legal_moves
    # Compute the SAN before pushing: san() needs the pre-move position.
    best_san = b.san(mv)
    b.push(mv); assert b.is_checkmate(); b.pop()
    return {"group":"MATE_IN_ONE","title":"Back-rank motif (Re8#)","fen":b.fen(),"best_move_uci":"e1e8","best_move_san":best_san,"top_moves_uci":None}

def try_single_legal_block():
    """Try a few hand-made layouts and return the first with exactly one legal move.

    Each layout has Black Ke8 and Bc5, White Re1 and Kg1, plus the listed
    white helper pieces; Black is to move. Returns a SINGLE_LEGAL_MOVE seed
    dict for the first layout with exactly one legal move, or None if no
    layout qualifies.
    """
    piece_types = {'Q': chess.QUEEN, 'R': chess.ROOK, 'B': chess.BISHOP, 'N': chess.KNIGHT}
    helper_layouts = (
        [("Q","e6"),("N","d6"),("N","f6")],
        [("Q","e6"),("B","c4"),("B","g4")],
    )
    fixed_pieces = (
        (chess.E8, chess.KING, chess.BLACK),
        (chess.E1, chess.ROOK, chess.WHITE),
        (chess.C5, chess.BISHOP, chess.BLACK),
        (chess.G1, chess.KING, chess.WHITE),
    )

    for layout in helper_layouts:
        board = chess.Board.empty()
        board.clear_board()
        for square, kind, colour in fixed_pieces:
            board.set_piece_at(square, chess.Piece(kind, colour))
        for symbol, square_name in layout:
            board.set_piece_at(
                chess.SQUARE_NAMES.index(square_name),
                chess.Piece(piece_types[symbol], chess.WHITE),
            )
        board.turn = chess.BLACK

        moves = list(board.legal_moves)
        if len(moves) != 1:
            continue
        forced = moves[0]
        return {
            "group": "SINGLE_LEGAL_MOVE",
            "title": "Single legal block",
            "fen": board.fen(),
            "best_move_uci": forced.uci(),
            "best_move_san": chess.Board(board.fen()).san(forced),
            "top_moves_uci": None,
        }
    return None

def try_single_legal_corner():
    """Try a few corner layouts and return the first with exactly one legal move.

    Each layout has Black Kh8 and White Kg1 plus the listed white helper
    pieces; Black is to move. Returns a SINGLE_LEGAL_MOVE seed dict for the
    first layout with exactly one legal move, or None if no layout qualifies.
    """
    piece_types = {'Q': chess.QUEEN, 'R': chess.ROOK, 'B': chess.BISHOP, 'N': chess.KNIGHT}
    helper_layouts = (
        [("R","a8"),("Q","g6"),("B","f7")],
        [("R","b8"),("Q","g6"),("B","e7")],
    )

    for layout in helper_layouts:
        board = chess.Board.empty()
        board.clear_board()
        board.set_piece_at(chess.H8, chess.Piece(chess.KING, chess.BLACK))
        board.set_piece_at(chess.G1, chess.Piece(chess.KING, chess.WHITE))
        for symbol, square_name in layout:
            board.set_piece_at(
                chess.SQUARE_NAMES.index(square_name),
                chess.Piece(piece_types[symbol], chess.WHITE),
            )
        board.turn = chess.BLACK

        moves = list(board.legal_moves)
        if len(moves) != 1:
            continue
        forced = moves[0]
        return {
            "group": "SINGLE_LEGAL_MOVE",
            "title": "Corner squeeze (one king move)",
            "fen": board.fen(),
            "best_move_uci": forced.uci(),
            "best_move_san": chess.Board(board.fen()).san(forced),
            "top_moves_uci": None,
        }
    return None

def seed_single_king_only():
    """Return a seed where a lone white king has exactly one legal move.

    Position: White Kh1; Black rooks on g3, f3 and f1; White to move.
    The rook on f1 gives check along the first rank, g1 and g2 are covered
    (g1 by both the f1 and g3 rooks, g2 by the g3 rook), and h2 is the only
    safe square.

    The previous layout had the third rook on f2. That rook covered h2 along
    the second rank, so the king had no legal move at all, the assertion
    below always failed and build_hardcoded_seeds always skipped this seed.

    Raises:
        AssertionError: if the position does not have exactly one legal move.
    """
    b = chess.Board()
    b.clear_board()
    b.set_piece_at(chess.H1, chess.Piece(chess.KING, chess.WHITE))
    b.set_piece_at(chess.G3, chess.Piece(chess.ROOK, chess.BLACK))
    b.set_piece_at(chess.F3, chess.Piece(chess.ROOK, chess.BLACK))
    b.set_piece_at(chess.F1, chess.Piece(chess.ROOK, chess.BLACK))
    b.turn = chess.WHITE
    legal = list(b.legal_moves)
    assert len(legal) == 1
    mv = legal[0]
    return {
        "group": "SINGLE_LEGAL_MOVE",
        "title": "Trapped king only one move",
        "fen": b.fen(),
        "best_move_uci": mv.uci(),
        "best_move_san": b.san(mv),
        "top_moves_uci": None,
    }

def build_hardcoded_seeds():
    """Collect the hand-made seed positions that pass their sanity checks.

    Fixed seeds must have a legal best move, and must deliver checkmate when
    their SAN contains "#". The "try_*" builders may return None; when they
    return a seed it must have exactly one legal move. Any seed whose builder
    or check raises is skipped with a warning.
    """
    seeds = []

    fixed_makers = (
        seed_scholars_mate_frame,
        seed_fools_mate_frame,
        seed_back_rank_mate_white,
        seed_single_king_only,
    )
    for maker in fixed_makers:
        try:
            seed = maker()
            board = chess.Board(seed["fen"])
            move = chess.Move.from_uci(seed["best_move_uci"])
            claims_mate = "#" in (seed.get("best_move_san") or "")
            assert move in board.legal_moves
            if claims_mate:
                # Play the move, confirm the mate, then restore the position.
                board.push(move)
                assert board.is_checkmate()
                board.pop()
            seeds.append(seed)
        except Exception as e:
            warnings.warn(f"Skipping seed '{maker.__name__}': {e}")

    optional_builders = (try_single_legal_block, try_single_legal_corner)
    for builder in optional_builders:
        try:
            seed = builder()
            if not seed:
                continue
            board = chess.Board(seed["fen"])
            assert len(list(board.legal_moves)) == 1
            seeds.append(seed)
        except Exception as e:
            warnings.warn(f"Skipping seed '{builder.__name__}': {e}")

    return seeds

# -------------------- main(args) --------------------
def main(args):
    """Run the chess move-prediction evaluation and write all result artifacts.

    Steps, in order:

    1. Create ``eval_results/<timestamp>_<model>/`` and write ``config.json``.
    2. Build the hard-coded seed positions and load ``2 * args.num_examples``
       rows from the Hugging Face dataset.
    3. Start Stockfish and label every dataset row that has no
       ``best_move_uci`` yet.
    4. Evaluate the model on ``seeds + dataset rows`` for every prompt variant
       (move format x reasoning on/off x temperature), appending one JSON line
       per (variant, example) to ``prompts_and_answers.jsonl``.
    5. Run the illegal-move and stratified analyses on that log and write
       ``illegal_analysis.json``, ``strata_analysis.json``, ``results.json``,
       ``summary.csv``, ``report.md`` and the bar charts under ``plots/``.

    Args:
        args: namespace-like object. Attributes read here: ``hf_dataset``,
            ``split``, ``num_examples``, ``model``, ``dtype``, ``load_in_4bit``,
            ``device_map``, ``max_new_tokens``, ``temperature``,
            ``temperatures``, ``top_p``, ``repetition_penalty``, ``k_values``,
            ``cand_mode``, ``no_chat_template``, ``quiet``, ``stockfish_path``,
            ``engine_multipv``, ``engine_movetime_ms`` and ``seed``.

    Raises:
        RuntimeError: if ``args.stockfish_path`` is empty or Stockfish cannot
            be started.
    """
    use_chat_template = not args.no_chat_template
    show_progress = (not args.quiet) and (tqdm is not None)
    rng = random.Random(args.seed)

    # Fail fast: check the engine path before creating output directories or
    # loading the dataset.
    if not args.stockfish_path:
        raise RuntimeError("Stockfish is required because the dataset has FEN only. "
                           "Set args.stockfish_path to a valid engine binary.")

    out_root = Path("eval_results") / f"{now_tag()}_{sanitize(args.model)}"
    (out_root / "plots").mkdir(parents=True, exist_ok=True)
    config_path = out_root / "config.json"
    config = {
        **vars(args),
        "use_chat_template": use_chat_template,
        "progress": show_progress,
    }
    config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")

    seeds = build_hardcoded_seeds()
    seeds_count = len(seeds)
    data_rows, hf_stats = load_hf(
        args.hf_dataset,
        args.split,
        n=2*args.num_examples,
        seed=args.seed,
        shuffle=True,
        show_progress=show_progress,
        log_stats=not args.quiet,
    )

    res = with_stockfish(args.stockfish_path, multipv=max(1, args.engine_multipv), movetime_ms=args.engine_movetime_ms)
    if not res:
        raise RuntimeError("Failed to start Stockfish. Check stockfish_path.")
    engine, close_engine = res

    # Successful analyses are memoised per FEN: the same positions are scored
    # once per variant, and re-running the engine each time is wasted work.
    # Empty results are not cached, so a failed analysis is retried.
    engine_cache = {}

    def engine_info(fen):
        info = engine_cache.get(fen)
        if not info:
            info = engine(fen) or {}
            if info:
                engine_cache[fen] = info
        return info

    logs_path = out_root / "prompts_and_answers.jsonl"
    agg = []

    # Everything that talks to the engine runs inside try/finally so the
    # Stockfish process is released even when labelling, model loading or
    # generation raises.
    try:
        # ---- Label dataset rows that have no target move yet ----
        it = data_rows
        if show_progress:
            it = tqdm(data_rows, desc="Engine-label dataset", leave=False)
        labeled = 0
        for ex in it:
            if ex.get("best_move_uci"):
                continue
            info = engine_info(ex["fen"])
            best = (info.get("engine_best_uci") or "").lower()
            if not best:
                continue
            top = [mv.lower() for mv in (info.get("engine_top_moves_uci") or [])]
            ex["best_move_uci"] = best
            try:
                ex["best_move_san"] = chess.Board(ex["fen"]).san(chess.Move.from_uci(best))
            except Exception:
                # The SAN is informational only; keep the row with its UCI label.
                ex["best_move_san"] = None
            ex["top_moves_uci"] = top or None
            labeled += 1
        if not args.quiet:
            print(f"[ENGINE-LABEL] Added labels to {labeled} dataset rows (MultiPV={args.engine_multipv}, movetime_ms={args.engine_movetime_ms}).")

        # ---- Final example list: seeds first, then labelled dataset rows ----
        seen = {s["fen"] for s in seeds}
        remainder = [r for r in data_rows if r["fen"] not in seen and r.get("best_move_uci")]
        take_more = max(0, args.num_examples - len(seeds))
        rows = seeds + remainder[:take_more]
        config.update({
            "hf_stats": hf_stats,
            "engine_labelled": labeled,
            "seeds_used": len(seeds),
            "remainder_available": len(remainder),
            "rows_final": len(rows),
        })
        config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")
        if not args.quiet:
            print(f"[DATA] seeds={len(seeds)}, remainder_available={len(remainder)}, taking={take_more}, total_rows={len(rows)}")

        model, tokenizer = load_model(args.model, dtype=args.dtype, load_in_4bit=args.load_in_4bit, device_map=args.device_map)

        # ---- K sweep parse ----
        # Parsed (and thereby validated) but not used to build variants while
        # the K sweep below is disabled.
        k_list = []
        for t in str(args.k_values).split(","):
            t = t.strip()
            if not t:
                continue
            if t.lower() == "none":
                k_list.append(None)
            elif t.upper() == "ALL":
                k_list.append("ALL")
            else:
                k_list.append(int(t))

        # ---- Temperature sweep parse ----
        temp_list = [float(t.strip()) for t in str(args.temperatures).split(",") if t.strip()]
        if not temp_list:
            temp_list = [float(args.temperature)]  # fallback

        # ---- Variants: move format x reasoning on/off x temperature sweep ----
        # NOTE(review): the earlier comments here said examples are disabled
        # ("example=False"), but the flag is forced to True - confirm which
        # setting is intended.
        variants = []
        for move_format in ("SAN", "UCI"):
            for include_reasoning in (False, True):
                include_example = True
                # K sweep disabled: every variant carries K=None so code that
                # reads the "K" entry keeps working.
                for temp in temp_list:
                    variants.append({
                        "move_format": move_format,
                        "include_reasoning": include_reasoning,
                        "include_example": include_example,
                        "temperature": temp,
                        "use_xml": True,
                        "K": None,
                    })
        total_variants = len(variants)

        # ---- Evaluation loop ----
        with logs_path.open("w", encoding="utf-8") as logs:
            for vidx, v in enumerate(variants, start=1):
                if not args.quiet:
                    color = Fore.CYAN if v["move_format"] == "SAN" else Fore.MAGENTA
                    print(f"{color}[{vidx}/{total_variants}] {variant_label(v)}{Style.RESET_ALL}")

                m = {"total":0, "valid":0, "correct":0, "topk_hit":0,
                     "engine_topk_hit":0, "engine_best_match":0}

                iterator = rows
                if show_progress:
                    iterator = tqdm(rows, desc="examples", leave=False)

                for ex in iterator:
                    fen = ex["fen"]; tgt = ex["best_move_uci"]
                    # Candidate lists are only built for variants with a K
                    # value; the result is currently not passed to build_prompt.
                    k_value = v.get("K")
                    cand = get_candidate_list(fen, tgt, v["move_format"], k_value, rng,
                                              cand_mode=args.cand_mode, dataset_topk=ex.get("top_moves_uci")) if k_value is not None else None
                    prompt = build_prompt(fen, v["move_format"], v["include_reasoning"], v["include_example"], v["use_xml"])
                    out = generate(model, tokenizer, prompt, max_new_tokens=args.max_new_tokens, temperature=v.get("temperature", args.temperature),
                                   top_p=args.top_p, repetition_penalty=args.repetition_penalty, use_chat_template=use_chat_template)
                    ev = eval_one(ex, v, out)

                    info = engine_info(fen)
                    top = [muv.lower() for muv in (info.get("engine_top_moves_uci") or [])]
                    best = (info.get("engine_best_uci") or "").lower()
                    if ev["pred_move"] and ev["valid"]:
                        pred_uci = ev["pred_move"] if v["move_format"].upper()=="UCI" else san_to_uci(fen, ev["pred_move"])
                        if pred_uci:
                            if top and pred_uci in top: m["engine_topk_hit"] += 1
                            if best and pred_uci == best: m["engine_best_match"] += 1

                    m["total"] += 1
                    m["valid"] += int(ev["valid"])
                    m["correct"] += int(ev["correct"])
                    m["topk_hit"] += int(ev["topk_hit"])

                    logs.write(json.dumps({
                        "variant": v, "group": ex.get("group"), "title": ex.get("title"),
                        "fen": fen, "target_uci": tgt, "target_san": ex.get("best_move_san"),
                        "prompt": prompt, "answer": out, "eval": ev,
                        "complexity": classify_position_complexity(fen) ,
                        "engine": {"engine_top_moves_uci": top, "engine_best_uci": best}
                    }) + "\n")

                    # Ensure data is written to disk immediately to minimize corrupt lines on crash
                    logs.flush()

                total = max(1, m["total"])  # guard against division by zero
                agg.append({
                    "variant_id": vidx-1, **v, "total": m["total"],
                    "valid_rate": m["valid"]/total,
                    "correct_rate": m["correct"]/total,
                    "topk_hit_rate": m["topk_hit"]/total,
                    "engine_topk_hit_rate": (m["engine_topk_hit"]/total),
                    "engine_best_match_rate": (m["engine_best_match"]/total),
                })
    finally:
        if close_engine:
            close_engine()

    # ---- Post-hoc analyses over the JSONL log ----
    if not args.quiet:
        print("[ANALYSIS] Analyzing illegal move patterns...")
    illegal_by_type, illegal_by_variant = analyze_illegal_patterns(logs_path)
    (out_root / "illegal_analysis.json").write_text(
        json.dumps({"by_type": illegal_by_type, "by_variant": illegal_by_variant}, indent=2),
        encoding="utf-8"
    )
    if not args.quiet:
        print("[ANALYSIS] Computing stratified statistics...")
    strata_results = aggregate_by_strata(logs_path)
    (out_root / "strata_analysis.json").write_text(
        json.dumps(strata_results, indent=2),
        encoding="utf-8"
    )

    out_root.joinpath("results.json").write_text(json.dumps(agg, indent=2), encoding="utf-8")
    pd.DataFrame(agg).to_csv(out_root / "summary.csv", index=False)

    # ---- Markdown report ----
    lines = [
        "# Chess LM Eval Report",
        f"- Model: `{args.model}`",
        f"- HF dataset: `{args.hf_dataset}` split `{args.split}`",
        f"- Seeds included: {seeds_count}",
        f"- N evaluated: {len(rows)}",
        f"- Variants: {len(variants)}",
        "",
        "## Summary",
        "| id | format | reasoning | T | N | valid | correct | topk_hit | eng_topk | eng_best |",
        # The delimiter row has one cell per header column (10).
        "|---|---|---|---|---|---|---|---|---|---|",
    ]
    # Summary rows go directly under their own header so they are rendered as
    # part of the Summary table rather than a later one.
    for r in agg:
        mode = "normal" if r["include_reasoning"] else "legal"
        lines.append(
            f"| {r['variant_id']} | {r['move_format']} | {mode} | {r.get('temperature')} | {r['total']} | "
            f"{r['valid_rate']:.3f} | {r['correct_rate']:.3f} | {r['topk_hit_rate']:.3f} | "
            f"{r['engine_topk_hit_rate']:.3f} | {r['engine_best_match_rate']:.3f} |"
        )

    # -------------------- Difficulty Analysis Report Block --------------------
    lines.extend([
        "",
        "## Difficulty Analysis: SAN vs UCI",
        "Breakdown of Legal Move Rate (Valid%) by Board Type.",
        "",
        "| Board Type | Category | SAN Valid% | UCI Valid% | SAN Count | UCI Count |",
        "|---|---|---|---|---|---|"
    ])

    # Define the specific board types we want to compare
    comparison_targets = [
        ("Game Phase", "phase", ["opening", "middlegame", "endgame"]),
        ("Legal Options", "move_bucket", ["very_few", "few", "medium", "many"]),
        ("Check State", "in_check", ["True", "False"]),
    ]
    empty_stats = {"valid_rate": 0, "total": 0}

    for label, key, values in comparison_targets:
        for val in values:
            stratum = f"{key}={val}"
            san_stats = strata_results.get(f"SAN|{stratum}", empty_stats)
            uci_stats = strata_results.get(f"UCI|{stratum}", empty_stats)

            # Only add row if we actually have data
            if san_stats['total'] > 0 or uci_stats['total'] > 0:
                lines.append(
                    f"| {label} | {val} | "
                    f"{san_stats['valid_rate']:.1%} | {uci_stats['valid_rate']:.1%} | "
                    f"{san_stats['total']} | {uci_stats['total']} |"
                )
    # --------------------------------------------------------------------------

    lines.extend([
        "",
        "## Illegal Move Patterns",
        "| Error Type | Count |",
        "|---|---|"
    ])
    for err_type, count in sorted(illegal_by_type.items(), key=lambda x: -x[1]):
        lines.append(f"| {err_type} | {count} |")

    out_root.joinpath("report.md").write_text("\n".join(lines), encoding="utf-8")

    # ---- Plots ----
    def bar(metric, title, fname):
        """Save a bar chart of ``metric`` across all variants to ``plots/<fname>``."""
        vals = [r[metric] if r[metric] is not None else 0.0 for r in agg]
        labels = []
        for r in agg:
            mode = "normal" if r["include_reasoning"] else "legal"
            labels.append(f"{r['move_format']}/R={mode}/T={r.get('temperature')}")

        plt.figure(figsize=(max(8, len(vals)*0.6), 5))
        plt.bar(range(len(vals)), vals)
        plt.xticks(range(len(vals)), labels, rotation=45, ha="right")
        plt.ylabel(metric); plt.title(title); plt.tight_layout()
        plt.savefig(out_root / "plots" / fname); plt.close()

    bar("valid_rate", "Legal move rate by variant", "valid_rate.png")
    bar("correct_rate", "Correctness by Variant","correct_rate.png")
    bar("engine_topk_hit_rate","Engine Top-K Hit by Variant","engine_topk_hit_rate.png")
    bar("engine_best_match_rate","Engine Best-Move Match by Variant","engine_best_match_rate.png")

    if not args.quiet:
        print(f"Done -> {out_root.resolve()}")

# --------- ARGS FOR COLAB (EDIT THESE) ----------
# Run configuration consumed by main(). All fields are written to config.json
# in the order they appear here.
args = SimpleNamespace(
    # --- data ---
    hf_dataset="czovekboti/chessdata",   # Hugging Face dataset id passed to load_hf
    split="train",                       # dataset split passed to load_hf
    num_examples=number_of_examples,     # target number of positions; main() loads 2x this many rows
    # --- model loading (forwarded to load_model) ---
    model=model_path,
    dtype="float16",
    load_in_4bit=False,
    device_map="auto",
    # --- generation (forwarded to generate) ---
    max_new_tokens=2048,
    temperature=0.2,                     # used only when `temperatures` below parses to an empty list
    top_p=0.95,
    repetition_penalty=1.05,
    # --- candidate lists ---
    k_values="none,2,5,10,ALL",          # parsed by main(), but the K sweep is commented out there
    cand_mode="target_plus_random",      # forwarded to get_candidate_list
    # --- prompting / logging ---
    no_chat_template=False,              # main() inverts this into use_chat_template
    quiet=False,                         # True suppresses status prints and progress bars
    # --- Stockfish ---
    stockfish_path="/usr/games/stockfish",  # required: main() raises RuntimeError when empty
    engine_multipv=5,                    # passed to with_stockfish as multipv (at least 1)
    engine_movetime_ms=30,               # passed to with_stockfish as movetime_ms
    # --- reproducibility / sweep ---
    seed=42,                             # seeds random.Random and the dataset shuffle in load_hf
    temperatures="0.0,0.2,0.4",          # comma-separated temperatures, one variant set per value
)

# Kick off the full evaluation with the configuration above.
main(args)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
ERROR 12-02 16:59:31 [fa_utils.py:64] Cannot use FA version 2 is not supported due to FA2 is only supported on devices with compute capability >= 8
🦥 Unsloth Zoo will now patch everything to make training faster!




[HF] Loading dataset 'czovekboti/chessdata' split 'train'


chessData.csv:   0%|          | 0.00/795M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12958035 [00:00<?, ? examples/s]

[HF] Shuffled with seed=42
[HF] Selected first 20 rows (requested n=20)




[HF] Summary: raw=12958035, after_select=20, kept=20 (100.0%)




[ENGINE-LABEL] Added labels to 20 dataset rows (MultiPV=5, movetime_ms=30).
[DATA] seeds=3, remainder_available=20, taking=7, total_rows=10
==((====))==  Unsloth 2025.11.6: Fast Qwen2 patching. Transformers: 4.57.2. vLLM: 0.11.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.88G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.33G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]