**Notes: Truncated NMS is a filtered version of NMS**
- NMS is a post-processing method to remove the redundant predictions while merging.
- **Task**: Convert Provided Truncated Non Maximum Suppression (NMS) algorithm to python version.
- First one is simple conversion of algorithm to python code version as it is.
- Second one is original Non Maximum Suppression (NMS). Truncated NMS is a filtered version of NMS.
- Third one is the Truncated NMS version converted in the form of NMS code. Basically, it does all the preliminary calculation like area calculated, sorting, IoU calculation as NMS code. Just, it adds few more condition to filter out some overlapping predictions based on IoOt and IoIt if the condition meets.

In [3]:
import torch

In [6]:
import numpy as np

def iou(box1, box2):
    """
    Calculate the Intersection Over Union (IOU) between two bounding boxes.
    
    box1 and box2 are tuples (x1, y1, x2, y2) representing the coordinates of the top-left and bottom-right corners.
    """
    x1_intersection = max(box1[0], box2[0])
    y1_intersection = max(box1[1], box2[1])
    x2_intersection = min(box1[2], box2[2])
    y2_intersection = min(box1[3], box2[3])

    # Calculate area of intersection
    intersection_width = max(0, x2_intersection - x1_intersection)
    intersection_height = max(0, y2_intersection - y1_intersection)
    intersection_area = intersection_width * intersection_height

    # Calculate area of both boxes
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # IOU = intersection area / union area
    union_area = area1 + area2 - intersection_area
    return intersection_area / union_area if union_area > 0 else 0

def truncated_nms_original(B, S, IOUt, IOIt, IOOt):
    """
    Perform truncated Non-Maximum Suppression (NMS) on the list of bounding boxes and their scores.
    
    Parameters:
        B: List of bounding boxes.
        S: List of corresponding detection scores.
        IOUt: IOU threshold to keep boxes.
        IOIt: Inside box threshold for truncated NMS.
        IOOt: Outside box threshold for truncated NMS.
        
    Returns:
        K: List of bounding boxes after truncated NMS.
    """
    K = []  # Initialize the list of selected boxes

    while B:
        # Select the box with the highest score
        m = np.argmax(S)
        M = B[m]
        
        # Remove the selected box from the list
        B.pop(m)
        S.pop(m)
        
        # Iterate over the remaining boxes
        for i, bi in enumerate(B[:]):
            # Check if the intersection condition is met
            intersection_area = iou(M, bi)
            condition1 = intersection_area > IOIt and intersection_area < iou(M, bi) <= IOOt
            condition2 = intersection_area <= IOOt and intersection_area > IOIt
            
            condition = condition1 or condition2

            if not condition:
                K.append(M)  # Keep the box in K
                if iou(M, bi) <= IOUt and not condition:
                    # Remove box if IOU condition fails
                    B.pop(i)
                    S.pop(i)
                else:
                    # Otherwise, remove the box and score from B and S
                    B.pop(i)
                    S.pop(i)
            else:
                # Keep the box for further iteration
                if iou(M, bi) >= IOUt:
                    B.pop(i)
                    S.pop(i)
                
    return K

# Example usage
B = [(0, 0, 50, 50), (10, 10, 60, 60), (30, 30, 80, 80)]  # Example bounding boxes
S = [0.9, 0.85, 0.8]  # Corresponding detection scores
IOUt = 0.5  # IOU threshold
IOIt = 0.3  # Inside box threshold
IOOt = 0.7  # Outside box threshold

K = truncated_nms_original(B, S, IOUt, IOIt, IOOt)
print("Selected Boxes:", K)

Selected Boxes: [(0, 0, 50, 50)]


