# Traffic Light Detection using Faster R-CNN on the LISA Dataset (Combined Day/Night)

This notebook implements and trains a Faster R-CNN model on the combined day and night data.
This version includes MetricLogger fixes, early stopping, P/R/F1 extraction from the COCO evaluation, plotting of graphs, and saving of results.

## 0. Import Diagnostics and Path Setup

In [None]:
import sys
import os

# Diagnostic cell: report the working directory, put its parent on sys.path
# so that `utils.*` can be imported, and show whether a `utils` folder exists.
print("--- Import Path Diagnostics ---")
current_working_directory = os.getcwd()
print(f"Current working directory (CWD): {current_working_directory}")

# The parent of the CWD is treated as the candidate project root.
project_root_directory = os.path.dirname(current_working_directory)
print(f"Parent directory (potential project_root): {project_root_directory}")

if project_root_directory in sys.path:
    print(f"Directory {project_root_directory} is already in sys.path.")
else:
    # Inserted at position 0 so it is searched before the other entries.
    sys.path.insert(0, project_root_directory)
    print(f"Added to sys.path: {project_root_directory}")

expected_utils_path_absolute = os.path.abspath(
    os.path.join(project_root_directory, 'utils')
)
_utils_folder_found = os.path.isdir(expected_utils_path_absolute)
print(f"Expected absolute path to utils folder: {expected_utils_path_absolute}")
print(f"Does the utils folder exist at the expected location? {_utils_folder_found}")
if _utils_folder_found:
    print(f"Contents of the utils folder: {os.listdir(expected_utils_path_absolute)}")

print("--- End of Diagnostics ---")

# Try to import the COCO / logging helpers from the project's `utils` package.
# When an import fails, minimal stand-ins are defined (only for names that are
# not already present in globals()) so the rest of the notebook can still run.
print("\nAttempting to import tools...")
COCO_UTILS_AVAILABLE = False
METRIC_LOGGER_AVAILABLE = False

try:
    from utils.coco_eval import CocoEvaluator
    from utils.coco_utils import get_coco_api_from_dataset
    COCO_UTILS_AVAILABLE = True
    print("SUCCESS: COCO tools imported successfully.")
except ImportError as e:
    print(f"COCO TOOLS IMPORT ERROR: {e}.")

    if 'CocoEvaluator' not in globals():
        class CocoEvaluator:
            """No-op stand-in; it stores its arguments and computes nothing."""

            def __init__(self, coco_gt, iou_types):
                self.coco_gt = coco_gt
                self.iou_types = iou_types
                self.eval_imgs = []

            def update(self, predictions):
                pass

            def synchronize_between_processes(self):
                pass

            def accumulate(self):
                pass

            def summarize(self):
                print("Mock CocoEvaluator.summarize() used: mAP is not calculated correctly.")

    if 'get_coco_api_from_dataset' not in globals():
        def get_coco_api_from_dataset(dataset):
            """Stand-in that announces itself and returns None."""
            print("Mock get_coco_api_from_dataset() used.")
            return None

try:
    from utils.utils import MetricLogger, SmoothedValue
    METRIC_LOGGER_AVAILABLE = True
    print("SUCCESS: MetricLogger and SmoothedValue imported successfully from utils.utils.")
except ImportError as e:
    print(f"IMPORT ERROR from utils.utils: {e}.")

    if 'MetricLogger' not in globals():
        class MetricLogger:
            """Stand-in logger: keeps registered meters, ignores updates."""

            def __init__(self, delimiter=None):
                self.meters = {}
                self.delimiter = delimiter
                print("WARNING: Mock MetricLogger used.")

            def add_meter(self, name, meter):
                self.meters[name] = meter

            def update(self, **kwargs):
                pass

            def synchronize_between_processes(self):
                pass

            def __str__(self):
                return "MockedMetricLogger"

            def log_every(self, iterable, print_freq, header=None):
                # print_freq is accepted for signature compatibility only;
                # progress is shown through a tqdm bar instead.
                if header:
                    print(header)
                from tqdm.auto import tqdm
                progress_label = header if header else "Iteration"
                yield from tqdm(iterable, desc=progress_label, leave=False)

    if 'SmoothedValue' not in globals():
        class SmoothedValue:
            """Stand-in meter: remembers every value and prints their mean."""

            def __init__(self, window_size=20, fmt=None):
                # window_size is accepted but not used by this stand-in.
                self.deque = []
                self.fmt = fmt

            def update(self, value, n=1):
                self.deque.append(value)

            def __str__(self):
                import numpy
                if self.deque:
                    return str(numpy.mean(self.deque))
                return "N/A"

## 1. Imports and Basic Configuration (continued)

In [None]:
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import albumentations as A
from albumentations.pytorch import ToTensorV2
import time
import cv2 
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import json 

# --- Device -----------------------------------------------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# --- Classes (index 0 is the background class) --------------------------------
CLASS_MAPPING = {'background': 0, 'go': 1, 'stop': 2, 'warning': 3, 'off': 4}
INV_CLASS_MAPPING = {class_id: class_name for class_name, class_id in CLASS_MAPPING.items()}
NUM_CLASSES_INC_BG = len(CLASS_MAPPING)
NUM_CLASSES_NO_BG = NUM_CLASSES_INC_BG - 1
CLASS_NAMES_NO_BG = [INV_CLASS_MAPPING[class_id] for class_id in range(1, NUM_CLASSES_INC_BG)]
IMAGE_SIZE = 640

# --- Input data ---------------------------------------------------------------
BASE_DATA_PATH = "../dataset/lisa_traffic_light_dataset/"
ANNOTATIONS_DIR = os.path.join(BASE_DATA_PATH, "annotations")
IMAGES_BASE_DIR = os.path.join(BASE_DATA_PATH, "images")

TRAIN_ANNOTATIONS_FILE = os.path.join(ANNOTATIONS_DIR, "train_annotations.csv")
VAL_ANNOTATIONS_FILE = os.path.join(ANNOTATIONS_DIR, "val_annotations.csv")

# --- Output locations (created on the spot) -----------------------------------
OUTPUT_DIR = "../results/faster_rcnn/"
VISUALIZATIONS_DIR = os.path.join(OUTPUT_DIR, "visualizations")
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(VISUALIZATIONS_DIR, exist_ok=True)

# --- Optimisation hyper-parameters --------------------------------------------
NUM_EPOCHS = 10
BATCH_SIZE = 4
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9
EARLY_STOPPING_PATIENCE = 10
PRINT_FREQ = 10

# --- Data subset / augmentation switches --------------------------------------
USE_SUBSET_DATA = True
SUBSET_SIZE_TRAIN = 50
SUBSET_SIZE_VAL = 20
APPLY_AUGMENTATIONS = False

# Snapshot of the run configuration; subset sizes are recorded as -1 when the
# full dataset is used.
TRAINING_ARGS = {
    "model_name": "FasterRCNN_ResNet50_FPN",
    "num_epochs": NUM_EPOCHS,
    "batch_size": BATCH_SIZE,
    "learning_rate": LEARNING_RATE,
    "weight_decay": WEIGHT_DECAY,
    "momentum": MOMENTUM,
    "early_stopping_patience": EARLY_STOPPING_PATIENCE,
    "image_size": IMAGE_SIZE,
    "use_subset_data": USE_SUBSET_DATA,
    "subset_size_train": SUBSET_SIZE_TRAIN if USE_SUBSET_DATA else -1,
    "subset_size_val": SUBSET_SIZE_VAL if USE_SUBSET_DATA else -1,
    "class_mapping": dict(CLASS_MAPPING),
    "apply_augmentations": APPLY_AUGMENTATIONS,
}

# Persist the configuration next to the results; a failure is reported but
# does not stop the notebook.
args_save_path = os.path.join(OUTPUT_DIR, 'training_arguments_faster_rcnn.json')
try:
    with open(args_save_path, 'w') as f:
        json.dump(TRAINING_ARGS, f, indent=4)
    print(f"Training arguments saved to {args_save_path}")
except Exception as e:
    print(f"Error saving training arguments: {e}")

## 2. `LISADataset` Class and Augmentations (with filename_to_id)

