In [1]:
import sys
import os

sys.path.append('/gpfs/helios/home/ploter/projects/MultiSensorDropout/')

In [2]:
import argparse
import torch

args = argparse.Namespace(
    # Basic training parameters
    seed=42,
    batch_size=1,
    epochs=18,
    learning_rate=1e-3,
    learning_rate_backbone=1e-4,
    learning_rate_backbone_names=["backbone"],
    weight_decay=0.01,
    scheduler_step_size=12,
    eval_interval=1,
    patience=5,
    model='perceiver',
    backbone='cnn',  # Set specific value for evaluation
    eval=True,
    weight_loss_center_point=5,
    weight_loss_bce=1,
    shuffle_views=False,
    object_detection=True,  # Set to True for object detection
    resize_frame = None,
    
    # Matcher parameters
    set_cost_class=1,
    set_cost_bbox=5,
    set_cost_giou=2,
    focal_alpha=0.25,
    focal_gamma=2,
    
    # Loss coefficients
    bbox_loss_coef=5,
    giou_loss_coef=2,
    eos_coef=0.1,
    
    # Checkpoint parameters
    resume='checkpoint_epoch_17.pth',  # Replace with actual checkpoint path
    output_dir="./output",  # Provide a default output directory
    device='cuda' if torch.cuda.is_available() else 'cpu',
    
    # Dataset parameters
    dataset='moving-mnist',
    dataset_path='Max-Ploter/detection-moving-mnist-easy',
    generate_dataset_runtime=False,
    num_workers=4,
    num_frames=20,
    train_dataset_fraction=1.0,
    train_dataset_size=1.0,  # Additional parameter from notebook
    test_dataset_fraction=1.0,
    frame_dropout_pattern=None,
    view_dropout_probs=[],#[0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85],
    sampler_steps=[], #[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
    sequential_sampler=False,
    grid_size=(1, 1),
    tile_overlap=0.0,
    
    # Wandb parameters
    wandb_project='multi-sensor-dropout',
    wandb_id=None,
    
    # Perceiver model parameters
    num_freq_bands=6,  # Set to 4 as in the notebook
    max_freq=10, 
    enc_layers=1,
    num_queries=16,  # Set to 16 as in the notebook
    hidden_dim=128,
    enc_nheads_cross=1,
    nheads=1,
    dropout=0.0,
    self_per_cross_attn=1,
    multi_classification_heads=False,
    
    # LSTM model parameters
    lstm_hidden_size=128,
    
    # Additional parameters for complete compatibility
    focal_loss=True,
)

In [3]:
class DummyCriterion:
    def __init__(self, device):
        # Initialize with the necessary attributes
        self.device = device
        self.weight_dict = {'loss_ce': 1}
        
    def __call__(self, outputs, targets, *args, **kwargs):
        # Return a dictionary with zero losses to maintain the expected interface
        return {
            'loss_ce': torch.tensor(0.0, device=self.device),
            'loss_bbox': torch.tensor(0.0, device=self.device),
            'loss_giou': torch.tensor(0.0, device=self.device),
            'loss': torch.tensor(0.0, device=self.device)
        }

In [4]:
import torch
from torch import nn

class PostProcessTopK(nn.Module):
    """ Wrapper that applies a post-processor and keeps only the top-k predictions by score and a score threshold """

    def __init__(self, post_processor, top_k=10, score_threshold=0.5):
        super().__init__()
        self.post_processor = post_processor
        self.top_k = top_k
        self.score_threshold = score_threshold

    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        # Get results from the original post processor
        results = self.post_processor(outputs, target_sizes)

        # Filter to keep only top-k results and results above the score threshold
        filtered_results = []
        for result in results:
            scores, labels, boxes = result['scores'], result['labels'], result['boxes']

            # Get top-k indices by score
            top_k = min(self.top_k, len(scores))
            if top_k > 0:  # Check if there are any predictions
                top_indices = torch.topk(scores, top_k).indices

                # Filter by top-k indices
                top_scores = scores[top_indices]
                top_labels = labels[top_indices]
                top_boxes = boxes[top_indices]

                # Filter by score threshold
                threshold_indices = top_scores >= self.score_threshold
                final_scores = top_scores[threshold_indices]
                final_labels = top_labels[threshold_indices]
                final_boxes = top_boxes[threshold_indices]

                filtered_results.append({'scores': final_scores, 'labels': final_labels, 'boxes': final_boxes})
            else:
                # Keep empty result if no predictions
                filtered_results.append({'scores': torch.empty(0), 'labels': torch.empty(0, dtype=torch.int64), 'boxes': torch.empty(0, 4)})

        return filtered_results

### Seq NMS post processor

In [5]:
import torch
import torch.nn as nn

# Try to import seq_nms, handle import error
try:
    from pt_seq_nms import seq_nms_from_list
except ImportError:
    print("Warning: 'pt_seq_nms' library not found. SeqNMSEvaluationPostprocessor will not function.")
    print("Install using: pip install git+https://github.com/MrParosk/seq_nms.git")
    seq_nms_from_list = None


