diff --git a/utils/loss.py b/utils/loss.py index bf7ab65a30..6eb70a2fa7 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -739,7 +739,7 @@ def build_targets(self, p, targets, imgs): + 3.0 * pair_wise_iou_loss ) - matching_matrix = torch.zeros_like(cost) + matching_matrix = torch.zeros_like(cost, device="cpu") for gt_idx in range(num_gt): _, pos_idx = torch.topk(