### Detect Grid Line help detect squares

In [1]:
def detect_grid_lines(image_rgb, show_lines=True, edge_threshold=20):
    height, width = image_rgb.shape[:2]

    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (17, 13), 0)

    edges = cv2.Canny(blurred, 50, 150)
    edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)

    lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=80, minLineLength=100, maxLineGap=10)
    
    if lines is None:
        print("No lines found")
        return [], []

    horizontal_lines = []
    vertical_lines = []

    for x1, y1, x2, y2 in lines[:, 0]:
        angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
        if abs(angle) < 10:
            horizontal_lines.append((x1, y1, x2, y2))
        elif abs(angle - 90) < 10 or abs(angle + 90) < 10:
            vertical_lines.append((x1, y1, x2, y2))

    def merge_line_coords(lines, axis='y', threshold=30):
        if not lines:
            return []
        coords = [int((y1 + y2) / 2) if axis == 'y' else int((x1 + x2) / 2) for x1, y1, x2, y2 in lines]
        coords = sorted(coords)
        merged = []
        current = coords[0]
        for val in coords[1:]:
            if abs(val - current) < threshold:
                current = int((current + val) / 2)
            else:
                merged.append(current)
                current = val
        merged.append(current)
        return merged

    merged_horizontal = merge_line_coords(horizontal_lines, axis='y')
    merged_vertical = merge_line_coords(vertical_lines, axis='x')

    # Remove lines too close to image edges
    merged_horizontal = [y for y in merged_horizontal if edge_threshold < y < (height - edge_threshold)]
    merged_vertical = [x for x in merged_vertical if edge_threshold < x < (width - edge_threshold)]

    filled_horizontal = []
    filled_vertical = []

    def fill_missing_lines(existing, filled, total_lines, size, tolerance=40):
        print(f"Filling missing lines: {len(existing)} found, {total_lines} expected")
        ideal_positions = list(np.linspace(0, size, total_lines + 2, dtype=int)[1:-1])
        for pos in ideal_positions:
            if not any(abs(pos - e) < tolerance for e in existing):
                existing.append(pos)
                filled.append(pos)
        existing.sort()
        return existing

    if len(merged_horizontal) < 7:
        merged_horizontal = fill_missing_lines(merged_horizontal, filled_horizontal, 7, height)
    if len(merged_vertical) < 7:
        merged_vertical = fill_missing_lines(merged_vertical, filled_vertical, 7, width)
        
    # If we don't have exactly 7 lines in each direction, use equally spaced lines as fallback
    if len(merged_horizontal) != 7 or len(merged_vertical) != 7:
        print(f"Insufficient grid lines detected: {len(merged_horizontal)} horizontal, {len(merged_vertical)} vertical")
        print("Falling back to equally spaced grid lines")

        # Generate 9 lines (including borders)
        horizontal_all = list(np.linspace(0, height, 9, dtype=int))
        vertical_all = list(np.linspace(0, width, 9, dtype=int))

        # Take internal 7 lines (excluding the first and last, which are borders)
        merged_horizontal = horizontal_all[1:-1]
        merged_vertical = vertical_all[1:-1]

        # Mark all as filled (fallback)
        filled_horizontal = merged_horizontal.copy()
        filled_vertical = merged_vertical.copy()

    if show_lines:
        merged_image = image_rgb.copy()
        for y in merged_horizontal:
            color = (255, 0, 255) if y in filled_horizontal else (0, 0, 255)  # Magenta or Blue
            cv2.line(merged_image, (0, y), (merged_image.shape[1], y), color, 3)

        for x in merged_vertical:
            color = (0, 255, 255) if x in filled_vertical else (0, 255, 0)  # Cyan or Green
            cv2.line(merged_image, (x, 0), (x, merged_image.shape[0]), color, 3)

        plt.figure(figsize=(8, 8))
        plt.imshow(merged_image)
        plt.title("Filtered + Filled Grid Lines (Distinct Colors)")
        plt.axis("off")
        plt.show()

    return merged_horizontal, merged_vertical

### Divide warped into Squares

In [2]:
files = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
def process_image(key, warped_results):
    h_lines, v_lines = detect_grid_lines(warped_results[key]['warped_image'], show_lines=False)

    if len(h_lines) != 7 or len(v_lines) != 7:
        print(f"Not enough grid lines detected in {key}.")
        return None

    height, width = warped_results[key]['warped_image'].shape[:2]
    h_lines_full = [0] + sorted(h_lines) + [height]
    v_lines_full = [0] + sorted(v_lines) + [width]

    squares = []
    square_centers = []
    square_bounds = []

    for i in range(8):
        for j in range(8):
            y1, y2 = h_lines_full[i], h_lines_full[i + 1]
            x1, x2 = v_lines_full[j], v_lines_full[j + 1]
            square = warped_results[key]['warped_image'][y1:y2, x1:x2]
            rank = 8 - i
            file = files[j]
            label = f"{file}{rank}"
            # Resize the square image to 224x224
            square_resized = cv2.resize(square, (224, 224), interpolation=cv2.INTER_AREA)
            squares.append({"position": (i, j), "label": label, "image": square_resized})
            center_x = (x1 + x2) // 2
            center_y = (y1 + y2) // 2
            square_centers.append((center_x, center_y))
            square_bounds.append((x1, y1, x2, y2))

    warped_results[key]['squares'] = squares
    return warped_results

