In [1]:
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset
import albumentations as A
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
IMG_SIZE = 256

In [2]:
class DartsDataset(Dataset):
    """Grayscale images paired with one normalised (x, y) keypoint each.

    Images (``*.jpg``) and labels (``*.txt``) are paired purely by their
    position in the name-sorted file lists, so both directories are assumed
    to use matching file stems (this is not verified beyond the file count).
    Each label file is expected to hold two numbers: the keypoint's x and y
    as fractions of the image width and height. Samples whose label is not
    strictly inside (0, 1) on both axes are dropped at construction time.

    Args:
        label_dir: Directory containing the ``*.txt`` label files.
        img_dir: Directory containing the ``*.jpg`` images.
        transform: Optional albumentations pipeline called with ``image=`` and
            ``keypoints=``. Its output image is indexed as (C, H, W), so the
            pipeline has to end in a tensor conversion. When it is ``None``
            (or keeps dropping the keypoint) a plain resize pipeline is used.
    """

    # Number of attempts given to ``transform`` before falling back to the
    # plain resize pipeline when the keypoint is missing from its output.
    _MAX_AUG_TRIES = 10

    def __init__(self, label_dir: Path, img_dir: Path, transform=None):
        self.label_dir = label_dir
        self.img_dir = img_dir
        self.transform = transform

        self.img_paths = sorted(Path(img_dir).glob("*.jpg"))
        self.label_paths = sorted(Path(label_dir).glob("*.txt"))

        assert len(self.img_paths) == len(self.label_paths), (
            "Image and Label count mismatch."
        )

        # Supplying columns/dtype at construction keeps an empty label
        # directory from failing on a later column assignment.
        labels = pd.DataFrame(
            [np.loadtxt(f) for f in self.label_paths],
            columns=["x", "y"],
            dtype=float,
        )
        # The original row index is kept on purpose: __getitem__ uses it to
        # look up the matching entry in img_paths / label_paths.
        self.img_labels = labels[
            (labels["x"] < 1)
            & (labels["y"] < 1)
            & (labels["x"] > 0)
            & (labels["y"] > 0)
        ]

        # Fallback pipeline; remove_invisible=False means the keypoint is
        # never filtered out here.
        self.safe_transform = A.Compose([
            A.Resize(height=IMG_SIZE, width=IMG_SIZE, p=1.0),
            A.ToFloat(max_value=255.0),
            A.pytorch.ToTensorV2(),
        ], keypoint_params=A.KeypointParams(format="xy", remove_invisible=False))

    def __len__(self):
        """Return the number of samples that survived the label filter."""
        return len(self.img_labels)

    @staticmethod
    def _to_target(keypoints, image_chw, clip=False):
        """Convert the first pixel-space keypoint into a normalised target.

        Args:
            keypoints: Sequence of keypoints as returned by the pipeline.
            image_chw: Transformed image, indexed as (C, H, W).
            clip: Clamp both coordinates to [0, 1] when True.

        Returns:
            A float32 tensor ``[x, y]``, or ``None`` if ``keypoints`` is empty.
        """
        if len(keypoints) == 0:
            return None
        out_h, out_w = image_chw.shape[1], image_chw.shape[2]
        x_px, y_px = keypoints[0][:2]
        x_rel = x_px / out_w
        y_rel = y_px / out_h
        if clip:
            x_rel = float(np.clip(x_rel, 0.0, 1.0))
            y_rel = float(np.clip(y_rel, 0.0, 1.0))
        return torch.tensor([x_rel, y_rel], dtype=torch.float32)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``.

        Returns:
            The transformed image and a float32 tensor ``[x, y]`` holding the
            keypoint as fractions of the transformed image's width and height.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        # Labels were already parsed in __init__, so reuse them instead of
        # reading the text file again for every sample.
        row = self.img_labels.iloc[idx]
        x_norm, y_norm = float(row["x"]), float(row["y"])

        real_idx = int(self.img_labels.index[idx])
        img_path = self.img_paths[real_idx]

        image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")

        # Add an explicit channel axis: (H, W) -> (H, W, 1).
        image = np.expand_dims(image, axis=2)
        h, w = image.shape[:2]

        # The pipelines are configured for absolute pixel coordinates.
        keypoints = [[x_norm * w, y_norm * h]]

        if self.transform:
            for _ in range(self._MAX_AUG_TRIES):
                augmented = self.transform(image=image, keypoints=keypoints)
                target = self._to_target(
                    augmented["keypoints"], augmented["image"], clip=True
                )
                if target is not None:
                    return augmented["image"], target

        # No transform given, or every attempt lost the keypoint.
        safe_aug = self.safe_transform(image=image, keypoints=keypoints)
        target = self._to_target(safe_aug["keypoints"], safe_aug["image"])
        if target is None:
            target = torch.tensor([x_norm, y_norm], dtype=torch.float32)

        return safe_aug["image"], target