class SeqNMSEvaluationPostprocessor(nn.Module):
    """
    Applies standard post-processing (using the *original* PostProcess logic)
    followed by Sequence Non-Maximum Suppression (Seq-NMS) to the outputs
    of a video object detector (like RecurrentVideoObjectModule).

    It first converts raw outputs frame-by-frame, filters low-confidence
    detections, applies Seq-NMS, and formats the final results.

    Assumes the original PostProcess class (as provided in the initial prompt)
    is available and handles the conversion from raw model output (logits, rel boxes)
    to scores, labels, and absolute boxes for foreground classes.
    """
    def __init__(self,
                 postprocessor: nn.Module, # Expects an instance of the original PostProcess
                 min_score_threshold: float = 0.01, # Threshold BEFORE SeqNMS
                 linkage_threshold: float = 0.5,
                 iou_threshold: float = 0.5):
        """
        Initializes the SeqNMSEvaluationPostprocessor.

        Args:
            postprocessor (nn.Module): An instance of the *original* PostProcess
                                       module definition.
            min_score_threshold (float): Minimum confidence score required for a
                                         detection to be considered *before* Seq-NMS.
                                         Helps filter out noise from padding/background.
                                         Defaults to 0.01.
            linkage_threshold (float): The linkage threshold used by Seq-NMS. Defaults to 0.5.
            iou_threshold (float): The IoU threshold used by Seq-NMS. Defaults to 0.5.

        Raises:
            ImportError: If the 'pt_seq_nms' library is not installed.
            TypeError: If the provided 'postprocessor' is not an instance of nn.Module.
        """
        super().__init__()
        if seq_nms_from_list is None:
            # Error message printed during import attempt
            raise ImportError("'pt_seq_nms' library is required but could not be imported.")
        if not isinstance(postprocessor, nn.Module):
             # Basic check, could be more specific if PostProcess class name is known
             raise TypeError(f"Argument 'postprocessor' must be an instance of nn.Module (expecting original PostProcess), but got {type(postprocessor)}")

        self.postprocess = postprocessor
        self.min_score_threshold = min_score_threshold
        self.linkage_threshold = linkage_threshold
        self.iou_threshold = iou_threshold
        print(f"Initialized SeqNMSEvaluationPostprocessor with min_score_thresh={min_score_threshold}, "
              f"linkage_thresh={linkage_threshold}, iou_thresh={iou_threshold}")


    @torch.no_grad()
    def forward(self, outputs: dict, target_sizes: torch.Tensor) -> list[dict]:
        """
        Processes the model output by applying PostProcess, pre-filtering, and then Seq-NMS.

        Args:
            outputs (dict): Output dictionary from the video object detector model.
                            Expected keys and tensor shapes (assuming batch size B=1):
                             - 'pred_logits': Tensor [T, NumQueries, NumClasses]
                             - 'pred_boxes': Tensor [T, NumQueries, 4] (relative cxcywh)
                             Where T is the number of frames.
            target_sizes (torch.Tensor): Tensor of shape [T, 2] containing the original
                                         (height, width) for each of the T frames.

        Returns:
            list[dict]: A list containing T dictionaries, one for each frame.
                        Each dictionary represents the detections kept after Seq-NMS:
                        - 'scores': Tensor [num_kept_detections], original scores of kept detections.
                        - 'labels': Tensor [num_kept_detections], original labels of kept detections.
                        - 'boxes': Tensor [num_kept_detections, 4] (absolute xyxy format).
                        If no detections are kept for a frame, the tensors in the dict will be empty.
        """
        pred_logits = outputs.get('pred_logits')
        pred_boxes = outputs.get('pred_boxes')

        # --- Input Validation ---
        if pred_logits is None or pred_boxes is None:
            raise KeyError("Input 'outputs' dictionary must contain 'pred_logits' and 'pred_boxes'.")
        if not isinstance(pred_logits, torch.Tensor) or not isinstance(pred_boxes, torch.Tensor):
             raise TypeError("'pred_logits' and 'pred_boxes' must be torch tensors.")
        if pred_logits.dim() != 3 or pred_boxes.dim() != 3:
             raise ValueError(f"Expected 3D tensors for logits and boxes [T, NumQueries, Dim], got {pred_logits.shape} and {pred_boxes.shape}")
        if pred_logits.shape[0] != pred_boxes.shape[0] or pred_logits.shape[1] != pred_boxes.shape[1]:
             raise ValueError(f"Shape mismatch between 'pred_logits' {pred_logits.shape} and 'pred_boxes' {pred_boxes.shape}")

        num_frames = pred_logits.shape[0]
        if not isinstance(target_sizes, torch.Tensor):
             raise TypeError("'target_sizes' must be a torch tensor.")
        if target_sizes.shape != (num_frames, 2):
            raise ValueError(f"Expected 'target_sizes' shape [{num_frames}, 2], got {target_sizes.shape}")

        # --- 1. Apply Original PostProcess & Filter Frame-by-Frame ---
        boxes_list_filtered = []
        scores_list_filtered = []
        labels_list_filtered = []
        print(f"Processing {num_frames} frames with original PostProcess and pre-filtering (min_score={self.min_score_threshold})...")
        num_total_before_seqnms = 0

        for t in range(num_frames):
            # Prepare input for PostProcess (needs batch dimension)
            frame_output = {
                'pred_logits': pred_logits[t:t+1], # Shape [1, NumQueries, NumClasses]
                'pred_boxes': pred_boxes[t:t+1]    # Shape [1, NumQueries, 4]
            }
            frame_target_size = target_sizes[t:t+1] # Shape [1, 2]

            # Run *original* PostProcess for the single frame
            # It returns a list containing one dictionary: [{'scores': [N], 'labels': [N], 'boxes': [N, 4]}]
            # N = NumQueries. Scores/Labels are from max over foreground classes.
            processed_result_list = self.postprocess(frame_output, frame_target_size)

            if not processed_result_list:
                 # Should not happen if PostProcess works correctly, but handle defensively
                 print(f"Warning: PostProcess returned empty list for frame {t}. Appending empty tensors.")
                 boxes_list_filtered.append(torch.empty((0, 4), device=pred_logits.device, dtype=torch.float))
                 scores_list_filtered.append(torch.empty(0, device=pred_logits.device, dtype=torch.float))
                 labels_list_filtered.append(torch.empty(0, device=pred_logits.device, dtype=torch.long))
                 continue

            # Extract the results dict for the frame
            result_t = processed_result_list[0]
            scores_t = result_t['scores']
            labels_t = result_t['labels']
            boxes_t = result_t['boxes']

            # --- Pre-filtering before SeqNMS ---
            # Keep only detections with score above the minimum threshold
            keep_mask_pre = scores_t >= self.min_score_threshold
            num_kept_frame_pre = keep_mask_pre.sum().item()
            num_total_before_seqnms += num_kept_frame_pre

            # Append filtered tensors to the lists for SeqNMS input
            boxes_list_filtered.append(boxes_t[keep_mask_pre])
            scores_list_filtered.append(scores_t[keep_mask_pre])
            labels_list_filtered.append(labels_t[keep_mask_pre])

        print(f"Finished PostProcess & pre-filtering. Kept {num_total_before_seqnms} detections across {num_frames} frames.")

        # --- 2. Check if any detections remain for Seq-NMS ---
        if num_total_before_seqnms == 0:
            print("No detections passed the minimum score threshold. Skipping Seq-NMS.")
            # Return a list of empty dictionaries for each frame
            empty_results = []
            device = pred_logits.device # Get device for empty tensors
            for _ in range(num_frames):
                 empty_results.append({
                     'scores': torch.empty(0, device=device, dtype=torch.float),
                     'labels': torch.empty(0, device=device, dtype=torch.long),
                     'boxes': torch.empty((0, 4), device=device, dtype=torch.float)
                 })
            return empty_results

        # --- 3. Apply Seq-NMS ---
        print(f"Applying Seq-NMS with linkage_thresh={self.linkage_threshold}, iou_thresh={self.iou_threshold}...")
        try:
            # seq_nms_from_list operates on the pre-filtered lists
            # It returns a list of updated score tensors where suppressed detections have score 0.
            updated_scores_list = seq_nms_from_list(
                boxes_list=boxes_list_filtered,
                scores_list=scores_list_filtered, # Pass scores after pre-filtering
                classes_list=labels_list_filtered, # Pass labels after pre-filtering
                linkage_threshold=self.linkage_threshold,
                iou_threshold=self.iou_threshold
            )
            print("Seq-NMS applied successfully.")
        except Exception as e:
             # Catch potential errors during the C++/CUDA call
             print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
             print(f"Error occurred during seq_nms_from_list execution: {e}")
             print(f"Inputs shapes summary (post-filtering):")
             print(f"  boxes_list_filtered: {len(boxes_list_filtered)} frames, shapes {[b.shape for b in boxes_list_filtered[:5]]}...")
             print(f"  scores_list_filtered: {len(scores_list_filtered)} frames, shapes {[s.shape for s in scores_list_filtered[:5]]}...")
             print(f"  labels_list_filtered: {len(labels_list_filtered)} frames, shapes {[l.shape for l in labels_list_filtered[:5]]}...")
             print(f"Returning results *before* Seq-NMS due to the error.")
             print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
             # Return the results obtained just after PostProcess and pre-filtering
             final_results_per_frame_error = []
             for i in range(num_frames):
                 final_results_per_frame_error.append({
                    'scores': scores_list_filtered[i],
                    'labels': labels_list_filtered[i],
                    'boxes': boxes_list_filtered[i]
                 })
             return final_results_per_frame_error

        # --- 4. Filter Results Based on Seq-NMS Output ---
        final_results_per_frame = []
        num_kept_total_post_seqnms = 0
        for i in range(num_frames):
            # Get the pre-filtered results for the frame (these were input to seq_nms)
            original_boxes = boxes_list_filtered[i]
            original_labels = labels_list_filtered[i]
            original_scores = scores_list_filtered[i] # Scores before Seq-NMS

            # Get the scores modified by Seq-NMS
            updated_scores = updated_scores_list[i] # Scores after Seq-NMS (0 if suppressed)

            # Create a mask to keep only detections whose score > 0 after Seq-NMS
            keep_mask_post = updated_scores > 0.0
            num_kept_frame_post = keep_mask_post.sum().item()
            num_kept_total_post_seqnms += num_kept_frame_post

            # Filter the *pre-filtered* data using the mask derived from updated scores
            final_results_per_frame.append({
                'scores': original_scores[keep_mask_post], # Report original score
                'labels': original_labels[keep_mask_post],
                'boxes': original_boxes[keep_mask_post]
            })

        print(f"Seq-NMS finished. Kept {num_kept_total_post_seqnms} final detections across {num_frames} frames.")
        return final_results_per_frame

