# Notebook Overview

This notebook is designed for preparing, processing, and training a classification model using a detection dataset. 


In [1]:
import os
import torch
import torchvision
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os
import json
from pathlib import Path
from torch.utils.data import DataLoader
from ultralytics import YOLO
from torchvision import models
import torch.optim as optim


In [2]:
# Path to annotations and images
train_ann_path = r"C:\Users\bumin\Downloads\DLCV project\TACO\dataset_split\train_iou_results.json" 
val_ann_path = r"C:\Users\bumin\Downloads\DLCV project\TACO\dataset_split\val_iou_results.json"
train_image_dir = r"C:\Users\bumin\Downloads\DLCV project\TACO\dataset_split\train"
val_image_dir = r"C:\Users\bumin\Downloads\DLCV project\TACO\dataset_split\val"
test_dir = r"C:\Users\bumin\Downloads\DLCV project\TACO\dataset_split\test"
test_ann_path = r"C:\Users\bumin\Downloads\DLCV project\TACO\dataset_split\test_iou_results.json"

In [3]:
# Load annotations
with open(train_ann_path, 'r') as f:
    train_annotations = json.load(f)

# Load annotations
with open(val_ann_path, 'r') as f:
    val_annotations = json.load(f)

# Load annotations
with open(test_ann_path, 'r') as f:
    test_annotations = json.load(f)

# Example format assumption: {'image.jpg': 'supercategory'}
label_mapping_train = train_annotations
label_mapping_val = val_annotations

In [4]:
# Example: Group annotations by image_id
image_annotations_train = {}
for item in train_annotations:
    image_id = item['image_id']
    if image_id not in image_annotations_train:
        image_annotations_train[image_id] = []
    image_annotations_train[image_id].append(item)

# Example: Group annotations by image_id
image_annotations_val = {}
for item in val_annotations:
    image_id = item['image_id']
    if image_id not in image_annotations_val:
        image_annotations_val[image_id] = []
    image_annotations_val[image_id].append(item)

# Example: Group annotations by image_id
image_annotations_test = {}
for item in test_annotations:
    image_id = item['image_id']
    if image_id not in image_annotations_test:
        image_annotations_test[image_id] = []
    image_annotations_test[image_id].append(item)

## Define Helper Functions
- Define utility functions like `bbox_center_distance` for computing the distance between bounding box centers.
- These functions aid in selecting the most relevant annotations and preprocessing data effectively.


In [5]:
def bbox_center_distance(box1, box2):
    """
    Calculate the Euclidean distance between the centers of two bounding boxes.
    Args:
        box1: [x_min, y_min, x_max, y_max] for box 1
        box2: [x_min, y_min, x_max, y_max] for box 2
    Returns:
        distance: Euclidean distance between centers
    """
    center1 = [(box1[0] + box1[2]) / 2, (box1[1] + box1[3]) / 2]
    center2 = [(box2[0] + box2[2]) / 2, (box2[1] + box2[3]) / 2]
    return ((center1[0] - center2[0]) ** 2 + (center1[1] - center2[1]) ** 2) ** 0.5


## Create Dataset Classes
- **DetectionToClassificationDataset**: Converts detection data to classification-ready data.
  - Crops images based on bounding boxes.
  - Extracts supercategories and maps them to integer labels.
- **FilteredDataset**: Filters invalid samples from the dataset to ensure only high-quality data is used for training and evaluation.