In [4]:
def nms(
    predictions: torch.tensor,
    match_metric: str = "IOU",
    match_threshold: float = 0.5,
):
    """
    Apply non-maximum suppression to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        predictions: (tensor) The location preds for the image
            along with the class predscores, Shape: [num_boxes,5].
        match_metric: (str) IOU or IOS
        match_threshold: (float) The overlap thresh for
            match metric.
    Returns:
        A list of filtered indexes, Shape: [ ,]
    """

    # we extract coordinates for every
    # prediction box present in P
    x1 = predictions[:, 0]
    y1 = predictions[:, 1]
    x2 = predictions[:, 2]
    y2 = predictions[:, 3]

    # we extract the confidence scores as well
    scores = predictions[:, 4]

    # calculate area of every block in P
    areas = (x2 - x1) * (y2 - y1)
    
    # sort the prediction boxes in P
    # according to their confidence scores
    order = scores.argsort()

    # initialise an empty list for
    # filtered prediction boxes
    keep = []

    while len(order) > 0:
        # extract the index of the
        # prediction with highest score
        # we call this prediction S
        idx = order[-1]

        # push S in filtered predictions list
        keep.append(idx.tolist())

        # remove S from P
        order = order[:-1]

        # sanity check
        if len(order) == 0:
            break

        # select coordinates of BBoxes according to
        # the indices in order
        xx1 = torch.index_select(x1, dim=0, index=order)
        xx2 = torch.index_select(x2, dim=0, index=order)
        yy1 = torch.index_select(y1, dim=0, index=order)
        yy2 = torch.index_select(y2, dim=0, index=order)

        # find the coordinates of the intersection boxes
        xx1 = torch.max(xx1, x1[idx])
        yy1 = torch.max(yy1, y1[idx])
        xx2 = torch.min(xx2, x2[idx])
        yy2 = torch.min(yy2, y2[idx])

        # find height and width of the intersection boxes
        w = xx2 - xx1
        h = yy2 - yy1

        # take max with 0.0 to avoid negative w and h
        # due to non-overlapping boxes
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)

        # find the intersection area
        inter = w * h

        # find the areas of BBoxes according the indices in order
        rem_areas = torch.index_select(areas, dim=0, index=order)

        if match_metric == "IOU":
            # find the union of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            union = (rem_areas - inter) + areas[idx]
            # find the IoU of every prediction in P with S
            match_metric_value = inter / union

        elif match_metric == "IOS":
            # find the smaller area of every prediction T in P
            # with the prediction S
            # Note that areas[idx] represents area of S
            smaller = torch.min(rem_areas, areas[idx])
            # find the IoU of every prediction in P with S
            match_metric_value = inter / smaller
        else:
            raise ValueError()

        # keep the boxes with IoU less than thresh_iou
        mask = match_metric_value < match_threshold
        order = order[mask]
    print("Final Bounding Box Count (NMS): ", len(keep))
    return keep

In [2]:
# Example usage
predictions = torch.tensor([
    [10, 10, 50, 50, 0.9, 1],
    [12, 12, 48, 48, 0.85, 1],
    [60, 60, 100, 100, 0.8, 2]
])

match_threshold = 0.5
match_metric= 'IOU'

result = nms(predictions,match_metric=match_metric, match_threshold=match_threshold)
print("Filtered Boxes:", result)

NameError: name 'torch' is not defined