## Recurrent module

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F # Needed for padding

class RecurrentVideoObjectModule(nn.Module):
    """
    Refactored module using a YOLO detector, producing output compatible
    with the original implementation's format {'pred_logits': Tensor, 'pred_boxes': Tensor}.
    Detections from each frame are padded to a maximum number before stacking.
    Includes fix for grayscale input.
    """
    def __init__(self,
                 detector_module: nn.Module, # E.g., a loaded YOLO model
                 num_classes: int = 10, # Needed for converting scores/classes to logits

                ):
        super().__init__()
        self.detector_module = detector_module
        self.num_classes = num_classes + 1
        # Store the number of foreground classes if needed elsewhere
        self.num_foreground_classes = num_classes

    @torch.no_grad() # Crucial for evaluation/inference mode
    def forward(self, samples, targets: list = None):
        """
        Processes video frames sequentially, collects raw detections, pads them
        to the maximum number of detections found in any frame, and stacks them.
        Handles potential grayscale input.

        Args:
            samples (torch.Tensor): Input video tensor (B, T, C, H, W).
                                    Module assumes B=1. Can handle C=1 or C=3.
            targets (list): List containing target dictionaries for the video.
                            Passed through unmodified.

        Returns:
            tuple[dict, list]:
                - dict: {'pred_logits': Tensor [T, MaxDetsPerFrame, NumClasses],
                         'pred_boxes': Tensor [T, MaxDetsPerFrame, 4]}
                         Padded and stacked detections across all frames.
                - list: The original targets list (passed through).
        """
        if samples.dim() != 5:
             raise ValueError(f"Expected input tensor with 5 dimensions (B, T, C, H, W), but got {samples.dim()}")

        if samples.shape[0] != 1:
            print(f"Warning: RecurrentVideoObjectModule expects batch size 1, but got {samples.shape[0]}. Processing only the first item.")
            samples = samples[:1]

        # samples shape: (B, T, C, H, W), B=1 assumed
        num_frames = samples.shape[1]
        src = samples.permute(1, 0, 2, 3, 4).squeeze(1) # Shape: (T, C, H, W)
        device = samples.device
        input_channels = src.shape[1] # Get the number of channels C

        # Check if input is grayscale (C=1) or color (C=3)
        if input_channels not in [1, 3]:
            raise ValueError(f"Expected input frames to have 1 or 3 channels, but got {input_channels}")

        logits_accumulator = []
        boxes_accumulator = []
        max_detections = 0 # Keep track of the maximum number of detections in any frame

        if isinstance(self.detector_module, nn.Module):
              self.detector_module.eval()

        for timestamp, frame in enumerate(src): # Iterate through frames (C, H, W)

            # --- Ensure 3 Channels for YOLO ---
            if input_channels == 1:
                # Repeat the grayscale channel 3 times
                frame = frame.repeat(3, 1, 1) # Shape becomes (3, H, W)

            frame_batch = frame.unsqueeze(0) # Add batch dimension (1, 3, H, W)

            # --- Run Detector (YOLO) ---
            try:
                with torch.no_grad():
                    # Use model.predict for easier handling of results and NMS
                    # Adjust conf/iou as needed for your specific detector/task
                    detector_preds = self.detector_module.predict(frame_batch, conf=0.001, iou=0.5, device=device, verbose=False)
                    # 'preds' is typically a list of Results objects, one per image

            except Exception as e:
                print(f"Error during detector_module.predict on frame {timestamp}: {e}")
                print("Skipping frame due to prediction error.")
                # Decide how to handle: continue, raise, or append empty tensors?
                # Appending empty tensors might be safer for padding later.
                # logits_accumulator.append(torch.empty((0, self.num_classes), device=device))
                # boxes_accumulator.append(torch.empty((0, 4), device=device))
                raise e

            # --- Process Detections ---
            results_obj = detector_preds[0] # Assuming batch size 1 for predict


            # Use normalized xywh format directly if available
            frame_boxes_cxcywh = results_obj.boxes.xywhn.to(device)

            # print(f"frame_boxes_cxcywh: {frame_boxes_cxcywh.shape}")
            
            frame_scores = results_obj.boxes.conf.to(device)

            # print(f"frame_scores: {frame_scores.shape}")
            # print(frame_scores)
            
            frame_classes = results_obj.boxes.cls.to(device).long()

            # print(f"frame_classes: {frame_classes.shape}")

            num_detections = frame_classes.shape[0]

            # print(f"num_detections: {num_detections}")

            # --- Convert scores/classes to approximate logits ---
            frame_logits = torch.zeros((num_detections, self.num_classes), device=device)
            
            # Place the score logits at the correct class index
            frame_logits[torch.arange(num_detections, device=device), frame_classes] = frame_scores
            
            logits_accumulator.append(frame_logits)
            boxes_accumulator.append(frame_boxes_cxcywh)

            # Update max detections found so far
            if num_detections > max_detections:
                max_detections = num_detections

        # --- Pad tensors to max_detections ---
        padded_logits = []
        padded_boxes = []
        pad_val_boxes = 0.0
        background_class_index = self.num_classes - 1 # Last index is background
        for frame_logits, frame_boxes in zip(logits_accumulator, boxes_accumulator):
            num_dets = frame_logits.shape[0]
            pad_size = max_detections - num_dets

            if pad_size > 0:
                # --- Pad logits ---
                # Create padding with 0.0 everywhere first
                logit_padding = torch.zeros((pad_size, self.num_classes), device=device, dtype=frame_logits.dtype)
                # Set the background class logit to +inf for padded entries
                logit_padding[:, background_class_index] = float('inf')
                # Concatenate original logits with the padding
                padded_frame_logits = torch.cat((frame_logits, logit_padding), dim=0)

                # --- Pad boxes ---
                # Pad boxes with 0.0
                box_padding = torch.full((pad_size, 4), pad_val_boxes, device=device, dtype=frame_boxes.dtype)
                # Concatenate original boxes with the padding
                padded_frame_boxes = torch.cat((frame_boxes, box_padding), dim=0)
            else:
                # No padding needed if already max_detections
                padded_frame_logits = frame_logits
                padded_frame_boxes = frame_boxes

            padded_logits.append(padded_frame_logits)
            padded_boxes.append(padded_frame_boxes)


        final_logits = torch.stack(padded_logits) # Shape: [T, max_detections, num_classes]
        final_boxes = torch.stack(padded_boxes)   # Shape: [T, max_detections, 4]

        result = {
            'pred_logits': final_logits,
            'pred_boxes': final_boxes
        }

        return result, targets[0]

    def eval(self):
        """Sets the detector module to evaluation mode."""
        if isinstance(self.detector_module, nn.Module):
            self.detector_module.eval()
        return self


