<a href="https://colab.research.google.com/github/guo1428397137-wq/guo-9517-ass/blob/Faster-R-CNN/%E2%80%9CUntitled2_ipynb%E2%80%9D%E7%9A%84%E5%89%AF%E6%9C%AC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchmetrics
import os
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.ops import nms
import torchvision.transforms.functional as TF
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import random
import json
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, roc_auc_score, roc_curve

# Basic config
DATA_ROOT = Path("/content/drive/MyDrive/archive")

TRAIN_IMG_DIR   = DATA_ROOT / "train" / "images"
TRAIN_LABEL_DIR = DATA_ROOT / "train" / "labels"
VAL_IMG_DIR     = DATA_ROOT / "valid" / "images"
VAL_LABEL_DIR   = DATA_ROOT / "valid" / "labels"

NUM_CLASSES   = 12     # foreground classes; the detector head adds one for background
BATCH_SIZE    = 12     # an L4 GPU can handle a larger batch
IMG_MAX_SIZE  = 640    # longest image side after dataset resizing (kept at 640)
EPOCHS        = 15     # 15 epochs should be enough
LR            = 0.01   # relatively high LR paired with the larger batch
WEIGHT_DECAY  = 5e-4
SCORE_THRESH  = 0.05   # minimum detection score kept for visualisation/evaluation

# Ensure output directory exists
os.makedirs("/content/", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

CLASS_NAMES = [
    'background',
    'Ants', 'Bees', 'Beetles', 'Caterpillars',
    'Earthworms', 'Earwigs', 'Grasshoppers', 'Moths',
    'Slugs', 'Snails', 'Wasps', 'Weevils'
]

# Index 0 is background, so the list must hold exactly NUM_CLASSES + 1 names;
# otherwise confusion-matrix axes and per-class reports silently mislabel classes.
if len(CLASS_NAMES) != NUM_CLASSES + 1:
    raise ValueError(
        f"CLASS_NAMES has {len(CLASS_NAMES)} entries; expected NUM_CLASSES + 1 = {NUM_CLASSES + 1}"
    )

print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Learning rate: {LR}")
print(f"  Number of training rounds: {EPOCHS}")
print(f"  Score threshold: {SCORE_THRESH}")


# Read labels
def load_yolo_labels(label_path: Path, img_w: int, img_h: int, min_box_size: float = 5.0):
    """Read a YOLO-format label file and convert it to absolute xyxy boxes.

    Each row is expected to be ``class cx cy w h`` with the centre and size
    normalised by the image width/height.

    Args:
        label_path: path to the ``.txt`` label file. A missing, unparsable,
            empty or malformed (fewer than 5 columns) file yields no boxes.
        img_w: image width in pixels.
        img_h: image height in pixels.
        min_box_size: boxes whose clipped width or height is not strictly
            greater than this many pixels are dropped (default 5, as before).

    Returns:
        ``(boxes, labels)``:
        - ``boxes`` is a float32 tensor of shape (N, 4) holding absolute pixel
          ``x1, y1, x2, y2``, clipped to the image.
        - ``labels`` is an int64 tensor of shape (N,) equal to the YOLO class
          id + 1, keeping 0 free for the background class.
    """
    def _empty():
        return torch.zeros((0, 4), dtype=torch.float32), torch.zeros((0,), dtype=torch.int64)

    if not label_path.exists():
        return _empty()

    try:
        data = np.loadtxt(str(label_path), ndmin=2)
    except Exception:
        # Best effort: an unparsable label file is treated as "no annotations".
        return _empty()

    # Every row needs class + cx, cy, w, h. Without this check a short row
    # raised an uncaught IndexError at data[:, 4].
    # NOTE(review): rows with MORE columns (e.g. polygon labels) are read as if
    # columns 1-4 were cx, cy, w, h; confirm the dataset is box-format only.
    if data.shape[0] == 0 or data.shape[1] < 5:
        return _empty()

    cls = data[:, 0].astype(np.int64)
    cx  = data[:, 1] * img_w
    cy  = data[:, 2] * img_h
    bw  = data[:, 3] * img_w
    bh  = data[:, 4] * img_h

    x1 = np.clip(cx - bw / 2, 0, img_w)
    y1 = np.clip(cy - bh / 2, 0, img_h)
    x2 = np.clip(cx + bw / 2, 0, img_w)
    y2 = np.clip(cy + bh / 2, 0, img_h)

    # Drop degenerate / tiny boxes after clipping.
    valid = (x2 > x1 + min_box_size) & (y2 > y1 + min_box_size)
    if not valid.any():
        return _empty()

    boxes = np.stack([x1[valid], y1[valid], x2[valid], y2[valid]], axis=1).astype(np.float32)
    labels = cls[valid] + 1  # shift so label 0 is reserved for background

    return torch.as_tensor(boxes, dtype=torch.float32), torch.as_tensor(labels, dtype=torch.int64)


def resize_keep_ratio(img: np.ndarray, target: Dict[str, Any], max_size: int):
    """Downscale ``img`` so its longer side is at most ``max_size``, keeping aspect ratio.

    Images that already fit are never upscaled.

    When the image is resized:
    - ``target["boxes"]`` (absolute xyxy pixels, shape (N, 4)) is rescaled and
      clamped to the new bounds. A new tensor is stored, so a boxes tensor
      shared with the caller is not modified.
    - ``target["area"]``, if present, is scaled by the same factors.

    ``target["size"]`` is always set to ``[new_h, new_w]``.

    Args:
        img: H x W (x C) image array.
        target: detection target dict with at least a ``"boxes"`` tensor.
        max_size: maximum allowed length of the longer image side.

    Returns:
        ``(img, target)``: the possibly resized image and the same ``target``
        dict, updated.
    """
    h, w = img.shape[:2]
    scale = min(max_size / max(h, w), 1.0)
    # max(1, ...) keeps very elongated images from collapsing to a 0-pixel
    # dimension, which cv2.resize rejects.
    nh, nw = max(1, int(h * scale)), max(1, int(w * scale))

    if (nh, nw) != (h, w):
        img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
        sx, sy = float(nw) / w, float(nh) / h
        if target["boxes"].numel() > 0:
            boxes = target["boxes"].clone()
            boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]] * sx, 0, nw)
            boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]] * sy, 0, nh)
            target["boxes"] = boxes
        if "area" in target and target["area"].numel() > 0:
            # Keep area in the same (resized) coordinate frame as the boxes.
            target["area"] = target["area"] * (sx * sy)

    target["size"] = torch.tensor([nh, nw], dtype=torch.int64)
    return img, target

