In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import torchvision.transforms as T
from tqdm import tqdm
from PIL import Image
import pandas as pd
import json
from pathlib import Path

In [2]:
import os
N_WORKERS = os.cpu_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N_EPOCHS = 100


In [3]:
# from comet_ml import Experiment


# # Create an instance of the Experiment class
# experiment = Experiment(
#     project_name="ResNet18 Piece Counter Combination",  # Replace with your project name
#     workspace="cristy17001"  # Replace with your workspace name
# )

# experiment.set_name("ResNet18 Piece Counter Combination 1")
# experiment.log_parameters({
#     "model": "resnet18",
#     "optimizer": "AdamW",
#     "lr": 1e-4,
#     "weight_decay": 1e-4,
#     "loss_function": "BCE + MSE",
#     "scheduler": "ReduceLROnPlateau",
#     "pretrained": True,
#     "patience": 2,
#     "batch_size": 64,
#     "epochs": N_EPOCHS,
# })

In [4]:
from torch.utils.data import Dataset

class PreloadedDataset(Dataset):
    def __init__(self, tensor_file):
        self.data = torch.load(tensor_file)  # list of (img_tensor, presence_tensor, count_tensor)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_tensor, presence_tensor, count_tensor = self.data[idx]
        return img_tensor, presence_tensor, count_tensor

In [5]:
train_dataset = PreloadedDataset("./split_count_presence/train_data.pt")
test_dataset = PreloadedDataset("./split_count_presence/test_data.pt")
val_dataset = PreloadedDataset("./split_count_presence/val_data.pt")

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)
validation_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)

  self.data = torch.load(tensor_file)  # list of (img_tensor, presence_tensor, count_tensor)


In [6]:
class ResNet18MultiTask(nn.Module):
    def __init__(self, pretrained=True):
        super(ResNet18MultiTask, self).__init__()
        
        # Load pretrained ResNet-50
        resnet = models.resnet50(pretrained=pretrained)

        # Remove the classification head (fc layer)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])  # Output: [B, 512, 1, 1]

        # Flatten layer (ResNet output is [B, 512, 1, 1])
        self.flatten = nn.Flatten()

        # Classification head for presence map (64 outputs for 8x8 grid)
        self.presence_head = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )

        # Regression head for piece count
        self.count_head = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1)  # Output is a scalar
        )

    def forward(self, x):
        features = self.feature_extractor(x)  # [B, 512, 1, 1]
        features = self.flatten(features)     # [B, 512]
        
        presence_out = self.presence_head(features)  # [B, 64]
        count_out = self.count_head(features)        # [B, 1]

        return presence_out, count_out

In [7]:
model = ResNet18MultiTask().to(device)



In [8]:
# Loss functions
count_loss = nn.L1Loss()
presence_loss = nn.BCEWithLogitsLoss()

PRESENCE_WEIGHT = 0.7
COUNT_WEIGHT = 0.3

In [9]:
# Optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.1, verbose=True)



In [10]:
from sklearn.metrics import mean_absolute_error, accuracy_score

# Training loop (simplified)
for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0

    for images, presence_maps, counts in train_loader:
        images = images.to(device)
        presence_maps = presence_maps.to(device)
        counts = counts.to(device)

        optimizer.zero_grad()
        outputs_presence, outputs_count = model(images)
        loss_presence = presence_loss(outputs_presence, presence_maps)
        loss_count = count_loss(outputs_count.squeeze(1), counts)
        loss = PRESENCE_WEIGHT * loss_presence + COUNT_WEIGHT * loss_count
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    all_presence_preds = []
    all_presence_labels = []
    all_count_preds = []
    all_count_labels = []

    with torch.no_grad():
        for images, presence_maps, counts in validation_loader:
            images = images.to(device)
            presence_maps = presence_maps.to(device)
            counts = counts.to(device)

            presence_pred, count_pred = model(images)

            loss_presence = presence_loss(presence_pred, presence_maps)
            loss_count = count_loss(count_pred.squeeze(1), counts)  # Fix shape here
            val_loss += (PRESENCE_WEIGHT * loss_presence + COUNT_WEIGHT * loss_count).item()  # Use same weights as training

            # Collect for metrics
            all_presence_preds.append(presence_pred.detach().cpu())
            all_presence_labels.append(presence_maps.detach().cpu())
            all_count_preds.append(count_pred.detach().cpu())
            all_count_labels.append(counts.detach().cpu())

    val_loss /= len(validation_loader)

    # Concatenate all batches
    all_presence_preds = torch.cat(all_presence_preds).numpy()
    all_presence_labels = torch.cat(all_presence_labels).numpy()
    all_count_preds = torch.cat(all_count_preds).numpy()
    all_count_labels = torch.cat(all_count_labels).numpy()

    # Convert presence predictions to binary by thresholding at 0.5
    all_presence_preds_binary = (all_presence_preds > 0.5).astype(int)
    all_presence_labels_int = all_presence_labels.astype(int)

    # Accuracy for presence map
    presence_accuracy = accuracy_score(all_presence_labels_int.flatten(), all_presence_preds_binary.flatten())

    # MAE for count regression
    count_mae = mean_absolute_error(all_count_labels, all_count_preds)

    print(f"Validation Loss: {val_loss:.4f} | Presence Accuracy: {presence_accuracy:.4f} | Count MAE: {count_mae:.4f}")
    scheduler.step(val_loss)

Epoch 1, Train Loss: 6.2425
Validation Loss: 4.0902 | Presence Accuracy: 0.7009 | Count MAE: 13.7282
Epoch 2, Train Loss: 4.8904
Validation Loss: 2.9434 | Presence Accuracy: 0.7182 | Count MAE: 9.7101
Epoch 3, Train Loss: 3.4143
Validation Loss: 1.9037 | Presence Accuracy: 0.7374 | Count MAE: 5.7177
Epoch 4, Train Loss: 1.7841
Validation Loss: 1.0231 | Presence Accuracy: 0.7455 | Count MAE: 2.1537
Epoch 5, Train Loss: 0.9478
Validation Loss: 0.8706 | Presence Accuracy: 0.7590 | Count MAE: 1.6201
Epoch 6, Train Loss: 0.7745
Validation Loss: 0.9190 | Presence Accuracy: 0.7636 | Count MAE: 1.8725
Epoch 7, Train Loss: 0.7267
Validation Loss: 0.7680 | Presence Accuracy: 0.7647 | Count MAE: 1.4631
Epoch 8, Train Loss: 0.6899
Validation Loss: 0.7267 | Presence Accuracy: 0.7695 | Count MAE: 1.3412
Epoch 9, Train Loss: 0.6404
Validation Loss: 0.7591 | Presence Accuracy: 0.7706 | Count MAE: 1.4090
Epoch 10, Train Loss: 0.6438
Validation Loss: 0.7003 | Presence Accuracy: 0.7711 | Count MAE: 1.237

KeyboardInterrupt: 