## MAP Evaluator

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import numpy as np
# Assuming box_ops contains the necessary bounding box utility functions like box_cxcywh_to_xyxy
# Import the actual box_ops module
from util import box_ops
# Assuming PostProcess is importable, e.g.:
# from models.perceiver import PostProcess # Adjust import path as needed
# Assuming box_iou is available from torchvision.ops
from torchvision.ops import box_iou


class MeanAveragePrecisionEvaluator:
    """
    Computes Mean Average Precision (mAP) for object detection tasks using a provided postprocessor.
    Also computes separate mAP scores for:
    1. Predictions matched to overlapping ground truth objects.
    2. Predictions matched to ground truth objects near the frame boundary.

    Processes model outputs assuming shape [sequence_length, num_queries, ...] and targets_flat
    as a list of dictionaries of length sequence_length.

    Standard mAP is updated per timestamp. Overlap and Boundary mAP use results
    from a single matcher call performed once per sequence.
    """
    def __init__(self, device, postprocessor, matcher, box_format='xyxy', iou_thresholds=None, overlap_iou_threshold=0.01, boundary_pixel_tolerance=0):
        """
        Initializes the mAP evaluator.

        Parameters:
        device (torch.device): The device to run computations on (e.g., 'cuda', 'cpu').
        postprocessor (torch.nn.Module): Postprocessing module (e.g., PostProcess).
        matcher (torch.nn.Module): Matcher instance (e.g., HungarianMatcher). Assumed to handle
                                   sequence dim as batch dim or be adapted accordingly.
        box_format (str): Box format for torchmetrics ('xyxy').
        iou_thresholds (list, optional): IoU thresholds for mAP calculation.
        overlap_iou_threshold (float): Min IoU between two GT boxes to be considered overlapping.
        boundary_pixel_tolerance (int): Pixel distance from edge to consider a GT box as 'on boundary'.
        """
        self.device = device
        self.postprocessor = postprocessor.to(self.device)
        self.matcher = matcher # Store matcher instance
        self.overlap_iou_threshold = overlap_iou_threshold
        self.boundary_pixel_tolerance = boundary_pixel_tolerance

        # --- Initialize Metrics ---
        self.map_metric = torchmetrics.detection.MeanAveragePrecision(
            box_format='xyxy', iou_type='bbox', iou_thresholds=iou_thresholds
        ).to(self.device)
        self.map_results = None

        self.map_metric_overlap = torchmetrics.detection.MeanAveragePrecision(
            box_format='xyxy', iou_type='bbox', iou_thresholds=iou_thresholds
        ).to(self.device)
        self.map_results_overlap = None

        self.map_metric_boundary = torchmetrics.detection.MeanAveragePrecision(
            box_format='xyxy', iou_type='bbox', iou_thresholds=iou_thresholds
        ).to(self.device)
        self.map_results_boundary = None


    # Helper function provided by user (might need adjustment based on actual matcher output format)
    def _get_src_permutation_idx(self, indices):
      # This helper might not be directly usable if indices is now a list over sequence length
      # The logic using matcher results below accesses indices[t] directly.
      print("Warning: _get_src_permutation_idx might not be suitable for the new matcher output format.")
      # Placeholder return if called unexpectedly
      return torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)


    def update(self, outputs, targets_flat):
        """
        Updates the evaluator state. Standard mAP is updated per timestamp.
        Matcher runs once per sequence. Overlap/Boundary metrics are updated
        per timestamp using the single matcher result.

        Parameters:
        outputs (dict): Raw outputs. Keys: 'pred_logits', 'pred_boxes'. Shape [seq_len, num_queries, ...].
        targets_flat (list[dict]): GT dicts. Len=seq_len. Keys: 'orig_size', 'boxes', 'labels'.
        """
        # --- Input Validation ---
        if not isinstance(targets_flat, list): return
        if 'pred_logits' not in outputs or 'pred_boxes' not in outputs: return

        pred_logits_seq = outputs['pred_logits'].to(self.device) # Shape [seq_len, num_queries, ...]
        pred_boxes_seq = outputs['pred_boxes'].to(self.device)   # Shape [seq_len, num_queries, ...]
        seq_len = pred_logits_seq.shape[0]

        if pred_logits_seq.dim() < 2 or seq_len != len(targets_flat):
             print(f"Error: Mismatch/Shape issue between outputs ({pred_logits_seq.shape}) and targets_flat ({len(targets_flat)}). Skipping.")
             return

        # --- Store per-timestamp processed data ---
        all_postprocessed_preds = []
        all_processed_targets = []
        all_orig_sizes = [] # Store original sizes (tensor[2]) for boundary check

        # --- First Pass: Update Standard mAP & Prepare Data ---
        for t in range(seq_len):
            target_dict = targets_flat[t]
            if not isinstance(target_dict, dict) or not all(k in target_dict for k in ['orig_size', 'boxes', 'labels']):
                 print(f"Error: Invalid target dict at timestamp {t}. Skipping timestamp.")
                 all_postprocessed_preds.append(None)
                 all_processed_targets.append(None)
                 all_orig_sizes.append(None)
                 continue

            # Prepare Predictions (Post-processed)
            orig_size_tensor_t = None
            prediction_dict_postprocessed_t = None
            try:
                orig_size_tensor_t = target_dict['orig_size'].to(self.device)
                if orig_size_tensor_t.shape != (2,): raise ValueError("Incorrect orig_size shape")
                orig_target_size_tensor_batch = orig_size_tensor_t.unsqueeze(0) # Shape [1, 2]

                outputs_raw_t_batch = { # Need batch dim for postprocessor
                    'pred_logits': pred_logits_seq[t].unsqueeze(0),
                    'pred_boxes': pred_boxes_seq[t].unsqueeze(0)
                }
                with torch.no_grad():
                    predictions_postprocessed_t_list = self.postprocessor(outputs_raw_t_batch, orig_target_size_tensor_batch)
                if len(predictions_postprocessed_t_list) != 1: raise RuntimeError("Postprocessor failed")
                prediction_dict_postprocessed_t = predictions_postprocessed_t_list[0]
                all_postprocessed_preds.append(prediction_dict_postprocessed_t)
                all_orig_sizes.append(orig_size_tensor_t) # Store shape [2] tensor

            except Exception as e:
                 print(f"Error processing predictions for timestamp {t}: {e}. Skipping.")
                 all_postprocessed_preds.append(None)
                 all_processed_targets.append(None)
                 all_orig_sizes.append(None)
                 continue

            # Prepare Targets (Absolute xyxy)
            target_dict_processed_t = None
            try:
                target_dict_device = {k: v.to(self.device) for k, v in target_dict.items() if isinstance(v, torch.Tensor)}
                img_h_target_tensor = orig_size_tensor_t[0]
                img_w_target_tensor = orig_size_tensor_t[1]
                scale_fct_target = torch.stack([img_w_target_tensor, img_h_target_tensor, img_w_target_tensor, img_h_target_tensor], dim=0)

                boxes_target = target_dict_device['boxes'] # Assumed relative cxcywh
                labels_target = target_dict_device['labels'] # Original labels

                gt_boxes_xyxy = torch.empty((0,4), device=self.device) # Default empty boxes
                labels_final = torch.empty(0, dtype=torch.long, device=self.device) # Default empty labels

                if boxes_target.numel() > 0:
                     # Only process boxes and use original labels if boxes exist
                     gt_boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_target) * scale_fct_target
                     # Ensure labels correspond to the boxes processed
                     if labels_target.shape[0] == boxes_target.shape[0]:
                          labels_final = labels_target
                     else:
                          print(f"Warning: Mismatch between number of boxes ({boxes_target.shape[0]}) and labels ({labels_target.shape[0]}) at timestamp {t}. Using empty labels.")
                          # Keep labels_final empty if mismatch occurs
                # else: gt_boxes_xyxy and labels_final remain empty

                target_dict_processed_t = {
                    'boxes': gt_boxes_xyxy, # Shape [N, 4] or [0, 4]
                    'labels': labels_final   # Shape [N] or [0]
                }
                all_processed_targets.append(target_dict_processed_t)
            except Exception as e:
                 print(f"Error preparing targets for timestamp {t}: {e}. Skipping.")
                 all_processed_targets.append(None)
                 if len(all_postprocessed_preds) > len(all_processed_targets):
                      all_postprocessed_preds.pop()
                      all_orig_sizes.pop()
                 continue

            # Update Standard mAP Metric for this timestamp
            # Ensure both prediction and target dicts are valid before updating
            if prediction_dict_postprocessed_t is not None and target_dict_processed_t is not None:
                try:
                    self.map_metric.update([prediction_dict_postprocessed_t], [target_dict_processed_t])
                except Exception as e:
                    print(f"Error updating standard mAP metric for timestamp {t}: {e}")
            else:
                 print(f"Skipping standard mAP update for timestamp {t} due to previous errors.")


        # --- Run Matcher ONCE for the whole sequence ---
        indices_seq = None # Will store list of tuples: [(pred_idx_t, gt_idx_t)] * seq_len
        try:
            # Prepare inputs for matcher (treating sequence as batch)
            outputs_raw_seq = {'pred_logits': pred_logits_seq, 'pred_boxes': pred_boxes_seq}
            # Ensure targets are on device and have required keys
            targets_flat_device = []
            valid_targets_exist = False
            for t_idx, td in enumerate(targets_flat):
                 # Use the already processed targets if available and valid
                 processed_target = all_processed_targets[t_idx]
                 if processed_target is not None and isinstance(processed_target.get('boxes'), torch.Tensor) and isinstance(processed_target.get('labels'), torch.Tensor):
                      # Matcher might expect original format, re-prepare if necessary
                      # For simplicity, let's assume matcher works with processed targets (abs xyxy)
                      # If matcher needs relative cxcywh, re-fetch from targets_flat
                      original_target = targets_flat[t_idx]
                      if isinstance(original_target, dict) and all(k in original_target for k in ['boxes', 'labels']):
                           targets_flat_device.append({k: v.to(self.device) for k, v in original_target.items() if k in ['boxes', 'labels']})
                           valid_targets_exist = True
                      else:
                           # Add placeholder if original target was invalid
                           targets_flat_device.append({'boxes': torch.empty((0,4), device=self.device), 'labels': torch.empty(0, dtype=torch.long, device=self.device)})

                 else:
                      # Add placeholder if processing failed
                      targets_flat_device.append({'boxes': torch.empty((0,4), device=self.device), 'labels': torch.empty(0, dtype=torch.long, device=self.device)})


            if valid_targets_exist: # Only run matcher if there's something to match
                 with torch.no_grad():
                      # Assumes matcher takes [Seq, N, C] outputs and List[Dict] targets (len=Seq)
                      indices_seq = self.matcher(outputs_raw_seq, targets_flat_device)
            else:
                 print("Skipping matcher run as no valid targets were found in the sequence.")

        except Exception as e:
             print(f"Error running matcher for the sequence: {e}")
             indices_seq = None # Ensure indices_seq is None if matcher fails

        # --- Second Pass: Update Overlap and Boundary Metrics ---
        if indices_seq is not None and len(indices_seq) == seq_len: # Check if matcher ran successfully and returned expected length
            for t in range(seq_len):
                # Retrieve pre-processed data for timestamp t
                prediction_dict_postprocessed_t = all_postprocessed_preds[t]
                target_dict_processed_t = all_processed_targets[t]
                orig_size_tensor_t = all_orig_sizes[t]

                # Skip if data preparation failed in the first pass for this timestamp
                if prediction_dict_postprocessed_t is None or target_dict_processed_t is None or orig_size_tensor_t is None:
                    continue

                # Get matcher results for this specific timestamp
                # Add check for empty matcher results for this timestamp
                if t >= len(indices_seq) or not indices_seq[t] or len(indices_seq[t]) != 2:
                     # print(f"Matcher results missing or invalid for timestamp {t}. Skipping overlap/boundary.")
                     continue # Skip if no valid match for this frame
                #matched_pred_indices_t, matched_gt_indices_t = indices_seq[t]
                matched_pred_indices_t = indices_seq[t][0].to(self.device)
                matched_gt_indices_t = indices_seq[t][1].to(self.device)

                gt_boxes_xyxy = target_dict_processed_t['boxes']
                num_gt = gt_boxes_xyxy.shape[0]

                # --- Overlap Calculation ---
                if num_gt > 1: # Overlap requires at least 2 GT boxes
                    try:
                        gt_iou_matrix = box_iou(gt_boxes_xyxy, gt_boxes_xyxy)
                        gt_iou_matrix.fill_diagonal_(0)
                        overlaps_exist = (gt_iou_matrix > self.overlap_iou_threshold).any(dim=1)
                        overlap_gt_indices = torch.where(overlaps_exist)[0].to(self.device)

                        if overlap_gt_indices.numel() > 0:
                            targets_overlap = {
                                'boxes': gt_boxes_xyxy[overlap_gt_indices],
                                'labels': target_dict_processed_t['labels'][overlap_gt_indices]
                            }
                            if matched_pred_indices_t.numel() > 0: # Check if matcher found matches for this frame
                                is_overlap_match = torch.isin(matched_gt_indices_t, overlap_gt_indices)
                                pred_indices_for_overlap = matched_pred_indices_t[is_overlap_match]
                                if pred_indices_for_overlap.numel() > 0:
                                    predictions_overlap = {
                                        'scores': prediction_dict_postprocessed_t['scores'][pred_indices_for_overlap],
                                        'labels': prediction_dict_postprocessed_t['labels'][pred_indices_for_overlap],
                                        'boxes': prediction_dict_postprocessed_t['boxes'][pred_indices_for_overlap],
                                    }
                                    self.map_metric_overlap.update([predictions_overlap], [targets_overlap])
                    except Exception as e:
                        print(f"Error calculating/updating overlap mAP for timestamp {t}: {e}")

                # --- Boundary Calculation ---
                if num_gt > 0: # Boundary check needs at least 1 GT box
                     try:
                        img_h_target_tensor = orig_size_tensor_t[0]
                        img_w_target_tensor = orig_size_tensor_t[1]
                        xmin, ymin, xmax, ymax = gt_boxes_xyxy.unbind(-1)
                        tol = self.boundary_pixel_tolerance
                        is_on_boundary = (xmin <= tol) | (ymin <= tol) | \
                                         (xmax >= img_w_target_tensor - tol) | (ymax >= img_h_target_tensor - tol)
                        boundary_gt_indices = torch.where(is_on_boundary)[0].to(self.device)

                        if boundary_gt_indices.numel() > 0:
                            targets_boundary = {
                                'boxes': gt_boxes_xyxy[boundary_gt_indices],
                                'labels': target_dict_processed_t['labels'][boundary_gt_indices]
                            }
                            if matched_pred_indices_t.numel() > 0: # Check if matcher found matches for this frame
                                is_boundary_match = torch.isin(matched_gt_indices_t, boundary_gt_indices)
                                pred_indices_for_boundary = matched_pred_indices_t[is_boundary_match]
                                if pred_indices_for_boundary.numel() > 0:
                                    predictions_boundary = {
                                        'scores': prediction_dict_postprocessed_t['scores'][pred_indices_for_boundary],
                                        'labels': prediction_dict_postprocessed_t['labels'][pred_indices_for_boundary],
                                        'boxes': prediction_dict_postprocessed_t['boxes'][pred_indices_for_boundary],
                                    }
                                    self.map_metric_boundary.update([predictions_boundary], [targets_boundary])
                     except Exception as e:
                         print(f"Error calculating/updating boundary mAP for timestamp {t}: {e}")
        else:
             print("Matcher did not run successfully or returned unexpected format. Skipping overlap/boundary metric updates.")


    def accumulate(self):
        """
        Computes the final mAP results across all updated timestamps from all sequences.
        Stores the results internally.
        """
        # Accumulate standard mAP
        try:
            self.map_results = self.map_metric.compute()
        except Exception as e:
            print(f"Error computing standard mAP metric: {e}")
            self.map_results = None

        # Accumulate overlap mAP
        try:
            self.map_results_overlap = self.map_metric_overlap.compute()
        except Exception as e:
            print(f"Error computing overlap mAP metric: {e}")
            self.map_results_overlap = None

        # Accumulate boundary mAP
        try:
            self.map_results_boundary = self.map_metric_boundary.compute()
        except Exception as e:
            print(f"Error computing boundary mAP metric: {e}")
            self.map_results_boundary = None


    def summary(self):
        """
        Processes and returns the computed mAP results in a dictionary, including overlap and boundary metrics.
        NOTE: Metrics reflect performance averaged over *all timestamps* processed.

        Returns:
        dict: Dictionary containing scalar mAP results prefixed with 'mAP_', 'mAP_overlap_', and 'mAP_boundary_'.
        """
        summary_dict = {}

        # Process standard mAP results
        if self.map_results is not None:
            print(f"\nRaw mAP Results (Evaluator): {self.map_results}\n")
            for k, v in self.map_results.items():
                if isinstance(v, torch.Tensor) and v.numel() == 1:
                    summary_dict[f'mAP_{k}'] = v.item() # Prefix standard keys
            if not any(k.startswith('mAP_') for k in summary_dict) and self.map_results:
                 if 'map' in self.map_results and self.map_results['map'] == 0: pass
                 else: print("Warning (Evaluator): No scalar standard mAP metrics found.")
        else:
             print("Warning (Evaluator): Standard mAP metric computation failed or produced no results.")

        # Process overlap mAP results
        if self.map_results_overlap is not None:
            print(f"\nRaw Overlap mAP Results (Evaluator): {self.map_results_overlap}\n")
            for k, v in self.map_results_overlap.items():
                if isinstance(v, torch.Tensor) and v.numel() == 1:
                    summary_dict[f'mAP_overlap_{k}'] = v.item() # Prefix overlap keys
            if not any(k.startswith('mAP_overlap_') for k in summary_dict) and self.map_results_overlap:
                 if 'map' in self.map_results_overlap and self.map_results_overlap['map'] == 0: pass
                 else: print("Warning (Evaluator): No scalar overlap mAP metrics found.")
        else:
             print("Warning (Evaluator): Overlap mAP metric computation failed or produced no results.")

        # Process boundary mAP results
        if self.map_results_boundary is not None:
            print(f"\nRaw Boundary mAP Results (Evaluator): {self.map_results_boundary}\n")
            for k, v in self.map_results_boundary.items():
                if isinstance(v, torch.Tensor) and v.numel() == 1:
                    summary_dict[f'mAP_boundary_{k}'] = v.item() # Prefix boundary keys
            if not any(k.startswith('mAP_boundary_') for k in summary_dict) and self.map_results_boundary:
                 if 'map' in self.map_results_boundary and self.map_results_boundary['map'] == 0: pass
                 else: print("Warning (Evaluator): No scalar boundary mAP metrics found.")
        else:
             print("Warning (Evaluator): Boundary mAP metric computation failed or produced no results.")


        return summary_dict

    def reset(self):
        """Resets the internal state of all metrics."""
        self.map_metric.reset()
        self.map_results = None
        self.map_metric_overlap.reset()
        self.map_results_overlap = None
        self.map_metric_boundary.reset()
        self.map_results_boundary = None



