In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import torchvision.transforms as T
from tqdm import tqdm
from PIL import Image
import pandas as pd
import json
from pathlib import Path

In [2]:
import os
N_WORKERS = os.cpu_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N_EPOCHS = 100

In [3]:
from comet_ml import Experiment


# Create an instance of the Experiment class
experiment = Experiment(
    project_name="ResNet50 Piece Counter Combination",  # Replace with your project name
    workspace="cristy17001"  # Replace with your workspace name
)

experiment.set_name("ResNet50 Piece Counter Combination 1")
experiment.log_parameters({
    "model": "resnet50",
    "optimizer": "AdamW",
    "lr": 1e-4,
    "weight_decay": 1e-4,
    "loss_function": "BCE + MSE",
    "scheduler": "ReduceLROnPlateau",
    "pretrained": True,
    "patience": 2,
    "batch_size": 64,
    "epochs": N_EPOCHS,
})

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/cristy17001/resnet50-piece-counter-combination/d3a137f472674ad9b9f3abeee560bf4c



In [4]:
from torch.utils.data import Dataset

class PreloadedDataset(Dataset):
    def __init__(self, tensor_file):
        self.data = torch.load(tensor_file)  # list of (img_tensor, presence_tensor, count_tensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_tensor, presence_tensor, count_tensor = self.data[idx]
        return img_tensor, presence_tensor, count_tensor

In [5]:
train_dataset = PreloadedDataset("./train_data_noWarp.pt")
test_dataset = PreloadedDataset("./test_data_noWarp.pt")
val_dataset = PreloadedDataset("./val_data_noWarp.pt")

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)
validation_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)

  self.data = torch.load(tensor_file)  # list of (img_tensor, presence_tensor, count_tensor)


In [6]:
class ResNet50MultiTask(nn.Module):
    def __init__(self, pretrained=True):
        super(ResNet50MultiTask, self).__init__()
        
        # Load pretrained ResNet-50
        resnet = models.resnet50(pretrained=pretrained)

        # Remove the classification head (fc layer)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])  # Output: [B, 2048, 1, 1]

        # Flatten layer (ResNet output is [B, 2048, 1, 1])
        self.flatten = nn.Flatten()

        # Classification head for presence map (64 outputs for 8x8 grid)
        self.presence_head = nn.Sequential(
            nn.Linear(2048, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )

        # Regression head for piece count
        self.count_head = nn.Sequential(
            nn.Linear(2048, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Hardtanh(min_val=0, max_val=32)
        )

    def forward(self, x):
        features = self.feature_extractor(x)  # [B, 2048, 1, 1]
        features = self.flatten(features)     # [B, 2048]
        
        presence_out = self.presence_head(features)  # [B, 64]
        count_out = self.count_head(features)        # [B, 1]

        return presence_out, count_out

In [7]:
model = ResNet50MultiTask().to(device)



In [8]:
# Loss functions
count_loss = nn.L1Loss()
presence_loss = nn.BCEWithLogitsLoss()

PRESENCE_WEIGHT = 0.65
COUNT_WEIGHT = 0.35

In [9]:
# Optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.1, verbose=True)



In [10]:
import numpy as np
from sklearn.metrics import mean_absolute_error, accuracy_score

best_count_mae = np.inf  # Initialize with infinity

for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0

    for images, presence_maps, counts in train_loader:
        images = images.to(device)
        presence_maps = presence_maps.to(device)
        counts = counts.to(device)

        optimizer.zero_grad()
        outputs_presence, outputs_count = model(images)
        loss_presence = presence_loss(outputs_presence, presence_maps)
        loss_count = count_loss(outputs_count.squeeze(1), counts)
        loss = PRESENCE_WEIGHT * loss_presence + COUNT_WEIGHT * loss_count
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    all_presence_preds = []
    all_presence_labels = []
    all_count_preds = []
    all_count_labels = []

    with torch.no_grad():
        for images, presence_maps, counts in validation_loader:
            images = images.to(device)
            presence_maps = presence_maps.to(device)
            counts = counts.to(device)

            presence_pred, count_pred = model(images)

            loss_presence = presence_loss(presence_pred, presence_maps)
            loss_count = count_loss(count_pred.squeeze(1), counts)
            val_loss += (PRESENCE_WEIGHT * loss_presence + COUNT_WEIGHT * loss_count).item()

            all_presence_preds.append(presence_pred.detach().cpu())
            all_presence_labels.append(presence_maps.detach().cpu())
            all_count_preds.append(count_pred.detach().cpu())
            all_count_labels.append(counts.detach().cpu())

    val_loss /= len(validation_loader)

    all_presence_preds = torch.cat(all_presence_preds).numpy()
    all_presence_labels = torch.cat(all_presence_labels).numpy()
    all_count_preds = torch.cat(all_count_preds).numpy()
    all_count_labels = torch.cat(all_count_labels).numpy()

    all_presence_preds_binary = (all_presence_preds > 0.5).astype(int)
    all_presence_labels_int = all_presence_labels.astype(int)

    presence_accuracy = accuracy_score(all_presence_labels_int.flatten(), all_presence_preds_binary.flatten())
    count_mae = mean_absolute_error(all_count_labels, all_count_preds)

    print(f"Validation Loss: {val_loss:.4f} | Presence Accuracy: {presence_accuracy:.4f} | Count MAE: {count_mae:.4f}")
    scheduler.step(val_loss)

    # Save checkpoint if this is the best count_mae so far
    if count_mae < best_count_mae:
        best_count_mae = count_mae
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'count_mae': count_mae,
        }, 'best_checkpoint_combination.pt')
        print(f"Checkpoint saved at epoch {epoch+1} with Count MAE: {count_mae:.4f}")

Epoch 1, Train Loss: 6.2699
Validation Loss: 3.2042 | Presence Accuracy: 0.7093 | Count MAE: 8.6637
Checkpoint saved at epoch 1 with Count MAE: 8.6637
Epoch 2, Train Loss: 2.5568
Validation Loss: 3.2243 | Presence Accuracy: 0.7194 | Count MAE: 7.6764
Checkpoint saved at epoch 2 with Count MAE: 7.6764
Epoch 3, Train Loss: 1.1257
Validation Loss: 1.4874 | Presence Accuracy: 0.7412 | Count MAE: 3.1232
Checkpoint saved at epoch 3 with Count MAE: 3.1232
Epoch 4, Train Loss: 1.0084
Validation Loss: 0.9652 | Presence Accuracy: 0.7356 | Count MAE: 1.6730
Checkpoint saved at epoch 4 with Count MAE: 1.6730
Epoch 5, Train Loss: 0.8778
Validation Loss: 1.1134 | Presence Accuracy: 0.7379 | Count MAE: 2.3388
Epoch 6, Train Loss: 0.7882
Validation Loss: 0.7668 | Presence Accuracy: 0.7510 | Count MAE: 1.1682
Checkpoint saved at epoch 6 with Count MAE: 1.1682
Epoch 7, Train Loss: 0.7263
Validation Loss: 0.8090 | Presence Accuracy: 0.7511 | Count MAE: 1.3181
Epoch 8, Train Loss: 0.7502
Validation Loss: 