# Dataset
class YoloDetectionDataset(Dataset):
    """Detection dataset pairing images with YOLO-format ``.txt`` label files.

    Images are found recursively under ``img_dir``. The label for ``foo.jpg``
    is read from ``label_dir / "foo.txt"``.

    Each item is ``(image, target)``:
    - ``image`` is a float tensor (C, H, W) normalised with ImageNet mean/std.
    - ``target`` holds:
      - ``boxes``: (N, 4) absolute xyxy pixels in the returned image's frame.
      - ``labels``: (N,) int64, YOLO class + 1.
      - ``image_id``: ``[idx]``.
      - ``area``: (N,), computed from the final boxes.
      - ``iscrowd``: (N,) zeros.
      - ``size``: ``[h, w]``, set by ``resize_keep_ratio``.

    When ``training`` is True and the image has boxes, it applies random
    horizontal flip (p=0.5), vertical flip (p=0.3) and brightness/contrast
    jitter (p=0.5).

    NOTE(review): torchvision detection models also normalise inputs inside
    their GeneralizedRCNNTransform by default. Make sure the model uses identity
    image_mean/image_std, or normalisation is applied twice.
    """

    # Lower-case file suffixes treated as images.
    IMG_EXTS = {".jpg", ".jpeg", ".png"}

    def __init__(self, img_dir: Path, label_dir: Path, training: bool = True, img_max_size: int = 800):
        self.img_dir   = Path(img_dir)
        self.label_dir = Path(label_dir)
        self.training  = training
        self.img_max_size = img_max_size

        # is_file() skips directories that happen to end in an image suffix.
        self.img_paths = sorted(
            p for p in self.img_dir.rglob("*")
            if p.suffix.lower() in self.IMG_EXTS and p.is_file()
        )

        if not self.img_paths:
            raise RuntimeError(f"No images found in {self.img_dir}")

    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _box_area(boxes: torch.Tensor) -> torch.Tensor:
        """Area of each xyxy box; an empty float32 tensor when there are no boxes."""
        if boxes.numel() == 0:
            return torch.zeros((0,), dtype=torch.float32)
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    @staticmethod
    def _hflip(img: np.ndarray, boxes: torch.Tensor):
        """Mirror the image left-right and return it with correspondingly mirrored boxes."""
        img = np.ascontiguousarray(img[:, ::-1])
        width = img.shape[1]
        flipped = boxes.clone()
        flipped[:, 0] = width - boxes[:, 2]
        flipped[:, 2] = width - boxes[:, 0]
        return img, flipped

    @staticmethod
    def _vflip(img: np.ndarray, boxes: torch.Tensor):
        """Mirror the image top-bottom and return it with correspondingly mirrored boxes."""
        img = np.ascontiguousarray(img[::-1, :])
        height = img.shape[0]
        flipped = boxes.clone()
        flipped[:, 1] = height - boxes[:, 3]
        flipped[:, 3] = height - boxes[:, 1]
        return img, flipped

    def __getitem__(self, idx: int):
        img_path = self.img_paths[idx]
        label_path = self.label_dir / (img_path.stem + ".txt")

        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(str(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        boxes, labels = load_yolo_labels(label_path, w, h)

        target = {
            "boxes":    boxes,
            "labels":   labels,
            "image_id": torch.tensor([idx]),
            "area":     self._box_area(boxes),
            "iscrowd":  torch.zeros((labels.shape[0],), dtype=torch.int64),
        }

        if self.training and boxes.numel() > 0:
            if np.random.rand() < 0.5:
                img, target["boxes"] = self._hflip(img, target["boxes"])

            if np.random.rand() < 0.3:
                img, target["boxes"] = self._vflip(img, target["boxes"])

            if np.random.rand() < 0.5:
                img = cv2.convertScaleAbs(
                    img,
                    alpha=np.random.uniform(0.7, 1.3),
                    beta=np.random.uniform(-20, 20),
                )

        img, target = resize_keep_ratio(img, target, self.img_max_size)
        # The area computed above is in original-resolution pixels; recompute
        # it so it matches the resized boxes actually returned.
        target["area"] = self._box_area(target["boxes"])

        img = TF.to_tensor(img)
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        return img, target


def collate_fn(batch):
    """Transpose a list of ``(image, target)`` samples into ``(images, targets)`` lists.

    Detection images differ in size, so they are kept as a list instead of being
    stacked into one tensor.
    """
    columns = [list(column) for column in zip(*batch)]
    images, targets = columns
    return images, targets

#  DataLoaders
# The training split gets random augmentation; the validation split is deterministic.
train_dataset = YoloDetectionDataset(
    TRAIN_IMG_DIR, TRAIN_LABEL_DIR, training=True, img_max_size=IMG_MAX_SIZE
)
val_dataset = YoloDetectionDataset(
    VAL_IMG_DIR, VAL_LABEL_DIR, training=False, img_max_size=IMG_MAX_SIZE
)

# Both loaders share batching/worker settings; only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=True,
)

print(f"\nTrain samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

# Sanity check: fetch the first training sample and report its tensor shapes.
img, target = train_dataset[0]
print("\nFirst sample:")
print("Image shape:", img.shape)
print("Boxes shape:", target["boxes"].shape)
print("Labels:", target["labels"])

# Model
# The dataset already normalises images with ImageNet mean/std (TF.normalize in
# YoloDetectionDataset). torchvision's Faster R-CNN normalises again inside its
# GeneralizedRCNNTransform using the same statistics by default. Identity
# mean/std turns that second pass into a no-op, so the pretrained backbone sees
# inputs scaled the way it was trained on.
# NOTE(review): the model transform still resizes images to its own
# min_size/max_size (torchvision defaults 800/1333), so 640px dataset images
# are upscaled. Confirm this is intended given IMG_MAX_SIZE.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights="DEFAULT",
    image_mean=[0.0, 0.0, 0.0],
    image_std=[1.0, 1.0, 1.0],
)

# Replace the box predictor: NUM_CLASSES foreground classes + 1 background.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES + 1)

model.to(device)
print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

# Optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer    = torch.optim.SGD(params, lr=LR, momentum=0.9, weight_decay=WEIGHT_DECAY)
# LR is multiplied by 0.1 every 7 scheduler steps. train_one_epoch steps the
# scheduler once per epoch, so epochs 1-7 use LR, 8-14 use LR/10, and so on.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
scaler       = torch.amp.GradScaler('cuda', enabled=(device.type == "cuda"))

# Train & eval
def train_one_epoch(epoch: int):
    """Train the global ``model`` for one pass over ``train_loader``.

    Uses autocast plus the global ``scaler`` on CUDA, clips the gradient norm
    to 5.0 and calls ``lr_scheduler.step()`` once at the end of the epoch.

    Args:
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        ``(mean_total_loss, mean_components)``. ``mean_components`` maps
        ``'cls'``, ``'box'``, ``'obj'`` and ``'rpn'`` to the per-batch mean of
        ``loss_classifier``, ``loss_box_reg``, ``loss_objectness`` and
        ``loss_rpn_box_reg`` from the model's loss dict.
    """
    # Short name in the returned dict -> key in the model's loss dict.
    component_keys = {
        'cls': 'loss_classifier',
        'box': 'loss_box_reg',
        'obj': 'loss_objectness',
        'rpn': 'loss_rpn_box_reg',
    }
    use_amp = device.type == "cuda"

    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch:02d}")
    running_loss = 0.0
    loss_components = {key: 0.0 for key in component_keys}

    for images, targets in pbar:
        images  = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=use_amp):
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())

        scaler.scale(loss).backward()
        # After backward() the gradients are multiplied by the loss scale
        # (GradScaler starts at 2**16). They must be unscaled BEFORE clipping.
        # Clipping the scaled gradients to norm 5 shrank every real update by
        # the scale factor, while SGD's weight decay still applied at full
        # strength, slowly decaying the weights. unscale_ is a no-op when AMP
        # is disabled.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)  # skips the update if inf/NaN gradients were found
        scaler.update()

        running_loss += loss.item()
        for short_name, loss_key in component_keys.items():
            loss_components[short_name] += loss_dict.get(loss_key, torch.tensor(0)).item()

        pbar.set_postfix({'loss': f'{loss.item():.3f}'})

    lr_scheduler.step()

    n = max(len(train_loader), 1)  # avoid ZeroDivisionError on an empty loader
    return running_loss / n, {k: v / n for k, v in loss_components.items()}


