In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import csv, chess, chess.syzygy
from typing import Optional, Iterable, Tuple, List, Set
import csv, itertools, os

## K vs K positions (bare kings)

In [None]:
def kings_ok(wk, bk):
    """Return True when the two king squares (0..63) are not adjacent.

    Squares are decoded as file = sq % 8 and rank = sq // 8; the kings are
    acceptable when they differ by more than one file or more than one rank.
    Identical squares therefore also return False.
    """
    w_rank, w_file = divmod(wk, 8)
    b_rank, b_file = divmod(bk, 8)
    file_gap = abs(w_file - b_file)
    rank_gap = abs(w_rank - b_rank)
    return file_gap > 1 or rank_gap > 1

def board_from(wk, bk, white_to_move=True):
    """Build a board that holds only the two kings.

    wk / bk are the square indices of the white and black king;
    white_to_move selects the side to move. The move stack is cleared
    before the board is returned.
    """
    board = chess.Board(None)  # None -> no pieces placed initially
    for square, color in ((wk, chess.WHITE), (bk, chess.BLACK)):
        board.set_piece_at(square, chess.Piece(chess.KING, color))
    board.turn = chess.WHITE if white_to_move else chess.BLACK
    board.clear_stack()
    return board


# Which side(s) to move to enumerate:
WHITE_ONLY = False   # True: White to move only; False: both sides to move

# Invariant across the whole enumeration, so computed once.
sides = (True,) if WHITE_ONLY else (True, False)

rows = []
for wk, bk in itertools.product(range(64), repeat=2):
    if wk == bk or not kings_ok(wk, bk):
        continue
    for stm in sides:
        b = board_from(wk, bk, stm)
        # The adjacency filter should already guarantee validity; keep the guard.
        if b.is_valid():
            # Bare kings: WDL is 0 (draw); DTZ is left as None,
            # which csv.writer emits as an empty field.
            rows.append([b.fen(), 0, None])