In [5]:
def truncated_nms_merge(
    predictions: torch.Tensor,
    match_metric: str = "IOU",
    IOUt: float = 0.7,  # Truncation IoU threshold for keeping boxes
    IOIt: float = 0.5,  # Inside IoU threshold
    IOOt: float = 0.3   # Outside IoU threshold
):
    """
    Apply truncated non-maximum suppression to avoid detecting too many
    overlapping bounding boxes for a given object, with added truncation logic.

    Args:
        predictions (tensor): The location preds for the image along with the class scores, Shape: [num_boxes, 5].
        match_metric (str): IOU or IOS (Intersection over Area or Intersection over Union)
        match_threshold (float): The overlap threshold for match metric.
        IOUt (float): Intersection over Union threshold for truncation (threshold to keep boxes)
        IOIt (float): Inside Intersection over Union threshold (threshold to keep inside box)
        IOOt (float): Outside Intersection over Union threshold (threshold for outside box)
    
    Returns:
        List: A list of filtered indexes
    """
    # Extract coordinates for every prediction box present in P
    x1 = predictions[:, 0]
    y1 = predictions[:, 1]
    x2 = predictions[:, 2]
    y2 = predictions[:, 3]

    # Extract the confidence scores
    scores = predictions[:, 4]

    # Calculate area of every box
    areas = (x2 - x1) * (y2 - y1)

    # Sort the prediction boxes in P according to their confidence scores
    order = scores.argsort()

    # Initialize an empty list for filtered prediction boxes
    keep = []

    while len(order) > 0:
        # Extract the index of the prediction with the highest score (S)
        idx = order[-1]

        # Push S in filtered predictions list
        keep.append(idx.tolist())

        # Remove S from P
        order = order[:-1]

        # Sanity check
        if len(order) == 0:
            break

        # Select coordinates of remaining boxes according to the indices in order
        xx1 = torch.index_select(x1, dim=0, index=order)
        xx2 = torch.index_select(x2, dim=0, index=order)
        yy1 = torch.index_select(y1, dim=0, index=order)
        yy2 = torch.index_select(y2, dim=0, index=order)

        # Find the coordinates of the intersection boxes
        xx1 = torch.max(xx1, x1[idx])
        yy1 = torch.max(yy1, y1[idx])
        xx2 = torch.min(xx2, x2[idx])
        yy2 = torch.min(yy2, y2[idx])

        # Find height and width of the intersection boxes
        w = xx2 - xx1
        h = yy2 - yy1

        # Take max with 0.0 to avoid negative width and height
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)

        # Find the intersection area
        inter = w * h

        # Find the areas of the remaining boxes according to the indices in order
        rem_areas = torch.index_select(areas, dim=0, index=order)

        # Calculate the match metric value (IoU or IoS)
        if match_metric == "IOU":
            # Find the union of every prediction T in P with the prediction S
            union = (rem_areas - inter) + areas[idx]
            match_metric_value = inter / union
        elif match_metric == "IOS":
            # Find the smaller area of every prediction T in P with the prediction S
            smaller = torch.min(rem_areas, areas[idx])
            match_metric_value = inter / smaller
        else:
            raise ValueError("Invalid match_metric. Choose either 'IOU' or 'IOS'.")

        # Add the condition for Truncated NMS:
        # - Keep boxes with IoU below the truncation threshold (IOUt)
        # - Handle inside (IOIt) and outside (IOOt) intersection thresholds
        mask = (match_metric_value < IOUt)  # Keep boxes with IoU below threshold

        # Apply the truncated NMS conditions
        for i, m in enumerate(mask):
            # Condition for truncated NMS (condition 1 and 2 from original pseudocode)
            iou = match_metric_value[i]
            if iou > IOIt and iou <= IOOt:  # Inside box condition
                mask[i] = False  # Remove box if condition is met
            elif iou <= IOOt and iou > IOIt:  # Outside box condition
                mask[i] = True  # Keep box if condition is met

        # Filter out the boxes based on the updated mask
        order = order[mask]

    print("Total Valid prediction: ", len(keep))
    return keep

In [10]:
# Example usage
predictions = torch.tensor([
        [10, 10, 50, 50, 0.8],
        [12, 12, 48, 48, 0.85],
        [60, 60, 100, 100, 0.75],
        [11, 11, 49, 49, 0.7]
])

IOUt = 0.7
IOIt = 0.5
IOOt = 0.3
match_threshold = 0.5
match_metric= 'IOU'

result = truncated_nms_merge(predictions,match_metric)
print("Filtered Boxes:", result)

Total Valid prediction:  2
Filtered Boxes: [1, 2]


In [11]:
import torch

