In [None]:
from positions import *
from ultralytics import YOLO
import os
import time

# Trained YOLO weight files: one model for piece detection, one for board segmentation.
pieces_model_weights = 'best_detection.pt'
board_model_weights = 'best_segmentation.pt'

# Test split: input images and the JSON files holding the expected output for each image.
TEST_FOLDER_IMAGES = 'dataset/test/images'
TEST_FOLDER_LABELS = 'dataset/test/correct_output'

pieces_model = YOLO(pieces_model_weights)
board_model = YOLO(board_model_weights)

# os.listdir() returns entries in arbitrary order; sorting makes the processing
# order (and therefore the printed progress and debug lists) reproducible.
images = sorted(
    os.path.join(TEST_FOLDER_IMAGES, f)
    for f in os.listdir(TEST_FOLDER_IMAGES)
    if f.endswith('.png')
)

# Piece symbol -> class index used by the detection model. Uppercase symbols
# occupy indices 0-5 and lowercase symbols 6-11, in the order listed here.
classes = {symbol: index for index, symbol in enumerate('BKNPQRbknpqr')}

# Reverse lookup: class index -> piece symbol.
inv_classes = dict(zip(classes.values(), classes.keys()))

In [None]:
def piece_in_list(piece, pieces):
    """Report whether ``pieces`` holds an entry matching ``piece``.

    Two entries match when both their ``'piece'`` and ``'square'`` values are
    equal; any other keys are ignored.

    Args:
        piece: Mapping with ``'piece'`` and ``'square'`` keys.
        pieces: Iterable of mappings with the same two keys.

    Returns:
        True if at least one entry of ``pieces`` matches, otherwise False.
    """
    return any(
        candidate['piece'] == piece['piece'] and candidate['square'] == piece['square']
        for candidate in pieces
    )
    

In [None]:
import cv2
import numpy as np
import json

# Every counter is a two-element list laid out as [correct, wrong].
stats = {metric: [0, 0] for metric in ('fen', 'pieces_pred', 'pieces_real')}

# Per-image wall-clock durations in seconds, one list per measured stage.
timing_stats = {
    stage: []
    for stage in ('pieces_inference', 'board_inference', 'total_processing')
}

# Per-class counters, stored under '<symbol>_pred' and '<symbol>_real'.
for symbol in classes:
    stats.update({f'{symbol}_pred': [0, 0], f'{symbol}_real': [0, 0]})

# Image paths collected during evaluation for later inspection.
many_errors_pred_board, many_errors_real_board, wrong_board = [], [], []
# Evaluation parameters (formerly inline magic numbers).
CONF_THRESHOLD = 0.5          # confidence threshold passed to both models
INITIAL_EPSILON_RATIO = 0.05  # starting approxPolyDP tolerance, as a fraction of the contour perimeter
MAX_APPROX_ITERATIONS = 100   # cap on epsilon adjustments while looking for a quadrilateral
MANY_ERRORS_THRESHOLD = 5     # images with more wrong pieces than this are listed for debugging