In [3]:
class CNNModel(torch.nn.Module):
    """Small CNN that regresses one (x, y) point per image.

    Two coordinate channels (x and y grids in [-1, 1]) are appended to the
    input before the first convolution, following the idea in
    https://arxiv.org/pdf/1807.03247. The final sigmoid keeps both outputs
    in [0, 1].

    Args:
        input_channels: Number of channels of the images passed to
            ``forward`` (the two coordinate channels are added internally).
    """

    def __init__(self, input_channels=1):
        super().__init__()
        # +2 for the x/y coordinate channels concatenated in forward().
        in_dim = input_channels + 2
        # NOTE: layer order determines the state-dict keys of saved
        # checkpoints, so it must not be changed.
        self.features = nn.Sequential(
            # 1. Block without dilation
            nn.Conv2d(in_dim, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # 2. Block with dilation (Increased receptive field)
            nn.Conv2d(32, 64, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # 3. Block without dilation
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # Global average pooling makes the head independent of the
            # input resolution.
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 2),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one (x, y) pair per image.

        Args:
            x: Batch of images in (B, C, H, W) format.

        Returns:
            Tensor of shape (B, 2) with values in [0, 1].
        """
        # https://arxiv.org/pdf/1807.03247
        batch_size, _, h, w = x.size()

        # Build the grids in the input's floating dtype so that half/bfloat16
        # models do not get their input promoted to float32 by the cat below.
        # Non-floating inputs keep the previous behaviour (default dtype).
        coord_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()

        # Create coordinate grid [-1, 1]
        y_coords = torch.linspace(-1, 1, h, device=x.device, dtype=coord_dtype)
        x_coords = torch.linspace(-1, 1, w, device=x.device, dtype=coord_dtype)
        yy, xx = torch.meshgrid(y_coords, x_coords, indexing="ij")

        # Broadcast the (H, W) grids to (B, 1, H, W) and append as channels.
        xx = xx.expand(batch_size, 1, h, w)
        yy = yy.expand(batch_size, 1, h, w)
        x = torch.cat([x, xx, yy], dim=1)

        # Forward pass through the network
        x = self.features(x)
        x = self.regressor(x)

        return x

In [4]:
def test_loop(dataloader, model, loss_fn, device):
    num_batches = len(dataloader)
    test_loss = 0.0
    mean_distance = 0.0

    model.eval()

    with torch.no_grad():
        for img, label in dataloader:
            x, y = img.to(device), label.to(device)

            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            dist = torch.sqrt(torch.sum((pred - y)**2, dim=1))
            mean_distance += dist.mean().item()

    test_loss /= num_batches
    mean_distance /= num_batches
    return test_loss, mean_distance

In [5]:
# Validation split of the darts-position data, located relative to the
# notebook. The path is built from components so it resolves on every OS
# (a literal backslash path is a single file name outside Windows).
dataset_dir = Path("..") / "data" / "darts_positions"
validation_images = dataset_dir / "images" / "val"
validation_labels = dataset_dir / "labels" / "val"

# Deterministic preprocessing only: resize, scale to [0, 1], convert to tensor.
validation_dataset = DartsDataset(
    label_dir=validation_labels,
    img_dir=validation_images,
    transform=A.Compose(
        [
            A.Resize(height=IMG_SIZE, width=IMG_SIZE, p=1.0),
            A.ToFloat(max_value=255.0),
            A.pytorch.ToTensorV2(),
        ],
        keypoint_params=A.KeypointParams(format="xy"),
    ),
)

# Fail early with a readable message instead of evaluating on nothing.
assert len(validation_dataset) > 0, (
    f"No usable validation samples found under {dataset_dir}"
)

In [12]:
# num_workers=0: the dataset class is defined inside this notebook, and on
# Windows the spawned DataLoader workers cannot re-import such a class.
# Raise it again only after moving the dataset into an importable module.
val_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False, num_workers=0)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

model = CNNModel()
model_locations = Path(r"G:\My Drive\CNNTraining")

# Keep only "best" checkpoints from training_* runs, ordered oldest to newest.
# A list (rather than a one-shot filter iterator) lets the evaluation cell be
# re-run without silently iterating over nothing.
model_paths = sorted(
    (
        p
        for p in model_locations.rglob("*.pt")
        if "best" in p.name and "training_" in str(p)
    ),
    key=lambda p: p.stat().st_mtime,
)

Using cpu device


In [None]:
# Evaluate every selected checkpoint on the validation set, then plot the
# metrics in checkpoint order.

# Materialise once: if model_paths is a one-shot iterator, re-running this
# cell would otherwise loop over nothing and draw an empty plot.
model_paths = list(model_paths)

# Loop-invariant setup.
loss_fn = nn.SmoothL1Loss()
model.to(device)

losses = []
mean_distances = []
for model_path in model_paths:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted location (consider weights_only=True where the installed
    # torch version supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))

    test_loss, mean_distance = test_loop(val_loader, model, loss_fn, device)
    print(f"Model: {model_path.name} | Test Loss: {test_loss:.6f} | Mean Distance: {mean_distance:.6f}")
    losses.append(test_loss)
    mean_distances.append(mean_distance)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses, label="Test Loss")
ax.plot(mean_distances, label="Mean Distance")
ax.set_xlabel("Checkpoint index")
ax.set_ylabel("Validation metric")
ax.set_title("Validation metrics per checkpoint")
# Without a legend the line labels above are never shown.
ax.legend();