@torch.no_grad()
def evaluate_map():
    """Score the global ``model`` on ``val_loader`` with torchmetrics box mAP.

    Returns:
        Dict with Python numbers ``map``, ``map_50`` and ``map_75`` taken from
        ``MeanAveragePrecision.compute()``. A key missing from the result
        becomes 0.
    """
    model.eval()
    metric = MeanAveragePrecision(iou_type="bbox")

    def _cpu(tensor):
        return tensor.detach().cpu()

    for batch_images, batch_targets in tqdm(val_loader, desc="Eval"):
        detections = model([image.to(device) for image in batch_images])
        pairs = list(zip(detections, batch_targets))

        predictions = [
            {
                "boxes":  _cpu(det["boxes"]),
                "scores": _cpu(det["scores"]),
                "labels": _cpu(det["labels"]),
            }
            for det, _ in pairs
        ]
        ground_truth = [
            {
                "boxes":   _cpu(gt["boxes"]),
                "labels":  _cpu(gt["labels"]),
                "iscrowd": _cpu(gt.get("iscrowd", torch.zeros(gt["labels"].shape[0], dtype=torch.int64))),
            }
            for _, gt in pairs
        ]
        metric.update(predictions, ground_truth)

    summary = metric.compute()
    return {name: summary.get(name, torch.tensor(0)).item() for name in ("map", "map_50", "map_75")}