In [None]:
class LISADataset(Dataset):
    """Object-detection dataset backed by an annotation CSV.

    The CSV is read with pandas; the code uses the columns ``filename``,
    ``xmin``, ``ymin``, ``xmax``, ``ymax`` and ``label``. Every unique
    ``filename`` is one sample, and its position in ``image_filenames`` is
    used as the sample's ``image_id``.

    Args:
        annotations_file: Path of the annotation CSV. If the file is missing,
            an error is printed and the dataset is empty.
        img_base_dir: Directory that ``filename`` values are joined onto.
        transforms: Optional callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` and expected to
            return a dict with ``image``, ``bboxes`` and ``labels``.
        class_mapping: Mapping from label string to integer class id.
    """

    def __init__(self, annotations_file, img_base_dir, transforms=None, class_mapping=None):
        try:
            self.full_annotations_df = pd.read_csv(annotations_file)
        except FileNotFoundError:
            print(f"CRITICAL ERROR: Annotation file {annotations_file} not found!")
            self.full_annotations_df = pd.DataFrame(columns=['filename', 'xmin', 'ymin', 'xmax', 'ymax', 'label'])

        self.img_base_dir = img_base_dir
        self.transforms = transforms
        self.class_mapping = class_mapping

        if self.full_annotations_df.empty:
            self.image_filenames = np.array([])
            self.image_annotations = {}
        else:
            self.image_filenames = self.full_annotations_df['filename'].unique()
            # One groupby pass instead of one full-frame boolean filter per
            # filename; row order inside each group is unchanged.
            self.image_annotations = {
                filename: rows
                for filename, rows in self.full_annotations_df.groupby('filename', sort=False)
            }

        self.filename_to_id = {fname: i for i, fname in enumerate(self.image_filenames)}

    def __len__(self):
        """Return the number of unique image filenames."""
        return len(self.image_filenames)

    @staticmethod
    def _apply_pipeline(pipeline, image_np, bboxes, labels):
        """Run ``pipeline`` and return ``(image, boxes_tensor, labels_tensor)``."""
        transformed = pipeline(image=image_np, bboxes=bboxes, labels=labels)
        if len(transformed['bboxes']) > 0:
            boxes_t = torch.as_tensor(transformed['bboxes'], dtype=torch.float32)
            labels_t = torch.as_tensor(transformed['labels'], dtype=torch.int64)
        else:
            boxes_t = torch.zeros((0, 4), dtype=torch.float32)
            labels_t = torch.zeros(0, dtype=torch.int64)
        return transformed['image'], boxes_t, labels_t

    @staticmethod
    def _finalize_target(target):
        """Drop degenerate boxes and (re)compute ``area`` and ``iscrowd`` in place."""
        boxes_t, labels_t = target['boxes'], target['labels']
        if boxes_t.shape[0] > 0:
            # Keep only boxes that are wider and taller than 1e-3.
            keep = (boxes_t[:, 2] > boxes_t[:, 0] + 1e-3) & (boxes_t[:, 3] > boxes_t[:, 1] + 1e-3)
            boxes_t, labels_t = boxes_t[keep], labels_t[keep]
        if boxes_t.shape[0] == 0:
            boxes_t = torch.zeros((0, 4), dtype=torch.float32)
            labels_t = torch.zeros(0, dtype=torch.int64)
        target['boxes'] = boxes_t
        target['labels'] = labels_t
        target['area'] = (boxes_t[:, 2] - boxes_t[:, 0]) * (boxes_t[:, 3] - boxes_t[:, 1])
        target['iscrowd'] = torch.zeros((boxes_t.shape[0],), dtype=torch.int64)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``target`` is a dict with ``boxes`` (N, 4) float32, ``labels`` (N,)
        int64, ``image_id`` (1,) int64, ``area`` (N,) and ``iscrowd`` (N,).
        ``(None, None)`` is returned when ``idx`` is past the end or the image
        cannot be opened, so that a collate function can filter the sample out.
        """
        if idx >= len(self.image_filenames):
            return None, None

        img_relative_path = self.image_filenames[idx]
        img_path = os.path.join(self.img_base_dir, img_relative_path)

        try:
            image = Image.open(img_path).convert("RGB")
        except FileNotFoundError:
            # Missing files are skipped without a message.
            return None, None
        except Exception as e_img_open:
            print(f"ERROR (LISADataset): Could not open image {img_path}: {e_img_open}")
            return None, None

        annots = self.image_annotations[img_relative_path]
        boxes = annots[['xmin', 'ymin', 'xmax', 'ymax']].values.astype(np.float32)

        try:
            labels_tensor = torch.tensor([self.class_mapping[lbl] for lbl in annots['label']], dtype=torch.int64)
        except KeyError as e:
            # An unmapped label invalidates the whole sample's annotations;
            # report it instead of dropping them silently.
            print(f"WARNING (LISADataset): Unknown label {e} for {img_relative_path}; its annotations are dropped.")
            labels_tensor = torch.zeros(0, dtype=torch.int64)
            boxes = np.zeros((0, 4), dtype=np.float32)

        unique_img_id = self.filename_to_id[img_relative_path]
        target = {
            'boxes': torch.as_tensor(boxes, dtype=torch.float32),
            'labels': labels_tensor,
            'image_id': torch.tensor([unique_img_id], dtype=torch.int64),
        }

        if self.transforms:
            image_np = np.array(image)
            labels_for_albumentations = target['labels'].tolist()
            bboxes_for_albumentations = target['boxes'].numpy()
            try:
                image, target['boxes'], target['labels'] = self._apply_pipeline(
                    self.transforms, image_np, bboxes_for_albumentations, labels_for_albumentations)
            except Exception as e:
                print(f"WARNING (LISADataset): Transform failed for {img_path}: {e}. Using resize/normalize fallback.")
                # NOTE(review): the fallback runs on the same boxes, so a failure
                # caused by the boxes themselves may be raised again here.
                default_transform = A.Compose([
                    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
                    A.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
                    ToTensorV2()
                ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
                image, target['boxes'], target['labels'] = self._apply_pipeline(
                    default_transform, image_np, bboxes_for_albumentations, labels_for_albumentations)

        # area / iscrowd are derived once, from the final set of boxes.
        self._finalize_target(target)
        return image, target

def get_train_transforms(apply_augmentations=True):
    """Build the albumentations pipeline used for training samples.

    The pipeline always ends with resize to IMAGE_SIZE x IMAGE_SIZE,
    normalization and tensor conversion. When ``apply_augmentations`` is true,
    the augmentation steps listed below are placed in front of those steps.
    Boxes are given in ``pascal_voc`` format with labels in the ``labels``
    field.
    """
    pipeline = []

    if apply_augmentations:
        pipeline.extend([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.15, rotate_limit=50, p=0.6,
                               border_mode=cv2.BORDER_CONSTANT),
            A.Affine(shear=(-10, 10), p=0.3),
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 7), p=0.5),
                A.MotionBlur(blur_limit=7, p=0.5),
            ], p=0.3),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            A.RandomGamma(gamma_limit=(80, 120), p=0.3),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.2),
        ])

    # Steps applied regardless of the augmentation switch.
    pipeline.extend([
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    box_settings = A.BboxParams(
        format='pascal_voc',
        label_fields=['labels'],
        min_visibility=0.2,
        min_area=10,
    )
    return A.Compose(pipeline, bbox_params=box_settings)

def get_val_test_transforms():
    """Build the deterministic pipeline for validation/test samples.

    Resize to IMAGE_SIZE x IMAGE_SIZE, normalize, convert to a tensor; boxes
    are in ``pascal_voc`` format with labels in the ``labels`` field.
    """
    steps = [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    box_settings = A.BboxParams(format='pascal_voc', label_fields=['labels'])
    return A.Compose(steps, bbox_params=box_settings)

def collate_fn(batch):
    """Collate detection samples, skipping the ones that failed to load.

    A sample is skipped when it is ``None`` or when its image or target is
    ``None``. Returns ``(None, None)`` if nothing is left; otherwise a tuple
    of per-field tuples, e.g. ``(images, targets)``.
    """
    usable = []
    for sample in batch:
        if sample is None:
            continue
        if sample[0] is None or sample[1] is None:
            continue
        usable.append(sample)

    if len(usable) == 0:
        return None, None
    return tuple(zip(*usable))

## 3. Faster R-CNN Model Definition

In [None]:
def get_faster_rcnn_model(num_classes_incl_background, backbone_name="resnet50"):
    """Create a Faster R-CNN (ResNet-50 FPN v2) with a fresh box predictor.

    Args:
        num_classes_incl_background: Number of outputs of the new predictor
            head (background included).
        backbone_name: Only ``"resnet50"`` is supported.

    Returns:
        The torchvision model, loaded with its DEFAULT weights, whose
        ``roi_heads.box_predictor`` has been replaced.

    Raises:
        ValueError: If ``backbone_name`` is not ``"resnet50"``.
    """
    if backbone_name != "resnet50":
        raise ValueError(f"Unsupported backbone for Faster R-CNN: {backbone_name}")

    detection = torchvision.models.detection
    model = detection.fasterrcnn_resnet50_fpn_v2(
        weights=detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    )

    # Swap the classification/regression head for one sized to our classes.
    head_in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes_incl_background)
    return model

## 4. Training Loop and Evaluation (with Early Stopping and Visualizations)

In [None]:
def train_one_epoch(model, optimizer, data_loader, device, epoch_num, print_freq=10):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Detection model; called as ``model(images, targets)`` and
            expected to return a dict of loss tensors.
        optimizer: Optimizer stepped once per successful batch.
        data_loader: Iterable of ``(images, targets)`` batches; batches where
            either element is ``None`` are skipped.
        device: Device the images and target tensors are moved to.
        epoch_num: Epoch number. A linear LR warm-up is used when it is 1.
        print_freq: Forwarded to the logger's ``log_every``.

    Returns:
        dict: Average value for ``loss_classifier``, ``loss_box_reg``,
        ``loss_objectness``, ``loss_rpn_box_reg`` and ``loss``. The logger's
        ``global_avg`` is used when available, otherwise the mean of the values
        collected here (0.0 if no batch succeeded).
    """
    model.train()
    current_metric_logger = None
    if METRIC_LOGGER_AVAILABLE:
        current_metric_logger = MetricLogger(delimiter="  ")
        current_metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
        for loss_meter_name in ('loss_classifier', 'loss_box_reg', 'loss_objectness', 'loss_rpn_box_reg'):
            current_metric_logger.add_meter(
                loss_meter_name, SmoothedValue(window_size=1, fmt='{median:.4f} ({global_avg:.4f})'))
    else:
        class MockMetricLoggerForTrain:
            """Minimal logger used when MetricLogger could not be imported."""

            def __init__(self, delimiter=None):
                self.meters = {}
                self.delimiter = delimiter

            def add_meter(self, name, meter):
                self.meters[name] = meter

            def update(self, **kwargs):
                pass

            def __str__(self):
                return "MockedTrainMetricLogger"

            def log_every(self, iterable, print_freq, header=None):
                from tqdm.auto import tqdm
                return tqdm(iterable, desc=header if header else "Training", leave=False)

        current_metric_logger = MockMetricLoggerForTrain(delimiter="  ")
        if callable(globals().get('SmoothedValue')):
            current_metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

    header = f"Epoch [{epoch_num}]"

    # Linear warm-up during the first epoch only (epochs are counted from 1).
    lr_scheduler = None
    if epoch_num - 1 == 0:
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1) if len(data_loader) > 1 else 1
        if warmup_iters > 0:
            lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=warmup_factor, total_iters=warmup_iters)

    epoch_losses = {'loss_classifier': [], 'loss_box_reg': [], 'loss_objectness': [], 'loss_rpn_box_reg': [], 'loss': []}

    for images, targets in current_metric_logger.log_every(data_loader, print_freq, header):
        if images is None or targets is None:
            continue

        images = [image.to(device) for image in images]
        # Only tensor entries of each target are kept and moved to the device.
        targets = [{k: v.to(device) for k, v in t.items() if isinstance(v, torch.Tensor)} for t in targets]
        for target_item in targets:
            target_item['boxes'] = target_item['boxes'].float()

        try:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()

            if not np.isfinite(loss_value):
                print(f"Epoch: {epoch_num}, Infinite loss: {loss_value}, skipping batch.")
                continue

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            if lr_scheduler is not None:
                lr_scheduler.step()

            loss_dict_items = {k: v.item() for k, v in loss_dict.items()}
            if METRIC_LOGGER_AVAILABLE:
                current_metric_logger.update(loss=loss_value, **loss_dict_items)
                current_metric_logger.update(lr=optimizer.param_groups[0]["lr"])

            epoch_losses['loss'].append(loss_value)
            for k, v_item in loss_dict_items.items():
                if k in epoch_losses:  # Ensure key exists
                    epoch_losses[k].append(v_item)

        except Exception as e:
            # Best effort: a failing batch is reported and skipped so that one
            # bad sample does not abort the whole epoch.
            print(f"Error in training step: {e}")
            import traceback
            traceback.print_exc()
            continue

    avg_epoch_losses = {k: np.mean(v) if v else 0.0 for k, v in epoch_losses.items()}

    returned_losses = {}
    logger_meters = getattr(current_metric_logger, 'meters', {}) if METRIC_LOGGER_AVAILABLE else {}
    for loss_name in epoch_losses:
        meter = logger_meters[loss_name] if loss_name in logger_meters else None
        meter_average = None
        if meter is not None:
            # NOTE(review): `global_avg` is presumably a property computing
            # total / count (as in the torchvision reference utils); on a meter
            # that was never updated it would raise ZeroDivisionError, which
            # hasattr() does not swallow. Fall back to the local mean then.
            try:
                meter_average = getattr(meter, 'global_avg', None)
            except ZeroDivisionError:
                meter_average = None
        if meter_average is not None:
            returned_losses[loss_name] = meter_average
        else:
            returned_losses[loss_name] = avg_epoch_losses[loss_name]

    return returned_losses