## Eval

In [8]:
from engine import evaluate
from dataset import build_dataset
import torch
from engine import evaluate
from dataset import build_dataset
from models import build_model
from models.perceiver import PostProcess
import os
from pathlib import Path
from util.misc import collate_fn, is_main_process, get_sha, get_rank
from torch.utils.data import DataLoader
from ultralytics import YOLO
import torchvision.transforms as T
from models.matcher import build_matcher

args.model = 'perceiver' # 'perceiver'

if args.model == 'perceiver':
    # https://wandb.ai/university-of-tartu-2/multi-sensor-dropout/runs/stu3mw4u/overview
    args.resume = 'checkpoint_epoch_31.pth'
    args.output_dir = "../not_tracked_dir/output_perceiver_detection_2025-04-23_19-42-48/"
    args.resize_frame = 320
elif args.model == 'YOLO':
    args.output_dir = "../not_tracked_dir/output_yolo_v8_2025-04-11/" # YOLO
    args.resume = 'weights/best.pt'
    args.resize_frame = 320
else:
    raise Error(f"Unsupported model: {args.model}")

args.test_dataset_fraction = 1
# args.num_workers = 1

checkpoint_path = os.path.join(args.output_dir, args.resume)

dataset_test = build_dataset(split='test', args=args)

print(f"ds size: {len(dataset_test)}")