In [6]:
class DetectionToClassificationDataset(Dataset):
    def __init__(self, image_dir, annotations, transform=None):
        """
        Args:
            image_dir (str): Base directory containing images.
            annotations (dict): Annotations with 'image_id' as keys and a list of bounding boxes.
            transform (callable, optional): Transformations to apply to the cropped images.
        """
        self.image_dir = Path(image_dir)
        self.annotations = annotations
        self.transform = transform
        self.image_paths = self._collect_image_paths()
        self.label_mapping = self._create_label_mapping()
        self.missing_files = []  # Track missing files
    def _collect_image_paths(self):
        """
        Collect all image paths from subdirectories and map them to their IDs.
        """
        image_paths = {}
        for image_path in self.image_dir.rglob("*.jpg"):  # Adjust extension if needed
            image_id = int(image_path.stem)  # Assuming filenames are integers matching 'image_id'
            image_paths[image_id] = image_path
        return image_paths

    def _create_label_mapping(self):
        """
        Create a mapping from supercategory to integer label.
        Filters out None values.
        """
        # Collect all supercategories from the annotations, ignoring None
        supercategories = {
            ann['supercategory']
            for anns in self.annotations.values()
            for ann in anns
            if ann['supercategory'] is not None
        }
        
        # Create and return the mapping
        return {label: idx for idx, label in enumerate(sorted(supercategories))}

    def __len__(self):
        return len(self.annotations)


    def __getitem__(self, idx):
        image_id = list(self.annotations.keys())[idx]
        
        if image_id not in self.image_paths:
            # Log missing file
            self.missing_files.append(image_id)
            return None  # Skip this file
        
        # Load image
        img_path = self.image_paths[image_id]
        image = Image.open(img_path).convert("RGB")

        # Get annotations for this image
        anns = self.annotations[image_id]
        
        # Select the detected_bbox closest to the actual_bbox
        closest_ann = None
        min_distance = float('inf')
        for ann in anns:
            detected_bbox = ann['detected_bbox']
            actual_bbox = ann['actual_bbox']
            distance = bbox_center_distance(detected_bbox, actual_bbox)
            if distance < min_distance:
                min_distance = distance
                closest_ann = ann

        # Process the selected annotation
        detected_bbox = closest_ann['detected_bbox']
        supercategory = closest_ann['supercategory']

        # Crop the detected bounding box
        x_min, y_min, x_max, y_max = map(int, detected_bbox)
        cropped_image = image.crop((x_min, y_min, x_max, y_max))

        # Apply transformations
        if self.transform:
            cropped_image = self.transform(cropped_image)

        # Get label
        label = self.label_mapping[supercategory]

        return cropped_image, label


In [7]:
def collate_fn(batch):
    """
    Custom collate function to filter out None values from the dataset.
    Args:
        batch (list): List of samples returned by the dataset.
    Returns:
        Filtered batch with None values removed.
    """
    # Filter out None samples
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:  # If all samples in batch are None
        return None  # Skip this batch entirely
    images, labels = zip(*batch)  # Separate images and labels
    return list(images), list(labels)


In [8]:

# Define transformations for input normalization and augmentation
transform = transforms.Compose([
   transforms.Resize((224, 224)),  # ResNet input size
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization

])


## Load and Preprocess Data
- Organize and preprocess datasets:
  - Apply transformations such as resizing, normalization, and cropping.
  - Structure data loaders to efficiently batch and shuffle the dataset during training.


In [9]:
# Initialize datasets
train_dataset = DetectionToClassificationDataset(train_image_dir, image_annotations_train, transform=transform)
val_dataset = DetectionToClassificationDataset(val_image_dir, image_annotations_val, transform=transform)
test_dataset = DetectionToClassificationDataset(test_dir, image_annotations_test, transform=transform)


In [10]:
def count_valid_samples(dataset):
    valid_samples = 0
    for image_id, anns in dataset.annotations.items():
        if not anns:
            continue
        if all(
            ann['detected_bbox'] is not None and
            ann['actual_bbox'] is not None and
            ann['supercategory'] is not None
            for ann in anns
        ):
            valid_samples += 1
    return valid_samples

# Usage
valid_train_samples = count_valid_samples(train_dataset)
valid_val_samples = count_valid_samples(val_dataset)

print(f"Number of valid samples in the training dataset: {valid_train_samples}")
print(f"Number of valid samples in the validation dataset: {valid_val_samples}")


Number of valid samples in the training dataset: 344
Number of valid samples in the validation dataset: 76