def _summarize_coco_bbox_metrics(coco_evaluator, metrics):
    """Copy the bbox summary numbers of a finished evaluator into ``metrics``.

    Returns the bbox evaluation object (for later curve plotting) or ``None``
    when no bbox statistics are available.
    """
    if not (coco_evaluator and hasattr(coco_evaluator, 'coco_eval')):
        return None

    bbox_eval_obj = None
    for iou_type, coco_eval_obj_instance in coco_evaluator.coco_eval.items():
        if iou_type != "bbox":
            continue
        if not (coco_eval_obj_instance and hasattr(coco_eval_obj_instance, 'stats')
                and coco_eval_obj_instance.stats is not None):
            continue
        stats = coco_eval_obj_instance.stats
        metrics["mAP_0.5_0.95"] = round(stats[0], 4) if len(stats) > 0 else -1.
        metrics["mAP_0.5"] = round(stats[1], 4) if len(stats) > 1 else -1.
        # NOTE: "Precision_coco" is AP@0.5 used as a precision proxy, and
        # "Recall_coco" is stats[8] (AR at maxDets=100 in pycocotools'
        # summarize() layout — verify against the installed version).
        precision_coco = metrics["mAP_0.5"]
        recall_coco = round(stats[8], 4) if len(stats) > 8 else -1.0
        metrics["Precision_coco"] = precision_coco
        metrics["Recall_coco"] = recall_coco
        if precision_coco > 0 and recall_coco > 0:
            metrics["F1_coco"] = round(2 * (precision_coco * recall_coco) / (precision_coco + recall_coco + 1e-9), 4)
        else:
            metrics["F1_coco"] = -1.0
        bbox_eval_obj = coco_eval_obj_instance
    return bbox_eval_obj


def _collect_confusion_matrix_labels(predictions, targets, confidence_threshold, iou_threshold):
    """Build paired (ground-truth, predicted) label lists for a confusion matrix.

    Predictions scoring below ``confidence_threshold`` are ignored. Each
    remaining prediction is matched to the not-yet-matched ground-truth box
    with the highest IoU; the pair counts if that IoU reaches
    ``iou_threshold``. Unmatched predictions are paired with background as the
    true label, unmatched ground truths with background as the predicted label.
    """
    background_id = CLASS_MAPPING['background']
    gt_labels_out = []
    pred_labels_out = []

    for preds_img, targets_img in zip(predictions, targets):
        gt_boxes_img = targets_img['boxes'].cpu().numpy()
        gt_labels_img = targets_img['labels'].cpu().numpy()
        confident_preds_mask = preds_img['scores'].cpu().numpy() >= confidence_threshold
        pred_boxes_img_conf = preds_img['boxes'].cpu().numpy()[confident_preds_mask]
        pred_labels_img_conf = preds_img['labels'].cpu().numpy()[confident_preds_mask]

        matched_gt = [False] * len(gt_boxes_img)
        for pred_box, pred_label in zip(pred_boxes_img_conf, pred_labels_img_conf):
            best_iou = 0
            best_gt_idx = -1
            for gt_idx, gt_box in enumerate(gt_boxes_img):
                if matched_gt[gt_idx]:
                    continue
                iou = calculate_iou(pred_box, gt_box)
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = gt_idx

            if best_gt_idx != -1 and best_iou >= iou_threshold:
                # Same bookkeeping whether the class is right or wrong: the
                # confusion matrix itself separates the two cases.
                gt_labels_out.append(gt_labels_img[best_gt_idx])
                pred_labels_out.append(pred_label)
                matched_gt[best_gt_idx] = True
            elif pred_label != background_id:
                gt_labels_out.append(background_id)
                pred_labels_out.append(pred_label)

        for gt_idx, was_matched in enumerate(matched_gt):
            if not was_matched and gt_labels_img[gt_idx] != background_id:
                gt_labels_out.append(gt_labels_img[gt_idx])
                pred_labels_out.append(background_id)

    return gt_labels_out, pred_labels_out


