In [1]:
from torch.utils.data import Dataset
import os
import cv2
from torchvision import transforms

from board_to_fen import *

class FENChessSquareDataset(Dataset):
    """Per-square classification dataset built from chess-board images.

    Every ``.jpeg`` file in ``img_dir`` is treated as one board and contributes
    64 samples, one per square. The position is derived from the file name via
    ``filename_to_fen`` and expanded to per-square labels with
    ``fen_to_labels``; each sample is a ``(square_image, class_index)`` pair.

    Args:
        img_dir: Directory containing the board images.
        transform: Optional callable applied to each square image. When it is
            omitted (or falsy) the square is converted to a PIL image, resized
            to 64x64 and converted to a tensor.

    NOTE(review): ``cv2.imread`` returns pixels in BGR channel order and the
    default transform does not reorder them, so the network is fed BGR data.
    This is left untouched to stay compatible with already-trained weights;
    confirm that inference code uses the same channel order.
    """

    # Number of samples (squares) produced by a single board image.
    SQUARES_PER_BOARD = 64
    # Side length, in pixels, every board is resized to before it is split.
    BOARD_SIZE = 800

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        # os.listdir() gives no ordering guarantee; sorting makes the
        # index -> (board, square) mapping reproducible between runs.
        self.filenames = sorted(
            f for f in os.listdir(img_dir) if f.endswith('.jpeg')
        )
        self.transform = transform or transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((64, 64)),
            transforms.ToTensor()
        ])
        # Label symbol -> class index: '.' is class 0, upper-case piece
        # letters are classes 1-6 and lower-case piece letters classes 7-12.
        self.piece_map = {
            '.': 0, 'P': 1, 'N': 2, 'B': 3, 'R': 4, 'Q': 5, 'K': 6,
            'p': 7, 'n': 8, 'b': 9, 'r': 10, 'q': 11, 'k': 12
        }

    def __len__(self):
        """Return the total number of square samples (64 per board image)."""
        return len(self.filenames) * self.SQUARES_PER_BOARD

    def __getitem__(self, idx):
        """Return ``(transformed_square_image, class_index)`` for sample ``idx``.

        Raises:
            IndexError: If ``idx`` refers to a board beyond the file list.
            OSError: If the board image cannot be read or decoded.
            KeyError: If a label symbol is not present in ``piece_map``.
        """
        # Samples are laid out board-major: 64 consecutive indices per board.
        board_idx = idx // self.SQUARES_PER_BOARD
        square_idx = idx % self.SQUARES_PER_BOARD

        filename = self.filenames[board_idx]
        img_path = os.path.join(self.img_dir, filename)

        # NOTE(review): the board is decoded again for every one of its 64
        # squares; consider caching if data loading becomes the bottleneck.
        img = cv2.imread(img_path)
        if img is None:
            # imread signals failure by returning None instead of raising;
            # fail here with the path rather than later inside cv2.resize.
            raise OSError(f"Could not read image file: {img_path}")
        img = cv2.resize(img, (self.BOARD_SIZE, self.BOARD_SIZE))

        # Presumably yields the 64 squares in the same order as the labels
        # produced by fen_to_labels -- TODO confirm against board_to_fen.
        squares = split_into_squares(img)

        fen = filename_to_fen(filename)
        labels = fen_to_labels(fen)

        square_img = squares[square_idx]
        label = self.piece_map[labels[square_idx]]

        return self.transform(square_img), label

In [2]:
from torch.utils.data import DataLoader

# Location of the training boards. Set the CHESS_TRAIN_DIR environment variable
# to run on another machine; the fallback is the original hard-coded path.
# (`os` is imported in the first cell.)
TRAIN_DIR = os.environ.get(
    "CHESS_TRAIN_DIR", "/Users/kisel/projects/shallow-blue/cv/dataset/train"
)
# Number of square samples per mini-batch.
BATCH_SIZE = 64

dataset = FENChessSquareDataset(TRAIN_DIR)

# shuffle=True mixes squares from different boards within each batch.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [3]:
import torch.nn as nn

class ChessPieceCNN(nn.Module):
    """Small convolutional classifier for single chess-square images.

    Architecture: two ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)`` stages
    (3 -> 32 -> 64 channels) followed by a 256-unit hidden layer and a final
    linear layer that emits one raw logit per class.

    Args:
        num_classes: Number of output logits. Defaults to 13.
        input_size: Spatial size of the input images, either a single int for
            square inputs or a ``(height, width)`` pair. Defaults to 64, which
            reproduces the original fixed ``64*16*16`` flattened width.

    Raises:
        ValueError: If ``input_size`` is too small to survive both poolings.
    """

    def __init__(self, num_classes=13, input_size=64):
        super().__init__()

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # The padded 3x3 convolutions keep the spatial size; each MaxPool2d(2)
        # floor-halves it, so two poolings divide each side by 4 (floored).
        pooled_h, pooled_w = height // 4, width // 4
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small; each side must be at least 4"
            )

        # Layer order is kept exactly as before so state_dict keys
        # (model.0, model.3, model.7, model.9) still match saved checkpoints.
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * pooled_h * pooled_w, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)`` for ``x``.

        ``x`` must be a ``(batch, 3, height, width)`` tensor whose spatial size
        matches the ``input_size`` given at construction.
        """
        return self.model(x)

In [3]:
# Run the training script in a separate Python process via the shell.
# NOTE(review): a subprocess does not share this kernel's namespace, so the
# `dataset`, `loader` and `ChessPieceCNN` objects defined above are only used
# if train.py builds or imports them itself -- confirm against train.py.
!python train.py

Using device: mps
Using 500 boards → 32000 total square samples
DataLoader will use 8 worker processes
Epoch 1/10: 100%|█| 500/500 [00:14<00:00, 34.24batch/s, loss=0.1312, acc=0.9683]
→ Epoch 1: Loss 0.1312 | Acc 0.9683
Epoch 2/10: 100%|█| 500/500 [00:13<00:00, 38.11batch/s, loss=0.0048, acc=0.9989]
→ Epoch 2: Loss 0.0048 | Acc 0.9989
Epoch 3/10: 100%|█| 500/500 [00:13<00:00, 37.04batch/s, loss=0.0029, acc=0.9991]
→ Epoch 3: Loss 0.0029 | Acc 0.9991
Epoch 4/10: 100%|█| 500/500 [00:13<00:00, 36.29batch/s, loss=0.0018, acc=0.9996]
→ Epoch 4: Loss 0.0018 | Acc 0.9996
Epoch 5/10: 100%|█| 500/500 [00:14<00:00, 35.44batch/s, loss=0.0044, acc=0.9987]
→ Epoch 5: Loss 0.0044 | Acc 0.9987
Epoch 6/10: 100%|█| 500/500 [00:14<00:00, 35.30batch/s, loss=0.0043, acc=0.9987]
→ Epoch 6: Loss 0.0043 | Acc 0.9987
Epoch 7/10: 100%|█| 500/500 [00:14<00:00, 33.52batch/s, loss=0.0000, acc=1.0000]
→ Epoch 7: Loss 0.0000 | Acc 1.0000
Epoch 8/10: 100%|█| 500/500 [00:16<00:00, 30.88batch/s, loss=0.0000, acc=1.000