sampler_test = torch.utils.data.SequentialSampler(dataset_test)

dataloader_test = DataLoader(dataset_test, sampler=sampler_test, batch_size=args.batch_size,
                                collate_fn=collate_fn, pin_memory=True)

if args.model == 'YOLO':
    print("Loading YOLO model for RecurrentVideoObjectModule...")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"YOLO checkpoint (.pt file) not found at {checkpoint_path}")

    # Load the base YOLO detector
    yolo_detector = YOLO(checkpoint_path)
    print(f"YOLO detector loaded from {checkpoint_path}")

    # Get number of classes from the loaded YOLO model
    num_classes = len(yolo_detector.names)
    print(f"Detected {num_classes} classes from YOLO model.")

    # Instantiate the wrapper module
    model = RecurrentVideoObjectModule(detector_module=yolo_detector, num_classes=num_classes)
    print("RecurrentVideoObjectModule created with YOLO detector.")

    print("Clear norm transforms")
    dataset_test.norm_transforms = T.Compose([])
    
elif args.model == 'perceiver' and os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    model = build_model(args, (0, 0))
    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    model.load_state_dict(checkpoint['model'])
else:
    raise Error("Model not supported.")
    

device = torch.device(args.device)

postprocessors = {
    #'bbox': PostProcess() # SeqNMSEvaluationPostprocessor(PostProcess())
    # '': 
}

matcher = build_matcher(args)
matcher.to(device)

evaluators = [
    MeanAveragePrecisionEvaluator(device=device, postprocessor = PostProcess(), matcher=matcher)
]

criterion = DummyCriterion(device)