### Load Truth Matrices

In [3]:
import os
import json
import numpy as np

training_inputs_dir = "../Shared/training_inputs/matrices"
training_files = [f for f in os.listdir(training_inputs_dir) if f.endswith('.json')]

training_data = []

for fname in training_files:
    path = os.path.join(training_inputs_dir, fname)
    with open(path, 'r') as f:
        data = json.load(f)
        if data.get("corners"):
            training_data.append({
                "file_name": data.get("file_name"),
                "image_id" : data.get("image_id"),
                "piece_count": data.get("piece_count"),
                "presence_matrix": np.flipud(data.get("presence_matrix")).tolist(),
                "piece_type_matrix": data.get("piece_type_matrix"),
                "corners": data.get("corners")
            })

print(f"Loaded {len(training_data)} training input files.")

Loaded 2078 training input files.


### Warp the image using corners

In [4]:
import cv2
import numpy as np

def warp_with_corners(image, corners_dict, board_size=800):
    """
    Warps the input image using the provided corners so that the bottom-left corner
    in corners_dict becomes the bottom-left of the output image.
    corners_dict: dict with keys 'top_left', 'top_right', 'bottom_left', 'bottom_right'
    image: input image (BGR or RGB)
    """
    # Order: [top_left, top_right, bottom_left, bottom_right]
    src_corners = np.array([
        corners_dict['top_left'],
        corners_dict['top_right'],
        corners_dict['bottom_left'],
        corners_dict['bottom_right']
    ], dtype="float32")

    # The destination corners: [top_left, top_right, bottom_left, bottom_right]
    dst_corners = np.array([
        [0, 0],
        [board_size - 1, 0],
        [0, board_size - 1],
        [board_size - 1, board_size - 1]
    ], dtype="float32")

    # Find the perspective transform matrix
    matrix = cv2.getPerspectiveTransform(src_corners, dst_corners)
    warped = cv2.warpPerspective(image, matrix, (board_size, board_size))

    return warped, matrix, src_corners, dst_corners

images_dir = "../Shared/all_images"
warped_results = {}
for entry in training_data:
    img_path = os.path.join(images_dir, entry['file_name'])
    if not os.path.exists(img_path):
        print(f"Image not found: {img_path}")
        continue
    img = cv2.imread(img_path)
    corners = entry['corners']
    warped_img, matrix, src, dst = warp_with_corners(img, corners)
    warped_results[entry['file_name']] = {
        'warped_image': warped_img,
        'matrix': matrix,
        'src_corners': src,
        'dst_corners': dst
    }

### Process the image

In [5]:
# for fname in warped_results.keys():
#     warped_results = process_image(fname, warped_results)

### Load splits

In [6]:
import json
from pathlib import Path

json_path = Path("../Shared/annotations.json")  # Update if needed

with open(json_path, 'r') as f:
    data = json.load(f)

# Load chessred2k split image IDs
chessred2k_splits = {}
for split in ['train', 'val', 'test']:
    chessred2k_splits[split] = data['splits']['chessred2k'][split]['image_ids']

print({k: len(v) for k, v in chessred2k_splits.items()})

{'train': 1442, 'val': 330, 'test': 306}


In [23]:
import torch
from torchvision import models
import torchvision.transforms as T
import torch.nn as nn
import cv2

def load_model(model, path):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
    model.eval()
    return model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Load model
# MODEL_NAME = 'resnet50'
# model = models.resnet50(weights=False)
# model.fc = nn.Linear(model.fc.in_features, 1)  # Binary classification
# model = load_model(model, './checkpoints_resnet50_1_squares/best_piece_classifier.pt')

# # Load MobileNetV3 Large model
# MODEL_NAME = 'MobileNetV3'
# model = models.mobilenet_v3_large(weights=False)
# num_features = model.classifier[3].in_features
# model.classifier[3] = nn.Linear(num_features, 1)
# model = load_model(model, './checkpoints_mobileNetV3_squares/best_piece_classifier.pt')

# Load SqueezeNet model
MODEL_NAME = 'SqueezeNet'
model = models.squeezenet1_0(pretrained=False)
model.classifier[1] = nn.Conv2d(512, 1, kernel_size=1)
model.num_classes = 1