def truncated_nms(
    predictions: torch.Tensor,
    iou_thresh: float = 0.7,  # IOUt: standard IoU threshold for suppression
    iit_thresh: float = 0.5,  # IOIt: threshold for inside overlap
    iot_thresh: float = 0.3   # IOOt: threshold for outside overlap
):
    """
    Apply Truncated Non-Maximum Suppression (NMS) to filter overlapping bounding boxes,
    incorporating additional filters based on spatial overlap.

    This function extends standard NMS by considering two extra ratios:
      - Inside Ratio: Intersection area / area(selected box)
      - Outside Ratio: Intersection area / area(candidate box)

    For each iteration:
      1. The highest-scored box is selected.
      2. For every remaining candidate, we compute:
           - IoU: intersection over union.
           - inside_ratio: fraction of the selected box overlapped.
           - outside_ratio: fraction of the candidate box overlapped.
      3. If the candidate’s IoU with the selected box is below `iou_thresh`,
         it is retained.
         Otherwise (IoU >= iou_thresh), the candidate is only retained if it is either
           - largely inside the selected box (inside_ratio >= iit_thresh), or
           - has significant overlap of its own area with the selected box (outside_ratio >= iot_thresh).
    
    Args:
        predictions (torch.Tensor): Bounding boxes with scores of shape [N, 5],
                                    where each row is [x1, y1, x2, y2, score].
        iou_thresh (float): IoU threshold (IOUt) for standard overlap suppression.
        iit_thresh (float): Intersection of inside threshold (IOIt).
        iot_thresh (float): Intersection of outside threshold (IOOt).

    Returns:
        List[int]: List of indices of bounding boxes to keep.
    """
    
    # Extract coordinates and scores
    x1 = predictions[:, 0]
    y1 = predictions[:, 1]
    x2 = predictions[:, 2]
    y2 = predictions[:, 3]
    scores = predictions[:, 4]

    # Compute areas of all boxes
    areas = (x2 - x1) * (y2 - y1)
    # Sort indices in ascending order of scores; highest score is last.
    order = scores.argsort()
    
    keep = []

    while order.numel() > 0:
        # Select index of the box with the highest score
        idx = order[-1].item()
        keep.append(idx)
        # Remove the selected index from the order list
        order = order[:-1]
        
        if order.numel() == 0:
            break

        # Select coordinates for the remaining boxes using the indices in 'order'
        xx1 = torch.index_select(x1, 0, order)
        yy1 = torch.index_select(y1, 0, order)
        xx2 = torch.index_select(x2, 0, order)
        yy2 = torch.index_select(y2, 0, order)
        
        # Compute coordinates for the intersection boxes
        inter_x1 = torch.max(xx1, x1[idx])
        inter_y1 = torch.max(yy1, y1[idx])
        inter_x2 = torch.min(xx2, x2[idx])
        inter_y2 = torch.min(yy2, y2[idx])
        
        # Compute width and height of the intersections
        w = (inter_x2 - inter_x1).clamp(min=0)
        h = (inter_y2 - inter_y1).clamp(min=0)
        inter = w * h

        # Calculate union area for IoU
        rem_areas = torch.index_select(areas, 0, order)
        union = rem_areas + areas[idx] - inter
        iou = inter / union

        # Compute additional ratios:
        # Inside ratio: fraction of the selected box's area that is overlapped by the candidate box.
        inside_ratio = inter / areas[idx]
        # Outside ratio: fraction of the candidate box's area that is overlapped by the selected box.
        outside_ratio = inter / rem_areas

        # Build mask for candidates to keep:
        # (1) If IoU is below the standard threshold, we keep the candidate.
        # (2) If IoU is above or equal to the threshold, we only keep it if either:
        #     - It overlaps a sufficient portion of the selected box (inside_ratio >= iit_thresh), or
        #     - A sufficient portion of the candidate overlaps with the selected box (outside_ratio >= iot_thresh).
        mask = (iou < iou_thresh) | ((iou >= iou_thresh) & ((inside_ratio >= iit_thresh) | (outside_ratio >= iot_thresh)))
        order = order[mask]

    print("Final Bounding Box Count (Truncated NMS):", len(keep))
    return keep

# Example usage:
if __name__ == "__main__":
    # Dummy predictions: each row is [x1, y1, x2, y2, score]
    predictions = torch.tensor([
        [10, 10, 50, 50, 0.8],
        [12, 12, 48, 48, 0.85],
        [60, 60, 100, 100, 0.75],
        [11, 11, 49, 49, 0.7]
    ], dtype=torch.float32)


    keep_indices = truncated_nms(predictions, iou_thresh=0.7, iit_thresh=0.5, iot_thresh=0.3)
    print("Kept indices:", keep_indices)


Final Bounding Box Count (Truncated NMS): 4
Kept indices: [1, 0, 2, 3]


In [12]:
import torch