model.to(args.device)

results = evaluate(
	model=model,
	dataloader=dataloader_test,
	criterion=criterion,  # Not needed for evaluation
	postprocessors=postprocessors,
	epoch=-1,
	device=device,
    evaluators=evaluators
)

print("Evaluation Results:")
for metric_name, value in results.items():
	print(f"{metric_name}: {value:.4f}")

  from .autonotebook import tqdm as notebook_tqdm


Generating test huggingface MovingMNIST dataset...
Using object detection mode
Resizing frames to 320x320
Transforms: Compose(
    <dataset.transformations.RandomResize object at 0x14775cec5210>
    <dataset.transformations.NormBoxesTransform object at 0x14775cec4950>
)
ds size: 10000
Loading checkpoint from ../not_tracked_dir/output_perceiver_detection_2025-04-23_19-42-48/checkpoint_epoch_31.pth
Using CNN v2 backbone with resized input size 320x320
num_freq_bands: 6
depth: 1
max_freq: 10
input_channels: 128
input_axis: 2
num_latents: 16
latent_dim: 128
cross_heads: 1
latent_heads: 1
cross_dim_head: 154
latent_dim_head: 128
num_classes: -1
attn_dropout: 0.0
ff_dropout: 0.0
weight_tie_layers: False
fourier_encode_data: True
self_per_cross_attn: 1
final_classifier_head: False
num_sensors: 1
__class__: <class 'models.perceiver.Perceiver'>
Using HungarianMatcher for object detection


Eval -1:: 100%|██████████| 10000/10000 [34:24<00:00,  4.84it/s, loss_running=0, class_error_running=nan, loss_center_point_running=nan, loss_ce_running=0, loss=0, loss_ce_unscaled=0, loss_bbox_unscaled=0, loss_giou_unscaled=0, loss_unscaled=0, loss_ce=0, view_dropout_prob=-1] 



Raw mAP Results (Evaluator): {'map': tensor(0.9025), 'map_50': tensor(0.9688), 'map_75': tensor(0.9442), 'map_small': tensor(0.9025), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.7367), 'mar_10': tensor(0.9240), 'mar_100': tensor(0.9240), 'mar_small': tensor(0.9240), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}


Raw Overlap mAP Results (Evaluator): {'map': tensor(0.6940), 'map_50': tensor(0.8853), 'map_75': tensor(0.7777), 'map_small': tensor(0.6940), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.6686), 'mar_10': tensor(0.7389), 'mar_100': tensor(0.7389), 'mar_small': tensor(0.7389), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}


Raw Boundary mAP Results 

## Perceiver

https://wandb.ai/university-of-tartu-2/multi-sensor-dropout/runs/stu3mw4u/overview
args.resume = 'checkpoint_epoch_31.pth'
args.output_dir = "../not_tracked_dir/output_perceiver_detection_2025-04-23_19-42-48/"


Raw mAP Results (Evaluator): {'map': tensor(0.9025), 'map_50': tensor(0.9688), 'map_75': tensor(0.9442), 'map_small': tensor(0.9025), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.7367), 'mar_10': tensor(0.9240), 'mar_100': tensor(0.9240), 'mar_small': tensor(0.9240), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}


Raw Overlap mAP Results (Evaluator): {'map': tensor(0.6940), 'map_50': tensor(0.8853), 'map_75': tensor(0.7777), 'map_small': tensor(0.6940), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.6686), 'mar_10': tensor(0.7389), 'mar_100': tensor(0.7389), 'mar_small': tensor(0.7389), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}


Raw Boundary mAP Results (Evaluator): {'map': tensor(0.8470), 'map_50': tensor(0.9552), 'map_75': tensor(0.9101), 'map_small': tensor(0.8470), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.8512), 'mar_10': tensor(0.8739), 'mar_100': tensor(0.8739), 'mar_small': tensor(0.8739), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}

## YOLO
Raw mAP Results (Evaluator): {'map': tensor(0.9209), 'map_50': tensor(0.9588), 'map_75': tensor(0.9367), 'map_small': tensor(0.9210), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.7747), 'mar_10': tensor(0.9375), 'mar_100': tensor(0.9375), 'mar_small': tensor(0.9375), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}


Raw Overlap mAP Results (Evaluator): {'map': tensor(0.8149), 'map_50': tensor(0.9072), 'map_75': tensor(0.8693), 'map_small': tensor(0.8149), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.7976), 'mar_10': tensor(0.8430), 'mar_100': tensor(0.8430), 'mar_small': tensor(0.8430), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}


Raw Boundary mAP Results (Evaluator): {'map': tensor(0.7145), 'map_50': tensor(0.7703), 'map_75': tensor(0.7448), 'map_small': tensor(0.7145), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.7244), 'mar_10': tensor(0.7425), 'mar_100': tensor(0.7425), 'mar_small': tensor(0.7425), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}


Raw mAP Results: {'map': tensor(0.9209), 'map_50': tensor(0.9588), 'map_75': tensor(0.9367), 'map_small': tensor(0.9210), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.7747), 'mar_10': tensor(0.9375), 'mar_100': tensor(0.9375), 'mar_small': tensor(0.9375), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}

## RESULT

Raw mAP Results: {'map': tensor(0.7331), 'map_50': tensor(0.9379), 'map_75': tensor(0.8674), 'map_small': tensor(0.7331), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.6338), 'mar_10': tensor(0.7885), 'mar_100': tensor(0.7885), 'mar_small': tensor(0.7885), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}

Evaluation Results:
loss_running: 0.0000
class_error_running: nan
loss_center_point_running: nan
loss_ce_running: 0.0000
loss: 0.0000
loss_ce_unscaled: 0.0000
loss_bbox_unscaled: 0.0000
loss_giou_unscaled: 0.0000
loss_unscaled: 0.0000
loss_ce: 0.0000
mAP_map: 0.7331
mAP_map_50: 0.9379
mAP_map_75: 0.8674
mAP_map_small: 0.7331
mAP_map_medium: -1.0000
mAP_map_large: -1.0000
mAP_mar_1: 0.6338
mAP_mar_10: 0.7885
mAP_mar_100: 0.7885
mAP_mar_small: 0.7885
mAP_mar_medium: -1.0000
mAP_mar_large: -1.0000
mAP_map_per_class: -1.0000
mAP_mar_100_per_class: -1.0000

## Result
top 10 and threashold 0.5

Raw mAP Results: {'map': tensor(0.6828), 'map_50': tensor(0.8640), 'map_75': tensor(0.8098), 'map_small': tensor(0.6828), 'map_medium': tensor(-1.), 'map_large': tensor(-1.), 'mar_1': tensor(0.5995), 'mar_10': tensor(0.7301), 'mar_100': tensor(0.7301), 'mar_small': tensor(0.7301), 'mar_medium': tensor(-1.), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.), 'classes': tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)}

Evaluation Results:
loss_running: 0.0000
class_error_running: nan
loss_center_point_running: nan
loss_ce_running: 0.0000
loss: 0.0000
loss_ce_unscaled: 0.0000
loss_bbox_unscaled: 0.0000
loss_giou_unscaled: 0.0000
loss_unscaled: 0.0000
loss_ce: 0.0000
mAP_map: 0.6828
mAP_map_50: 0.8640
mAP_map_75: 0.8098
mAP_map_small: 0.6828
mAP_map_medium: -1.0000
mAP_map_large: -1.0000
mAP_mar_1: 0.5995
mAP_mar_10: 0.7301
mAP_mar_100: 0.7301
mAP_mar_small: 0.7301
mAP_mar_medium: -1.0000
mAP_mar_large: -1.0000
mAP_map_per_class: -1.0000
mAP_mar_100_per_class: -1.0000

