In [5]:
import sys
sys.path.append("../src")
import torch
import os
import cv2
from torchvision import transforms
from board_to_fen import split_into_squares
from dataset import ChessPieceCNN  # replace with actual file/module name

# 1. Define label-to-piece mapping
# Class index -> FEN symbol is defined by position in this string:
# '.' (empty) is 0, white pieces P N B R Q K are 1-6, black p n b r q k are 7-12.
piece_map = {symbol: index for index, symbol in enumerate(".PNBRQKpnbrqk")}
# Inverse lookup used to turn predicted class indices back into FEN symbols.
idx_to_symbol = {index: symbol for symbol, index in piece_map.items()}

# 2. Load the trained model
model = ChessPieceCNN()
# NOTE(review): torch.load unpickles the checkpoint, which can execute code;
# only load files you trust (newer torch versions accept weights_only=True).
model.load_state_dict(
    torch.load(
        '../models/chess_piece_cnn.pth',
        map_location='cpu',  # load tensors onto the CPU regardless of where they were saved
    )
)
model.eval()  # switch to inference behaviour (e.g. dropout / batch-norm layers)

# 3. Load and preprocess the board image
img_path = '../single_test/test_image.png'
img = cv2.imread(img_path)
# cv2.imread signals failure by returning None instead of raising.
if img is None:
    raise FileNotFoundError(f"Cannot read image: {img_path}")

# Normalise the board to a fixed size before cutting it into squares.
img = cv2.resize(img, (800, 800))
# Materialise as a list so len() works even if the helper yields lazily.
squares = list(split_into_squares(img))  # expected: 64 cropped square images
# The 8x8 rank grouping below assumes exactly 64 squares; fail loudly instead
# of silently printing a malformed FEN.
if len(squares) != 64:
    raise ValueError(f"Expected 64 squares from split_into_squares, got {len(squares)}")

# 4. Apply transforms
# Per-square preprocessing: numpy image -> PIL -> 64x64 -> float tensor in [0, 1].
# NOTE(review): cv2.imread yields BGR channel order, while ToPILImage treats a
# 3-channel array as RGB. Confirm the model was trained on squares loaded the same way.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)
batch = torch.stack(list(map(transform, squares)))  # (64, 3, 64, 64)

# 5. Predict all 64 squares
# Classify all 64 squares in one forward pass; no gradients are needed.
with torch.no_grad():
    logits = model(batch)
preds = logits.argmax(dim=1).tolist()  # highest-scoring class index per square

# Map class indices back to FEN symbols ('.' = empty square).
fen_symbols = [idx_to_symbol[index] for index in preds]

# Slice the flat 64-symbol list into 8 rows of 8.
# NOTE(review): this treats fen_symbols[0:8] as FEN rank 8 (a8..h8). The sample
# output recorded below has white's pieces in that first rank. Either the image
# is viewed from black's side or split_into_squares orders rows differently.
# Confirm the orientation.
ranks = [fen_symbols[start:start + 8] for start in range(0, 64, 8)]

def compress_rank(rank):
    """Encode one rank of board symbols in FEN piece-placement notation.

    Each run of consecutive empty squares ('.') is replaced by its length;
    piece symbols are copied through unchanged.

    Args:
        rank: An iterable of single-character symbols (normally 8 of them).

    Returns:
        The FEN text for that rank, e.g. ``"2P2p2"`` for ``"..P..p.."``.
    """
    parts = []
    gap = 0
    for symbol in rank:
        if symbol != '.':
            # A piece ends any pending run of empties; flush the count first.
            if gap:
                parts.append(str(gap))
            gap = 0
            parts.append(symbol)
            continue
        gap += 1
    # Trailing empties at the end of the rank.
    if gap:
        parts.append(str(gap))
    return "".join(parts)

fen_rows = list(map(compress_rank, ranks))
# Only the piece-placement field is produced; side to move, castling rights,
# en passant, and move counters are not predicted.
fen_str = "/".join(fen_rows)
print("Predicted FEN:", fen_str)


Predicted FEN: 1P1K3R/2P2P2/4BP1P/1N2P3/NPnPp3/3p1n2/npp2ppp/1knq1p1r
