In [10]:
import os, random
import numpy as np
import pandas as pd
from collections import Counter
from tqdm import tqdm
import chess
import chess.pgn
import tensorflow as tf
from tensorflow.keras import layers, models


In [None]:
CSV_PATH = "games.csv"   
STOCKFISH_PATH = None                 

TOP_N = 256      
TEST_SPLIT = 0.20
RANDOM_SEED = 42
EPOCHS = 12      
BATCH_SIZE = 64
INPUT_PLANES = 12

random.seed(RANDOM_SEED); np.random.seed(RANDOM_SEED); tf.random.set_seed(RANDOM_SEED)

In [None]:
# piece -> plane index
piece_to_plane = {
    'P':0,'N':1,'B':2,'R':3,'Q':4,'K':5,
    'p':6,'n':7,'b':8,'r':9,'q':10,'k':11
}

def fen_to_planes(fen):
    """Convert a FEN to 8x8x12 binary planes (uint8)."""
    board = chess.Board(fen)
    planes = np.zeros((8,8,INPUT_PLANES), dtype=np.uint8)
    for sq in chess.SQUARES:
        piece = board.piece_at(sq)
        if piece:
            plane = piece_to_plane[piece.symbol()]
            r = 7 - chess.square_rank(sq)  
            c = chess.square_file(sq)
            planes[r, c, plane] = 1
    return planes

def san_moves_to_uci_list(san_moves_str):
    """
    Convert a SAN or UCI-like space-separated move string (or PGN without headers) into a UCI list.
    Returns list of UCI strings or None if parsing fails.
    """
    s = str(san_moves_str).strip()
    if s == "" or pd.isna(s):
        return None
    toks = s.split()
    cleaned = []
    for tok in toks:
        if '.' in tok and tok.split('.')[0].isdigit():
            tok = tok.split('.',1)[1]
        if tok in ('1-0','0-1','1/2-1/2','*'):
            continue
        cleaned.append(tok)
    toks = cleaned
    board = chess.Board()
    uci_list = []
    for tok in toks:
        # try UCI parse first if looks like UCI (e2e4 or e7e8q)
        tried_uci = False
        if len(tok) in (4,5) and tok[0].isalpha() and tok[1].isdigit():
            try:
                m = chess.Move.from_uci(tok)
                if m in board.legal_moves:
                    board.push(m); uci_list.append(m.uci()); tried_uci = True
            except Exception:
                tried_uci = False
        if tried_uci:
            continue
        # try SAN parse
        try:
            move = board.parse_san(tok)
            uci_list.append(move.uci())
            board.push(move)
        except Exception:
            return None
    return uci_list

# Quick test
print("san->uci example:", san_moves_to_uci_list("e4 e5 Nf3 Nc6"))


san->uci example: ['e2e4', 'e7e5', 'g1f3', 'b8c6']


In [None]:
# load CSV
df = pd.read_csv(CSV_PATH, dtype=str, low_memory=False)
print("Loaded rows:", len(df))
print("Columns in CSV:", df.columns.tolist())

moves_col = None
for candidate in ["moves","Moves","pgn","PGN","moves_pgn","uci_moves","moves_list"]:
    if candidate in df.columns:
        moves_col = candidate
        break
if moves_col is None:
    for c in df.columns:
        sample = df[c].astype(str).head(20).str.cat(sep=' ')
        if '1.' in sample or 'e4' in sample or 'Nf3' in sample:
            moves_col = c
            break
if moves_col is None:
    raise SystemExit("Could not find moves column. Please set CSV_PATH to your file and ensure it has a moves column.")

print("Using moves column:", moves_col)

# parse and create examples
examples = []
skipped_rows = 0
rows = df[moves_col].astype(str).tolist()
for raw in tqdm(rows, desc="Parsing moves -> examples"):
    uci_list = san_moves_to_uci_list(raw)
    if not uci_list:
        skipped_rows += 1
        continue
    board = chess.Board()
    for u in uci_list:
        fen_before = board.fen()
        examples.append((fen_before, u))
        board.push(chess.Move.from_uci(u))

print("Total examples extracted (plys):", len(examples))
print("Rows skipped due to parse errors:", skipped_rows)


Loaded rows: 20058
Columns in CSV: ['id', 'rated', 'created_at', 'last_move_at', 'turns', 'victory_status', 'winner', 'increment_code', 'white_id', 'white_rating', 'black_id', 'black_rating', 'moves', 'opening_eco', 'opening_name', 'opening_ply']
Using moves column: moves


Parsing moves -> examples: 100%|██████████| 20058/20058 [01:07<00:00, 296.35it/s]

Total examples extracted (plys): 1212827
Rows skipped due to parse errors: 0





In [None]:
# build move freq and pick top N
move_counts = Counter([m for (_,m) in examples])
print("Unique moves seen:", len(move_counts))
top_moves = [m for m,_ in move_counts.most_common(TOP_N)]
move2idx = {m:i for i,m in enumerate(top_moves)}
print("Top moves size:", len(top_moves))

