In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import torchvision.transforms as T
from tqdm import tqdm
from PIL import Image
import pandas as pd
import json
from pathlib import Path

In [2]:
import os
N_WORKERS = os.cpu_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from torch.utils.data import Dataset

class PreloadedDataset(Dataset):
    def __init__(self, tensor_file):
        self.data = torch.load(tensor_file)  # list of (img_tensor, presence_tensor, count_tensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_tensor, presence_tensor, count_tensor = self.data[idx]
        return img_tensor, presence_tensor, count_tensor

In [4]:
train_dataset = PreloadedDataset("./train_data_noWarp.pt")
test_dataset = PreloadedDataset("./test_data_noWarp.pt")
val_dataset = PreloadedDataset("./val_data_noWarp.pt")

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)
validation_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)

  self.data = torch.load(tensor_file)  # list of (img_tensor, presence_tensor, count_tensor)


In [5]:
class ResNet50MultiTask(nn.Module):
    def __init__(self, pretrained=True):
        super(ResNet50MultiTask, self).__init__()
        
        # Load pretrained ResNet-50
        resnet = models.resnet50(pretrained=pretrained)

        # Remove the classification head (fc layer)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])  # Output: [B, 2048, 1, 1]

        # Flatten layer (ResNet output is [B, 2048, 1, 1])
        self.flatten = nn.Flatten()

        # Classification head for presence map (64 outputs for 8x8 grid)
        self.presence_head = nn.Sequential(
            nn.Linear(2048, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )

        # Regression head for piece count
        self.count_head = nn.Sequential(
            nn.Linear(2048, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Hardtanh(min_val=0, max_val=32)
        )

    def forward(self, x):
        features = self.feature_extractor(x)  # [B, 2048, 1, 1]
        features = self.flatten(features)     # [B, 2048]
        
        presence_out = self.presence_head(features)  # [B, 64]
        count_out = self.count_head(features)        # [B, 1]

        return presence_out, count_out

In [6]:
checkpoint = torch.load("best_checkpoint_combination.pt", map_location=device)
model = ResNet50MultiTask().to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

  checkpoint = torch.load("best_checkpoint_combination.pt", map_location=device)


ResNet50MultiTask(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      

In [7]:
from sklearn.metrics import mean_absolute_error, accuracy_score

piece_counts = []
true_counts = []
with torch.no_grad():
    for images, _, count_tensor in test_loader:
        images = images.to(device)
        _, count_out = model(images)
        piece_counts.extend(torch.round(count_out.cpu().squeeze()).tolist())
        true_counts.extend(torch.round(count_tensor.cpu().squeeze()).tolist())

mae = mean_absolute_error(true_counts, piece_counts)
accuracy = accuracy_score(true_counts, piece_counts)

print("True piece counts for test set:", true_counts)
print("Predicted piece counts for test set:", piece_counts)
print(f"MAE: {mae:.4f}")
print(f"Piece count accuracy: {accuracy:.4f}")

True piece counts for test set: [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 31.0, 30.0, 29.0, 28.0, 28.0, 28.0, 28.0, 27.0, 26.0, 26.0, 26.0, 26.0, 26.0, 26.0, 25.0, 24.0, 24.0, 24.0, 24.0, 23.0, 22.0, 21.0, 20.0, 19.0, 19.0, 19.0, 19.0, 19.0, 19.0, 19.0, 19.0, 19.0, 19.0, 19.0, 18.0, 17.0, 17.0, 17.0, 16.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 14.0, 13.0, 13.0, 13.0, 12.0, 12.0, 12.0, 11.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 9.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 31.0, 31.0, 31.0, 30.0, 29.0, 28.0, 28.0, 28.0, 28.0, 28.0, 28.0, 28.0, 28.0, 28.0, 28.0, 28.0, 27.0, 27.0, 27.0, 27.0, 27.0, 27.0, 27.0, 26.0, 25.0, 25.0, 25.0, 25.0, 25.0, 24.0, 24.0, 24.0, 23.0, 22.0, 21.0, 20.0, 20.0, 20.0, 20.0, 19.0, 18.0, 18.0, 17.0, 17.0, 17.0, 17