# Training loop
# Start below any reachable mAP so the first evaluated epoch always writes a
# checkpoint (with 0.0, a run whose mAP never rises above 0 saved nothing).
best_map = float("-inf")
history = {
    'train_loss': [],
    'map': [],
    'map_50': [],
    'map_75': [],
    'loss_cls': [],
    'loss_box': [],
    'lr': []
}

print("\n" + "="*70)
print("Start training")
print("="*70)

import time
start_time = time.time()

for epoch in range(1, EPOCHS + 1):
    # Read the LR BEFORE training. train_one_epoch() calls lr_scheduler.step()
    # at its end, so reading afterwards logs the next epoch's LR (e.g. epoch 7
    # was reported as 0.001 although it trained at 0.01).
    current_lr = optimizer.param_groups[0]['lr']

    train_loss, loss_comp = train_one_epoch(epoch)
    metrics = evaluate_map()

    history['train_loss'].append(train_loss)
    history['map'].append(metrics['map'])
    history['map_50'].append(metrics['map_50'])
    history['map_75'].append(metrics.get('map_75', 0))
    history['loss_cls'].append(loss_comp['cls'])
    history['loss_box'].append(loss_comp['box'])
    history['lr'].append(current_lr)

    print(f"[Epoch {epoch:02d}/{EPOCHS}] "
          f"Loss={train_loss:.4f} | "
          f"mAP={metrics['map']:.4f} | "
          f"mAP@50={metrics['map_50']:.4f} | "
          f"LR={current_lr:.6f}")

    if metrics['map'] > best_map:
        best_map = metrics['map']
        # Scheduler and scaler state are included so training can be resumed
        # from this checkpoint.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'map': metrics['map'],
            'history': history
        }, "/content/fasterrcnn_best.pth")
        print("Saved best model")
    print("-" * 70)

training_time = time.time() - start_time
print(f"\nTraining finished in {training_time/60:.2f} minutes")
print(f"Best mAP: {best_map:.4f}")


