In [None]:
import torch
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from dataset import FENChessSquareDataset, ChessPieceCNN

# 1. Load model & data
# Backend preference: Apple MPS first, then CUDA, falling back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Rebuild the network, restore the saved weights directly onto `device`,
# and switch to inference mode (model.eval()).
model = ChessPieceCNN(13)
model = model.to(device)
model.load_state_dict(
    torch.load("chess_piece_cnn.pth", map_location=device)
)
model.eval()

# Evaluate on at most the first 100,000 samples of the test split, in their
# stored order (no shuffling), loaded in the main process.
val_ds = FENChessSquareDataset("dataset/test")
indices = list(range(len(val_ds)))
val_subset = Subset(val_ds, indices[:100_000])
val_loader = DataLoader(
    dataset=val_subset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
)

# 2. Collect preds & labels
# Both lists are filled one element per sample, in loader order, so that
# position i of all_preds lines up with position i of all_labels.
all_preds = []
all_labels = []
with torch.no_grad():  # inference only: no autograd bookkeeping needed
    progress = tqdm(val_loader, desc="Evaluating")
    for imgs, labels in progress:
        logits = model(imgs.to(device))
        # Index of the highest logit per sample is the predicted class.
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

# 3. Compute accuracy
# Turn the per-sample lists into arrays so they can be compared elementwise.
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Fraction of samples whose predicted class equals the true class.
is_correct = all_preds == all_labels
accuracy = np.mean(is_correct)
print(f"Overall accuracy: {accuracy:.4f}")

# 4. Build confusion matrix
# cm[t, p] counts the samples whose true label is t and predicted label is p
# (rows = true class, columns = predicted class).
num_classes = len(val_ds.piece_map)
cm = np.zeros((num_classes, num_classes), dtype=int)
# np.add.at performs an unbuffered in-place add, so (t, p) pairs that occur
# many times are each counted -- unlike `cm[all_labels, all_preds] += 1`,
# which would count every distinct pair only once. This replaces a
# per-sample Python loop with a single vectorized call; the counts are the same.
np.add.at(cm, (all_labels, all_preds), 1)

# 5. Plot
# Draw the confusion matrix as a heatmap on an explicit figure/axes pair;
# rows are true classes and columns are predicted classes.
fig, ax = plt.subplots(figsize=(8, 8))
heatmap = ax.imshow(cm, interpolation="nearest", cmap="Blues")
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
fig.colorbar(heatmap, ax=ax)  # colour scale for the cell counts
plt.show()


Evaluating:   7%|▋         | 102/1563 [00:05<01:19, 18.46it/s]