@torch.inference_mode()
def evaluate(model, data_loader, device, class_mapping, coco_utils_available_flag, metric_logger_available_flag, num_classes_no_bg_for_curves, inv_class_mapping_for_curves):
    """Run ``model`` over ``data_loader`` and collect evaluation metrics.

    Args:
        model: Detection model; called as ``model(images)`` in eval mode and
            expected to return one dict per image with ``boxes``, ``labels``
            and ``scores``.
        data_loader: Iterable of ``(images, targets)`` batches; its
            ``dataset`` (unwrapped if it is a ``Subset``) is used to build the
            COCO ground truth.
        device: Device the images are moved to.
        class_mapping: Currently unused; the confusion-matrix logic reads the
            module-level ``CLASS_MAPPING``.
        coco_utils_available_flag: Whether to run the COCO evaluator.
        metric_logger_available_flag: Whether to use ``MetricLogger``.
        num_classes_no_bg_for_curves: Forwarded to
            ``calculate_metrics_vs_confidence``.
        inv_class_mapping_for_curves: Forwarded to
            ``calculate_metrics_vs_confidence``.

    Returns:
        dict with ``mAP_0.5_0.95``, ``mAP_0.5``, ``Precision_coco``,
        ``Recall_coco``, ``F1_coco``, ``FPS`` (each -1.0 when unavailable),
        ``coco_eval_obj_for_curves``, ``custom_metrics_data`` and ``cm_data``.
    """
    cpu_device = torch.device("cpu")
    iou_threshold = 0.5
    optimal_confidence_threshold_for_cm = 0.5

    all_predictions_for_curves = []
    all_targets_for_curves = []
    inference_times = []
    num_processed_frames = 0
    coco_evaluator = None

    n_threads = torch.get_num_threads()
    torch.set_num_threads(1)
    try:
        model.eval()
        if metric_logger_available_flag:
            current_metric_logger_eval = MetricLogger(delimiter="  ")
            if callable(globals().get('SmoothedValue')):
                current_metric_logger_eval.add_meter("model_time", SmoothedValue(window_size=1, fmt="{value:.4f}"))
        else:
            class MockMetricLoggerEval:
                """Minimal logger used when MetricLogger is unavailable."""

                def __init__(self, delimiter=None):
                    self.delimiter = delimiter
                    self.meters = {}

                def add_meter(self, name, meter):
                    pass

                def update(self, **kwargs):
                    pass

                def synchronize_between_processes(self):
                    pass

                def __str__(self):
                    return "MockedEvalMetricLogger"

                def log_every(self, iterable, print_freq, header=None):
                    if header:
                        print(header)
                    from tqdm.auto import tqdm
                    yield from tqdm(iterable, desc=header if header else "Evaluating", leave=False)

            current_metric_logger_eval = MockMetricLoggerEval(delimiter="  ")

        header = "Test:"
        actual_dataset_for_coco = data_loader.dataset
        if isinstance(data_loader.dataset, torch.utils.data.Subset):
            actual_dataset_for_coco = data_loader.dataset.dataset

        if coco_utils_available_flag:
            try:
                coco = get_coco_api_from_dataset(actual_dataset_for_coco)
                if coco:
                    try:
                        iou_types = _get_iou_types(model)
                    except NameError:
                        iou_types = ["bbox"]
                    coco_evaluator = CocoEvaluator(coco, iou_types)
            except Exception as e_coco_api:
                print(f"ERROR (evaluate) during get_coco_api or CocoEvaluator: {e_coco_api}")

        for images, targets in current_metric_logger_eval.log_every(data_loader, 100, header):
            if images is None or targets is None:
                continue
            images_for_model = [img.to(device) for img in images]
            targets_cpu = [{k: v.cpu().clone() if isinstance(v, torch.Tensor) else v for k, v in t.items()}
                           for t in targets]

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            model_time_start = time.time()
            outputs = model(images_for_model)
            if torch.cuda.is_available():
                # Wait for queued GPU work so the measured time (and the FPS
                # derived from it) covers the whole forward pass.
                torch.cuda.synchronize()
            model_time = time.time() - model_time_start

            outputs_cpu = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
            inference_times.append(model_time)
            num_processed_frames += len(images)

            if coco_evaluator:
                res = {}
                for t_cpu, o_cpu in zip(targets_cpu, outputs_cpu):
                    image_id = t_cpu.get("image_id")
                    if isinstance(image_id, torch.Tensor) and image_id.numel() == 1:
                        res[image_id.item()] = o_cpu
                if res:
                    coco_evaluator.update(res)

            all_predictions_for_curves.extend(outputs_cpu)
            all_targets_for_curves.extend(targets_cpu[:len(outputs_cpu)])

            if metric_logger_available_flag:
                current_metric_logger_eval.update(model_time=model_time)

        if metric_logger_available_flag:
            current_metric_logger_eval.synchronize_between_processes()
            print("Averaged stats (MetricLogger):", current_metric_logger_eval)

        if coco_evaluator:
            try:
                coco_evaluator.synchronize_between_processes()
                coco_evaluator.accumulate()
                coco_evaluator.summarize()
            except Exception as e_coco_final:
                print(f"COCO finalization error: {e_coco_final}")
                import traceback
                traceback.print_exc()
    finally:
        # Always hand the original thread count back, even if inference failed.
        torch.set_num_threads(n_threads)

    metrics = {"mAP_0.5_0.95": -1., "mAP_0.5": -1., "Precision_coco": -1., "Recall_coco": -1., "F1_coco": -1., "FPS": -1.}
    coco_eval_stats_for_curves = None
    if coco_utils_available_flag:
        coco_eval_stats_for_curves = _summarize_coco_bbox_metrics(coco_evaluator, metrics)

    if inference_times and num_processed_frames > 0:
        total_inf_time = sum(inference_times)
        fps = num_processed_frames / total_inf_time if total_inf_time > 0 else 0.0
        metrics["FPS"] = round(fps, 2)

    metrics['coco_eval_obj_for_curves'] = coco_eval_stats_for_curves

    print("Calculating P, R, F1 vs Confidence metrics...")
    confidence_thresholds_custom, p_per_class, r_per_class, f1_per_class, p_overall, r_overall, f1_overall = \
        calculate_metrics_vs_confidence(all_predictions_for_curves, all_targets_for_curves,
                                        num_classes_no_bg=num_classes_no_bg_for_curves,
                                        inv_class_mapping=inv_class_mapping_for_curves,
                                        iou_threshold=iou_threshold)
    metrics['custom_metrics_data'] = {
        'conf_thresholds': confidence_thresholds_custom,
        'p_per_class': p_per_class, 'r_per_class': r_per_class, 'f1_per_class': f1_per_class,
        'p_overall': p_overall, 'r_overall': r_overall, 'f1_overall': f1_overall
    }
    print("Finished calculating P, R, F1 vs Confidence metrics.")

    # len() instead of plain truthiness so that an array-like f1_overall does
    # not raise on an ambiguous truth value.
    if f1_overall is not None and len(f1_overall) > 0 and any(f > 0 for f in f1_overall):
        best_f1_idx_overall = np.argmax(f1_overall)
        optimal_confidence_threshold_for_cm = confidence_thresholds_custom[best_f1_idx_overall]
        print(f"Using optimal confidence threshold for CM: {optimal_confidence_threshold_for_cm:.3f} (maximizes overall F1)")

    all_gt_labels_for_cm, all_pred_labels_for_cm = _collect_confusion_matrix_labels(
        all_predictions_for_curves, all_targets_for_curves,
        optimal_confidence_threshold_for_cm, iou_threshold)
    metrics['cm_data'] = {'true': all_gt_labels_for_cm, 'pred': all_pred_labels_for_cm}

    return metrics

def _get_iou_types(model):
    """Return the IoU types to evaluate for ``model``.

    Always contains ``"bbox"``; ``"segm"`` is added for a Mask R-CNN and
    ``"keypoints"`` for a Keypoint R-CNN. A DistributedDataParallel wrapper is
    unwrapped first.
    """
    import torchvision

    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    core_model = model.module if is_ddp else model

    optional_types = (
        (torchvision.models.detection.MaskRCNN, "segm"),
        (torchvision.models.detection.KeypointRCNN, "keypoints"),
    )
    return ["bbox"] + [
        type_name for model_cls, type_name in optional_types
        if isinstance(core_model, model_cls)
    ]

## 4.1 Functions for Visualizing Results