with open('/content/training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

# Plot
# The training loop numbers epochs from 1, so plot against 1..N instead of the
# implicit list indices 0..N-1 that mislabelled the "Epoch" axis.
epochs_axis = list(range(1, len(history['train_loss']) + 1))

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Loss
axes[0, 0].plot(epochs_axis, history['train_loss'], marker='o', label='Total Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch', fontsize=11)
axes[0, 0].set_ylabel('Loss', fontsize=11)
axes[0, 0].set_title('Training Loss', fontsize=13, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# mAP
axes[0, 1].plot(epochs_axis, history['map'], marker='o', label='mAP', linewidth=2)
axes[0, 1].plot(epochs_axis, history['map_50'], marker='s', label='mAP@50', linewidth=2)
axes[0, 1].set_xlabel('Epoch', fontsize=11)
axes[0, 1].set_ylabel('mAP', fontsize=11)
axes[0, 1].set_title('Validation mAP', fontsize=13, fontweight='bold')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Loss components
axes[1, 0].plot(epochs_axis, history['loss_cls'], marker='o', label='Classification', linewidth=2)
axes[1, 0].plot(epochs_axis, history['loss_box'], marker='s', label='Box Regression', linewidth=2)
axes[1, 0].set_xlabel('Epoch', fontsize=11)
axes[1, 0].set_ylabel('Loss', fontsize=11)
axes[1, 0].set_title('Loss Components', fontsize=13, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Learning rate (log scale: StepLR drops it by 10x per step)
axes[1, 1].plot(epochs_axis, history['lr'], marker='o', color='orangered', linewidth=2)
axes[1, 1].set_xlabel('Epoch', fontsize=11)
axes[1, 1].set_ylabel('Learning Rate', fontsize=11)
axes[1, 1].set_title('Learning Rate Schedule', fontsize=13, fontweight='bold')
axes[1, 1].set_yscale('log')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('/content/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

# Visualization
def denormalize_image(img_tensor):
    """Invert ImageNet mean/std normalisation for display.

    Takes a (3, H, W) tensor and returns an (H, W, 3) numpy array clipped to
    [0, 1].
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
    channel_std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
    restored = img_tensor.mul(channel_std).add(channel_mean).clamp(0, 1)
    return restored.permute(1, 2, 0).numpy()


def visualize_predictions(dataset, num_images=6, score_thresh=SCORE_THRESH, nms_thresh=0.5):
    """Draw predictions (red) and ground-truth boxes (green) on random samples.

    Args:
        dataset: indexable dataset returning ``(normalised_image_tensor, target)``
            pairs, as produced by ``YoloDetectionDataset``.
        num_images: number of random samples to show (capped at ``len(dataset)``).
        score_thresh: detections scoring below this are dropped before drawing.
        nms_thresh: IoU threshold for an additional per-class NMS pass on the
            kept detections.

    The figure is saved to ``/content/predictions.png`` and then shown.
    """
    import math
    from matplotlib.patches import Patch

    model.eval()
    idxs = random.sample(range(len(dataset)), k=min(num_images, len(dataset)))

    # Size the grid to the sample count (at most 3 columns, 6x6 inches per
    # panel) so num_images > 6 cannot overflow the axes array; unused panels
    # are hidden.
    n_panels = max(1, len(idxs))
    ncols = min(3, n_panels)
    nrows = math.ceil(n_panels / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False)
    axes = axes.flatten()
    for ax in axes[len(idxs):]:
        ax.axis('off')

    for ax, idx in zip(axes, idxs):
        img, target = dataset[idx]
        img_np = denormalize_image(img.cpu())

        with torch.no_grad():
            out = model([img.to(device)])[0]

        boxes  = out["boxes"].cpu()
        scores = out["scores"].cpu()
        labels = out["labels"].cpu()

        keep = scores >= score_thresh
        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

        if boxes.numel() > 0:
            # Per-class NMS: boxes with different labels never suppress each other.
            keep = torchvision.ops.batched_nms(boxes, scores, labels, nms_thresh)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

        boxes, scores, labels = boxes.numpy(), scores.numpy(), labels.numpy()

        # img_np is RGB, so (255, 0, 0) renders red and (0, 255, 0) green in imshow.
        vis = (img_np * 255).astype(np.uint8).copy()

        for (x1, y1, x2, y2), s, lb in zip(boxes, scores, labels):
            x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
            cv2.rectangle(vis, (x1, y1), (x2, y2), (255, 0, 0), 2)
            class_name = CLASS_NAMES[int(lb)] if int(lb) < len(CLASS_NAMES) else str(int(lb))
            text = f"{class_name}: {s:.2f}"
            (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            cv2.rectangle(vis, (x1, max(0, y1-th-5)), (x1+tw, y1), (255, 0, 0), -1)
            cv2.putText(vis, text, (x1, max(th, y1-5)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 2)

        for (x1, y1, x2, y2) in target["boxes"].numpy():
            x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
            cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)

        ax.imshow(vis)
        ax.axis('off')
        ax.set_title(f'Sample {idx} | {len(boxes)} pred boxes', fontsize=10, fontweight='bold')

    legend = [
        Patch(facecolor='red', label='Prediction'),
        Patch(facecolor='green', label='Ground Truth')
    ]
    fig.legend(handles=legend, loc='upper right', fontsize=12)

    plt.tight_layout()
    plt.savefig('/content/predictions.png', dpi=300, bbox_inches='tight')
    plt.show()

print("\nVisualizing predictions (with NMS):")
visualize_predictions(
    dataset=val_dataset,
    num_images=6,
    score_thresh=SCORE_THRESH,
    nms_thresh=0.5,
)



def _pairwise_iou(pred_boxes: np.ndarray, gt_boxes: np.ndarray) -> np.ndarray:
    """Return the (P, G) IoU matrix between xyxy boxes; 0 where boxes do not overlap."""
    if len(pred_boxes) == 0 or len(gt_boxes) == 0:
        return np.zeros((len(pred_boxes), len(gt_boxes)))
    p = pred_boxes[:, None, :]
    g = gt_boxes[None, :, :]
    inter_w = np.clip(np.minimum(p[..., 2], g[..., 2]) - np.maximum(p[..., 0], g[..., 0]), 0, None)
    inter_h = np.clip(np.minimum(p[..., 3], g[..., 3]) - np.maximum(p[..., 1], g[..., 1]), 0, None)
    inter = inter_w * inter_h
    area_p = (p[..., 2] - p[..., 0]) * (p[..., 3] - p[..., 1])
    area_g = (g[..., 2] - g[..., 0]) * (g[..., 3] - g[..., 1])
    union = area_p + area_g - inter
    return np.where(union > 0, inter / np.where(union > 0, union, 1.0), 0.0)


def _match_detections(pred_boxes, pred_labels, pred_scores, gt_boxes, gt_labels, iou_thresh):
    """Greedily match predictions (in the given order) one-to-one to ground truth.

    Each prediction takes the not-yet-matched GT with the highest IoU, provided
    that IoU is > 0 and >= ``iou_thresh``.

    Matching ignores class, and a matched GT is consumed whether or not the
    labels agree. Each GT therefore contributes exactly one confusion-matrix
    entry. Previously a GT hit by a wrong-class prediction was also counted
    again as a miss, and could be claimed by several predictions.

    Returns:
        ``(y_true, y_pred, correct, scores)`` lists. They hold one entry per
        prediction (unmatched predictions get true label 0, background),
        followed by one entry per unmatched GT (pred label 0, score 0.0).
    """
    iou = _pairwise_iou(pred_boxes, gt_boxes)
    matched = np.zeros(len(gt_boxes), dtype=bool)
    y_true, y_pred, correct, scores = [], [], [], []

    for i, (p_label, p_score) in enumerate(zip(pred_labels, pred_scores)):
        best_gt = -1
        if len(gt_boxes) > 0:
            candidates = np.where(matched, -1.0, iou[i])
            j = int(np.argmax(candidates))
            if candidates[j] > 0 and candidates[j] >= iou_thresh:
                best_gt = j

        if best_gt >= 0:
            matched[best_gt] = True
            g_label = gt_labels[best_gt]
            y_true.append(g_label)
            y_pred.append(p_label)
            correct.append(int(p_label == g_label))
        else:
            # False positive: prediction vs. background.
            y_true.append(0)
            y_pred.append(p_label)
            correct.append(0)
        scores.append(p_score)

    # Missed ground truth: object vs. predicted background.
    for j in np.flatnonzero(~matched):
        y_true.append(gt_labels[j])
        y_pred.append(0)
        correct.append(0)
        scores.append(0.0)

    return y_true, y_pred, correct, scores


def _per_class_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """One-vs-rest precision/recall/F1/support for foreground labels 1..NUM_CLASSES."""
    class_metrics = {}
    for label in range(1, NUM_CLASSES + 1):
        mask_true = (y_true == label)
        mask_pred = (y_pred == label)

        tp = np.sum(mask_true & mask_pred)
        fp = np.sum(~mask_true & mask_pred)
        fn = np.sum(mask_true & ~mask_pred)

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        class_metrics[label] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': int(np.sum(mask_true)),
        }
    return class_metrics


def _plot_confusion_matrix(cm: np.ndarray) -> None:
    """Save the confusion-matrix heatmap to /content/confusion_matrix.png and show it."""
    plt.figure(figsize=(14, 12))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASS_NAMES,
                yticklabels=CLASS_NAMES,
                cbar_kws={'label': 'Count'})
    plt.xlabel('Predicted Label', fontsize=12, fontweight='bold')
    plt.ylabel('True Label', fontsize=12, fontweight='bold')
    plt.title('Confusion Matrix', fontsize=14, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig('/content/confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()


@torch.no_grad()
def comprehensive_evaluation(dataset, score_thresh=SCORE_THRESH, iou_thresh=0.5):
    """Comprehensive evaluation: confusion matrix, Precision/Recall/F1/Accuracy and AUC.

    Runs the global ``model`` on every sample of ``dataset`` and keeps detections
    with score >= ``score_thresh``. Those are matched one-to-one to ground truth
    at IoU >= ``iou_thresh`` (see ``_match_detections``), and the resulting
    (true, predicted) label pairs are scored, with label 0 meaning background.

    Side effects: prints a report, saves /content/confusion_matrix.png and
    writes /content/evaluation_results.json.

    Returns:
        The dict written to evaluation_results.json. Its keys are ``overall``,
        ``per_class``, ``confusion_matrix`` and ``settings``.
    """
    model.eval()

    y_true, y_pred, all_correct, all_pred_scores = [], [], [], []

    print("\n" + "="*70)
    print("Running comprehensive evaluation...")
    print("="*70)

    for idx in tqdm(range(len(dataset)), desc='Evaluating'):
        img, target = dataset[idx]
        out = model([img.to(device)])[0]

        pred_scores = out["scores"].cpu().numpy()
        keep = pred_scores >= score_thresh

        t, p, c, s = _match_detections(
            out["boxes"].cpu().numpy()[keep],
            out["labels"].cpu().numpy()[keep],
            pred_scores[keep],
            target["boxes"].numpy(),
            target["labels"].numpy(),
            iou_thresh,
        )
        y_true.extend(t)
        y_pred.extend(p)
        all_correct.extend(c)
        all_pred_scores.extend(s)

    all_correct = np.asarray(all_correct, dtype=np.int64)
    all_pred_scores = np.asarray(all_pred_scores, dtype=np.float64)
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)

    print(f"\nTotal predictions: {len(all_correct)}")
    print(f"Correct predictions: {all_correct.sum()}")

    print("\n" + "="*70)
    print("CONFUSION MATRIX")
    print("="*70)

    cm = confusion_matrix(y_true, y_pred, labels=list(range(NUM_CLASSES + 1)))
    _plot_confusion_matrix(cm)

    print("\nPer-class metrics:")
    class_metrics = _per_class_metrics(y_true, y_pred)
    for label, m in class_metrics.items():
        class_name = CLASS_NAMES[label].lower()
        idx = f"{label-1:02d}"  # original (0-based) YOLO class id
        print(f"{idx} {class_name:<13} | P: {m['precision']:.4f} R: {m['recall']:.4f} "
              f"F1: {m['f1']:.4f} AUC: -- S: {m['support']:>3d}")

    print("\nConfusion Matrix (rows=true, cols=pred):")
    print(cm)

    print("\nFull classification report:")
    print(f"{'':>15} {'precision':>10} {'recall':>10} {'f1-score':>10} {'support':>10}")
    print()

    for label, m in class_metrics.items():
        class_name = CLASS_NAMES[label].lower()
        print(f"{class_name:>15} {m['precision']:10.4f} {m['recall']:10.4f} {m['f1']:10.4f} {m['support']:10d}")

    print()

    # Exclude background-vs-background pairs. The matching never emits them,
    # so this is a safeguard.
    mask_objects = (y_true > 0) | (y_pred > 0)
    y_true_obj = y_true[mask_objects]
    y_pred_obj = y_pred[mask_objects]

    overall_accuracy = accuracy_score(y_true_obj, y_pred_obj)
    macro_precision = precision_score(y_true_obj, y_pred_obj, average='macro', zero_division=0)
    macro_recall = recall_score(y_true_obj, y_pred_obj, average='macro', zero_division=0)
    macro_f1 = f1_score(y_true_obj, y_pred_obj, average='macro', zero_division=0)

    overall_precision = precision_score(y_true_obj, y_pred_obj, average='weighted', zero_division=0)
    overall_recall = recall_score(y_true_obj, y_pred_obj, average='weighted', zero_division=0)
    overall_f1 = f1_score(y_true_obj, y_pred_obj, average='weighted', zero_division=0)

    print(f"{'accuracy':>15} {overall_accuracy:34.4f} {len(y_true_obj):10d}")
    print(f"{'macro avg':>15} {macro_precision:10.4f} {macro_recall:10.4f} {macro_f1:10.4f} {len(y_true_obj):10d}")
    print(f"{'weighted avg':>15} {overall_precision:10.4f} {overall_recall:10.4f} {overall_f1:10.4f} {len(y_true_obj):10d}")

    # AUC of "entry is correct" vs. its score. It is undefined when only one
    # outcome class is present.
    if len(all_correct) > 0 and len(np.unique(all_correct)) > 1:
        auc_score = roc_auc_score(all_correct, all_pred_scores)
        print(f"\nAUC (Correct vs Incorrect): {auc_score:.4f}")
    else:
        auc_score = None

    evaluation_results = {
        'overall': {
            'precision_weighted': float(overall_precision),
            'recall_weighted': float(overall_recall),
            'f1_weighted': float(overall_f1),
            'accuracy': float(overall_accuracy),
            'precision_macro': float(macro_precision),
            'recall_macro': float(macro_recall),
            'f1_macro': float(macro_f1),
            'auc': float(auc_score) if auc_score is not None else None
        },
        'per_class': {
            CLASS_NAMES[i]: {
                'precision': float(class_metrics[i]['precision']),
                'recall': float(class_metrics[i]['recall']),
                'f1': float(class_metrics[i]['f1']),
                'support': int(class_metrics[i]['support'])
            } for i in range(1, NUM_CLASSES + 1)
        },
        'confusion_matrix': cm.tolist(),
        'settings': {
            'score_threshold': score_thresh,
            'iou_threshold': iou_thresh
        }
    }

    with open('/content/evaluation_results.json', 'w') as f:
        json.dump(evaluation_results, f, indent=2)

    return evaluation_results



eval_results = comprehensive_evaluation(
    dataset=val_dataset,
    score_thresh=SCORE_THRESH,
    iou_thresh=0.3,
)

Using device: cuda
  Batch Size: 12
  Learning rate: 0.01
  Number of training rounds: 15
  Score threshold: 0.05

Train samples: 11502, Val samples: 1095

First sample:
Image shape: torch.Size([3, 640, 640])
Boxes shape: torch.Size([1, 4])
Labels: tensor([11])

Model parameters: 41,355,536

Start training


Epoch 01: 100%|██████████| 959/959 [1:09:02<00:00,  4.32s/it, loss=1.342]
Eval: 100%|██████████| 92/92 [12:40<00:00,  8.27s/it]


[Epoch 01/15] Loss=2.2997 | mAP=0.0000 | mAP@50=0.0000 | LR=0.010000
Saved best model
----------------------------------------------------------------------


Epoch 02: 100%|██████████| 959/959 [13:45<00:00,  1.16it/s, loss=0.739]
Eval: 100%|██████████| 92/92 [00:47<00:00,  1.92it/s]


[Epoch 02/15] Loss=1.0981 | mAP=0.0001 | mAP@50=0.0004 | LR=0.010000
Saved best model
----------------------------------------------------------------------


Epoch 03: 100%|██████████| 959/959 [13:46<00:00,  1.16it/s, loss=0.739]
Eval: 100%|██████████| 92/92 [00:47<00:00,  1.92it/s]


[Epoch 03/15] Loss=0.7411 | mAP=0.0001 | mAP@50=0.0008 | LR=0.010000
Saved best model
----------------------------------------------------------------------


Epoch 04: 100%|██████████| 959/959 [13:48<00:00,  1.16it/s, loss=1.036]
Eval: 100%|██████████| 92/92 [00:48<00:00,  1.90it/s]


[Epoch 04/15] Loss=0.7523 | mAP=0.0001 | mAP@50=0.0006 | LR=0.010000
----------------------------------------------------------------------


Epoch 05: 100%|██████████| 959/959 [13:50<00:00,  1.15it/s, loss=1.065]
Eval: 100%|██████████| 92/92 [00:49<00:00,  1.87it/s]


[Epoch 05/15] Loss=0.8738 | mAP=0.0001 | mAP@50=0.0006 | LR=0.010000
----------------------------------------------------------------------


Epoch 06: 100%|██████████| 959/959 [13:52<00:00,  1.15it/s, loss=1.187]
Eval: 100%|██████████| 92/92 [00:50<00:00,  1.84it/s]


[Epoch 06/15] Loss=1.0915 | mAP=0.0002 | mAP@50=0.0009 | LR=0.010000
Saved best model
----------------------------------------------------------------------


Epoch 07: 100%|██████████| 959/959 [13:54<00:00,  1.15it/s, loss=1.455]
Eval: 100%|██████████| 92/92 [00:50<00:00,  1.81it/s]


[Epoch 07/15] Loss=1.2831 | mAP=0.0002 | mAP@50=0.0009 | LR=0.001000
Saved best model
----------------------------------------------------------------------


Epoch 08:   4%|▍         | 36/959 [00:32<13:23,  1.15it/s, loss=1.452]