with open("kvk.csv", "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["fen", "wdl", "dtz"])
    w.writerows(rows)

print("Wrote kvk.csv")


## Generate endgame `.csv` files from Syzygy tablebases

In [2]:
def generate_endgame_csv(
    endgame: str,                      # "KRK", "KQK", or "KBBK"
    tb_dir: str,                       # path to Syzygy tables
    out_csv: str,
    both_sides_to_move: bool = False,  # if True, write w/b to move; else only White to move
    include_kbk_from_kbbk: bool = True,# only applies to KBBK
    opposite_colors_only: bool = False,# only applies to KBBK
    dedupe_kbk: bool = True,           # dedupe KBK FENs when include_kbk_from_kbbk=True
    limit: Optional[int] = None,       # cap number of *main* endgame positions for testing
    preview_rows: int = 0              # return up to this many rows as a DataFrame preview
) -> Optional[pd.DataFrame]:
    """
    Enumerate KRK / KQK / KBBK positions, probe WDL & DTZ from Syzygy tables,
    and stream the results to a CSV with columns: fen, side, wdl, dtz.

    Every placement of the white king, the black king (never adjacent to the
    white king) and the white piece(s) on distinct squares is tried; positions
    rejected by ``chess.Board.is_valid()`` are skipped. For KBBK with
    ``include_kbk_from_kbbk=True``, each bishop pair is followed by the KBK
    positions obtained by keeping only one of the two bishops.

    Args:
        endgame: "KRK", "KQK" or "KBBK" (case-insensitive).
        tb_dir: directory passed to ``chess.syzygy.open_tablebase``.
        out_csv: output path; its parent directory is created if missing.
        both_sides_to_move: also emit Black-to-move positions.
        include_kbk_from_kbbk: KBBK only; also emit the derived KBK rows.
        opposite_colors_only: KBBK only; keep only bishop pairs standing on
            squares of different colour.
        dedupe_kbk: write each derived KBK FEN at most once. When False, a KBK
            position is written once for every bishop pair that produces it.
        limit: stop (and return) once this many main rows have been written;
            the check runs after each probed main position.
        preview_rows: number of leading rows to return as a DataFrame.

    Returns:
        A DataFrame with the first ``preview_rows`` written rows, or None when
        ``preview_rows`` is 0.

    Raises:
        ValueError: if ``endgame`` is not one of the supported endgames.

    Positions whose table cannot be probed (``KeyError`` from the tablebase)
    are left out of the CSV; a RuntimeWarning reports how many were skipped so
    that a wrong ``tb_dir`` does not silently yield a header-only file.
    """
    import warnings  # local import: only needed for the skipped-probe notice

    endgame = endgame.upper()
    if endgame not in {"KRK", "KQK", "KBBK"}:
        raise ValueError("endgame must be one of {'KRK','KQK','KBBK'}")

    columns = ["fen", "side", "wdl", "dtz"]
    is_kbbk = endgame == "KBBK"
    piece_type = {"KRK": chess.ROOK, "KQK": chess.QUEEN, "KBBK": chess.BISHOP}[endgame]
    n_white_pieces = 2 if is_kbbk else 1
    sides = (True, False) if both_sides_to_move else (True,)
    derive_kbk = is_kbbk and include_kbk_from_kbbk
    # None means "do not dedupe"; a set (even an empty one) means "dedupe".
    seen_kbk: Optional[Set[str]] = set() if (derive_kbk and dedupe_kbk) else None

    preview: List[Tuple[str, str, int, int]] = []
    total_main = 0
    skipped = 0

    def kings_ok(wk: int, bk: int) -> bool:
        """True when the kings are more than one file or rank apart."""
        ax, ay = wk % 8, wk // 8
        bx, by = bk % 8, bk // 8
        return max(abs(ax - bx), abs(ay - by)) > 1  # not adjacent

    def square_parity(sq: int) -> int:
        """0/1 colour class of a square; only equality between squares matters."""
        return (sq % 8 + sq // 8) % 2

    def build_board(white_to_move: bool, wk: int, bk: int, white_squares) -> chess.Board:
        """Empty board + both kings + one white `piece_type` on each given square."""
        board = chess.Board(None)
        board.turn = chess.WHITE if white_to_move else chess.BLACK
        board.clear_stack()  # halfmove=0, fullmove=1
        board.set_piece_at(wk, chess.Piece(chess.KING, chess.WHITE))
        board.set_piece_at(bk, chess.Piece(chess.KING, chess.BLACK))
        for sq in white_squares:
            board.set_piece_at(sq, chess.Piece(piece_type, chess.WHITE))
        return board

    def finish() -> Optional[pd.DataFrame]:
        """Report skipped probes (if any) and build the return value."""
        if skipped:
            warnings.warn(
                f"{skipped} probe(s) skipped: no matching Syzygy table found in {tb_dir!r}",
                RuntimeWarning,
                stacklevel=3,
            )
        return pd.DataFrame(preview, columns=columns) if preview_rows else None

    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)

    with chess.syzygy.open_tablebase(tb_dir) as tb, open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(columns)

        def probe_and_write(board: chess.Board, white_to_move: bool, fen: str) -> bool:
            """Probe one position and write its row; False when no table covers it."""
            nonlocal skipped
            try:
                wdl = tb.probe_wdl(board)
                dtz = tb.probe_dtz(board)
            except KeyError:
                skipped += 1
                return False
            row = (fen, "w" if white_to_move else "b", wdl, dtz)
            writer.writerow(row)
            if len(preview) < preview_rows:
                preview.append(row)
            return True

        for wk, bk in itertools.product(range(64), repeat=2):
            if bk == wk or not kings_ok(wk, bk):
                continue

            free_squares = [s for s in range(64) if s not in (wk, bk)]
            for xs in itertools.combinations(free_squares, n_white_pieces):
                if (is_kbbk and opposite_colors_only
                        and square_parity(xs[0]) == square_parity(xs[1])):
                    continue

                for white_to_move in sides:
                    board = build_board(white_to_move, wk, bk, xs)
                    if not board.is_valid():
                        continue
                    if probe_and_write(board, white_to_move, board.fen()):
                        total_main += 1
                    if limit is not None and total_main >= limit:
                        return finish()

                # Only for KBBK: also emit KBK rows (remove one bishop each)
                if derive_kbk:
                    for keep in xs:
                        for white_to_move in sides:
                            board_k = build_board(white_to_move, wk, bk, (keep,))
                            if not board_k.is_valid():
                                continue
                            fen_k = board_k.fen()
                            if seen_kbk is not None:
                                if fen_k in seen_kbk:
                                    continue
                                # Recorded before probing: a position without
                                # a table would fail again on every retry.
                                seen_kbk.add(fen_k)
                            probe_and_write(board_k, white_to_move, fen_k)

        return finish()


In [4]:
# KRK, both sides to move; written next to the tables so that
# split_data("./krk/krk") can read ./krk/krk_full.csv.
krk_preview = generate_endgame_csv(
    "KRK",
    "./krk/",
    "./krk/krk_full.csv",
    both_sides_to_move=True,
)

In [None]:
# KQK, both sides to move.
# The CSV is written into ./kqk/ because split_data("./kqk/kqk") reads
# "./kqk/kqk_full.csv"; writing to the working directory instead would
# leave that later cell without its input on a fresh run.
kqk_preview = generate_endgame_csv(
    endgame="KQK",
    tb_dir="./kqk/",
    out_csv="./kqk/kqk_full.csv",
    both_sides_to_move=True,
)

In [None]:
# KBBK + KBK, opposite-colored bishops only, both sides to move.
# split_data(prefix) reads f"{prefix}_full.csv", so the file must carry the
# "_full" suffix for split_data("./kbbk/kbbk_plus_kbk") to find it.
kbbk_preview = generate_endgame_csv(
    endgame="KBBK",
    tb_dir="./kbbk/",
    out_csv="./kbbk/kbbk_plus_kbk_full.csv",
    both_sides_to_move=True,
    include_kbk_from_kbbk=True,
    opposite_colors_only=True,
    dedupe_kbk=True,
)

## Train / validation / test split

In [10]:
def split_data(endgame):
    """Split the White-to-move rows of `{endgame}_full.csv` into 80/10/10 sets.

    `endgame` is a path prefix: the input is read from `{endgame}_full.csv`
    and the three parts are written to `{endgame}_train.csv`,
    `{endgame}_val.csv` and `{endgame}_test.csv` (without the index).

    Both splits are first attempted stratified on the `dtz` column with
    random_state=42. If scikit-learn raises ValueError, a warning is printed
    and both splits are redone without stratification.

    Prints the size of each part and the per-DTZ share in the original and
    in each part. Returns None.
    """
    positions = pd.read_csv(f"{endgame}_full.csv")
    white_to_move_df = positions[positions['side'].str.contains('w')]

    def two_stage_split(stratified):
        """80/20 split, then cut the 20% hold-out in half (validation / test)."""
        train_part, holdout = train_test_split(
            white_to_move_df,
            test_size=0.2,
            random_state=42,
            stratify=white_to_move_df['dtz'] if stratified else None,
        )
        val_part, test_part = train_test_split(
            holdout,
            test_size=0.5,
            random_state=42,
            stratify=holdout['dtz'] if stratified else None,
        )
        return train_part, val_part, test_part

    try:
        # Try with stratification first
        train_df, val_df, test_df = two_stage_split(stratified=True)
    except ValueError as e:
        print(f"Warning: Stratification failed: {e}")
        print("Falling back to non-stratified split.")
        train_df, val_df, test_df = two_stage_split(stratified=False)

    for suffix, part in (("train", train_df), ("val", val_df), ("test", test_df)):
        part.to_csv(f"{endgame}_{suffix}.csv", index=False)

    # Dataset sizes, as absolute counts and as a share of the White-to-move rows
    n_rows = len(white_to_move_df)
    for label, part in (("Train", train_df), ("Validation", val_df), ("Test", test_df)):
        print(f"{label} set size: {len(part)} ({len(part)/n_rows*100:.2f}%)")

    # Per-DTZ share in each part, to eyeball how similar the distributions are
    print("\nDTZ distribution sample:")
    for dtz_value in sorted(white_to_move_df['dtz'].unique()):
        orig_pct = (white_to_move_df['dtz'] == dtz_value).mean() * 100
        train_pct = (train_df['dtz'] == dtz_value).mean() * 100
        val_pct = (val_df['dtz'] == dtz_value).mean() * 100
        test_pct = (test_df['dtz'] == dtz_value).mean() * 100
        print(f"DTZ {dtz_value}: Original={orig_pct:.2f}%, Train={train_pct:.2f}%, Val={val_pct:.2f}%, Test={test_pct:.2f}%")

In [11]:
# Reads ./krk/krk_full.csv; writes ./krk/krk_train.csv, krk_val.csv and krk_test.csv.
split_data("./krk/krk")

Train set size: 140134 (80.00%)
Validation set size: 17517 (10.00%)
Test set size: 17517 (10.00%)

DTZ distribution sample:
DTZ 1: Original=0.86%, Train=0.86%, Val=0.86%, Test=0.86%
DTZ 3: Original=2.67%, Train=2.67%, Val=2.67%, Test=2.67%
DTZ 5: Original=2.20%, Train=2.20%, Val=2.20%, Test=2.20%
DTZ 7: Original=1.08%, Train=1.08%, Val=1.08%, Test=1.08%
DTZ 9: Original=2.77%, Train=2.77%, Val=2.77%, Test=2.77%
DTZ 11: Original=4.97%, Train=4.97%, Val=4.97%, Test=4.97%
DTZ 13: Original=6.46%, Train=6.46%, Val=6.46%, Test=6.46%
DTZ 15: Original=9.80%, Train=9.80%, Val=9.81%, Test=9.80%
DTZ 17: Original=11.47%, Train=11.47%, Val=11.47%, Test=11.47%
DTZ 19: Original=10.86%, Train=10.86%, Val=10.85%, Test=10.86%
DTZ 21: Original=11.69%, Train=11.69%, Val=11.69%, Test=11.69%
DTZ 23: Original=12.26%, Train=12.26%, Val=12.26%, Test=12.26%
DTZ 25: Original=10.18%, Train=10.18%, Val=10.17%, Test=10.18%
DTZ 27: Original=9.21%, Train=9.21%, Val=9.21%, Test=9.21%
DTZ 29: Original=2.99%, Train=2.99%

In [12]:
# Reads ./kqk/kqk_full.csv; writes ./kqk/kqk_train.csv, kqk_val.csv and kqk_test.csv.
split_data("./kqk/kqk")

Falling back to non-stratified split.
Train set size: 115606 (80.00%)
Validation set size: 14451 (10.00%)
Test set size: 14451 (10.00%)

DTZ distribution sample:
DTZ 1: Original=1.69%, Train=1.69%, Val=1.67%, Test=1.76%
DTZ 3: Original=3.47%, Train=3.49%, Val=3.22%, Test=3.55%
DTZ 5: Original=6.27%, Train=6.27%, Val=6.25%, Test=6.28%
DTZ 7: Original=13.82%, Train=13.77%, Val=14.27%, Test=13.71%
DTZ 9: Original=18.11%, Train=18.11%, Val=17.85%, Test=18.30%
DTZ 11: Original=22.19%, Train=22.27%, Val=21.98%, Test=21.76%
DTZ 13: Original=22.22%, Train=22.25%, Val=21.90%, Test=22.29%
DTZ 15: Original=10.38%, Train=10.30%, Val=10.95%, Test=10.48%
DTZ 17: Original=1.85%, Train=1.85%, Val=1.90%, Test=1.85%
DTZ 19: Original=0.01%, Train=0.01%, Val=0.00%, Test=0.01%


In [15]:
# Reads ./kbbk/kbbk_plus_kbk_full.csv; writes the matching _train/_val/_test CSVs.
split_data("./kbbk/kbbk_plus_kbk")

Train set size: 96195 (80.00%)
Validation set size: 12024 (10.00%)
Test set size: 12025 (10.00%)

DTZ distribution sample:
DTZ 0: Original=7.27%, Train=7.27%, Val=7.27%, Test=7.27%
DTZ 1: Original=0.31%, Train=0.31%, Val=0.32%, Test=0.31%
DTZ 2: Original=0.10%, Train=0.10%, Val=0.11%, Test=0.10%
DTZ 4: Original=0.41%, Train=0.41%, Val=0.41%, Test=0.41%
DTZ 6: Original=0.14%, Train=0.14%, Val=0.14%, Test=0.13%
DTZ 8: Original=0.19%, Train=0.19%, Val=0.19%, Test=0.18%
DTZ 10: Original=0.47%, Train=0.47%, Val=0.47%, Test=0.47%
DTZ 12: Original=0.66%, Train=0.66%, Val=0.66%, Test=0.66%
DTZ 14: Original=1.17%, Train=1.17%, Val=1.16%, Test=1.17%
DTZ 16: Original=1.75%, Train=1.75%, Val=1.75%, Test=1.75%
DTZ 18: Original=3.69%, Train=3.69%, Val=3.69%, Test=3.69%
DTZ 20: Original=6.09%, Train=6.09%, Val=6.09%, Test=6.10%
DTZ 22: Original=7.25%, Train=7.25%, Val=7.24%, Test=7.25%
DTZ 24: Original=10.45%, Train=10.45%, Val=10.45%, Test=10.45%
DTZ 26: Original=13.99%, Train=13.99%, Val=14.00%, Te