# Task 2b: Attribute Classification (Mounting Type)

**Student:** Stefania Livori  
**Attribute:** Mounting Type (Pole-mounted vs Wall-mounted)  
**Model:** Faster R-CNN (ResNet50 FPN)

This notebook implements a detector specifically for classifying the mounting type of traffic signs. It covers:
1.  **Dataset Preparation**: Custom `MountingDataset` parsing "mounting" attributes.
2.  **Model Training**: Training to detect and classify 'Pole-mounted' vs 'Wall-mounted'.
3.  **Evaluation**: Computing an IoU-based F1 score (`f1_score_by_iou`) on the validation split after every training epoch.
4.  **Inference & Analytics**: Counting the distribution of mounting types.

In [None]:
# Import necessary libraries
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from PIL import Image
import json
import os
import matplotlib.pyplot as plt
from stefania_livori_utils import *

# Ensure reproducible results: seed torch's global RNG so that anything drawing
# from it later in the notebook (e.g. a random_split or a shuffling DataLoader
# created without an explicit generator, and presumably any randomly
# initialised model weights) behaves the same across runs.
torch.manual_seed(42)

## 1. Dataset Preparation

We define the `MountingDataset` class. This class parses the Label Studio JSON export to extract the "mounting" attribute for each sign.
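**Robustness update:** images containing several signs with different mounting types are handled by linking each mounting choice to its bounding box (via its linked ID, `from_id`); a mounting choice that is not linked to any box is applied to every box in that annotation.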

In [None]:
# Detector class id -> human-readable mounting type. Ids start at 1 because
# id 0 is reserved for the background class.
MOUNTING_CLASSES = dict(enumerate(("Pole-mounted", "Wall-mounted"), start=1))

# Faster R-CNN counts background as one extra class on top of the real ones.
NUM_CLASSES = 1 + len(MOUNTING_CLASSES)

class MountingDataset(Dataset):
    """Detection dataset of traffic-sign mounting types from a Label Studio export.

    Each item is ``(image, target)``. ``target`` has these keys:
    ``boxes`` (float32, shape ``(N, 4)``, absolute ``[x1, y1, x2, y2]`` pixels),
    ``labels`` (int64, shape ``(N,)``; 1 = Pole-mounted, 2 = Wall-mounted) and
    ``image_id`` (the task index).

    Args:
        root: folder containing the images.
        ann_file: Label Studio JSON export (a list of tasks).
        transforms: optional callable applied to the PIL image.
        preload: decode every image once in ``__init__`` and serve copies
            from memory for faster training.

    Raises:
        FileNotFoundError: from ``__init__`` (when preloading) or from
            ``__getitem__`` when a task's image file is missing. Also raised
            from ``__getitem__`` when a task references no image at all.
    """

    def __init__(self, root, ann_file, transforms=None, preload=True):
        self.root = root
        self.transforms = transforms

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(ann_file, encoding="utf-8") as f:
            self.tasks = json.load(f)

        self.preload = preload
        if preload:
            # One entry per task; None marks a task that names no image.
            self.loaded_images = []
            for task in self.tasks:
                img_path = self._image_path(task)
                self.loaded_images.append(
                    None if img_path is None else self._open_rgb(img_path)
                )

        # Label Studio choice text -> detector class id (0 is background).
        self.map = {
            "Pole-mounted": 1,
            "Wall-mounted": 2
        }

    def __len__(self):
        return len(self.tasks)

    def __getitem__(self, idx):
        task = self.tasks[idx]

        if self.preload:
            img = self.loaded_images[idx]
            if img is None:
                raise FileNotFoundError(f"Image not found at index {idx}")
            # Work on a copy so transforms never touch the cached image.
            img = img.copy()
        else:
            img_path = self._image_path(task)
            if img_path is None:
                raise FileNotFoundError(f"Image not found at index {idx}")
            img = self._open_rgb(img_path)

        # Percent coordinates are scaled with the untransformed image size.
        iw, ih = img.size
        boxes, labels = self._parse_targets(task, iw, ih)

        if boxes:
            boxes = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.tensor(labels, dtype=torch.int64)
        else:
            # Empty but correctly shaped tensors for images with no usable boxes.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx])
        }

        if self.transforms:
            img = self.transforms(img)

        return img, target

    def _image_path(self, task):
        """Return the on-disk image path for *task*, or None if it names no image."""
        data = task.get("data") or {}
        if "image" in data:
            # Strip the Label Studio upload URL prefix so the path is relative to root.
            return os.path.join(self.root, data["image"].replace("/data/upload/", ""))
        if "file_name" in task:
            return os.path.join(self.root, task["file_name"])
        return None

    @staticmethod
    def _open_rgb(img_path):
        """Decode the image at *img_path* into an RGB PIL image detached from the file.

        Raises:
            FileNotFoundError: if *img_path* does not exist.
        """
        try:
            with Image.open(img_path) as img:
                # convert() returns a new, fully loaded image, so it stays
                # usable after the file handle is closed.
                return img.convert("RGB")
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Image not found: {img_path}") from err

    def _parse_targets(self, task, iw, ih):
        """Collect ``([[x1, y1, x2, y2], ...], [label, ...])`` for the signs in *task*.

        A mounting choice is attached to a rectangle in three ways, in order:
        1. through ``from_id``;
        2. by sharing the rectangle's ``id`` (per-region choice);
        3. otherwise it applies to every rectangle in that annotation that has
           no linked choice.
        Rectangles with no resolvable mounting type are skipped.
        """
        boxes, labels = [], []

        # NOTE(review): every annotation of a task is used, so a task labelled
        # by several annotators yields duplicate boxes. Confirm this is intended.
        for ann in task.get("annotations", []):
            # Pass 1: gather rectangles (by region id) and valid mounting choices.
            rects = {}
            choices = []
            for r in ann.get("result", []):
                if r.get("type") == "rectanglelabels":
                    rects[r.get("id")] = r
                elif r.get("from_name") == "mounting":
                    picked = (r.get("value") or {}).get("choices") or []
                    mount_label = self.map.get(picked[0]) if picked else None
                    if mount_label is not None:
                        choices.append((r, mount_label))

            # Resolve each choice to a specific rectangle or to the global fallback.
            # This is done after pass 1 because rectangles may appear after their choices.
            linked = {}
            global_mount = None
            for r, mount_label in choices:
                rid = r.get("id")
                if "from_id" in r:
                    linked[r["from_id"]] = mount_label
                elif rid is not None and rid in rects:
                    linked[rid] = mount_label
                else:
                    global_mount = mount_label

            # Pass 2: emit absolute-pixel boxes (priority: specific link > global).
            for rid, r in rects.items():
                mount = linked.get(rid, global_mount)
                if mount is None:
                    continue
                v = r["value"]
                # Values are percentages of the image size; clamp to the image.
                x1 = max(0.0, v["x"] / 100 * iw)
                y1 = max(0.0, v["y"] / 100 * ih)
                x2 = min(float(iw), (v["x"] + v["width"]) / 100 * iw)
                y2 = min(float(ih), (v["y"] + v["height"]) / 100 * ih)
                if x2 <= x1 or y2 <= y1:
                    # Faster R-CNN rejects boxes with non-positive width/height.
                    continue
                boxes.append([x1, y1, x2, y2])
                labels.append(mount)

        return boxes, labels

### Initialize Dataset and DataLoaders

In [None]:
from torch.utils.data import Subset

transform = T.Compose([T.ToTensor()])

# Paths are resolved relative to the notebook's working directory.
DATA_DIR = "label-studio/label-studio/media/upload"
ANNOTATION_FILE = "json_stefania.json"

dataset = MountingDataset(
    root=DATA_DIR,
    ann_file=ANNOTATION_FILE,
    transforms=transform
)

# Filter out tasks without an image. MountingDataset.__getitem__ raises
# FileNotFoundError for them, which would abort a training epoch midway.
usable_ds = dataset
if dataset.preload:
    valid_indices = [i for i, im in enumerate(dataset.loaded_images) if im is not None]
    if len(valid_indices) < len(dataset):
        print("Warning: Some images failed to load.")
        print(f"Skipping {len(dataset) - len(valid_indices)} task(s) without an image.")
        usable_ds = Subset(dataset, valid_indices)

# 80/20 split. A dedicated seeded generator makes every re-run of this cell
# produce the same split, regardless of how much the global RNG was used.
total_size = len(usable_ds)
val_size = int(0.2 * total_size)
train_size = total_size - val_size

split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(usable_ds, [train_size, val_size], generator=split_generator)


def detection_collate(batch):
    """Regroup ``[(img, target), ...]`` into ``(imgs, targets)`` tuples.

    Targets hold a variable number of boxes, and images may differ in size,
    so samples are kept as tuples instead of being stacked into one tensor.
    """
    return tuple(zip(*batch))


train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=detection_collate)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, collate_fn=detection_collate)

print(f"Train size: {len(train_ds)}, Val size: {len(val_ds)}")

## 2. Model Configuration

In [None]:
device = get_device()
model = get_faster_rcnn(NUM_CLASSES).to(device)

# Hand the optimiser only trainable parameters, following the torchvision
# detection tutorial. Any layers get_faster_rcnn may have frozen would receive
# no gradient anyway.
trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.SGD(
    trainable_params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# StepLR multiplies the LR by `gamma` every `step_size` scheduler steps; the
# training loop calls scheduler.step() once per epoch.
# NOTE(review): with step_size=3 and num_epochs=3 the decay only fires after
# the final epoch, so it never affects training. Tune the two together
# (e.g. a smaller step_size or more epochs).
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=3, gamma=0.1
)

## 3. Training Loop

In [None]:
num_epochs = 3

# Each epoch:
# 1. one pass of train_one_epoch over the training loader;
# 2. one LR-schedule step;
# 3. a report of f1_score_by_iou on the validation loader.
for epoch in range(num_epochs):
    human_epoch = epoch + 1
    print(f"\nStarting Epoch {human_epoch}/{num_epochs}")

    loss = train_one_epoch(model, train_loader, optimizer, device)
    scheduler.step()

    f1 = f1_score_by_iou(model, val_loader, device)
    print(f"Epoch {human_epoch} | Loss {loss:.4f} | f1-score {f1:.4f}")

## 4. Analytics & Visualization

In [None]:
model.eval()

# Count detections with score > 0.5 per mounting type over the validation split.
class_counts = dict.fromkeys(MOUNTING_CLASSES.values(), 0)

with torch.no_grad():
    for sample_img, _sample_target in val_ds:
        pred = model([sample_img.to(device)])[0]
        confident_ids = pred["labels"][pred["scores"] > 0.5].tolist()
        for class_id in confident_ids:
            class_name = MOUNTING_CLASSES.get(class_id, "Unknown")
            # Ids outside MOUNTING_CLASSES map to "Unknown" and are not counted.
            if class_name in class_counts:
                class_counts[class_name] += 1

print("Detected Mounting Types Distribution:")
for name, count in class_counts.items():
    print(f"{name}: {count}")

### Visualisation

In [None]:
# Relative path (like DATA_DIR), so the checkpoint lands in the project folder
# on any machine instead of a hard-coded drive path.
MODEL_PATH = "stefania_livori_mounting.pt"
torch.save(model.state_dict(), MODEL_PATH)

SCORE_THRESHOLD = 0.5

# torchvision detection models only return predictions (rather than requiring
# targets) in eval mode; don't rely on an earlier cell having switched it.
model.eval()

print("\nVisualizing Sample Predictions:")
for i, (img, target) in enumerate(val_ds):
    with torch.no_grad():
        prediction = model([img.to(device)])[0]

    gt_labels = [MOUNTING_CLASSES.get(label_id, "Unknown") for label_id in target["labels"].tolist()]
    print(f"Sample {i+1} GT: {gt_labels}")
    visualize_predictions(img, prediction, MOUNTING_CLASSES, threshold=SCORE_THRESHOLD)