# Main evaluation loop: for every test image, run both models, map detected
# pieces to squares, and compare the result with the label file. Updates
# `stats`, `timing_stats` and the three debug lists in place.
for i, image in enumerate(images):
    print(f'Processing {i+1}/{len(images)}: {image}')
    print(f'{stats["fen"]}, {stats["pieces_pred"]}, {stats["pieces_real"]}')

    start_total = time.time()

    # The label shares the image's base name with a .json extension. splitext
    # swaps only the final extension, unlike str.replace which would rewrite
    # every '.png' occurring in the name.
    label_name = os.path.splitext(os.path.basename(image))[0] + '.json'
    label_file = os.path.join(TEST_FOLDER_LABELS, label_name)
    if not os.path.exists(label_file):
        print(f'Label file not found: {label_file}')
        continue

    # The label is read for its 'pieces', 'fen' and 'white_turn' entries.
    with open(label_file, 'r') as f:
        label_data = json.load(f)

    # Detect pieces with timing
    start_pieces = time.time()
    pieces_results = pieces_model.predict(image, conf=CONF_THRESHOLD, verbose=False, save=False)
    pieces_time = time.time() - start_pieces
    timing_stats['pieces_inference'].append(pieces_time)

    # Detect board with timing
    start_board = time.time()
    board_results = board_model.predict(image, conf=CONF_THRESHOLD, verbose=False, save=False)
    board_time = time.time() - start_board
    timing_stats['board_inference'].append(board_time)

    detected_pieces = []
    correct_pieces = label_data['pieces']

    # Ultralytics documents Results.masks as optional: it is None when the
    # model produced no masks, so it must be checked before reading `.xy`.
    masks = board_results[0].masks
    board_found = masks is not None and masks.xy is not None and len(masks.xy) > 0

    if not board_found:
        # No pieces can be placed without a board; the image is still scored
        # below, so it counts as a failure instead of silently disappearing.
        print(f'No board detected: {image}')
    else:
        # Outline of the first mask as an (N, 2) float array for OpenCV.
        contour_points = np.array(masks.xy[0], dtype=np.float32)

        # Approximate the outline with a polygon, nudging the tolerance up
        # (fewer vertices) or down (more vertices) until it has four corners.
        epsilon = INITIAL_EPSILON_RATIO * cv2.arcLength(contour_points, True)
        board_vert = cv2.approxPolyDP(contour_points, epsilon, True)
        rep = 0
        while len(board_vert) != 4:
            if len(board_vert) > 4:
                epsilon *= 1.05
            else:
                epsilon *= 0.95
            board_vert = cv2.approxPolyDP(contour_points, epsilon, True)
            rep += 1

            if rep > MAX_APPROX_ITERATIONS:
                print('Too many iterations, breaking out')
                break

        if len(board_vert) == 4:
            board_vert = board_vert.reshape(-1, 2)  # Flatten to 2D array
            board_vert = [[float(x), float(y)] for x, y in board_vert]

            transform = calc_transform(board_vert)

            # Third argument of calc_position: 0 on white's turn, 2 otherwise.
            # NOTE(review): presumably a board-orientation index; confirm in positions.calc_position.
            orientation = 0 if label_data["white_turn"] else 2

            # Process pieces results
            if len(pieces_results) > 0:
                boxes = pieces_results[0].boxes
                for j in range(len(boxes.cls)):
                    # [x1, y1, w, h]: first two values of the xyxy box,
                    # last two values (width, height) of the xywh box.
                    box = [float(boxes.xyxy[j][0]), float(boxes.xyxy[j][1]),
                           float(boxes.xywh[j][2]), float(boxes.xywh[j][3])]

                    piece = {
                        'piece': inv_classes[int(boxes.cls[j])],
                        'box': box,
                        'square': calc_position(box, transform, orientation)
                    }

                    # Detections whose square string contains 'ERROR' are dropped.
                    if 'ERROR' not in piece['square']:
                        detected_pieces.append(piece)

    # Predicted pieces: counted correct when the label has the same piece on the same square.
    wrong_pieces_pred = 0
    for piece in detected_pieces:
        if piece_in_list(piece, correct_pieces):
            stats['pieces_pred'][0] += 1
            stats[piece['piece'] + '_pred'][0] += 1
        else:
            wrong_pieces_pred += 1
            stats['pieces_pred'][1] += 1
            stats[piece['piece'] + '_pred'][1] += 1

    # Labelled pieces: counted correct when a matching detection exists.
    wrong_pieces_real = 0
    for piece in correct_pieces:
        if piece_in_list(piece, detected_pieces):
            stats['pieces_real'][0] += 1
            stats[piece['piece'] + '_real'][0] += 1
        else:
            wrong_pieces_real += 1
            stats['pieces_real'][1] += 1
            stats[piece['piece'] + '_real'][1] += 1

    # Convert detected pieces to FEN
    fen_pred = pieces_to_fen(detected_pieces)
    fen_real = label_data['fen']

    if fen_pred == fen_real:
        stats['fen'][0] += 1
    else:
        wrong_board.append(image)
        stats['fen'][1] += 1

    if wrong_pieces_pred > MANY_ERRORS_THRESHOLD:
        many_errors_pred_board.append(image)
    if wrong_pieces_real > MANY_ERRORS_THRESHOLD:
        many_errors_real_board.append(image)

    total_time = time.time() - start_total
    timing_stats['total_processing'].append(total_time)
    

def _format_accuracy(correct, wrong):
    """Format correct / (correct + wrong) as 'xx.xx%', or 'n/a' when both counts are zero."""
    total = correct + wrong
    if total == 0:
        # A counter that was never incremented (e.g. a class absent from the
        # test set) has no defined accuracy; dividing would raise ZeroDivisionError.
        return "n/a"
    return f"{(correct / total) * 100:.2f}%"


# print wrong boards for debugging
print(f"Boards with many errors in predicted pieces ({len(many_errors_pred_board)}):", many_errors_pred_board)
print(f"Boards with many errors in real pieces ({len(many_errors_real_board)}):", many_errors_real_board)
print(f"Boards with wrong FEN ({len(wrong_board)}):", wrong_board)

print("Statistics:")
for key, (correct, wrong) in stats.items():
    print(f"{key}: Correct: {correct}, Wrong: {wrong} ({_format_accuracy(correct, wrong)})")

# Overall totals. Each total is printed once, and the board count carries its
# own label instead of reusing "Total FEN correct".
fen_correct, fen_wrong = stats['fen']
print(f"Total pieces predicted: {stats['pieces_pred'][0] + stats['pieces_pred'][1]}")
print(f"Total pieces real: {stats['pieces_real'][0] + stats['pieces_real'][1]}")
print(f"Total boards evaluated: {fen_correct + fen_wrong}")
print(f"Total FEN correct: {fen_correct} / {fen_correct + fen_wrong} ({_format_accuracy(fen_correct, fen_wrong)})")

# Report average / fastest / slowest duration for each measured stage.
# Stages with no recorded samples are skipped.
print("\nTiming Statistics:")
for stage_label, stage_key in (
    ('Pieces model', 'pieces_inference'),
    ('Board model', 'board_inference'),
    ('Total processing', 'total_processing'),
):
    samples = timing_stats[stage_key]
    if not samples:
        continue
    mean_duration = sum(samples) / len(samples)
    print(f"{stage_label} - Avg: {mean_duration:.4f}s, Min: {min(samples):.4f}s, Max: {max(samples):.4f}s")

In [None]:
import json
# Persist the accumulated [correct, wrong] counters as indented JSON.
stats_file = 'test_stats.json'
serialized_stats = json.dumps(stats, indent=2)
with open(stats_file, mode='w') as stats_handle:
    stats_handle.write(serialized_stats)