In [11]:
class FilteredDataset(DetectionToClassificationDataset):
    def __init__(self, original_dataset):
        """
        Filter out invalid samples from the original dataset.
        Args:
            original_dataset (Dataset): The dataset to filter.
        """
        self.original_dataset = original_dataset
        self.valid_samples = []

        # Iterate over the original dataset to collect valid samples
        for i in range(len(original_dataset)):
            try:
                sample = original_dataset[i]  # Fetch the sample
                if sample is not None:  # Check if the sample is valid
                    self.valid_samples.append(sample)  # Append valid sample
            except Exception as e:
                print(f"Error filtering index {i}: {e}")

    def __len__(self):
        return len(self.valid_samples)

    def __getitem__(self, idx):
        # Return the valid sample at the given index
        return self.valid_samples[idx]

    
# Wrap train and validation datasets

filtered_train_dataset = FilteredDataset(train_dataset)
filtered_val_dataset = FilteredDataset(val_dataset)
filtered_test_dataset = FilteredDataset(test_dataset)
print(f"Number of valid samples in the training dataset: {len(filtered_train_dataset)}")
print(f"Number of valid samples in the validation dataset: {len(filtered_val_dataset)}")
print(f"Number of valid samples in the test dataset: {len(filtered_test_dataset)}")


Error filtering index 0: 'NoneType' object is not subscriptable
Error filtering index 1: 'NoneType' object is not subscriptable
Error filtering index 2: 'NoneType' object is not subscriptable
Error filtering index 3: 'NoneType' object is not subscriptable
Error filtering index 4: 'NoneType' object is not subscriptable
Error filtering index 5: 'NoneType' object is not subscriptable
Error filtering index 6: 'NoneType' object is not subscriptable
Error filtering index 7: 'NoneType' object is not subscriptable
Error filtering index 8: 'NoneType' object is not subscriptable
Error filtering index 9: 'NoneType' object is not subscriptable
Error filtering index 10: 'NoneType' object is not subscriptable
Error filtering index 13: 'NoneType' object is not subscriptable
Error filtering index 15: 'NoneType' object is not subscriptable
Error filtering index 16: 'NoneType' object is not subscriptable
Error filtering index 19: 'NoneType' object is not subscriptable
Error filtering index 20: 'NoneType

In [12]:
print(f"Filtered train dataset length: {len(filtered_train_dataset)}")
print(f"Filtered val dataset length: {len(filtered_val_dataset)}")
print(f"Filtered test dataset length: {len(filtered_test_dataset)}")


Filtered train dataset length: 344
Filtered val dataset length: 76
Filtered test dataset length: 79


In [13]:

# DataLoaders
train_loader = DataLoader(
    filtered_train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=collate_fn
)
val_loader = DataLoader(
    filtered_val_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn
)

test_loader = DataLoader(filtered_test_dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)


In [14]:
# Inspect the first 5 samples from the dataset
for i in range(5):
    sample = filtered_train_dataset[i]
    image, label = sample
    print(f"Sample {i}: Image shape = {image.shape}, Label = {label}")


Sample 0: Image shape = torch.Size([3, 224, 224]), Label = 0
Sample 1: Image shape = torch.Size([3, 224, 224]), Label = 0
Sample 2: Image shape = torch.Size([3, 224, 224]), Label = 14
Sample 3: Image shape = torch.Size([3, 224, 224]), Label = 1
Sample 4: Image shape = torch.Size([3, 224, 224]), Label = 2


In [15]:
# Test DataLoader output
for batch in train_loader:
    images, labels = batch
    print(f"Batch size: {len(images)}")
    print(f"Labels: {labels[:5]}")
    break


Batch size: 32
Labels: [14, 8, 18, 11, 5]


## Define and Train the Model
- Use a pretrained ResNet architecture for classification.
  - Replace the final fully connected layer to match the number of classes in the dataset.
- Train the model on the processed data and validate its performance after each epoch.

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet = models.resnet18(pretrained=True)

# Modify the final layer for your number of classes
num_classes = len(train_dataset.label_mapping)  # Number of unique labels
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)