class SqueezeNetBinary(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model

    def forward(self, x):
        x = self.base(x)
        return x.view(x.size(0), -1)  # Flatten to shape [batch_size, 1]

model = SqueezeNetBinary(model)
model = load_model(model, './checkpoints_SqueezeNet_1.0_squares/best_piece_classifier.pt')

# Preprocessing for ResNet50
preprocess = T.Compose([
    T.ToPILImage(),
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model = model.to(device)

# Split and image filtering
train_image_ids = set(chessred2k_splits['test'])
train_entries = [entry for entry in training_data if entry['image_id'] in train_image_ids]

results = {}

for entry in train_entries:
    fname = entry['file_name']
    if fname not in warped_results:
        print(f"Warped image not found for {fname}")
        continue

    warped_results = process_image(fname, warped_results)
    squares = warped_results[fname].get('squares', [])

    # Skip if no squares found
    if not squares:
        continue

    # Preprocess all 64 squares into a batch
    batch_tensors = []
    for square in squares:
        img = square['image']
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = preprocess(img_rgb)
        batch_tensors.append(tensor)

    # Create single batch tensor
    input_batch = torch.stack(batch_tensors).to(device)  # Shape: [64, 3, 224, 224]

    # Inference in one go
    with torch.no_grad():
        logits = model(input_batch)                       # [64, 1]
        probs = torch.sigmoid(logits).squeeze(1)          # [64]

    # Store predictions
    predictions = []
    for square, logit, prob in zip(squares, logits.squeeze(1), probs):
        predictions.append({
            'label': square['label'],
            'logit': logit.item(),
            'prob': prob.item()
        })

    results[fname] = predictions

  checkpoint = torch.load(path, map_location='cpu')


Filling missing lines: 4 found, 7 expected
Filling missing lines: 5 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 4 found, 7 expected
Insufficient grid lines detected: 9 horizontal, 7 vertical
Falling back to equally spaced grid lines
Filling missing lines: 5 found, 7 expected
Filling missing lines: 5 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Insufficient grid lines detected: 7 horizontal, 8 vertical
Falling back to equally spaced grid lines
Filling missing lines: 6 found, 7 expected
Filling missing lines: 6 found, 7 expected
Filling missing lines: 5 found, 7 expected
Filling missing lines: 6

In [25]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, mean_squared_error, mean_absolute_error
)
import numpy as np

# --- Per-Square Evaluation (classification-style) ---
y_true = []
y_pred = []  # Binary class: 0 or 1
y_prob = []  # Probabilistic output: sigmoid

file_to_presence = {entry['file_name']: entry['presence_matrix'] for entry in training_data}

for fname, predictions in results.items():
    presence_matrix = file_to_presence.get(fname)
    if presence_matrix is None:
        continue
    gt_flat = np.array(presence_matrix).flatten()
    for idx, pred in enumerate(predictions):
        y_true.append(gt_flat[idx])
        y_prob.append(pred['prob'])
        y_pred.append(int(pred['prob'] > 0.5))

y_true = np.array(y_true)
y_pred = np.array(y_pred)
y_prob = np.array(y_prob)

# Classification metrics
print(f"\n----- Model {MODEL_NAME} -----")
print()
print("----- Per-Square Classification Metrics -----")
print("Accuracy:", accuracy_score(y_true, y_pred)*100, "%") 
print("Precision:", precision_score(y_true, y_pred, zero_division=0)*100, "%")
print("Recall:", recall_score(y_true, y_pred, zero_division=0)*100, "%")
print("F1 Score:", f1_score(y_true, y_pred, zero_division=0)*100, "%")
print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))

# --- Per-Image Evaluation (counting pieces per board) ---
true_counts = []
pred_counts = []
soft_pred_counts = []

for fname, predictions in results.items():
    presence_matrix = file_to_presence.get(fname)
    if presence_matrix is None:
        continue
    true_count = np.array(presence_matrix).sum()
    thresholded_count = sum(pred['prob'] > 0.5 for pred in predictions)
    soft_count = sum(pred['prob'] for pred in predictions)

    true_counts.append(true_count)
    pred_counts.append(thresholded_count)
    soft_pred_counts.append(soft_count)

# Thresholded count metrics
mse_count = mean_squared_error(true_counts, pred_counts)
rmse_count = np.sqrt(mse_count)
mae_count = mean_absolute_error(true_counts, pred_counts)

print("\n----- Per-Image Piece Count Metrics -----")
print("MSE (count-level, thresholded):", mse_count)
print("RMSE (count-level, thresholded):", rmse_count)
print("MAE (count-level, thresholded):", mae_count)


----- Model SqueezeNet -----

----- Per-Square Classification Metrics -----
Accuracy: 99.68341503267973 %
Precision: 99.13276568905708 %
Recall: 99.88878296790594 %
F1 Score: 99.50933839822729 %
Confusion Matrix:
 [[13235    55]
 [    7  6287]]

----- Per-Image Piece Count Metrics -----
MSE (count-level, thresholded): 0.3202614379084967
RMSE (count-level, thresholded): 0.5659164584181102
MAE (count-level, thresholded): 0.20261437908496732
