In [None]:
import os
import glob
import xml.etree.ElementTree as ET
from collections import defaultdict

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
from skmultilearn.model_selection import iterative_train_test_split

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset, ConcatDataset
import torchvision
from torchvision import models, transforms
from torchvision.datasets import VOCDetection
from torchvision.ops import nms
from torchvision.transforms.functional import to_pil_image
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Every image is resized to the 448x448 YOLO input resolution and turned into a float tensor.
# NOTE(review): there is no Normalize step; torchvision's ImageNet-pretrained ResNets are
# usually fed mean/std-normalized input — confirm skipping it is intentional.
transform = transforms.Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ]
)

# These VOCDetection objects exist to download/extract VOC2007 under ./train and ./test.
# Both names are rebound to Subsets of the custom VOCDataset further down.
train_dataset = VOCDetection(
    root='./train', year='2007', image_set='trainval', download=True, transform=transform
)
val_dataset = VOCDetection(
    root='./test', year='2007', image_set='test', download=True, transform=transform
)

In [None]:
# The 20 Pascal VOC object categories. A name's position in this list is its class index.
CLASS_NAMES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
    'horse', 'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor',
]
# Reverse lookup: class name -> class index.
CLASS_TO_IDX = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))

In [None]:
class VOCDataset(Dataset):
    """Pascal VOC detection dataset that yields YOLO-style grid targets.

    Expects the VOC directory layout under ``root_dir``: ``JPEGImages/<id>.jpg``,
    ``Annotations/<id>.xml`` and ``ImageSets/Main/<image_set>.txt``.

    Each item is ``(image, target, boxes, labels, image_id)``:

    * ``image`` -- the RGB image after ``transform`` (a PIL image if ``transform`` is None).
    * ``target`` -- tensor of shape
      ``(grid_size, grid_size, num_boxes * 5 + len(CLASS_NAMES))`` built by
      :meth:`encode_boxes`.
    * ``boxes`` -- float32 tensor ``(num_objects, 4)`` of ``[xmin, ymin, xmax, ymax]``
      rescaled to an ``img_size`` x ``img_size`` image.
    * ``labels`` -- int64 tensor ``(num_objects,)`` of class indices.
    * ``image_id`` -- the VOC image id string.

    Boxes are always rescaled to ``img_size``. For boxes and pixels to line up, the
    caller's ``transform`` must resize images to that same resolution.
    """

    def __init__(self, root_dir, image_set='trainval', year='2007', transform=None,
                 grid_size=7, num_boxes=2, img_size=448):
        # `year` is accepted but unused: root_dir must already point at one VOC<year> folder.
        self.root_dir = root_dir
        self.transform = transform
        self.grid_size = grid_size
        self.num_boxes = num_boxes
        self.img_size = img_size
        self.image_dir = os.path.join(root_dir, 'JPEGImages')
        self.annotation_dir = os.path.join(root_dir, 'Annotations')
        list_path = os.path.join(root_dir, 'ImageSets', 'Main', f'{image_set}.txt')
        with open(list_path) as f:
            # Skip blank lines (e.g. a trailing newline) so no empty image id is produced.
            self.image_ids = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, f"{image_id}.jpg")
        ann_path = os.path.join(self.annotation_dir, f"{image_id}.xml")

        # The context manager closes the file handle; convert() returns an independent copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        orig_w, orig_h = image.size
        boxes, labels = self.parse_voc_xml(ann_path, (orig_w, orig_h))

        if self.transform:
            image = self.transform(image)

        # Rescale pixel boxes from the original resolution to img_size x img_size.
        scaled_boxes = [
            [xmin / orig_w * self.img_size, ymin / orig_h * self.img_size,
             xmax / orig_w * self.img_size, ymax / orig_h * self.img_size]
            for xmin, ymin, xmax, ymax in boxes
        ]

        target = self.encode_boxes(scaled_boxes, labels)

        # Raw (rescaled) boxes/labels are also returned for mAP computation.
        # reshape(-1, 4) keeps a (0, 4) shape for images without objects.
        boxes_tensor = torch.tensor(scaled_boxes, dtype=torch.float32).reshape(-1, 4)
        labels_tensor = torch.tensor(labels, dtype=torch.int64)

        return image, target, boxes_tensor, labels_tensor, image_id

    def parse_voc_xml(self, xml_path, image_size):
        """Read ``(boxes, labels)`` from a VOC annotation file.

        ``boxes`` is a list of ``[xmin, ymin, xmax, ymax]`` pixel ints (truncated from
        the XML values). ``labels`` is a list of class indices. ``image_size`` is accepted
        but unused. Raises KeyError for a class name not in CLASS_TO_IDX.
        """
        root = ET.parse(xml_path).getroot()
        boxes, labels = [], []

        for obj in root.iter('object'):
            label = obj.find('name').text.strip()
            bbox = obj.find('bndbox')
            boxes.append([int(float(bbox.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
            labels.append(CLASS_TO_IDX[label])

        return boxes, labels

    def encode_boxes(self, boxes, labels):
        """Encode pixel boxes (in img_size x img_size space) into a YOLO grid target.

        The cell containing a box centre gets ``[x_cell, y_cell, w, h, 1]`` in every one
        of its ``num_boxes`` slots, plus a 1 at ``num_boxes * 5 + label``.
        ``x_cell``/``y_cell`` are relative to the cell; ``w``/``h`` are relative to the image.
        Only the first object whose centre falls in a cell is encoded; later ones are
        skipped so that a cell's box and its class one-hot describe the same object.
        """
        num_classes = len(CLASS_NAMES)
        target = torch.zeros((self.grid_size, self.grid_size, self.num_boxes * 5 + num_classes))
        img_w = img_h = self.img_size

        for (xmin, ymin, xmax, ymax), label in zip(boxes, labels):
            x_center = (xmin + xmax) / 2.0 / img_w
            y_center = (ymin + ymax) / 2.0 / img_h
            box_w = (xmax - xmin) / img_w
            box_h = (ymax - ymin) / img_h

            # A centre exactly on the right/bottom edge maps to grid_size; clamp it into the last cell.
            grid_x = min(int(x_center * self.grid_size), self.grid_size - 1)
            grid_y = min(int(y_center * self.grid_size), self.grid_size - 1)

            # One object per cell: otherwise coords would be overwritten while class one-hots accumulate.
            if target[grid_y, grid_x, 4] > 0:
                continue

            # Convert to cell-relative coordinates.
            cell_x = x_center * self.grid_size - grid_x
            cell_y = y_center * self.grid_size - grid_y

            box_vec = torch.tensor([cell_x, cell_y, box_w, box_h, 1.0])
            for b in range(self.num_boxes):
                target[grid_y, grid_x, b * 5:(b + 1) * 5] = box_vec

            target[grid_y, grid_x, self.num_boxes * 5 + label] = 1

        return target

In [None]:
# Custom YOLO-target datasets over the VOC2007 trees (presumably the layout the
# VOCDetection(download=True) calls above extracted to).
trainval_dataset = VOCDataset(root_dir='train/VOCdevkit/VOC2007', image_set='trainval', transform=transform)

test_root = 'test/VOCdevkit/VOC2007'
test_dataset = VOCDataset(root_dir=test_root, image_set='test', transform=transform)

In [None]:
def get_image_labels(dataset):
    """Collect the set of class indices present in each item of ``dataset``.

    ``dataset[i]`` is expected to return ``(image, target, boxes, labels, image_id)``,
    the VOCDataset item layout. Only ``labels`` (an integer tensor) is used. An item
    that raises is reported and mapped to an empty set, so one bad annotation does not
    abort the scan.

    NOTE(review): indexing the dataset decodes and transforms every image just to read
    its labels, which is slow for ~10k VOC images. Parsing the XML annotations directly
    would avoid that cost.

    Returns:
        list[set[int]] aligned with dataset indices.
    """
    image_labels = []
    for i in range(len(dataset)):
        try:
            labels = dataset[i][3]
            # Convert to plain ints: torch tensors hash by identity, so a set of 0-d
            # tensors would never de-duplicate repeated classes.
            unique_labels = {int(label) for label in torch.as_tensor(labels).reshape(-1).tolist()}
        except Exception as e:
            print(f"Error at index {i}: {e}")
            unique_labels = set()
        image_labels.append(unique_labels)
    return image_labels

def labels_to_multihot(labels_list, num_classes=20):
    """Turn per-item label collections into a multi-hot matrix.

    Returns an int array of shape ``(len(labels_list), num_classes)``. Entry ``[r, c]``
    is 1 when class ``c`` appears in ``labels_list[r]`` and 0 otherwise.
    """
    multihot = np.zeros((len(labels_list), num_classes), dtype=int)
    for row_idx in range(len(labels_list)):
        for class_idx in labels_list[row_idx]:
            multihot[row_idx, class_idx] = 1
    return multihot

# Pool VOC2007 trainval and test so they can be re-split below with multi-label
# stratification (the resulting test split is therefore not the official VOC2007 test set).
full_dataset = ConcatDataset([trainval_dataset, test_dataset])

image_labels = get_image_labels(full_dataset)
# The splitter takes a 2-D X; a single column of dataset indices is enough.
X = np.arange(len(full_dataset))[:, np.newaxis]
Y = labels_to_multihot(image_labels, num_classes=len(CLASS_NAMES))

In [None]:
# Seed the RNGs the pipeline draws from: numpy (presumably what the iterative split below
# falls back to) and torch (head weight init, DataLoader shuffling, dropout).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Stratified multi-label split of the pooled images: 80% train / 20% held out,
# then the held-out part is halved -> 10% validation / 10% test.
X_train, Y_train, X_temp, Y_temp = iterative_train_test_split(X, Y, test_size=0.2)
X_val, Y_val, X_test, Y_test = iterative_train_test_split(X_temp, Y_temp, test_size=0.5)

# X holds dataset indices in a single column; flatten each split back into an index list.
train_indices, val_indices, test_indices = (
    split_X.ravel().tolist() for split_X in (X_train, X_val, X_test)
)

train_dataset, val_dataset, test_dataset = (
    Subset(full_dataset, indices) for indices in (train_indices, val_indices, test_indices)
)

def custom_collate(batch):
    """Collate ``(image, target, boxes, labels, image_id)`` samples into a batch.

    Images and encoded grid targets have fixed shapes and are stacked into tensors.
    Boxes and labels vary in length per image, so they stay as Python lists, as do the ids.
    """
    images, targets, boxes, labels, ids = [[sample[k] for sample in batch] for k in range(5)]
    return torch.stack(images), torch.stack(targets), boxes, labels, ids

# Shared loader settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=16, num_workers=2, collate_fn=custom_collate)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
def show_image_with_boxes(image, target, grid_size=7, class_names=CLASS_NAMES, conf_thresh=0.5, num_boxes=2):
    """Draw the boxes stored in a YOLO grid tensor on top of an image.

    Args:
        image: (3, H, W) image tensor; detached and moved to CPU for plotting.
        target: (grid_size, grid_size, num_boxes * 5 + num_classes) tensor. Each cell
            holds box slots ``[x_cell, y_cell, w, h, conf]`` followed by class scores.
        grid_size: number of cells per side.
        class_names: names indexed by class id.
        conf_thresh: a box slot is drawn only when its confidence is > this value.
        num_boxes: box slots per cell (previously hard-coded to 2).

    Each drawn box is labelled with the cell's argmax class and that class's score.
    In an encoded ground-truth target every slot of a cell holds the same box, so that
    box is drawn once per slot.
    """
    image_np = image.detach().cpu().permute(1, 2, 0).numpy()
    h, w = image_np.shape[:2]

    fig, ax = plt.subplots(1)
    ax.imshow(image_np)

    class_offset = num_boxes * 5
    for i in range(grid_size):          # grid row -> y
        for j in range(grid_size):      # grid column -> x
            cell = target[i, j]
            for b in range(num_boxes):
                conf = cell[b * 5 + 4].item()
                if not conf > conf_thresh:
                    continue

                # Cell-relative centre -> pixel centre; image-relative size -> pixels.
                x, y, box_w, box_h = cell[b * 5:b * 5 + 4].tolist()
                cx = (j + x) / grid_size * w
                cy = (i + y) / grid_size * h
                bw = box_w * w
                bh = box_h * h
                left, top = cx - bw / 2, cy - bh / 2

                rect = patches.Rectangle((left, top), bw, bh, linewidth=2, edgecolor='r', facecolor='none')
                ax.add_patch(rect)

                class_probs = cell[class_offset:]
                class_idx = class_probs.argmax().item()
                score = class_probs[class_idx].item()

                ax.text(left, top - 5, f'{class_names[class_idx]} ({score:.2f})', color='white',
                        fontsize=8, backgroundcolor='red')

    plt.axis('off')
    plt.show()

# Sanity check: draw the ground-truth boxes encoded in one test sample's YOLO target.
img, tgt = test_dataset[4][:2]
show_image_with_boxes(img, tgt)

In [None]:
# Sanity-check each split: size, image tensor shape and encoded YOLO target shape.
for header, split_ds in (
    ("Train Dataset:", train_dataset),
    ("\nValidation Dataset:", val_dataset),
    ("\nTest Dataset:", test_dataset),
):
    print(header)
    print("Number of images:", len(split_ds))
    img, target, _, _, _ = split_ds[0]
    print("Image shape:", img.shape)
    print("Target shape:", target.shape)

In [None]:
class YOLOHead(nn.Module):
    """Detection head that maps a backbone feature map to a YOLO prediction grid.

    The layers are: two 3x3 conv + ReLU (1024 channels); adaptive average pooling to a
    ``grid_size x grid_size`` map; a 4096-unit fully connected layer with dropout; and a
    linear projection. ``forward`` returns shape
    ``(batch, grid_size, grid_size, num_boxes * 5 + num_classes)``.

    The submodule names (``conv``, ``pool``, ``fc``) and the layer order inside them
    define the state_dict keys, so they must stay stable for saved weights to load.
    """

    def __init__(self, in_channels, grid_size=7, num_boxes=2, num_classes=20):
        super().__init__()
        self.grid_size = grid_size
        self.num_boxes = num_boxes
        self.num_classes = num_classes
        per_cell = num_boxes * 5 + num_classes
        self.output_dim = grid_size * grid_size * per_cell

        conv_layers = [
            nn.Conv2d(in_channels, 1024, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*conv_layers)

        self.pool = nn.AdaptiveAvgPool2d((grid_size, grid_size))

        fc_layers = [
            nn.Flatten(),
            nn.Linear(1024 * grid_size * grid_size, 4096),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(4096, self.output_dim),
        ]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, x):
        flat = self.fc(self.pool(self.conv(x)))
        per_cell = self.num_boxes * 5 + self.num_classes
        return flat.view(-1, self.grid_size, self.grid_size, per_cell)

class YOLOResNet(nn.Module):
    """ImageNet-pretrained ResNet-50 trunk followed by a YOLOHead.

    ``forward(x)`` returns the head output of shape
    ``(batch, grid_size, grid_size, num_boxes * 5 + num_classes)``.

    NOTE(review): the notebook's input transform (Resize + ToTensor) applies no ImageNet
    mean/std normalization, which torchvision's pretrained weights are usually paired
    with. Confirm this is intended.
    """

    def __init__(self, grid_size=7, num_boxes=2, num_classes=20):
        super(YOLOResNet, self).__init__()
        resnet = self._load_pretrained_resnet50()
        # Drop the last two children (global pooling and the classifier) to keep the
        # spatial 2048-channel feature map for the detection head.
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-2])
        self.yolo_head = YOLOHead(in_channels=2048, grid_size=grid_size, num_boxes=num_boxes, num_classes=num_classes)

    @staticmethod
    def _load_pretrained_resnet50():
        """Build ResNet-50 with ImageNet weights, preferring the torchvision>=0.13 weights API."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            # IMAGENET1K_V1 is the checkpoint the deprecated `pretrained=True` flag loads.
            return models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        return models.resnet50(pretrained=True)  # torchvision < 0.13

    def forward(self, x):
        features = self.feature_extractor(x)
        return self.yolo_head(features)

In [None]:
class YOLOLoss(nn.Module):
    """Sum-of-squared-errors YOLO loss, averaged over the batch.

    ``predictions`` and ``target`` must view to ``(N, S, S, B * 5 + C)``. Each cell holds
    ``B`` box slots ``[x, y, w, h, conf]`` followed by ``C`` class scores. A slot is
    "responsible" when its target confidence is > 0.

    Terms (each is a sum over its elements):
      * class  -- class-score error in cells that have any responsible slot;
      * coord  -- lambda_coord * (xy error + error between sign(wh) * sqrt(|wh| + 1e-6)
                  and sqrt(target wh)), over responsible slots;
      * object -- confidence error over responsible slots;
      * noobj  -- lambda_noobj * confidence error over all other slots.
    The returned total is divided by N.
    """

    def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.S, self.B, self.C = S, B, C
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def forward(self, predictions, target):
        n = predictions.size(0)
        S, B, C = self.S, self.B, self.C
        box_channels = B * 5
        grid_shape = (n, S, S, box_channels + C)

        pred = predictions.view(grid_shape)
        tgt = target.view(grid_shape)

        pred_slots = pred[..., :box_channels].view(n, S, S, B, 5)
        tgt_slots = tgt[..., :box_channels].view(n, S, S, B, 5)
        pred_cls = pred[..., box_channels:]
        tgt_cls = tgt[..., box_channels:]

        tgt_conf = tgt_slots[..., 4]
        responsible = tgt_conf > 0                  # per slot
        cell_has_object = tgt_conf.sum(dim=-1) > 0  # per cell

        # Classification, only in cells that contain an object.
        class_term = self.mse(pred_cls[cell_has_object], tgt_cls[cell_has_object])

        # Box geometry, responsible slots only.
        p = pred_slots[responsible]
        t = tgt_slots[responsible]
        xy_term = self.mse(p[..., 0:2], t[..., 0:2])
        # sign * sqrt(|.|) keeps the value real for negative raw width/height predictions.
        p_wh = p[..., 2:4]
        wh_term = self.mse(torch.sign(p_wh) * torch.sqrt(torch.abs(p_wh) + 1e-6), torch.sqrt(t[..., 2:4]))
        coord_term = (xy_term + wh_term) * self.lambda_coord

        # Confidence of responsible slots.
        object_term = self.mse(p[..., 4], t[..., 4])

        # Confidence of every other slot, down-weighted.
        idle = ~responsible
        noobj_term = self.mse(pred_slots[idle][..., 4], tgt_slots[idle][..., 4]) * self.lambda_noobj

        return (class_term + coord_term + object_term + noobj_term) / n

In [None]:
def convert_yolo_output_to_boxes(output, grid_size=7, num_boxes=2, num_classes=20, conf_thresh=0.25, iou_thresh=0.5,
                                 img_size=448):
    """Decode raw YOLO grid predictions into per-image detections, with per-class NMS.

    Args:
        output: tensor of shape ``(batch, grid_size, grid_size, num_boxes * 5 + C)``.
            A flat ``(batch, grid_size * grid_size * (num_boxes * 5 + C))`` tensor is
            reshaped to that layout. Box slots are ``[x_cell, y_cell, w, h, conf]``, with
            x/y relative to the cell and w/h relative to the image. The trailing ``C``
            values are class scores.
        grid_size, num_boxes: head geometry.
        num_classes: detections whose argmax class is ``>= num_classes`` are dropped.
        conf_thresh: box slots with ``conf < conf_thresh`` are discarded.
        iou_thresh: IoU threshold for per-class NMS.
        img_size: side length in pixels that the normalized boxes are scaled to.

    Returns:
        ``(boxes, scores, classes)`` -- three lists with one CPU tensor per image:
        ``(K, 4)`` float boxes ``[x1, y1, x2, y2]`` in pixels; ``(K,)`` float scores
        (class score * box confidence); ``(K,)`` int64 class indices. Detections are
        grouped by ascending class.
    """
    # One device->host copy up front instead of ~100 .item() syncs per image.
    preds = output.detach().to("cpu", torch.float32)
    batch_size = preds.shape[0]
    preds = preds.reshape(batch_size, grid_size, grid_size, -1)
    box_dim = num_boxes * 5

    slots = preds[..., :box_dim].reshape(batch_size, grid_size, grid_size, num_boxes, 5)
    # Best class per cell; max(dim) returns the first index on ties, like argmax.
    cls_score, cls_idx = preds[..., box_dim:].max(dim=-1)  # (B, S, S)

    offsets = torch.arange(grid_size, dtype=preds.dtype)
    col = offsets.view(1, 1, grid_size, 1)  # j, x direction
    row = offsets.view(1, grid_size, 1, 1)  # i, y direction
    x_center = (col + slots[..., 0]) / grid_size
    y_center = (row + slots[..., 1]) / grid_size
    half_w = slots[..., 2].clamp(0, 1) / 2
    half_h = slots[..., 3].clamp(0, 1) / 2
    boxes_all = torch.stack(
        [
            (x_center - half_w).clamp(0, 1),
            (y_center - half_h).clamp(0, 1),
            (x_center + half_w).clamp(0, 1),
            (y_center + half_h).clamp(0, 1),
        ],
        dim=-1,
    ) * img_size                                          # (B, S, S, NB, 4)
    conf = slots[..., 4]                                  # (B, S, S, NB)
    scores_all = cls_score.unsqueeze(-1) * conf
    classes_all = cls_idx.unsqueeze(-1).expand_as(conf)
    keep_all = conf >= conf_thresh

    all_boxes, all_scores, all_class_idxs = [], [], []
    for b in range(batch_size):
        mask = keep_all[b]
        boxes = boxes_all[b][mask]
        scores = scores_all[b][mask]
        class_idxs = classes_all[b][mask]

        keep_boxes, keep_scores, keep_classes = [], [], []
        for cls in torch.unique(class_idxs).tolist():  # ascending class order
            if cls >= num_classes:
                continue
            inds = (class_idxs == cls).nonzero(as_tuple=True)[0]
            cls_boxes, cls_scores = boxes[inds], scores[inds]
            # A lone box always survives NMS, so NMS only runs when there is something to suppress.
            if inds.numel() > 1:
                keep = nms(cls_boxes, cls_scores, iou_thresh)
                cls_boxes, cls_scores = cls_boxes[keep], cls_scores[keep]
            keep_boxes.append(cls_boxes)
            keep_scores.append(cls_scores)
            keep_classes.append(torch.full((cls_boxes.shape[0],), cls, dtype=torch.int64))

        if keep_boxes:
            all_boxes.append(torch.cat(keep_boxes, dim=0))
            all_scores.append(torch.cat(keep_scores, dim=0))
            all_class_idxs.append(torch.cat(keep_classes, dim=0))
        else:
            all_boxes.append(torch.zeros((0, 4)))
            all_scores.append(torch.zeros(0))
            all_class_idxs.append(torch.zeros(0, dtype=torch.int64))

    return all_boxes, all_scores, all_class_idxs

In [None]:
def calculate_iou(box1, box2):
    """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Accepts 1-D tensors or plain sequences. Always returns a Python float in [0, 1].
    Disjoint boxes and a degenerate zero-area union give 0.0 rather than NaN.
    """
    ax1, ay1, ax2, ay2 = (float(v) for v in (box1.tolist() if hasattr(box1, "tolist") else box1))
    bx1, by1, bx2, by2 = (float(v) for v in (box2.tolist() if hasattr(box2, "tolist") else box2))

    inter_w = min(ax2, bx2) - max(ax1, bx1)
    inter_h = min(ay2, by2) - max(ay1, by1)
    if inter_w < 0 or inter_h < 0:
        return 0.0

    intersection_area = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - intersection_area
    if union <= 0:
        return 0.0
    return intersection_area / union

def calculate_map(
    all_pred_boxes_batch,
    all_pred_scores_batch,
    all_pred_classes_batch,
    all_gt_boxes_batch,
    all_gt_classes_batch,
    num_classes=20,
    iou_threshold=0.5
):
    """Mean average precision over classes at a single IoU threshold.

    All inputs are per-image lists. The predictions are ``(K, 4)`` boxes, ``(K,)``
    scores and ``(K,)`` class tensors; the ground truths are ``(G, 4)`` boxes and
    ``(G,)`` class tensors, in the same pixel space.

    For each class, detections from all images are sorted by score and matched greedily
    to the unused ground-truth box with the highest IoU. AP is the area under the
    precision envelope, taken at every recall change.

    Returns:
        ``(mean_ap, average_precisions)``. ``average_precisions`` maps every class index
        to its AP. Classes with no ground truth get NaN and are excluded from the mean.
        Classes with ground truth but no detections get 0.0.
    """
    import math

    average_precisions = {}
    epsilon = 1e-8

    for class_idx in range(num_classes):
        detections_for_class = []
        ground_truths_for_class_by_image = defaultdict(list)
        total_gt_for_class = 0

        for i in range(len(all_pred_boxes_batch)):
            gt_boxes_img = all_gt_boxes_batch[i]
            gt_classes_img = all_gt_classes_batch[i]

            for gt_idx in (gt_classes_img == class_idx).nonzero(as_tuple=True)[0]:
                ground_truths_for_class_by_image[i].append({'box': gt_boxes_img[gt_idx], 'used': False})
                total_gt_for_class += 1

            pred_boxes_img = all_pred_boxes_batch[i]
            pred_scores_img = all_pred_scores_batch[i]
            pred_classes_img = all_pred_classes_batch[i]

            for pred_idx in (pred_classes_img == class_idx).nonzero(as_tuple=True)[0]:
                detections_for_class.append({
                    'score': pred_scores_img[pred_idx].item(),
                    'image_idx': i,
                    'box': pred_boxes_img[pred_idx]
                })

        if total_gt_for_class == 0:
            # AP is undefined without ground truth; NaN keeps it out of the mean.
            average_precisions[class_idx] = float('nan')
            continue

        if not detections_for_class:
            average_precisions[class_idx] = 0.0
            continue

        detections_for_class.sort(key=lambda d: d['score'], reverse=True)

        num_detections = len(detections_for_class)
        tp_arr = torch.zeros(num_detections)
        fp_arr = torch.zeros(num_detections)

        for det_idx, det in enumerate(detections_for_class):
            gt_objects_in_image = ground_truths_for_class_by_image[det['image_idx']]

            best_iou = -1.0
            best_gt_match_idx = -1
            for gt_obj_idx, gt_obj in enumerate(gt_objects_in_image):
                iou = calculate_iou(det['box'], gt_obj['box'])
                if iou > best_iou:
                    best_iou = iou
                    best_gt_match_idx = gt_obj_idx

            matched = (
                best_gt_match_idx >= 0
                and best_iou >= iou_threshold
                and not gt_objects_in_image[best_gt_match_idx]['used']
            )
            if matched:
                tp_arr[det_idx] = 1
                gt_objects_in_image[best_gt_match_idx]['used'] = True
            else:
                # Below threshold, no GT in the image, or a duplicate of an already-matched GT.
                fp_arr[det_idx] = 1

        tp_cumsum = torch.cumsum(tp_arr, dim=0)
        fp_cumsum = torch.cumsum(fp_arr, dim=0)

        recalls = tp_cumsum / (total_gt_for_class + epsilon)
        precisions = tp_cumsum / (tp_cumsum + fp_cumsum + epsilon)

        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))

        # Precision envelope: running maximum taken from the right.
        precisions = torch.flip(torch.cummax(torch.flip(precisions, dims=[0]), dim=0).values, dims=[0])

        recall_changes_indices = torch.where(recalls[1:] != recalls[:-1])[0]
        ap = torch.sum(
            (recalls[recall_changes_indices + 1] - recalls[recall_changes_indices])
            * precisions[recall_changes_indices + 1]
        )
        average_precisions[class_idx] = ap.item()

    valid_aps = [ap for ap in average_precisions.values() if not math.isnan(ap)]
    mean_ap = sum(valid_aps) / len(valid_aps) if valid_aps else 0.0

    return mean_ap, average_precisions

def calculate_precision_recall(pred_boxes, pred_scores, pred_classes, gt_boxes, gt_classes, num_classes=20, iou_threshold=0.5):
    """Micro-averaged precision and recall of detections at one IoU threshold.

    Within each image, predictions are processed in descending score order. Each takes
    the same-class ground-truth box with the highest positive IoU. It is a true positive
    if that IoU is >= ``iou_threshold`` and the box is not already matched; otherwise it
    is a false positive. Ground-truth boxes left unmatched are false negatives.

    Args:
        pred_boxes, pred_scores, pred_classes: per-image lists of (K, 4), (K,), (K,) tensors.
        gt_boxes, gt_classes: per-image lists of (G, 4) and (G,) tensors, indexed like
            ``pred_boxes``.
        num_classes: unused; kept for signature compatibility.
        iou_threshold: minimum IoU for a match.

    Returns:
        ``(precision, recall)`` as floats.
    """
    total_tp = 0
    total_fp = 0
    total_fn = 0

    for i in range(len(pred_boxes)):
        boxes = pred_boxes[i]
        scores = pred_scores[i]
        classes = pred_classes[i]

        gt_b = gt_boxes[i]
        gt_c = gt_classes[i]

        gt_detected = torch.zeros(len(gt_b))

        # Most confident predictions claim ground truths first.
        if len(boxes) > 0:
            order = torch.argsort(torch.as_tensor(scores), descending=True).tolist()
        else:
            order = []

        for j in order:
            max_iou = 0
            max_idx = -1

            for k in range(len(gt_b)):
                if gt_c[k] != classes[j]:
                    continue
                iou = calculate_iou(boxes[j], gt_b[k])
                if iou > max_iou:
                    max_iou = iou
                    max_idx = k

            # max_idx >= 0 prevents a spurious match (and writing gt_detected[-1]) when no
            # same-class box overlaps and iou_threshold <= 0.
            if max_idx >= 0 and max_iou >= iou_threshold and gt_detected[max_idx] == 0:
                total_tp += 1
                gt_detected[max_idx] = 1
            else:
                total_fp += 1

        total_fn += (1 - gt_detected).sum().item()

    precision = total_tp / (total_tp + total_fp + 1e-8)
    recall = total_tp / (total_tp + total_fn + 1e-8)

    return precision, recall

In [None]:
def evaluate_model(model, data_loader, device):
    """Run ``model`` over ``data_loader`` and compute detection metrics.

    The model is put in eval mode and run without gradients. Raw outputs are decoded
    with ``convert_yolo_output_to_boxes`` using its default thresholds. Predictions and
    ground truth are accumulated per image across the whole loader.

    Returns:
        ``(map_score, precision, recall, ap_per_class)`` from ``calculate_map`` and
        ``calculate_precision_recall`` with their default settings.
    """
    model.eval()

    pred_boxes_all, pred_scores_all, pred_classes_all = [], [], []
    gt_boxes_all, gt_classes_all = [], []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            images, _, gt_boxes, gt_classes, _ = batch
            decoded = convert_yolo_output_to_boxes(model(images.to(device)))
            pred_boxes, pred_scores, pred_classes = decoded

            for sink, chunk in (
                (pred_boxes_all, pred_boxes),
                (pred_scores_all, pred_scores),
                (pred_classes_all, pred_classes),
                (gt_boxes_all, gt_boxes),
                (gt_classes_all, gt_classes),
            ):
                sink.extend(chunk)

    map_score, ap_per_class = calculate_map(
        pred_boxes_all, pred_scores_all, pred_classes_all, gt_boxes_all, gt_classes_all
    )
    precision, recall = calculate_precision_recall(
        pred_boxes_all, pred_scores_all, pred_classes_all, gt_boxes_all, gt_classes_all
    )
    return map_score, precision, recall, ap_per_class

In [None]:
def show_image_with_boxes_nms(image, target, grid_size=7, class_names=CLASS_NAMES, conf_thresh=0.5, num_boxes=2):
    """Draw at most one box per grid cell: the box slot with the highest confidence.

    Despite the name, no IoU-based non-maximum suppression is applied. Suppression only
    keeps the best slot within each cell; on a confidence tie the later slot wins.

    Args:
        image: (3, H, W) image tensor; detached and moved to CPU for plotting.
        target: (grid_size, grid_size, num_boxes * 5 + num_classes) tensor, either an
            encoded target or a model output, with box slots ``[x_cell, y_cell, w, h, conf]``
            followed by class scores.
        grid_size: number of cells per side.
        class_names: names indexed by class id.
        conf_thresh: cells whose best slot has confidence < this value are skipped.
        num_boxes: box slots per cell (previously hard-coded to 2).
    """
    image_np = image.detach().permute(1, 2, 0).cpu().numpy()
    h, w = image_np.shape[:2]

    fig, ax = plt.subplots(1)
    ax.imshow(image_np)

    class_offset = num_boxes * 5
    for i in range(grid_size):
        for j in range(grid_size):
            cell = target[i, j]

            # Keep the current best unless it is strictly more confident than the next slot.
            best_slot, best_conf = 0, cell[4].item()
            for b in range(1, num_boxes):
                slot_conf = cell[b * 5 + 4].item()
                if not best_conf > slot_conf:
                    best_slot, best_conf = b, slot_conf

            if best_conf < conf_thresh:
                continue

            # Cell-relative centre -> pixel centre; image-relative size -> pixels.
            x, y, box_w, box_h = cell[best_slot * 5:best_slot * 5 + 4].tolist()
            cx = (j + x) / grid_size * w
            cy = (i + y) / grid_size * h
            bw = box_w * w
            bh = box_h * h

            rect = patches.Rectangle((cx - bw / 2, cy - bh / 2), bw, bh, linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)

            class_probs = cell[class_offset:]
            if class_probs.numel() == 0:
                continue

            class_idx = class_probs.argmax().item()
            class_score = class_probs[class_idx].item()
            class_name = class_names[class_idx]

            ax.text(cx - bw / 2, cy - bh / 2 - 5, f'{class_name} ({class_score:.2f})', color='white',
                    fontsize=8, backgroundcolor='red')

    plt.axis('off')
    plt.show()

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = YOLOResNet().to(device)
# To start from previously saved weights instead of the pretrained backbone alone:
#model.load_state_dict(torch.load('/content/yolo_model_weights_2.pth', map_location=device))

criterion = YOLOLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# mode='max': the training loop steps this scheduler with validation mAP (higher is better).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm it is still accepted.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True)

num_epochs = 20

# Per-epoch history, appended to by the training loop.
train_losses, val_losses = [], []
val_maps, val_precisions, val_recalls = [], [], []

In [None]:
# Keep the weights from the epoch with the best validation mAP, the metric the
# scheduler monitors; the last epoch is not necessarily the best. The snapshot is a full
# CPU copy of the state_dict (roughly 1 GB here, dominated by the 4096-unit FC head).
best_val_map = float("-inf")
best_epoch = None
best_state = None

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

    for images, targets, _, _, _ in loop:
        images = images.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- Validation loss ----
    # NOTE(review): evaluate_model() below makes a second full pass over val_loader;
    # computing the loss in that same pass would halve validation time.
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for images, targets, _, _, _ in val_loader:
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)
            loss = criterion(outputs, targets)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # ---- Detection metrics ----
    print("\nCalculating validation metrics...")
    map_score, precision, recall, ap_per_class = evaluate_model(model, val_loader, device)
    val_maps.append(map_score)
    val_precisions.append(precision)
    val_recalls.append(recall)

    print(f"Epoch [{epoch+1}/{num_epochs}]:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  Val mAP@0.5: {map_score:.4f}")
    print(f"  Val Precision: {precision:.4f}")
    print(f"  Val Recall: {recall:.4f}")

    if isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(map_score)

    if map_score > best_val_map:
        best_val_map = map_score
        best_epoch = epoch + 1
        best_state = {name: tensor.detach().to("cpu", copy=True) for name, tensor in model.state_dict().items()}
        print(f"  New best Val mAP@0.5 ({best_val_map:.4f}); weights snapshotted.")

# Leave the model holding the best-validation weights for saving and test evaluation.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"\nRestored weights from epoch {best_epoch} (best Val mAP@0.5: {best_val_map:.4f}).")

In [None]:
# Persist the model's state_dict (parameters and buffers), not the module object itself.
torch.save(obj=model.state_dict(), f="yolo_model_weights.pth")

In [None]:
print("\n=== Final Evaluation on Test Set ===")
test_map, test_precision, test_recall, test_ap_per_class = evaluate_model(model, test_loader, device)

# Headline metrics at IoU 0.5.
for metric_label, metric_value in (
    ("Test mAP@0.5", test_map),
    ("Test Precision", test_precision),
    ("Test Recall", test_recall),
):
    print(f"{metric_label}: {metric_value:.4f}")

# Per-class AP breakdown, in class-index order.
print("\nClass-wise AP scores:")
for cls_idx, ap in test_ap_per_class.items():
    print(f"{CLASS_NAMES[cls_idx]}: {ap:.4f}")

In [None]:
def visualize_prediction_with_metrics(model, image, target, original_boxes, original_labels, iou_threshold=0.5,
                                      score_threshold=0.3, device=None):
    """Show ground-truth boxes and model predictions side by side for one image.

    Args:
        model: detector whose output ``convert_yolo_output_to_boxes`` can decode.
        image: (3, H, W) image tensor.
        target: encoded YOLO target; accepted for call-site symmetry but not used.
        original_boxes: (K, 4) tensor of ``[x1, y1, x2, y2]`` ground-truth boxes in pixels.
        original_labels: (K,) tensor of ground-truth class indices.
        iou_threshold: IoU for the extra class-agnostic NMS pass.
        score_threshold: predictions with score <= this value are hidden.
        device: where to run the model; defaults to the device of the model's parameters.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(device))

    boxes_per_image, scores_per_image, classes_per_image = convert_yolo_output_to_boxes(output)
    pred_boxes = boxes_per_image[0]
    pred_scores = scores_per_image[0]
    pred_classes = classes_per_image[0]

    keep = pred_scores > score_threshold
    pred_boxes = pred_boxes[keep]
    pred_scores = pred_scores[keep]
    pred_classes = pred_classes[keep]

    # convert_yolo_output_to_boxes already ran per-class NMS; this class-agnostic pass
    # additionally suppresses overlapping boxes that carry different classes.
    if len(pred_boxes) > 0:
        keep_idx = nms(pred_boxes, pred_scores, iou_threshold)
        pred_boxes = pred_boxes[keep_idx]
        pred_scores = pred_scores[keep_idx]
        pred_classes = pred_classes[keep_idx]

    image_np = image.detach().permute(1, 2, 0).cpu().numpy()

    fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(15, 7))

    # Ground truth (plain floats/ints so matplotlib never receives tensors).
    ax_gt.imshow(image_np)
    ax_gt.set_title("Ground Truth")
    gt_boxes = torch.as_tensor(original_boxes).reshape(-1, 4).tolist()
    gt_labels = torch.as_tensor(original_labels).reshape(-1).tolist()
    for (x1, y1, x2, y2), class_idx in zip(gt_boxes, gt_labels):
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='g', facecolor='none')
        ax_gt.add_patch(rect)
        ax_gt.text(x1, y1 - 5, CLASS_NAMES[class_idx], color='white', fontsize=8, backgroundcolor='green')

    # Predictions
    ax_pred.imshow(image_np)
    ax_pred.set_title("Predictions")
    for (x1, y1, x2, y2), score, class_idx in zip(pred_boxes.tolist(), pred_scores.tolist(), pred_classes.tolist()):
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='r', facecolor='none')
        ax_pred.add_patch(rect)
        ax_pred.text(x1, y1 - 5, f'{CLASS_NAMES[class_idx]} ({score:.2f})', color='white', fontsize=8,
                     backgroundcolor='red')

    plt.tight_layout()
    plt.show()

# Compare ground truth and predictions for one test sample; change the index to browse others.
sample_idx = 99
img, target, orig_boxes, orig_labels, _ = test_dataset[sample_idx]
visualize_prediction_with_metrics(
    model, img, target, orig_boxes, orig_labels
)