diff --git a/src/transformers/models/maskformer/feature_extraction_maskformer.py b/src/transformers/models/maskformer/feature_extraction_maskformer.py index fce59b0a4b8b..ca7c08f23fa5 100644 --- a/src/transformers/models/maskformer/feature_extraction_maskformer.py +++ b/src/transformers/models/maskformer/feature_extraction_maskformer.py @@ -538,7 +538,6 @@ def post_process_panoptic_segmentation( # create the area, since bool we just need to sum :) mask_k_area = mask_k.sum() # this is the area of all the stuff in query k - # TODO not 100%, why are the taking the k query here???? original_area = (mask_probs[k] >= 0.5).sum() mask_does_exist = mask_k_area > 0 and original_area > 0 @@ -565,5 +564,5 @@ def post_process_panoptic_segmentation( ) if is_stuff: stuff_memory_list[pred_class] = current_segment_id - results.append({"segmentation": segmentation, "segments": segments}) + results.append({"segmentation": segmentation, "segments": segments}) return results diff --git a/tests/maskformer/test_modeling_maskformer.py b/tests/maskformer/test_modeling_maskformer.py index f2e1f56f0f5b..67151ead6ff8 100644 --- a/tests/maskformer/test_modeling_maskformer.py +++ b/tests/maskformer/test_modeling_maskformer.py @@ -404,3 +404,23 @@ def test_with_annotations_and_loss(self): outputs = model(**inputs) self.assertTrue(outputs.loss is not None) + + def test_panoptic_segmentation(self): + model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() + feature_extractor = self.default_feature_extractor + + inputs = feature_extractor( + [np.zeros((3, 384, 384)), np.zeros((3, 384, 384))], + annotations=[ + {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)}, + {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)}, + ], + return_tensors="pt", + ) + + with torch.no_grad(): + outputs = model(**inputs) + + panoptic_segmentation = feature_extractor.post_process_panoptic_segmentation(outputs) + + self.assertTrue(len(panoptic_segmentation) == 2)