resnet = resnet.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.parameters(), lr=1e-4)

# Training loop
epochs = 10
for epoch in range(epochs):
    resnet.train()
    running_loss = 0.0
    for batch in train_loader:
        if batch is None:  # Skip empty batches
            continue

        images, labels = batch
        images, labels = torch.stack(images).to(device), torch.tensor(labels).to(device)

        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}")





Epoch 1/10, Loss: 2.8718
Epoch 2/10, Loss: 1.2786
Epoch 3/10, Loss: 0.6417
Epoch 4/10, Loss: 0.3163
Epoch 5/10, Loss: 0.1688
Epoch 6/10, Loss: 0.0989
Epoch 7/10, Loss: 0.0610
Epoch 8/10, Loss: 0.0485
Epoch 9/10, Loss: 0.0375
Epoch 10/10, Loss: 0.0318


In [17]:
print(f"Number of samples in filtered_val_dataset: {len(filtered_val_dataset)}")


Number of samples in filtered_val_dataset: 76


## Evaluate the Model
- Compute metrics like accuracy, precision, and recall on the validation and test datasets.
- Log and visualize evaluation results to understand the model's performance.

In [18]:

# Validation loop
resnet.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = torch.stack(images).to(device), torch.tensor(labels).to(device)
        outputs = resnet(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Validation Accuracy: {100 * correct / total:.2f}%")

Validation Accuracy: 3.95%


In [19]:
resnet.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = torch.stack(images).to(device), torch.tensor(labels).to(device)
        outputs = resnet(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 15.19%


In [20]:
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, classification_report, confusion_matrix
import numpy as np

# Validation/Test Loop
resnet.eval()
all_true_labels = []
all_predicted_labels = []

with torch.no_grad():
    for i, batch in enumerate(test_loader):  # Replace with val_loader for validation
        try:
            images, labels = batch
            images, labels = torch.stack(images).to(device), torch.tensor(labels).to(device)

            # Forward pass
            outputs = resnet(images)
            _, predicted = torch.max(outputs, 1)

            # Collect true and predicted labels
            all_true_labels.extend(labels.cpu().numpy())  # Convert to numpy for scikit-learn
            all_predicted_labels.extend(predicted.cpu().numpy())

        except Exception as e:
            print(f"Error in batch {i}: {e}")
            continue

# Convert lists to numpy arrays
all_true_labels = np.array(all_true_labels)
all_predicted_labels = np.array(all_predicted_labels)

# 1. Accuracy
accuracy = accuracy_score(all_true_labels, all_predicted_labels)
print(f"Accuracy: {accuracy:.2f}")

# 2. Precision
precision = precision_score(all_true_labels, all_predicted_labels, average='weighted')  # Weighted for multiclass
print(f"Precision: {precision:.2f}")

# 3. Recall
recall = recall_score(all_true_labels, all_predicted_labels, average='weighted')  # Weighted for multiclass
print(f"Recall: {recall:.2f}")

# Full Classification Report (Optional)
print("\nClassification Report:")
print(classification_report(all_true_labels, all_predicted_labels))

# Confusion Matrix (Optional)
print("\nConfusion Matrix:")
cm = confusion_matrix(all_true_labels, all_predicted_labels)
print(cm)


Accuracy: 0.15
Precision: 0.11
Recall: 0.15

Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.00      0.00      0.00         1
           2       0.62      0.80      0.70        10
           3       0.33      1.00      0.50         2
           4       0.00      0.00      0.00        13
           5       0.15      0.20      0.17        10
           6       0.00      0.00      0.00         1
           7       0.00      0.00      0.00         5
           8       0.00      0.00      0.00         2
           9       0.00      0.00      0.00         3
          10       0.00      0.00      0.00         3
          11       0.00      0.00      0.00         2
          12       0.00      0.00      0.00        14
          13       0.00      0.00      0.00         6
          14       0.00      0.00      0.00         1
          15       0.00      0.00      0.00         3
          16 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