def efficient_truncated_nms(
    predictions: torch.Tensor,
    iou_thresh: float = 0.7,  # IOUt: standard IoU threshold for suppression
    iit_thresh: float = 0.5,  # IOIt: threshold for inside overlap
    iot_thresh: float = 0.3   # IOOt: threshold for outside overlap
):
    """
    Apply an optimized version of Truncated Non-Maximum Suppression (NMS) to filter overlapping bounding boxes.
    
    This implementation uses advanced tensor indexing and vectorized operations to reduce redundant boxes
    while preserving useful predictions.
    
    Args:
        predictions (torch.Tensor): Bounding boxes with scores of shape [N, 5],
                                    where each row is [x1, y1, x2, y2, score].
        iou_thresh (float): IoU threshold (IOUt) for standard overlap suppression.
        iit_thresh (float): Intersection of inside threshold (IOIt).
        iot_thresh (float): Intersection of outside threshold (IOOt).

    Returns:
        List[int]: List of indices of bounding boxes to keep.
    """
    # Extract box coordinates and scores
    boxes = predictions[:, :4]
    scores = predictions[:, 4]
    
    # Compute area of each box
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    
    # Sort indices based on scores (ascending order; highest score is last)
    order = scores.argsort()
    keep = []

    while order.numel() > 0:
        # Select the box with the highest score
        current_idx = order[-1].item()
        keep.append(current_idx)
        
        # Remove the current highest scoring box from the list
        order = order[:-1]
        if order.numel() == 0:
            break

        # Use advanced indexing to select the remaining boxes
        current_box = boxes[current_idx].unsqueeze(0)  # Shape: [1, 4]
        rem_boxes = boxes[order]                       # Shape: [N, 4]

        # Compute the intersection coordinates
        inter_x1 = torch.maximum(rem_boxes[:, 0], current_box[0, 0])
        inter_y1 = torch.maximum(rem_boxes[:, 1], current_box[0, 1])
        inter_x2 = torch.minimum(rem_boxes[:, 2], current_box[0, 2])
        inter_y2 = torch.minimum(rem_boxes[:, 3], current_box[0, 3])
        
        # Compute width and height of the intersections and clamp to zero
        inter_w = (inter_x2 - inter_x1).clamp(min=0)
        inter_h = (inter_y2 - inter_y1).clamp(min=0)
        inter_area = inter_w * inter_h

        # Compute IoU: intersection over union
        rem_areas = areas[order]
        union = rem_areas + areas[current_idx] - inter_area
        iou = inter_area / union

        # Compute the additional spatial ratios:
        # Inside ratio: fraction of the selected box's area overlapped by the candidate
        inside_ratio = inter_area / areas[current_idx]
        # Outside ratio: fraction of the candidate's area overlapped by the selected box
        outside_ratio = inter_area / rem_areas

        # Build the mask:
        # (a) If IoU is below the threshold, candidate is retained.
        # (b) If IoU is high, candidate is retained only if it satisfies either:
        #     - inside_ratio >= iit_thresh or
        #     - outside_ratio >= iot_thresh.
        mask = (iou < iou_thresh) | ((iou >= iou_thresh) & ((inside_ratio >= iit_thresh) | (outside_ratio >= iot_thresh)))
        
        # Apply the mask to update the order list
        order = order[mask]

    print("Final Bounding Box Count (Efficient Truncated NMS):", len(keep))
    return keep

# Example usage:
if __name__ == "__main__":
    # Dummy predictions: each row is [x1, y1, x2, y2, score]
    predictions = torch.tensor([
        [10, 10, 50, 50, 0.8],
        [12, 12, 48, 48, 0.85],
        [60, 60, 100, 100, 0.75],
        [11, 11, 49, 49, 0.7]
    ], dtype=torch.float32)
    
    keep_indices = efficient_truncated_nms(predictions, iou_thresh=0.7, iit_thresh=0.5, iot_thresh=0.3)
    print("Kept indices:", keep_indices)


Final Bounding Box Count (Efficient Truncated NMS): 4
Kept indices: [1, 0, 2, 3]


In [13]:
import torch