# filter examples whose move is in top set
filtered = [(f,m) for f,m in examples if m in move2idx]
print("Examples after filtering to top-N moves:", len(filtered))
if len(filtered) < 50:
    print("WARNING: too few examples after filtering. Consider lowering TOP_N or using more CSV rows.")

# build arrays 
MAX_EXAMPLES = 20000  
if len(filtered) > MAX_EXAMPLES:
    filtered = random.sample(filtered, MAX_EXAMPLES)

X = np.stack([fen_to_planes(fen) for fen,_ in filtered])
y = np.array([move2idx[m] for _,m in filtered], dtype=np.int32)

# shuffle + train/test split
idxs = np.arange(len(X)); np.random.shuffle(idxs)
split_idx = int(len(X)*(1-TEST_SPLIT))
train_idx = idxs[:split_idx]; test_idx = idxs[split_idx:]

X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

print("Train size:", len(X_train), "Test size:", len(X_test), "Vocab:", len(top_moves))


Unique moves seen: 1885
Top moves size: 256
Examples after filtering to top-N moves: 759738
Train size: 16000 Test size: 4000 Vocab: 256


In [15]:
input_shape = (8,8,INPUT_PLANES)
num_classes = len(top_moves)

model = models.Sequential([
    layers.Input(shape=input_shape),
    layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
    layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])
model.summary()

history = model.fit(X_train, y_train,
                    validation_data=(X_test, y_test),
                    epochs=EPOCHS,
                    batch_size=BATCH_SIZE,
                    verbose=2)


Epoch 1/12
250/250 - 3s - 13ms/step - loss: 4.9815 - sparse_categorical_accuracy: 0.0601 - val_loss: 4.8187 - val_sparse_categorical_accuracy: 0.0765
Epoch 2/12
250/250 - 2s - 9ms/step - loss: 4.5427 - sparse_categorical_accuracy: 0.1050 - val_loss: 4.5636 - val_sparse_categorical_accuracy: 0.0970
Epoch 3/12
250/250 - 2s - 10ms/step - loss: 4.1349 - sparse_categorical_accuracy: 0.1405 - val_loss: 4.3623 - val_sparse_categorical_accuracy: 0.1115
Epoch 4/12
250/250 - 2s - 9ms/step - loss: 3.7495 - sparse_categorical_accuracy: 0.1822 - val_loss: 4.2169 - val_sparse_categorical_accuracy: 0.1280
Epoch 5/12
250/250 - 2s - 9ms/step - loss: 3.4102 - sparse_categorical_accuracy: 0.2262 - val_loss: 4.1598 - val_sparse_categorical_accuracy: 0.1380
Epoch 6/12
250/250 - 2s - 9ms/step - loss: 3.1091 - sparse_categorical_accuracy: 0.2746 - val_loss: 4.1408 - val_sparse_categorical_accuracy: 0.1423
Epoch 7/12
250/250 - 2s - 9ms/step - loss: 2.8408 - sparse_categorical_accuracy: 0.3183 - val_loss: 4.17

In [16]:
probs = model.predict(X_test)
top1_preds = np.argmax(probs, axis=1)
top3_preds = np.argsort(probs, axis=1)[:,-3:]  # last 3 columns -> top3 indices

top1_acc = np.mean(top1_preds == y_test)
top3_acc = np.mean([y_test[i] in top3_preds[i] for i in range(len(y_test))])

print(f"Top-1 accuracy: {top1_acc:.2%}")
print(f"Top-3 accuracy: {top3_acc:.2%}")


[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
Top-1 accuracy: 14.62%
Top-3 accuracy: 28.30%


In [24]:
inv_vocab = {v:k for k,v in move2idx.items()}

def predict_next_moves_from_sequence(move_sequence, top_k=3):
    """Given a move sequence in SAN or UCI (e.g. 'e2e4' or 'e2e4 e7e5'), return top_k predicted UCI moves with confidences."""
    uci_list = san_moves_to_uci_list(move_sequence)
    if uci_list is None:
        return f"Could not parse move sequence: {move_sequence}"
    board = chess.Board()
    for u in uci_list:
        board.push(chess.Move.from_uci(u))
    fen = board.fen()
    x = fen_to_planes(fen)[None,...]  # batch dim
    probs = model.predict(x)[0]
    top_idx = np.argsort(probs)[-top_k:][::-1]
    return [(inv_vocab.get(i,"<UNK>"), float(probs[i])) for i in top_idx]

# Examples (run these after training)
print(predict_next_moves_from_sequence("e2e4"))           # predicts Black reply to 1.e4
print(predict_next_moves_from_sequence("e4 e5 Nf3 Nc6 Bb5 a6 "))    


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step
[('c7c5', 0.4973739683628082), ('e7e5', 0.20621387660503387), ('d7d5', 0.0726172998547554)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step
[('b5a4', 0.5904750227928162), ('b5c6', 0.15680380165576935), ('e1g1', 0.0728546753525734)]
