# In [None]:
import torch

def non_max_suppression(boxes, scores, iou_threshold, *, offset=1):
    """
    Perform greedy non-maximum suppression (NMS) on bounding boxes.

    Boxes are visited in descending score order. Each visited box is kept, and
    every not-yet-visited box whose IoU with it exceeds ``iou_threshold`` is
    discarded.

    Parameters:
    boxes (torch.Tensor): Tensor of shape (N, 4), where N is the number of bounding boxes.
                          Each box is represented by [x1, y1, x2, y2].
    scores (torch.Tensor): Tensor of shape (N,) holding the score of each bounding box.
    iou_threshold (float): Boxes whose IoU with an already-kept box is strictly greater
                           than this value are suppressed.
    offset (int or float, keyword-only): Added to ``x2 - x1`` and ``y2 - y1`` when computing
                          widths and heights. The default of 1 reproduces the legacy
                          inclusive-pixel convention (box [0, 0, 0, 0] covers one pixel);
                          pass 0 for continuous coordinates.

    Returns:
    keep (torch.Tensor): int64 tensor with the indices of the kept boxes, ordered by
                         decreasing score. It lives on the same device as ``scores``
                         (on ``boxes.device`` when there are no boxes).

    Raises:
    ValueError: if ``scores`` is not a 1-D tensor with one entry per box.
    """
    if boxes.size(0) == 0:
        # Build the empty result on the input's device so callers can index with it directly.
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    if scores.dim() != 1 or scores.size(0) != boxes.size(0):
        # Without this check a shorter `scores` silently ignores trailing boxes.
        raise ValueError(
            f"scores must have shape ({boxes.size(0)},), got {tuple(scores.shape)}"
        )

    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + offset) * (y2 - y1 + offset)

    # Indices of boxes sorted by score, highest first.
    _, order = scores.sort(0, descending=True)

    keep = []
    while order.numel() > 0:
        # The highest-scoring remaining box always survives.
        i = order[0]
        keep.append(i)

        rest = order[1:]
        if rest.numel() == 0:
            break

        # Intersection of box i with every remaining box.
        xx1 = torch.max(x1[i], x1[rest])
        yy1 = torch.max(y1[i], y1[rest])
        xx2 = torch.min(x2[i], x2[rest])
        yy2 = torch.min(y2[i], y2[rest])

        # Clamp so non-overlapping boxes contribute zero intersection.
        w = (xx2 - xx1 + offset).clamp(min=0)
        h = (yy2 - yy1 + offset).clamp(min=0)
        inter = w * h
        iou = inter / (areas[i] + areas[rest] - inter)

        # Retain only boxes that do not overlap box i too much.
        order = rest[iou <= iou_threshold]

    # Stacking the 0-d index tensors keeps dtype (int64) and device of `scores`.
    return torch.stack(keep)

# Example usage: box 1 overlaps box 0 almost entirely, box 2 only partially.
iou_threshold = 0.3
scores = torch.tensor([0.9, 0.8, 0.7])
boxes = torch.tensor(
    [
        [100, 100, 210, 210],
        [105, 105, 215, 215],
        [150, 150, 255, 255],
    ]
)

keep = non_max_suppression(boxes, scores, iou_threshold)
# Select the surviving rows along dim 0 using the returned index tensor.
filtered_boxes = boxes.index_select(0, keep)

print(f"Kept boxes after NMS: {filtered_boxes}")
