Skip to content

Commit

Permalink
[SegFormer] Add support for segmentation masks with one label (#20279)
Browse files Browse the repository at this point in the history
* Add support for binary segmentation

* Fix loss calculation and add test

* Remove space

* Use f-string

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
  • Loading branch information
3 people committed Dec 20, 2022
1 parent 2280880 commit 2875fa9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/transformers/models/segformer/modeling_segformer.py
Expand Up @@ -806,15 +806,20 @@ def forward(

loss = None
if labels is not None:
if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
# upsample logits to the images' original size
upsampled_logits = nn.functional.interpolate(
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
# upsample logits to the images' original size
upsampled_logits = nn.functional.interpolate(
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
if self.config.num_labels > 1:
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
loss = loss_fct(upsampled_logits, labels)
elif self.config.num_labels == 1:
valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
loss_fct = BCEWithLogitsLoss(reduction="none")
loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
loss = (loss * valid_mask).mean()
else:
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")

if not return_dict:
if output_hidden_states:
Expand Down
14 changes: 14 additions & 0 deletions tests/models/segformer/test_modeling_segformer.py
Expand Up @@ -140,6 +140,16 @@ def create_and_check_for_image_segmentation(self, config, pixel_values, labels):
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
)
self.parent.assertGreater(result.loss, 0.0)

def create_and_check_for_binary_image_segmentation(self, config, pixel_values, labels):
    """Check the single-label (binary) segmentation path of the model.

    Forces ``config.num_labels`` to 1, builds a fresh
    ``SegformerForSemanticSegmentation`` in eval mode and runs it on
    ``pixel_values`` with freshly drawn binary masks.

    Args:
        config: model configuration; its ``num_labels`` is overwritten with 1.
        pixel_values: input batch passed to the model unchanged.
        labels: ignored; replaced by random 0/1 masks of shape
            ``(batch_size, image_size, image_size)``.
    """
    config.num_labels = 1
    model = SegformerForSemanticSegmentation(config=config)
    model.to(torch_device)
    model.eval()
    # `torch.randint` excludes `high`, so the upper bound must be 2 for the
    # masks to contain both classes; with `high=1` every label is 0 and the
    # positive class is never exercised.
    labels = torch.randint(0, 2, (self.batch_size, self.image_size, self.image_size)).to(torch_device)
    result = model(pixel_values, labels=labels)
    # Same shape check as the multi-label segmentation test, with a single
    # output channel because num_labels == 1.
    self.parent.assertEqual(
        result.logits.shape, (self.batch_size, 1, self.image_size // 4, self.image_size // 4)
    )
    self.parent.assertGreater(result.loss, 0.0)

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
Expand Down Expand Up @@ -177,6 +187,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

def test_for_binary_image_segmentation(self):
    """Prepare inputs with the model tester and run its binary segmentation check on them."""
    tester = self.model_tester
    prepared_inputs = tester.prepare_config_and_inputs()
    tester.create_and_check_for_binary_image_segmentation(*prepared_inputs)

def test_for_image_segmentation(self):
    """Prepare inputs with the model tester and run its image segmentation check on them."""
    tester = self.model_tester
    prepared_inputs = tester.prepare_config_and_inputs()
    tester.create_and_check_for_image_segmentation(*prepared_inputs)
Expand Down

0 comments on commit 2875fa9

Please sign in to comment.