## Plot

In [None]:
from ultralytics.utils.plotting import plot_images
from engine import evaluate
from dataset import build_dataset
import torch
from engine import evaluate
from dataset import build_dataset
from models import build_model
from models.perceiver import PostProcess
import os
from pathlib import Path
from util.misc import collate_fn, is_main_process, get_sha, get_rank
from torch.utils.data import DataLoader
from ultralytics import YOLO
import torchvision.transforms as T

args.model = 'YOLO' # 'perceiver'

if args.model == 'YOLO':

    args.output_dir = "../not_tracked_dir/output_yolo_v8_2025-04-11/" # YOLO
    args.resume = 'weights/best.pt'
    args.resize_frame = 320
    
    checkpoint_path = os.path.join(args.output_dir, args.resume)
    print("Loading YOLO model for RecurrentVideoObjectModule...")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"YOLO checkpoint (.pt file) not found at {checkpoint_path}")

    # Load the base YOLO detector
    yolo_detector = YOLO(checkpoint_path)
    print(f"YOLO detector loaded from {checkpoint_path}")

    # Get number of classes from the loaded YOLO model
    num_classes = len(yolo_detector.names)
    print(f"Detected {num_classes} classes from YOLO model.")

    # Instantiate the wrapper module
    model = RecurrentVideoObjectModule(detector_module=yolo_detector, num_classes=num_classes)
    print("RecurrentVideoObjectModule created with YOLO detector.")

    args.test_dataset_fraction = 0.01
    
elif args.model == 'perceiver' and os.path.exists(checkpoint_path):

    # https://wandb.ai/university-of-tartu-2/multi-sensor-dropout/runs/stu3mw4u/overview
    args.resume = 'checkpoint_epoch_31.pth'
    args.output_dir = "../not_tracked_dir/output_perceiver_detection_2025-04-23_19-42-48"
    args.resize_frame = 320
    args.test_dataset_fraction = 0.01
    checkpoint_path = os.path.join(args.output_dir, args.resume)
    model = build_model(args, (0,0))
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=args.device)
    model.load_state_dict(checkpoint['model'])


save_dir = os.path.join(args.output_dir, 'test_eval_best')
os.makedirs(save_dir, exist_ok=True)

dataset_test = build_dataset(split='test', args=args)

if args.model == 'YOLO':
    print("Clear norm transforms")
    dataset_test.norm_transforms = T.Compose([])

sampler_test = torch.utils.data.SequentialSampler(dataset_test)

dataloader_test = DataLoader(dataset_test, sampler=sampler_test, batch_size=args.batch_size,
                                collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)


model.to(args.device)
model.eval()

def get_batch_by_index(dataloader, index):
    iterator = iter(dataloader)
    for i in range(index + 1):
        try:
            batch = next(iterator)
            if i == index:
                return batch
        except StopIteration:
            return None  # Index out of bounds


In [None]:
import torch.nn.functional as F
from util import box_ops
import numpy as np

ni = 2

samples=None

if samples is None:
    samples, targets = get_batch_by_index(dataloader_test, ni)
    original_video, targets_ = dataloader_test.dataset.dataset[ni]

out, targets_flat = model(samples, targets)

out_logits, pred_boxes = out['pred_logits'], out['pred_boxes']
batch_size = out_logits.shape[0]

if args.model == 'perceiver':
    prob = F.softmax(out_logits, -1)
elif args.model == 'YOLO':
    prob = out_logits # confs instead

scores, indices = prob[..., :-1].max(-1)

keep = scores > 0.1

# Prepare lists to store filtered results across the batch
all_boxes_pixels = []
all_scores = []
all_classes = []
all_batch_idx = []
paths = []

for i in range(batch_size): # frames
    keep_batch = keep[i] # Boolean mask for queries for this image
    paths.append(f'{args.model}_vd_{ni}_frm_{i}')
    if keep_batch.any():
        # Get filtered scores, classes, and boxes for this image
        batch_scores = scores[i][keep_batch]
        batch_classes = indices[i][keep_batch]
        batch_boxes_norm = pred_boxes[i][keep_batch] # Normalized cxcywh

        # Convert boxes to pixel xyxy
        # batch_boxes_pixels = rescale_bboxes(batch_boxes_norm, img_size_wh)
        batch_boxes_pixels = batch_boxes_norm# box_ops.box_cxcywh_to_xyxy(batch_boxes_norm)
        # batch_boxes_pixels = batch_boxes_pixels * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=batch_boxes_norm.device)

        # Store results
        all_boxes_pixels.append(batch_boxes_pixels)
        all_scores.append(batch_scores)
        all_classes.append(batch_classes)
        # Add batch index for each kept box
        all_batch_idx.append(torch.full_like(batch_scores, fill_value=i, dtype=torch.long))

# Concatenate results from all images into single tensors
if all_batch_idx: # Check if any boxes were kept
    final_pred_boxes = torch.cat(all_boxes_pixels, dim=0)
    final_pred_scores = torch.cat(all_scores, dim=0)
    final_pred_classes = torch.cat(all_classes, dim=0)
    final_pred_batch_idx = torch.cat(all_batch_idx, dim=0)
else:
    # Handle case with no detections
    final_pred_boxes = torch.empty((0, 4), device=samples.device)
    final_pred_scores = torch.empty((0,), device=samples.device)
    final_pred_classes = torch.empty((0,), dtype=torch.long, device=samples.device)
    final_pred_batch_idx = torch.empty((0,), dtype=torch.long, device=samples.device)

# --- 3. Prepare Images ---

def denormalize_image(tensor, mean, std):
    """Denormalize a tensor image"""
    # Clone to avoid modifying original tensor
    tensor = tensor.clone()
    mean = torch.tensor(mean, device=tensor.device).view(-1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(-1, 1, 1)
    tensor.mul_(std).add_(mean)
    # Clamp and convert to uint8
    tensor = torch.clamp(tensor * 255.0, min=0.0, max=255.0).to(torch.uint8)
    return tensor

mmnist_stat = {
    'perceiver': (0.023958550628466375, 0.14140212075592035), # mean, std
    'YOLO': (0, 1), # mean, std
}

# Denormalize and convert B C H W -> B H W C
images_to_plot = []
samples_flat = samples.squeeze(0) # B T C H W -> T C H W
for i in range(samples_flat.shape[0]):
    img = denormalize_image(samples_flat[i], mmnist_stat[args.model][0], mmnist_stat[args.model][1]) # B C H W -> C H W (uint8)
    # img = img.permute(1, 2, 0) # C H W -> H W C
    images_to_plot.append(img.cpu().numpy()) # plot_images expects numpy HWC uint8

# # plot_images expects a single array/tensor if plotting multiple images from a batch
images_to_plot_batch = np.stack(images_to_plot) # Create batch H W C


# Plot Predictions
plot_images(
    images=images_to_plot_batch,
    batch_idx=final_pred_batch_idx.cpu(),  # Index mapping box to image
    cls=final_pred_classes.cpu(),          # Class index for each box
    bboxes=final_pred_boxes.detach().cpu(),         # Boxes in pixel xyxy format
    confs=final_pred_scores.cpu(),      # Optional: Include scores on plot
    fname=os.path.join(save_dir, f"{args.model}_val_batch{ni}_pred.jpg"),
    paths=paths
)