
Add SegFormer #14019

Merged (32 commits into huggingface:master, Oct 28, 2021)

Conversation

@NielsRogge (Contributor) commented on Oct 15, 2021:

What does this PR do?

This PR adds SegFormer, a new model by NVIDIA that is surprisingly simple yet very powerful for semantic segmentation of images. It uses a hierarchical Transformer as its backbone and an all-MLP decode head. I've implemented three model classes:

  • SegformerModel (backbone-only)
  • SegformerForImageClassification (backbone + classifier head)
  • SegformerForSemanticSegmentation (backbone + semantic segmentation all-MLP head)

Models are on the hub (with approval from the author): https://huggingface.co/models?other=segformer
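
For reference, here's a minimal sketch of the image-classification variant. The checkpoint name is an assumption for illustration only (any SegFormer encoder checkpoint with a classification head on the hub should work the same way):

from transformers import SegformerFeatureExtractor, SegformerForImageClassification
from PIL import Image

feature_extractor = SegformerFeatureExtractor(do_random_crop=False)
# "nvidia/mit-b0" is an assumed checkpoint name for this sketch
model = SegformerForImageClassification.from_pretrained("nvidia/mit-b0")

image = Image.open("...")

# prepare image for model
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

# forward pass; logits are of shape (batch_size, num_labels)
outputs = model(pixel_values)
predicted_class_idx = outputs.logits.argmax(-1).item()
print(model.config.id2label[predicted_class_idx])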

Here's how to use the semantic segmentation model:

from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from PIL import Image

feature_extractor = SegformerFeatureExtractor(do_random_crop=False)
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

image = Image.open("...")

# prepare image for model
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

# forward pass
outputs = model(pixel_values)

# logits are of shape (batch_size, num_labels, height/4, width/4)
logits = outputs.logits
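
Since the logits come out at 1/4 of the input resolution, here's a minimal sketch (plain PyTorch, not part of this PR's API) of how one could upsample them and take the argmax to get a per-pixel segmentation map:

import torch

# upsample logits to the original image size (PIL's image.size is (width, height))
upsampled_logits = torch.nn.functional.interpolate(
    logits, size=image.size[::-1], mode="bilinear", align_corners=False
)

# per-pixel class indices of shape (batch_size, height, width)
segmentation_map = upsampled_logits.argmax(dim=1)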

Quick inference notebook with visualization: https://colab.research.google.com/drive/1Kc1VLuFrWUPz0rZXA2E_rKQqdK7kV2iH?usp=sharing

To do/questions

  • Decide on the default values of the feature extractor (which are kind of arbitrary right now)
  • I've called the decode head SegformerDecodeHead rather than SegformerDecoder, since it's more of a lightweight head than a decoder. Is this OK?
  • Add padding of images + segmentation maps (probably a single function in image_utils.py), cc @sgugger. Currently I rely on torch.nn.functional.pad, which makes the feature extractor depend on PyTorch. It could also make sense to do it in NumPy (this model, for example, pads after normalizing, so it would benefit, as the outputs after normalization are NumPy arrays).
  • Make sure model doesn't return hidden states when the user doesn't want to
  • The model currently returns a SequenceClassifierOutput; this renders the wrong logit shapes in the docs. The logits are actually of shape (batch_size, num_labels, height/4, width/4); a dedicated output class could document this (see the sketch after this list).
  • Add model cards (author has joined the NVIDIA org on the hub and might create these)
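
As a rough illustration of the SequenceClassifierOutput point above, a dedicated output class could document the logit shape directly. The class name below is hypothetical and the import path is assumed; this is only a sketch, not what the PR necessarily ends up with:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.file_utils import ModelOutput


@dataclass
class SegformerSemanticSegmentationOutput(ModelOutput):  # hypothetical name
    """
    Output of the semantic segmentation head; ``logits`` are of shape
    (batch_size, num_labels, height/4, width/4).
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None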

@NielsRogge (Contributor, Author):

The PR is ready for review; the only thing left to add is padding.

@sgugger (Collaborator) left a comment:

Mostly left nits, this is very clean! Great new addition!

Resolved review threads: README.md, src/transformers/image_utils.py

def pad_images(self, images):
"""Pad images to ``self.crop_size``."""
padded_images = nn.functional.pad(images, pad=self.crop_size, value=self.padding_value)

Collaborator:

I think for this, it would be useful to have our own pad function that uses PyTorch if images is a torch Tensor and NumPy if images is a NumPy array. Wdyt?

Contributor (Author):

Yes, if you could implement that, that would be great :)

Collaborator:

Happy to, but I won't have time to do this this week (and I'm not sure about next week either).

Member:

There's numpy.pad (https://numpy.org/doc/stable/reference/generated/numpy.pad.html), which would make this implementation quite simple. Can you give it a try, @NielsRogge?
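
Something along these lines, perhaps; a minimal sketch only (the function name, the (C, H, W) layout, and the bottom/right padding convention are assumptions, not what image_utils.py will necessarily end up with):

import numpy as np


def pad(image, target_size, pad_value=0):
    """Pad a (C, H, W) image on the bottom/right up to target_size = (height, width)."""
    target_height, target_width = target_size
    pad_bottom = target_height - image.shape[-2]
    pad_right = target_width - image.shape[-1]
    if isinstance(image, np.ndarray):
        # NumPy path: pad_width is a (before, after) pair per dimension
        pad_width = [(0, 0)] * (image.ndim - 2) + [(0, pad_bottom), (0, pad_right)]
        return np.pad(image, pad_width, mode="constant", constant_values=pad_value)
    # otherwise assume a torch.Tensor; import lazily so the module doesn't hard-depend on PyTorch
    import torch

    # torch's pad spec covers the last dimensions first: (left, right, top, bottom)
    return torch.nn.functional.pad(image, (0, pad_right, 0, pad_bottom), value=pad_value)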

Resolved review threads: src/transformers/models/segformer/modeling_segformer.py (4 threads), tests/test_modeling_segformer.py

@LysandreJik (Member) left a comment:

Thank you for working on this Niels! You'll have to rebase on the current master branch and re-run make fix-copies to ensure that the Korean readme also gets updated.

This looks good to me, I have only left nits and one request regarding the padding method.


def forward(self, hidden_states, height, width, output_attentions=False):
self_attention_outputs = self.attention(
self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention

Member:

Can this be done outside of the call for readability?
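
i.e. something like this (just a sketch of the suggestion; the remaining call arguments are assumed to mirror the quoted snippet):

def forward(self, hidden_states, height, width, output_attentions=False):
    # in SegFormer, layernorm is applied before self-attention (pre-norm);
    # hoist it out of the attention call for readability
    normalized_hidden_states = self.layer_norm_1(hidden_states)
    self_attention_outputs = self.attention(
        normalized_hidden_states, height, width, output_attentions=output_attentions
    )
    ...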

Resolved review threads: src/transformers/models/segformer/modeling_segformer.py (2 threads)

@LysandreJik merged commit 1dc96a7 into huggingface:master on Oct 28, 2021
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request on Jan 27, 2022:
* First draft

* Make style & quality

* Improve conversion script

* Add print statement to see actual slice

* Make absolute tolerance smaller

* Fix image classification models

* Add post_process_semantic method

* Disable padding

* Improve conversion script

* Rename to ForSemanticSegmentation, add integration test, remove post_process methods

* Improve docs

* Fix code quality

* Fix feature extractor tests

* Fix tests for image classification model

* Delete file

* Add is_torch_available to feature extractor

* Improve documentation of feature extractor methods

* Apply suggestions from @sgugger's code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply some more suggestions of code review

* Rebase with master

* Fix rebase issues

* Make sure model only outputs hidden states when the user wants to

* Apply suggestions from code review

* Add pad method

* Support padding of 2d images

* Add print statement

* Add print statement

* Move padding method to SegformerFeatureExtractor

* Fix issue

* Add casting of segmentation maps

* Add test for padding

* Add small note about padding

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>