def weighted_truncated_nms(
    predictions: torch.Tensor,
    iou_thresh: float = 0.7,  # IOUt: standard IoU threshold for merging
    iit_thresh: float = 0.5,  # IOIt: threshold for inside overlap
    iot_thresh: float = 0.3   # IOOt: threshold for outside overlap
):
    """
    Apply Weighted Truncated Non-Maximum Suppression (NMS) that merges redundant boxes
    using weighted averaging instead of simply discarding them.

    For each iteration:
      1. The highest-scored box is selected.
      2. All candidate boxes that have a high IoU with the selected box are further checked:
           - They are only considered redundant if the overlap does not meet the
             spatial quality criteria (inside_ratio or outside_ratio).
      3. All redundant boxes (including the selected box) are merged using a weighted average.
         The weighted average is computed using the confidence scores as weights.
      4. The merged box is added to the final list, and all merged boxes are removed
         from further consideration.

    Args:
        predictions (torch.Tensor): Bounding boxes with scores of shape [N, 5],
                                    where each row is [x1, y1, x2, y2, score].
        iou_thresh (float): IoU threshold for merging boxes.
        iit_thresh (float): Intersection of inside threshold.
        iot_thresh (float): Intersection of outside threshold.

    Returns:
        List[Tuple[torch.Tensor, float]]: List of merged boxes and their associated scores.
    """
    boxes = predictions[:, :4]      # shape [N, 4]
    scores = predictions[:, 4]        # shape [N]
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    
    # Sort boxes by score (lowest to highest) so that highest is at the end.
    order = scores.argsort()
    merged_boxes = []

    while order.numel() > 0:
        # Get the index of the highest scoring box
        current_idx = order[-1].item()
        current_box = boxes[current_idx].clone()
        current_score = scores[current_idx].clone()
        
        # Remove current box from order
        order = order[:-1]
        
        # If no remaining boxes, simply add the current box.
        if order.numel() == 0:
            merged_boxes.append((current_box, current_score))
            break

        # Retrieve remaining boxes and scores using advanced indexing.
        rem_boxes = boxes[order]
        rem_scores = scores[order]
        current_box_expanded = current_box.unsqueeze(0)
        
        # Compute intersection coordinates
        inter_x1 = torch.maximum(rem_boxes[:, 0], current_box_expanded[0, 0])
        inter_y1 = torch.maximum(rem_boxes[:, 1], current_box_expanded[0, 1])
        inter_x2 = torch.minimum(rem_boxes[:, 2], current_box_expanded[0, 2])
        inter_y2 = torch.minimum(rem_boxes[:, 3], current_box_expanded[0, 3])
        
        inter_w = (inter_x2 - inter_x1).clamp(min=0)
        inter_h = (inter_y2 - inter_y1).clamp(min=0)
        inter_area = inter_w * inter_h

        rem_areas = areas[order]
        union = rem_areas + areas[current_idx] - inter_area
        iou = inter_area / union

        # Compute spatial overlap ratios:
        inside_ratio = inter_area / areas[current_idx]
        outside_ratio = inter_area / rem_areas

        # Determine which boxes are considered redundant:
        # (a) Low IoU boxes are kept separate.
        # (b) For high IoU (>= iou_thresh), keep only those with good spatial overlap:
        #     either inside_ratio >= iit_thresh or outside_ratio >= iot_thresh.
        merge_mask = (iou >= iou_thresh) & ~((inside_ratio >= iit_thresh) | (outside_ratio >= iot_thresh))
        
        # Identify indices (within 'order') of boxes to merge with the current box.
        merge_indices = torch.nonzero(merge_mask).squeeze(1)

        if merge_indices.numel() > 0:
            # Gather boxes and scores to merge (include the current box).
            merge_boxes = torch.cat([current_box.unsqueeze(0), rem_boxes[merge_indices]], dim=0)
            merge_scores = torch.cat([torch.tensor([current_score], device=predictions.device), rem_scores[merge_indices]], dim=0)
            
            # Compute weighted average for each coordinate using scores as weights.
            weighted_box = torch.zeros(4, device=predictions.device)
            for i in range(4):
                weighted_box[i] = (merge_boxes[:, i] * merge_scores).sum() / merge_scores.sum()
            
            # Optionally, update the score (here we keep the maximum score).
            merged_score = merge_scores.max().item()
            
            # Append the merged box.
            merged_boxes.append((weighted_box, merged_score))
            
            # Remove merged boxes from 'order'. We create a mask that keeps those not merged.
            remaining_mask = torch.ones(order.size(0), dtype=torch.bool, device=order.device)
            remaining_mask[merge_indices] = False
            order = order[remaining_mask]
        else:
            # No redundant boxes to merge; keep the current box as is.
            merged_boxes.append((current_box, current_score))
    
    print("Final Merged Box Count (Weighted Truncated NMS):", len(merged_boxes))
    return merged_boxes

# Example usage:
if __name__ == "__main__":
    # Dummy predictions: each row is [x1, y1, x2, y2, score]
    predictions = torch.tensor([
        [10, 10, 50, 50, 0.9],
        [12, 12, 52, 52, 0.8],
        [60, 60, 100, 100, 0.75],
        [11, 11, 49, 49, 0.7]
    ], dtype=torch.float32)
    
    merged = weighted_truncated_nms(predictions, iou_thresh=0.7, iit_thresh=0.5, iot_thresh=0.3)
    for idx, (box, score) in enumerate(merged):
        print(f"Merged Box {idx}: {box.tolist()}, Score: {score}")


Final Merged Box Count (Weighted Truncated NMS): 4
Merged Box 0: [10.0, 10.0, 50.0, 50.0], Score: 0.8999999761581421
Merged Box 1: [12.0, 12.0, 52.0, 52.0], Score: 0.800000011920929
Merged Box 2: [60.0, 60.0, 100.0, 100.0], Score: 0.75
Merged Box 3: [11.0, 11.0, 49.0, 49.0], Score: 0.699999988079071