In [None]:
def denormalize_image_tensor(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo a per-channel mean/std normalization on an image tensor.

    Args:
        tensor (torch.Tensor): Normalized image of size (C, H, W). It is not
            modified; the work is done on a copy.
        mean (tuple): Per-channel mean that was subtracted.
        std (tuple): Per-channel standard deviation that was divided by.
    Returns:
        torch.Tensor: ``tensor * std + mean``, clamped to the range [0, 1].
    """
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W.
    per_channel = (-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(*per_channel)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(*per_channel)

    restored = tensor.clone()
    restored *= std_t
    restored += mean_t
    return restored.clamp(0, 1)

def plot_coco_evaluation_curves(coco_eval_obj, save_dir, class_names_dict_inv):
    """Plot precision-recall curves at the first IoU threshold and save them.

    Args:
        coco_eval_obj: Evaluation object exposing ``eval['precision']``,
            ``params.recThrs``, ``params.catIds`` and ``stats``. The precision
            array is indexed here as ``[iou, recall, category, area, maxDets]``
            with area index 0 and maxDets index 2 (presumably the pycocotools
            COCOeval layout — verify against the installed version).
        save_dir: Directory for ``PR_curve.png``; created if it is missing.
        class_names_dict_inv: Mapping from category id to class name. Anything
            that is not a dict is treated as empty.

    Returns:
        dict: ``{"pr_curve": <path>}`` on success, ``{}`` when there is no
        evaluation data to plot.
    """
    if coco_eval_obj is None or not hasattr(coco_eval_obj, 'eval') or not coco_eval_obj.eval or 'precision' not in coco_eval_obj.eval:
        print("No COCO evaluation data to draw standard PR curve.")
        return {}
    precision = coco_eval_obj.eval['precision']
    recall_thresholds = coco_eval_obj.params.recThrs

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    pr_path = os.path.join(save_dir, "PR_curve.png")
    fig = plt.figure(figsize=(12, 9))
    try:
        if precision.shape[0] > 0 and precision.shape[2] > 0:
            per_category_precision = precision[0, :, :, 0, 2]  # (recall, category)
            # Entries of -1 mark "no data" and must not drag the average down;
            # the per-class AP below already ignores them the same way.
            has_data = per_category_precision > -1
            n_with_data = has_data.sum(axis=1)
            precision_sum = np.where(has_data, per_category_precision, 0.0).sum(axis=1)
            mean_precisions_iou05 = np.divide(
                precision_sum, n_with_data,
                out=np.zeros_like(precision_sum, dtype=float),
                where=n_with_data > 0)
            plt.plot(recall_thresholds, mean_precisions_iou05, color='navy', lw=3,
                     label=f'Average for all classes (mAP@0.5 = {coco_eval_obj.stats[1]:.3f})')

        cat_ids_in_eval = coco_eval_obj.params.catIds
        if not isinstance(class_names_dict_inv, dict):
            class_names_dict_inv = {}

        num_categories_in_precision_matrix = precision.shape[2]

        try:
            colors = plt.cm.get_cmap('tab10', num_categories_in_precision_matrix)
        except AttributeError:
            # Fallback for matplotlib versions where plt.cm.get_cmap is missing.
            colors = plt.cm.tab10

        for k_idx in range(min(num_categories_in_precision_matrix, len(cat_ids_in_eval))):
            cat_id = cat_ids_in_eval[k_idx]
            class_name = class_names_dict_inv.get(cat_id, f"ClsID {cat_id}")
            if class_name == 'background':
                continue
            precision_for_cat = precision[0, :, k_idx, 0, 2]
            valid_points = precision_for_cat[precision_for_cat > -1]
            ap_for_cat_iou05 = np.mean(valid_points) if valid_points.size else 0.0
            current_color = colors(k_idx % 10) if callable(colors) else colors.colors[k_idx % len(colors.colors)]
            plt.plot(recall_thresholds, precision_for_cat, color=current_color, lw=1.5,
                     label=f'{class_name} (AP@0.5 = {ap_for_cat_iou05:.3f})')

        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve for IoU=0.5 (from COCOeval)')
        plt.legend(loc='center left', bbox_to_anchor=(1.05, 0.5), borderaxespad=0., fontsize='small')
        plt.grid(True)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        # Leave room on the right for the legend placed outside the axes.
        plt.tight_layout(rect=[0, 0, 0.78, 1])
        plt.savefig(pr_path)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)
    print(f"Saved standard PR_curve (from COCOeval) to {pr_path}")
    return {"pr_curve": pr_path}

def plot_training_summary_curves(metrics_df, output_dir):
    """Plot per-epoch training losses and validation metrics on one summary figure.

    Every entry of ``metrics_to_plot_ordered`` gets its own subplot, in order.
    A metric whose column is absent from ``metrics_df`` (or contains only NaNs)
    is reported on stdout and shown as an empty, titled placeholder so the
    layout stays the same from run to run.

    Args:
        metrics_df: DataFrame with an ``epoch`` column plus any of the metric
            columns named in ``metrics_to_plot_ordered``.
        output_dir: Directory that receives ``faster_rcnn_metrics_summary.png``.

    Returns:
        None. Prints a message and returns early when ``epoch`` is missing.
    """
    if 'epoch' not in metrics_df.columns:
        print("Missing 'epoch' column in metrics_df. Cannot generate summary.")
        return
    epochs = metrics_df['epoch']

    # (subplot title, metrics_df column) pairs, in display order.
    metrics_to_plot_ordered = [
        ("train/box_loss", "loss_box_reg"),
        ("train/cls_loss", "loss_classifier"),
        ("train/obj_loss", "loss_objectness"),
        ("metrics/precision", "Precision_coco"),
        ("metrics/recall", "Recall_coco"),
        ("train/box_loss_RPN", "loss_rpn_box_reg"),
        ("train/loss_overall", "loss"),
        ("metrics/F1-score", "F1_coco"),
        ("metrics/mAP50", "mAP_0.5"),
        ("metrics/mAP50-95", "mAP_0.5_0.95")
    ]

    num_cols_plot = 5
    # Derive the row count from the metric list (ceiling division) so that
    # extending the list can never run out of subplots.
    num_rows_plot = max(1, -(-len(metrics_to_plot_ordered) // num_cols_plot))

    fig, axs = plt.subplots(num_rows_plot, num_cols_plot, figsize=(6 * num_cols_plot, 5 * num_rows_plot), squeeze=False)
    axs = axs.flatten()
    fig.suptitle('Faster R-CNN Training & Validation Summary', fontsize=18)

    for ax, (plot_title, df_col_name) in zip(axs, metrics_to_plot_ordered):
        if df_col_name in metrics_df.columns and metrics_df[df_col_name].notna().any():
            ax.plot(epochs, metrics_df[df_col_name], label=plot_title.split('/')[-1], marker='.')
            ax.set_title(plot_title, fontsize=10)
            ax.set_xlabel('Epoch')
            ax.grid(True)
            ax.legend(fontsize='small')
        else:
            print(f"Warning: Metric '{df_col_name}' for title '{plot_title}' not found in metrics_df or contains only NaNs. Skipping this plot.")
            # Placeholder: axis decorations hidden, title kept so the gap is explained.
            ax.axis('off')
            ax.set_title(f'{plot_title}\n(No data)', fontsize=9, color='grey')

    # Hide any grid cells left over after the last metric.
    for ax in axs[len(metrics_to_plot_ordered):]:
        ax.axis('off')

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    save_path = os.path.join(output_dir, "faster_rcnn_metrics_summary.png")
    # Save through the figure handle rather than the implicit "current figure".
    fig.savefig(save_path)
    plt.close(fig)
    print(f"Saved metrics summary (plots) to {save_path}")

def draw_predictions_on_image(image_pil, gt_boxes, gt_labels, pred_boxes, pred_labels, pred_scores, class_mapping_inv, score_thresh=0.3):
    """Overlay ground-truth and predicted boxes, with text labels, on a PIL image.

    Ground truth is drawn in lime with a ``GT:`` prefix; predictions are drawn
    in red with a ``P:`` prefix and their score. The image is modified in place
    and also returned.

    Args:
        image_pil: PIL image to draw on.
        gt_boxes, gt_labels: Ground-truth boxes and label ids; the ground-truth
            pass is skipped when either is None.
        pred_boxes, pred_labels, pred_scores: Predictions; the prediction pass
            is skipped when any of the three is None.
        class_mapping_inv: Mapping from label id to display name; ids missing
            from it are rendered as ``str(id)``.
        score_thresh: Predictions scoring below this value are not drawn.

    Returns:
        The same ``image_pil`` object.
    """
    canvas = ImageDraw.Draw(image_pil)
    try:
        label_font = ImageFont.truetype("arial.ttf", 12)
    except IOError:
        # Font file not available: fall back to PIL's built-in font.
        label_font = ImageFont.load_default()

    if gt_boxes is not None and gt_labels is not None:
        for gt_box, gt_label in zip(gt_boxes, gt_labels):
            if isinstance(gt_box, torch.Tensor):
                gt_box = gt_box.tolist()
            if isinstance(gt_label, torch.Tensor):
                gt_label = gt_label.item()
            canvas.rectangle(gt_box, outline="lime", width=2)
            gt_name = class_mapping_inv.get(gt_label, str(gt_label))
            # Put the caption above the box unless it would leave the image at the top.
            if gt_box[1] > 12:
                caption_y = gt_box[1] - 12
            else:
                caption_y = gt_box[1] + 1
            canvas.text((gt_box[0], caption_y), f"GT: {gt_name}", fill="lime", font=label_font)

    if pred_boxes is None or pred_labels is None or pred_scores is None:
        return image_pil

    for pred_box, pred_label, pred_score in zip(pred_boxes, pred_labels, pred_scores):
        if isinstance(pred_box, torch.Tensor):
            pred_box = pred_box.tolist()
        if isinstance(pred_label, torch.Tensor):
            pred_label = pred_label.item()
        if isinstance(pred_score, torch.Tensor):
            pred_score = pred_score.item()
        if pred_score < score_thresh:
            continue
        canvas.rectangle(pred_box, outline="red", width=2)
        pred_name = class_mapping_inv.get(pred_label, str(pred_label))
        # Larger offset than the GT caption so the two captions do not overlap.
        if pred_box[1] > 24:
            caption_y = pred_box[1] - 24
        else:
            caption_y = pred_box[1] + 10
        canvas.text((pred_box[0], caption_y), f"P: {pred_name} {pred_score:.2f}", fill="red", font=label_font)
    return image_pil

# The collage written by the example savers holds at most this many images.
_EXAMPLE_GRID_MAX_IMAGES = 16
_EXAMPLE_GRID_COLUMNS = 4
# Mean/std handed to denormalize_image_tensor before converting to PIL.
_EXAMPLE_NORM_MEAN = [0.485, 0.456, 0.406]
_EXAMPLE_NORM_STD = [0.229, 0.224, 0.225]


def _example_loader_is_empty(data_loader):
    """Return True when there is no loader, or its dataset reports zero length."""
    return data_loader is None or (hasattr(data_loader, 'dataset') and len(data_loader.dataset) == 0)


def _iter_example_batches(data_loader, num_images, desc):
    """Yield (images, targets) batches, reading only as many batches as num_images needs."""
    batch_size = getattr(data_loader, 'batch_size', None) or 1
    batches_to_process = (num_images + batch_size - 1) // batch_size
    data_iter = iter(data_loader)
    for _ in tqdm(range(batches_to_process), desc=desc, leave=False):
        try:
            images_batch, targets_batch = next(data_iter)
        except StopIteration:
            return
        if images_batch is None or targets_batch is None:
            continue
        yield images_batch, targets_batch


def _example_tensor_to_pil(img_tensor, error_suffix=""):
    """Denormalize one image tensor and convert it to PIL; return None if conversion fails."""
    denormalized = denormalize_image_tensor(img_tensor.cpu(), mean=_EXAMPLE_NORM_MEAN, std=_EXAMPLE_NORM_STD)
    try:
        return torchvision.transforms.ToPILImage()(denormalized)
    except Exception as e_pil:
        # Best effort: one unconvertible image must not abort the whole collage.
        print(f"Error converting tensor (after denormalization) to PIL{error_suffix}: {e_pil}")
        return None


def _save_image_grid(images, save_path, num_cols=_EXAMPLE_GRID_COLUMNS):
    """Paste images row by row into a white grid sized from the first image and save it."""
    cell_width, cell_height = images[0].size
    num_rows = (len(images) + num_cols - 1) // num_cols
    grid_image = Image.new('RGB', (cell_width * num_cols, cell_height * num_rows), color='white')
    for idx, img_to_paste in enumerate(images):
        if img_to_paste.size != (cell_width, cell_height):
            img_to_paste = img_to_paste.resize((cell_width, cell_height))
        row_idx, col_idx = divmod(idx, num_cols)
        grid_image.paste(img_to_paste, (col_idx * cell_width, row_idx * cell_height))
    grid_image.save(save_path)


def save_prediction_examples(model, data_loader, device, num_images_to_save, output_dir, class_mapping_inv, epoch_num_str):
    """Run the model on a few batches and save a collage of predictions vs. ground truth.

    Writes ``predictions.jpg`` into ``output_dir``. At most 16 images are
    placed in the collage regardless of ``num_images_to_save``.

    Args:
        model: Detection model; it is switched to eval mode and left there.
        data_loader: Loader yielding ``(images, targets)`` batches.
        device: Device the images are moved to before inference.
        num_images_to_save: Requested number of example images.
        output_dir: Directory that receives the collage.
        class_mapping_inv: Label id -> display name mapping used for captions.
        epoch_num_str: Text shown in the progress-bar description only.

    Returns:
        None. Prints a message and returns early when there is no data or no
        image could be produced.
    """
    model.eval()
    if _example_loader_is_empty(data_loader):
        print("No data in data_loader for save_prediction_examples.")
        return

    # Only the first 16 images ever reach the collage, so stop collecting there.
    images_wanted = min(num_images_to_save, _EXAMPLE_GRID_MAX_IMAGES)
    drawn_images = []
    desc = f"Generating prediction examples (epoch {epoch_num_str})"
    for images_batch, targets_batch in _iter_example_batches(data_loader, num_images_to_save, desc):
        images_for_model = [img.to(device) for img in images_batch]
        with torch.no_grad():
            predictions = model(images_for_model)

        for img_tensor, target, prediction in zip(images_batch, targets_batch, predictions):
            if len(drawn_images) >= images_wanted:
                break
            pil_img = _example_tensor_to_pil(img_tensor)
            if pil_img is None:
                continue
            drawn_images.append(draw_predictions_on_image(
                pil_img,
                target.get('boxes', torch.empty(0, 4)),
                target.get('labels', torch.empty(0, dtype=torch.int64)),
                prediction.get('boxes', torch.empty(0, 4)).cpu(),
                prediction.get('labels', torch.empty(0, dtype=torch.int64)).cpu(),
                prediction.get('scores', torch.empty(0)).cpu(),
                class_mapping_inv))
        if len(drawn_images) >= images_wanted:
            break

    if not drawn_images:
        print("Failed to generate any images with predictions.")
        return

    save_path = os.path.join(output_dir, "predictions.jpg")
    _save_image_grid(drawn_images, save_path)
    print(f"Saved example predictions to {save_path}")


def save_ground_truth_examples(data_loader, device, num_images_to_save, output_dir, class_mapping_inv, epoch_num_str_for_desc=""):
    """Save a collage of dataset images with only their ground-truth boxes drawn.

    Writes ``ground_truth_examples.jpg`` into ``output_dir``. At most 16 images
    are placed in the collage regardless of ``num_images_to_save``.

    Args:
        data_loader: Loader yielding ``(images, targets)`` batches.
        device: Unused; kept so the signature parallels save_prediction_examples.
        num_images_to_save: Requested number of example images.
        output_dir: Directory that receives the collage.
        class_mapping_inv: Label id -> display name mapping used for captions.
        epoch_num_str_for_desc: Optional text appended to the progress-bar description.

    Returns:
        None. Prints a message and returns early when there is no data or no
        image could be produced.
    """
    if _example_loader_is_empty(data_loader):
        print("No data in data_loader for save_ground_truth_examples.")
        return

    desc_str = "Generating ground truth examples"
    if epoch_num_str_for_desc:
        desc_str += f" (epoch {epoch_num_str_for_desc})"

    images_wanted = min(num_images_to_save, _EXAMPLE_GRID_MAX_IMAGES)
    drawn_images = []
    for images_batch, targets_batch in _iter_example_batches(data_loader, num_images_to_save, desc_str):
        for img_tensor, target in zip(images_batch, targets_batch):
            if len(drawn_images) >= images_wanted:
                break
            pil_img = _example_tensor_to_pil(img_tensor, error_suffix=" in GT")
            if pil_img is None:
                continue
            drawn_images.append(draw_predictions_on_image(
                pil_img,
                target.get('boxes', torch.empty(0, 4)),
                target.get('labels', torch.empty(0, dtype=torch.int64)),
                pred_boxes=None, pred_labels=None, pred_scores=None,
                class_mapping_inv=class_mapping_inv))
        if len(drawn_images) >= images_wanted:
            break

    if not drawn_images:
        print("Failed to generate any images with ground truth.")
        return

    save_path = os.path.join(output_dir, "ground_truth_examples.jpg")
    _save_image_grid(drawn_images, save_path)
    print(f"Saved example ground truth images to {save_path}")

def calculate_iou(box1, box2):
    """Return the Intersection over Union (IoU) of two axis-aligned boxes.

    Both boxes are indexable as [xmin, ymin, xmax, ymax]. The result is 0.0
    when the union area is not positive (e.g. two zero-area boxes).
    """
    # Overlap extent along each axis, clamped at zero for disjoint boxes.
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h

    first_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    second_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = first_area + second_area - intersection

    if union > 0:
        return intersection / union
    return 0.0

def _precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, F1) from raw counts, using 0 wherever a denominator is 0."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1


def calculate_metrics_vs_confidence(all_predictions, all_targets, num_classes_no_bg, inv_class_mapping, iou_threshold=0.5, confidence_thresholds=None):
    """Compute precision, recall and F1 as functions of the confidence threshold.

    For each threshold, predictions scoring at least that value are greedily
    matched (highest score first) to not-yet-matched ground-truth boxes of the
    same label. A prediction is a true positive when its best match reaches
    ``iou_threshold``, otherwise a false positive; unmatched ground truth
    counts as false negatives.

    Args:
        all_predictions: Per-image dicts with tensor entries ``boxes``,
            ``labels`` and ``scores``.
        all_targets: Per-image dicts with tensor entries ``boxes`` and
            ``labels``, indexed in parallel with ``all_predictions``.
        num_classes_no_bg: Not used by this function; kept for interface
            compatibility.
        inv_class_mapping: Mapping whose keys are class ids; the id equal to
            ``CLASS_MAPPING['background']`` is excluded from per-class results.
        iou_threshold: Minimum IoU for a match to count as a true positive.
        confidence_thresholds: Thresholds to evaluate; defaults to 100 values
            evenly spaced over [0.01, 1.0].

    Returns:
        Tuple ``(confidence_thresholds, precisions_per_class,
        recalls_per_class, f1_scores_per_class, overall_precisions,
        overall_recalls, overall_f1_scores)``. The per-class items are dicts
        mapping class id to a list with one value per threshold; the overall
        items are lists with one value per threshold.
    """
    if confidence_thresholds is None:
        confidence_thresholds = np.linspace(0.01, 1.0, 100)
    actual_class_ids = [cls_id for cls_id in inv_class_mapping.keys() if cls_id != CLASS_MAPPING['background']]
    precisions_per_class = {cls_id: [] for cls_id in actual_class_ids}
    recalls_per_class = {cls_id: [] for cls_id in actual_class_ids}
    f1_scores_per_class = {cls_id: [] for cls_id in actual_class_ids}
    overall_precisions = []
    overall_recalls = []
    overall_f1_scores = []

    # Tensor -> NumPy conversion does not depend on the threshold, so do it
    # once per image instead of once per (threshold, image) pair.
    per_image_arrays = []
    for i, preds_img in enumerate(all_predictions):
        targets_img = all_targets[i]
        per_image_arrays.append((
            preds_img['boxes'].cpu().numpy(),
            preds_img['labels'].cpu().numpy(),
            preds_img['scores'].cpu().numpy(),
            targets_img['boxes'].cpu().numpy(),
            targets_img['labels'].cpu().numpy(),
        ))

    for conf_thresh in tqdm(confidence_thresholds, desc="Calculating P/R/F1 vs Confidence"):
        total_tp_overall = 0
        total_fp_overall = 0
        total_fn_overall = 0
        tp_per_class = {cls_id: 0 for cls_id in actual_class_ids}
        fp_per_class = {cls_id: 0 for cls_id in actual_class_ids}
        fn_per_class = {cls_id: 0 for cls_id in actual_class_ids}

        for pred_boxes_img, pred_labels_img, pred_scores_img, gt_boxes_img, gt_labels_img in per_image_arrays:
            confident_mask = pred_scores_img >= conf_thresh
            pred_boxes_img_conf = pred_boxes_img[confident_mask]
            pred_labels_img_conf = pred_labels_img[confident_mask]
            pred_scores_img_conf = pred_scores_img[confident_mask]

            if len(pred_boxes_img_conf) > 0:
                # Highest-scoring predictions get first pick of the ground truth.
                sorted_indices = np.argsort(pred_scores_img_conf)[::-1]
                pred_boxes_img_conf = pred_boxes_img_conf[sorted_indices]
                pred_labels_img_conf = pred_labels_img_conf[sorted_indices]

            matched_gt_indices = [False] * len(gt_boxes_img)
            for pred_box, pred_label in zip(pred_boxes_img_conf, pred_labels_img_conf):
                best_iou = 0
                best_gt_idx = -1
                for gt_idx, gt_box in enumerate(gt_boxes_img):
                    if gt_labels_img[gt_idx] == pred_label and not matched_gt_indices[gt_idx]:
                        iou = calculate_iou(pred_box, gt_box)
                        if iou > best_iou:
                            best_iou = iou
                            best_gt_idx = gt_idx
                # Require an actual candidate: with iou_threshold <= 0 the IoU test
                # alone would pass with best_gt_idx == -1 and index the wrong slot.
                if best_gt_idx != -1 and best_iou >= iou_threshold:
                    if pred_label in tp_per_class:
                        tp_per_class[pred_label] += 1
                    total_tp_overall += 1
                    matched_gt_indices[best_gt_idx] = True
                else:
                    if pred_label in fp_per_class:
                        fp_per_class[pred_label] += 1
                    total_fp_overall += 1

            for gt_idx, was_matched in enumerate(matched_gt_indices):
                if not was_matched:
                    gt_label = gt_labels_img[gt_idx]
                    # As before, a miss is tallied (per class and overall) only
                    # when its label is one of the evaluated class ids.
                    if gt_label in fn_per_class:
                        fn_per_class[gt_label] += 1
                        total_fn_overall += 1

        for cls_id in actual_class_ids:
            p_cls, r_cls, f1_cls = _precision_recall_f1(tp_per_class[cls_id], fp_per_class[cls_id], fn_per_class[cls_id])
            precisions_per_class[cls_id].append(p_cls)
            recalls_per_class[cls_id].append(r_cls)
            f1_scores_per_class[cls_id].append(f1_cls)
        p_overall, r_overall, f1_overall = _precision_recall_f1(total_tp_overall, total_fp_overall, total_fn_overall)
        overall_precisions.append(p_overall)
        overall_recalls.append(r_overall)
        overall_f1_scores.append(f1_overall)
    return confidence_thresholds, precisions_per_class, recalls_per_class, f1_scores_per_class, overall_precisions, overall_recalls, overall_f1_scores

def _plot_metric_vs_confidence(conf_thresholds, per_class_values, overall_values, class_ids,
                               inv_class_mapping, colors, y_label, title, best_idx, best_conf,
                               best_label_suffix, save_dir, file_name):
    """Draw one metric-vs-confidence figure (per-class lines plus the overall line) and save it."""
    plt.figure(figsize=(10, 7))
    for i, cls_id in enumerate(class_ids):
        plt.plot(conf_thresholds, per_class_values[cls_id], color=colors(i % 10), lw=1, label=f'{inv_class_mapping[cls_id]}')
    if best_idx != -1:
        overall_label = f'All classes {overall_values[best_idx]:.2f} at {best_conf:.3f} conf{best_label_suffix}'
    else:
        overall_label = 'All classes (no F1 data)'
    plt.plot(conf_thresholds, overall_values, color='blue', lw=3, label=overall_label)
    plt.xlabel('Confidence Threshold')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend(loc='center left', bbox_to_anchor=(1.05, 0.5))
    plt.grid(True)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    # Leave the right-hand strip of the figure free for the outside legend.
    plt.tight_layout(rect=[0, 0, 0.8, 1])
    plt.savefig(os.path.join(save_dir, file_name))
    plt.close()
    print(f"Saved {file_name} to {save_dir}")


def plot_custom_metric_curves(custom_metrics_data, save_dir, inv_class_mapping):
    """Save Precision-, Recall- and F1-vs-confidence plots.

    Writes ``P_curve.png``, ``R_curve.png`` and ``F1_curve.png`` into
    ``save_dir``. Each figure has one thin line per class and a thick line for
    all classes; the thick line's legend quotes the metric at the confidence
    threshold that maximizes the overall F1.

    Args:
        custom_metrics_data: Dict with keys ``conf_thresholds``,
            ``p_per_class``, ``r_per_class``, ``f1_per_class``, ``p_overall``,
            ``r_overall`` and ``f1_overall``. Nothing is plotted when it is
            empty or None.
        save_dir: Directory that receives the three PNG files.
        inv_class_mapping: Class id -> name mapping; the id equal to
            ``CLASS_MAPPING['background']`` is skipped.

    Returns:
        None.
    """
    if not custom_metrics_data:
        print("No data to draw custom metric curves.")
        return
    conf_thresholds = custom_metrics_data['conf_thresholds']
    f1_overall = custom_metrics_data['f1_overall']
    actual_class_ids = [cls_id for cls_id in inv_class_mapping.keys() if cls_id != CLASS_MAPPING['background']]
    num_actual_classes = len(actual_class_ids)
    # plt.get_cmap is used instead of plt.cm.get_cmap, which newer Matplotlib
    # releases no longer provide; a lut of None leaves the colormap unresampled.
    colors = plt.get_cmap('tab10', num_actual_classes if num_actual_classes > 0 else None)

    best_f1_idx_overall = -1
    best_conf_for_f1 = -1
    # any() over the elements works for lists and NumPy arrays alike, and is
    # False for an empty sequence.
    if f1_overall is not None and any(f1 > 0 for f1 in f1_overall):
        best_f1_idx_overall = np.argmax(f1_overall)
        best_conf_for_f1 = conf_thresholds[best_f1_idx_overall]

    # (per-class key, overall key, y-axis label, title, legend suffix, file name)
    curve_specs = [
        ('p_per_class', 'p_overall', 'Precision', 'Precision-Confidence Curve', ' (for max F1)', "P_curve.png"),
        ('r_per_class', 'r_overall', 'Recall', 'Recall-Confidence Curve', ' (for max F1)', "R_curve.png"),
        ('f1_per_class', 'f1_overall', 'F1-Score', 'F1-Confidence Curve', '', "F1_curve.png"),
    ]
    for per_class_key, overall_key, y_label, title, suffix, file_name in curve_specs:
        _plot_metric_vs_confidence(
            conf_thresholds, custom_metrics_data[per_class_key], custom_metrics_data[overall_key],
            actual_class_ids, inv_class_mapping, colors, y_label, title,
            best_f1_idx_overall, best_conf_for_f1, suffix, save_dir, file_name)

def _save_confusion_heatmap(matrix, value_format, title, tick_labels, save_path):
    """Render one confusion-matrix heatmap with class names on both axes and save it."""
    plt.figure(figsize=(12, 10))
    sns.heatmap(matrix, annot=True, fmt=value_format, cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.title(title)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def plot_confusion_matrices(cm_data, class_names_no_bg, output_dir):
    """Save raw and row-normalized confusion-matrix heatmaps.

    Writes ``confusion_matrix.png`` and ``confusion_matrix_normalized.png``
    into ``output_dir``. Rows and columns cover every id in ``CLASS_MAPPING``,
    named through ``INV_CLASS_MAPPING``.

    Args:
        cm_data: Dict with ``true`` and ``pred`` sequences of integer label
            ids. Nothing is plotted when it is None/empty or when either
            sequence is missing or empty.
        class_names_no_bg: Not used by this function; kept for interface
            compatibility.
        output_dir: Directory that receives the two PNG files.

    Returns:
        None.
    """
    # .get() so that a partially filled dict is reported as "no data"
    # instead of raising KeyError.
    true_labels_int = cm_data.get('true') if cm_data else None
    pred_labels_int = cm_data.get('pred') if cm_data else None
    if not true_labels_int or not pred_labels_int:
        print("No data to generate confusion matrix.")
        return

    all_possible_labels_ids = sorted(CLASS_MAPPING.values())
    all_possible_labels_names = [INV_CLASS_MAPPING.get(l_id, f'ID {l_id}') for l_id in all_possible_labels_ids]

    cm = confusion_matrix(true_labels_int, pred_labels_int, labels=all_possible_labels_ids)
    _save_confusion_heatmap(cm, 'd', 'Confusion Matrix', all_possible_labels_names,
                            os.path.join(output_dir, "confusion_matrix.png"))
    print(f"Saved confusion_matrix.png to {output_dir}")

    # Normalize each row by its number of true samples; rows without samples
    # stay all-zero instead of dividing by zero.
    cm_sum_axis1 = cm.sum(axis=1)[:, np.newaxis]
    cm_normalized = np.zeros_like(cm, dtype=float)
    np.divide(cm, cm_sum_axis1, out=cm_normalized, where=cm_sum_axis1 != 0)

    _save_confusion_heatmap(cm_normalized, '.2f', 'Normalized Confusion Matrix (by true labels)',
                            all_possible_labels_names,
                            os.path.join(output_dir, "confusion_matrix_normalized.png"))
    print(f"Saved confusion_matrix_normalized.png to {output_dir}")

## 5. Main Training Script (Combined Data)

In [None]:
print("--- Training Faster R-CNN on combined data (Day/Night) ---")


def _annotations_unusable(path, min_bytes=50):
    """Return True when *path* does not exist or is smaller than *min_bytes* bytes."""
    return not os.path.exists(path) or os.path.getsize(path) < min_bytes


def _usable_length(dataset):
    """Return len(dataset), or 0 when it is empty or has no non-empty ``image_filenames``."""
    if dataset and hasattr(dataset, 'image_filenames') and len(dataset.image_filenames) > 0:
        return len(dataset)
    return 0


def _subset_indices(full_length, subset_size):
    """Return ``subset_size`` random indices when the dataset is larger, else all indices."""
    if full_length > subset_size and full_length > 0:
        return torch.randperm(full_length)[:subset_size].tolist()
    return list(range(full_length))


# Check the two annotation files independently so that a problem with the
# validation file is reported even when the training file is also unusable.
if _annotations_unusable(TRAIN_ANNOTATIONS_FILE):
    print(f"CRITICAL ERROR: Training annotations file {TRAIN_ANNOTATIONS_FILE} does not exist or is empty.")
if _annotations_unusable(VAL_ANNOTATIONS_FILE):
    print(f"WARNING: Validation annotations file {VAL_ANNOTATIONS_FILE} does not exist or is empty. Evaluation may not be possible.")

dataset_train_full = LISADataset(annotations_file=TRAIN_ANNOTATIONS_FILE,
                                   img_base_dir=IMAGES_BASE_DIR,
                                   transforms=get_train_transforms(apply_augmentations=APPLY_AUGMENTATIONS),
                                   class_mapping=CLASS_MAPPING)

dataset_val_full = LISADataset(annotations_file=VAL_ANNOTATIONS_FILE,
                                 img_base_dir=IMAGES_BASE_DIR,
                                 transforms=get_val_test_transforms(),
                                 class_mapping=CLASS_MAPPING)

# Default to the full datasets; replaced below only when a non-empty subset is drawn.
final_dataset_train = dataset_train_full
final_dataset_val = dataset_val_full

if USE_SUBSET_DATA:
    len_train_full = _usable_length(dataset_train_full)
    len_val_full = _usable_length(dataset_val_full)
    # Draw the training indices before the validation indices so the random
    # number generator is consumed in a fixed order.
    train_indices = _subset_indices(len_train_full, SUBSET_SIZE_TRAIN)
    val_indices = _subset_indices(len_val_full, SUBSET_SIZE_VAL)
    # An empty index list (empty dataset or zero subset size) keeps the full dataset.
    if train_indices:
        final_dataset_train = Subset(dataset_train_full, train_indices)
    if val_indices:
        final_dataset_val = Subset(dataset_val_full, val_indices)
    print(f"Using subset of data: {len(final_dataset_train)} training, {len(final_dataset_val)} validation.")
else:
    print(f"Using full dataset: {len(final_dataset_train)} training, {len(final_dataset_val)} validation.")

if APPLY_AUGMENTATIONS:
    print("Augmentations for the training set ARE ENABLED.")
else:
    print("Augmentations for the training set ARE DISABLED.")

# Train the model, evaluate after every epoch, keep the best checkpoint by
# mAP@0.5, then write the metrics CSV, the summary plots and example images.
if len(final_dataset_train) == 0:
    print("CRITICAL ERROR: Training set is empty. Training cannot start.")
else:
    data_loader_train = DataLoader(final_dataset_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0, pin_memory=torch.cuda.is_available())
    # The validation loader stays None when there is nothing to validate on;
    # every evaluation step below is guarded on it.
    data_loader_val = None
    if len(final_dataset_val) > 0:
        data_loader_val = DataLoader(final_dataset_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0, pin_memory=torch.cuda.is_available())
    else: print("Validation set is empty, DataLoader for validation will not be created.")

    model = get_faster_rcnn_model(num_classes_incl_background=NUM_CLASSES_INC_BG, backbone_name="resnet50")
    model.to(DEVICE)

    # Optimize only the parameters that require gradients.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    # Multiply the learning rate by 0.1 after every 5 scheduler steps (one step per epoch below).
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 

    all_metrics_history = []          # one dict per epoch; becomes the metrics DataFrame
    best_map_metric = -1.0            # best mAP@0.5 seen so far; -1.0 means "none yet"
    epochs_without_improvement = 0    # early-stopping counter
    last_eval_metrics_for_plotting = {} 

    for epoch_idx in range(NUM_EPOCHS):
        current_epoch_num = epoch_idx + 1
        avg_train_losses_dict = train_one_epoch(model, optimizer, data_loader_train, DEVICE, current_epoch_num, print_freq=PRINT_FREQ) 
        
        current_epoch_metrics_row = {"epoch": current_epoch_num}
        for loss_name, loss_val in avg_train_losses_dict.items():
            current_epoch_metrics_row[loss_name] = loss_val 

        eval_metrics_dict = {}
        if data_loader_val and len(final_dataset_val) > 0:
            print(f"--- Epoch {current_epoch_num} Evaluation ---")
            eval_metrics_dict = evaluate(model, data_loader_val, DEVICE, CLASS_MAPPING, 
                                         COCO_UTILS_AVAILABLE, METRIC_LOGGER_AVAILABLE,
                                         NUM_CLASSES_NO_BG, INV_CLASS_MAPPING) 
            last_eval_metrics_for_plotting = eval_metrics_dict 
            
            # -1.0 is the placeholder for any metric evaluate() did not return.
            current_epoch_metrics_row["mAP_0.5"] = eval_metrics_dict.get("mAP_0.5", -1.0)
            current_epoch_metrics_row["mAP_0.5_0.95"] = eval_metrics_dict.get("mAP_0.5_0.95", -1.0)
            current_epoch_metrics_row["Precision_coco"] = eval_metrics_dict.get("Precision_coco", -1.0)
            current_epoch_metrics_row["Recall_coco"] = eval_metrics_dict.get("Recall_coco", -1.0)
            current_epoch_metrics_row["F1_coco"] = eval_metrics_dict.get("F1_coco", -1.0)
            current_epoch_metrics_row["FPS"] = eval_metrics_dict.get("FPS", -1.0)
            # NOTE(review): these three evaluation artefacts are stored on every
            # epoch's row, so all of them stay referenced until the history is
            # released, although only the last non-null one is used for plotting below.
            current_epoch_metrics_row["coco_eval_obj_for_curves"] = eval_metrics_dict.get('coco_eval_obj_for_curves', None)
            current_epoch_metrics_row["custom_metrics_data"] = eval_metrics_dict.get('custom_metrics_data', None)
            current_epoch_metrics_row["cm_data"] = eval_metrics_dict.get('cm_data', None) 

            all_metrics_history.append(current_epoch_metrics_row)
            current_map_val = eval_metrics_dict.get("mAP_0.5", -1.0)
            # Strict '>' : if mAP_0.5 is never returned (stays at -1.0) no checkpoint is written.
            if current_map_val > best_map_metric:
                best_map_metric = current_map_val
                epochs_without_improvement = 0
                model_save_path = os.path.join(OUTPUT_DIR, "faster_rcnn_best_model.pth")
                torch.save(model.state_dict(), model_save_path)
                print(f"Epoch {current_epoch_num}: New best model saved to {model_save_path} with mAP@0.5: {best_map_metric:.4f}")
            else:
                epochs_without_improvement += 1
                print(f"Epoch {current_epoch_num}: No improvement in mAP@0.5. Best: {best_map_metric:.4f}. Epochs without improvement: {epochs_without_improvement}")
            if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
                print(f"Early stopping! No improvement in mAP@0.5 for {EARLY_STOPPING_PATIENCE} epochs.")
                # Leaves the loop before lr_scheduler.step(); harmless because training ends here.
                break 
        else:
            # No validation data: record the training losses only.
            # NOTE(review): on this path no checkpoint is ever saved and early
            # stopping never triggers - confirm that is intended.
            all_metrics_history.append(current_epoch_metrics_row) 

        lr_scheduler.step() 

    if all_metrics_history: 
        metrics_df = pd.DataFrame(all_metrics_history)
        coco_eval_object_for_plotting = None
        custom_metrics_data_for_plotting = None
        cm_data_for_plotting = None

        # Pick the most recent non-null artefact of each kind. These come from
        # the LAST evaluated epoch, which is not necessarily the best-mAP epoch.
        if 'coco_eval_obj_for_curves' in metrics_df.columns:
            valid_coco_evals = metrics_df['coco_eval_obj_for_curves'].dropna()
            if not valid_coco_evals.empty:
                coco_eval_object_for_plotting = valid_coco_evals.iloc[-1]
        if 'custom_metrics_data' in metrics_df.columns:
            valid_custom_metrics = metrics_df['custom_metrics_data'].dropna()
            if not valid_custom_metrics.empty:
                custom_metrics_data_for_plotting = valid_custom_metrics.iloc[-1]
        if 'cm_data' in metrics_df.columns: 
            valid_cm_data = metrics_df['cm_data'].dropna()
            if not valid_cm_data.empty:
                cm_data_for_plotting = valid_cm_data.iloc[-1]

        # The object-valued columns are not meaningful in a CSV, so drop them before saving.
        cols_to_drop_for_csv = ['coco_eval_obj_for_curves', 'custom_metrics_data', 'cm_data'] 
        metrics_df_for_csv = metrics_df.drop(columns=cols_to_drop_for_csv, errors='ignore')
        
        metrics_filename = os.path.join(OUTPUT_DIR, "faster_rcnn_training_metrics.csv")
        metrics_df_for_csv.to_csv(metrics_filename, index=False)
        print(f"Training metrics saved to {metrics_filename}")
        
        plot_training_summary_curves(metrics_df, OUTPUT_DIR) 

        # For each plot family: prefer the artefact recovered from the history,
        # fall back to the last evaluate() result, otherwise report that nothing is available.
        if coco_eval_object_for_plotting: 
            plot_coco_evaluation_curves(coco_eval_object_for_plotting, VISUALIZATIONS_DIR, INV_CLASS_MAPPING)
        elif last_eval_metrics_for_plotting.get('coco_eval_obj_for_curves'): 
             plot_coco_evaluation_curves(last_eval_metrics_for_plotting['coco_eval_obj_for_curves'], VISUALIZATIONS_DIR, INV_CLASS_MAPPING)
        else:
            print("COCOeval object not found to generate standard PR curve.")
        
        if custom_metrics_data_for_plotting:
            plot_custom_metric_curves(custom_metrics_data_for_plotting, VISUALIZATIONS_DIR, INV_CLASS_MAPPING)
        elif last_eval_metrics_for_plotting.get('custom_metrics_data'):
            plot_custom_metric_curves(last_eval_metrics_for_plotting.get('custom_metrics_data'), VISUALIZATIONS_DIR, INV_CLASS_MAPPING)
        else:
            print("Data not found to generate P, R, F1 vs Confidence curves.")
        
        if cm_data_for_plotting:
            plot_confusion_matrices(cm_data_for_plotting, CLASS_NAMES_NO_BG, VISUALIZATIONS_DIR)
        elif last_eval_metrics_for_plotting.get('cm_data'):
            plot_confusion_matrices(last_eval_metrics_for_plotting.get('cm_data'), CLASS_NAMES_NO_BG, VISUALIZATIONS_DIR)
        else:
            print("Data not found to generate confusion matrices.")

    else: print("No metrics collected to save.")
    
    if data_loader_val and len(final_dataset_val) > 0:
        print("\nSaving example predictions and ground truth from the last/best model...")
        best_model_path = os.path.join(OUTPUT_DIR, "faster_rcnn_best_model.pth")
        # Used only in the progress-bar descriptions of the two example savers.
        final_epoch_num_str_desc = str(all_metrics_history[-1]['epoch']) if all_metrics_history else 'final'
        
        # Build a fresh model for visualization so the training model object is left untouched.
        vis_model = get_faster_rcnn_model(num_classes_incl_background=NUM_CLASSES_INC_BG, backbone_name="resnet50")
        # NOTE(review): only the file's existence is checked, so a checkpoint left
        # in OUTPUT_DIR by an earlier run would be loaded if this run saved none.
        if os.path.exists(best_model_path):
            print(f"Loading best model from {best_model_path} for visualization.")
            vis_model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
        else: 
            print("Using last model state from training loop for visualization ('best_model.pth' not found).")
            vis_model.load_state_dict(model.state_dict()) 
        vis_model.to(DEVICE)
        vis_model.eval() 

        save_prediction_examples(vis_model, data_loader_val, DEVICE, 
                                 num_images_to_save=16, 
                                 output_dir=VISUALIZATIONS_DIR, 
                                 class_mapping_inv=INV_CLASS_MAPPING, 
                                 epoch_num_str=final_epoch_num_str_desc) 

        save_ground_truth_examples(data_loader_val, DEVICE, 
                                   num_images_to_save=16, 
                                   output_dir=VISUALIZATIONS_DIR, 
                                   class_mapping_inv=INV_CLASS_MAPPING,
                                   epoch_num_str_for_desc=final_epoch_num_str_desc)


# Printed unconditionally, including when the training set was empty and nothing ran.
print("\n--- Faster R-CNN training completed ---")