Skip to content

Commit

Permalink
conversations
Browse files Browse the repository at this point in the history
[Unispeech] Fix slow tests (#15818)

* remove soundfile old way of loading audio

* Adapt slow test

[Barthez Tokenizer] Fix saving (#15815)

[TFXLNet] Correct tf xlnet generate (#15822)

* [TFXLNet] Correct tf xlnet

* adapt test comment

Fix the push run (#15807)

Fix semantic segmentation pipeline test (#15826)

Fix dummy_inputs() to dummy_inputs in symbolic_trace doc (#15776)

Add model specific output classes to PoolFormer model docs (#15746)

* Added model specific output classes to poolformer docs

* Fixed Segformer typo in Poolformer docs

Adding the option to return_timestamps on pure CTC ASR models. (#15792)

* Adding the option to return_timestamps on pure CTC ASR models.

* Remove `math.prod` which was introduced in Python 3.8

* int are not floats.

* Reworking the PR to support "char" vs "word" output.

* Fixup!

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Quality.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

HFTracer.trace should use/return self.graph to be compatible with torch.fx.Tracer (#15824)

Fix tf.concatenate + test past_key_values for TF models (#15774)

* fix wrong method name tf.concatenate

* add tests related to causal LM / decoder

* make style and quality

* clean-up

* Fix TFBertModel's extended_attention_mask when past_key_values is provided

* Fix tests

* fix copies

* More tf.int8 -> tf.int32 in TF test template

* clean-up

* Update TF test template

* revert the previous commit + update the TF test template

* Fix TF template extended_attention_mask when past_key_values is provided

* Fix some styles manually

* clean-up

* Fix ValueError: too many values to unpack in the test

* Fix more: too many values to unpack in the test

* Add a comment for extended_attention_mask when there is past_key_values

* Fix TFElectra extended_attention_mask when past_key_values is provided

* Add tests to other TF models

* Fix for TF Electra test: add prepare_config_and_inputs_for_decoder

* Fix not passing training arg to lm_head in TFRobertaForCausalLM

* Fix tests (with past) for TF Roberta

* add testing for pask_key_values for TFElectra model

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

[examples/summarization and translation] fix readme (#15833)

Add ONNX Runtime quantization for text classification notebook (#15817)

Re-enable doctests for the quicktour (#15828)

* Re-enable doctests for the quicktour

* Re-enable doctests for task_summary (#15830)

* Remove &

Framework split model report (#15825)

Add TFConvNextModel (#15750)

* feat: initial implementation of convnext in tensorflow.

* fix: sample code for the classification model.

* chore: added checked for  from the classification model.

* chore: set bias initializer in the classification head.

* chore: updated license terms.

* chore: removed ununsed imports

* feat: enabled  argument during using drop_path.

* chore: replaced tf.identity with layers.Activation(linear).

* chore: edited default checkpoint.

* fix: minor bugs in the initializations.

* partial-fix: tf model errors for loading pretrained pt weights.

* partial-fix: call method updated

* partial-fix: cross loading of weights (4x3 variables to be matched)

* chore: removed unneeded comment.

* removed playground.py

* rebasing

* rebasing and removing playground.py.

* fix: renaming TFConvNextStage conv and layer norm layers

* chore: added initializers and other minor additions.

* chore: added initializers and other minor additions.

* add: tests for convnext.

* fix: integration tester class.

* fix: issues mentioned in pr feedback (round 1).

* fix: how output_hidden_states arg is propoagated inside the network.

* feat: handling of  arg for pure cnn models.

* chore: added a note on equal contribution in model docs.

* rebasing

* rebasing and removing playground.py.

* feat: encapsulation for the convnext trunk.

* Fix variable naming; Test-related corrections; Run make fixup

* chore: added Joao as a contributor to convnext.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* chore: corrected copyright year and added comment on NHWC.

* chore: fixed the black version and ran formatting.

* chore: ran make style.

* chore: removed from_pt argument from test, ran make style.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* fix: tests in the convnext subclass, ran make style.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* chore: moved convnext test to the correct location

* fix: locations for the test file of convnext.

* fix: convnext tests.

* chore: applied  sgugger's suggestion for dealing w/ output_attentions.

* chore: added comments.

* chore: applied updated quality enviornment style.

* chore: applied formatting with quality enviornment.

* chore: revert to the previous tests/test_modeling_common.py.

* chore: revert to the original test_modeling_common.py

* chore: revert to previous states for test_modeling_tf_common.py and modeling_tf_utils.py

* fix: tests for convnext.

* chore: removed output_attentions argument from convnext config.

* chore: revert to the earlier tf utils.

* fix: output shapes of the hidden states

* chore: removed unnecessary comment

* chore: reverting to the right test_modeling_tf_common.py.

* Styling nits

Co-authored-by: ariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
  • Loading branch information
4 people committed Feb 28, 2022
1 parent d3c83b7 commit ad2db3e
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 47 deletions.
7 changes: 2 additions & 5 deletions docs/source/model_doc/maskformer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@ The abstract from the paper is the following:
*Modern approaches typically formulate semantic segmentation as a per-pixel classification task, while instance-level segmentation is handled with an alternative mask classification. Our key insight: mask classification is sufficiently general to solve both semantic- and instance-level segmentation tasks in a unified manner using the exact same model, loss, and training procedure. Following this observation, we propose MaskFormer, a simple mask classification model which predicts a set of binary masks, each associated with a single global class label prediction. Overall, the proposed mask classification-based method simplifies the landscape of effective approaches to semantic and panoptic segmentation tasks and shows excellent empirical results. In particular, we observe that MaskFormer outperforms per-pixel classification baselines when the number of classes is large. Our mask classification-based method outperforms both current state-of-the-art semantic (55.6 mIoU on ADE20K) and panoptic segmentation (52.7 PQ on COCO) models.*

Tips:
- During training, the authors of DETR did find it helpful to use auxiliary losses in the decoder, especially to help
the model output the correct number of objects of each class. If you set the parameter `use_auxilary_loss` of
[`MaskFormerConfig`] to `True`, then prediction feedforward neural networks and Hungarian losses
are added after each decoder layer (with the FFNs sharing parameters).
- MaskFormer's Transformer decoder is identical to the decoder of [DETR](detr). During training, the authors of DETR did find it helpful to use auxiliary losses in the decoder, especially to help the model output the correct number of objects of each class. If you set the parameter `use_auxilary_loss` of [`MaskFormerConfig`] to `True`, then prediction feedforward neural networks and Hungarian losses are added after each decoder layer (with the FFNs sharing parameters).
- If you want to train the model in a distributed environment across multiple nodes, then one should update the
`get_num_masks` function inside in the `MaskFormerLoss` class of `modeling_maskformer.py`. When training on multiple nodes, this should be
set to the average number of target masks across all nodes, as can be seen in the original implementation [here](https://github.com/facebookresearch/MaskFormer/blob/da3e60d85fdeedcb31476b5edd7d328826ce56cc/mask_former/modeling/criterion.py#L169).
- One can use [`MaskFormerFeatureExtractor`] to prepare images for the model.
- To get the final segmentation, depending on the task, you can call [`~MaskFormerFeatureExtractor.post_process_semantic_segmentation`] or [`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`].
- To get the final segmentation, depending on the task, you can call [`~MaskFormerFeatureExtractor.post_process_semantic_segmentation`] or [`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`]. Both tasks can be solved using [`MaskFormerForInstanceSegmentation`] output, the latter needs an additional `is_thing_map` to know which instances must be merged together..

This model was contributed by [francesco](https://huggingface.co/francesco). The original code can be found [here](https://github.com/facebookresearch/MaskFormer).

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class MaskFormerConfig(PretrainedConfig):
no_object_weight (`float`, *optional*, defaults to 0.1):
Weight to apply to the null (no object) class.
use_auxilary_loss (`bool`, *optional*, defaults to `False`):
If `true` [`MaskFormerOutput`] will contain the axusilary losses computed using the logits from each
decoder's stage.
If `true` [`MaskFormerForInstanceSegmentationOutput`] will contain the axusilary losses computed using the
logits from each decoder's stage.
backbone_config (`Dict`, *optional*):
The configuration passed to the backbone, if unset, the configuration corresponding to
`swin-base-patch4-window12-384` will be used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __call__(self, original_config: object) -> MaskFormerFeatureExtractor:
)


class OriginalMaskFormerCheckpoinToOursConverter:
class OriginalMaskFormerCheckpointToOursConverter:
def __init__(self, original_model: nn.Module, config: MaskFormerConfig):
self.original_model = original_model
self.config = config
Expand Down Expand Up @@ -676,7 +676,7 @@ def get_name(checkpoint_file: Path):
if not save_directory.exists():
save_directory.mkdir(parents=True)

for config_file, checkpoint_file in OriginalMaskFormerCheckpoinToOursConverter.using_dirs(
for config_file, checkpoint_file in OriginalMaskFormerCheckpointToOursConverter.using_dirs(
checkpoints_dir, config_dir
):

Expand All @@ -695,7 +695,7 @@ def get_name(checkpoint_file: Path):

mask_former = MaskFormerModel(config=config).eval()

converter = OriginalMaskFormerCheckpoinToOursConverter(original_model, config)
converter = OriginalMaskFormerCheckpointToOursConverter(original_model, config)

maskformer = converter.convert(mask_former)

Expand Down
58 changes: 29 additions & 29 deletions src/transformers/models/maskformer/feature_extraction_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@

class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a MaskFormer feature extractor.
Constructs a MaskFormer feature extractor. The feature extractor can be used to prepare image(s) and optional
targets for the model.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Expand All @@ -55,15 +56,18 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM
The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
set to `True`.
size_divisibility (`int`, *optional*, defaults to 32):
Some backbones need images divisible by a certain number, if not passes it detauls to the value used in
swin.
Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
Swin Transformer.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
ImageNet std.
ignore_index (`int`, *optional*, default to 255):
Value of the index (label) to ignore.
"""

model_input_names = ["pixel_values", "pixel_mask"]
Expand All @@ -77,14 +81,15 @@ def __init__(
do_normalize=True,
image_mean=None,
image_std=None,
ignore_index=255,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.max_size = max_size
self.size_divisibility = size_divisibility
self.ignore_label = 255
self.ignore_index = ignore_index
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406] # ImageNet mean
self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225] # ImageNet std
Expand Down Expand Up @@ -153,17 +158,6 @@ def get_size(image_size, size, max_size=None):

return rescaled_image, target

def _normalize(self, image, mean, std, target=None):
"""
Normalize the image with a certain mean and std.
If given, also normalize the target bounding boxes based on the size of the image.
"""

image = self.normalize(image, mean=mean, std=std)

return image, target

def __call__(
self,
images: ImageInput,
Expand Down Expand Up @@ -265,17 +259,7 @@ def __call__(
images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0]

if self.do_normalize:
if annotations is not None:
for idx, (image, target) in enumerate(zip(images, annotations)):
image, target = self._normalize(
image=image, mean=self.image_mean, std=self.image_std, target=target
)
images[idx] = image
annotations[idx] = target
else:
images = [
self._normalize(image=image, mean=self.image_mean, std=self.image_std)[0] for image in images
]
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# NOTE I will be always forced to pad them them since they have to be stacked in the batch dim
encoded_inputs = self.encode_inputs(
images, annotations, pad_and_return_pixel_mask, return_tensors=return_tensors
Expand Down Expand Up @@ -315,6 +299,20 @@ def encode_inputs(
pixel_values_list (`List[torch.Tensor]`):
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
width)`.
annotations (`Dict`, `List[Dict]`, *optional*):
The corresponding annotations as dictionary of numpy arrays with the following keys:
- **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
- **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
If left to the default, will return a pixel mask that is:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
Expand All @@ -337,7 +335,6 @@ def encode_inputs(
pixel_mask = []
mask_labels = []
class_labels = []

for idx, image in enumerate(pixel_values_list):
# create padded image
if pad_and_return_pixel_mask:
Expand Down Expand Up @@ -383,6 +380,9 @@ def post_process_segmentation(
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
The outputs from [`MaskFormerForInstanceSegmentation`].
target_size (`Tuple[int, int]`, *optional*):
If set, the `masks_queries_logits` will be resized to `target_size`.
Returns:
`torch.Tensor`:
A tensor of shape (`batch_size, num_labels, height, width`).
Expand Down Expand Up @@ -475,14 +475,14 @@ def post_process_panoptic_segmentation(
object_mask_threshold (`float`, *optional*, defaults to 0.8):
The object mask threshold.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
The overlap mask area threshold.
The overlap mask area threshold to use.
is_thing_map (`Dict[int, bool]`, *optional*):
Dictionary mapping class indices to either `True` or `False`, depending on whether or not they are a
thing. If not set, defaults to the `is_thing_map` of COCO panoptic.
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represent a `segment_id`.
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`.
- **segments** -- a dictionary with the following keys
- **id** -- an integer representing the `segment_id`.
- **category_id** -- an integer representing the segment's label.
Expand Down
22 changes: 14 additions & 8 deletions src/transformers/models/maskformer/modeling_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class MaskFormerSwinModelOutputWithPooling(ModelOutput):
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
`batch, channels, height, width`. Due to padding, their spatial size cannot inferred before the `forward`
method.
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the
`forward` method.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Expand Down Expand Up @@ -314,7 +314,7 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):

def upsample_like(pixel_values: Tensor, like: Tensor, mode: str = "bilinear") -> Tensor:
"""
An utility function that upsamples `pixel_values` to match the dimension of `like`
An utility function that upsamples `pixel_values` to match the dimension of `like`.
Args:
pixel_values (`torch.Tensor`):
Expand Down Expand Up @@ -369,7 +369,7 @@ def sigmoid_focal_loss(
) -> Tensor:
r"""
Focal loss proposed in [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) originally used in
RetinaNet. The loss is computed as follows
RetinaNet. The loss is computed as follows:
$$ \mathcal{L}_{\text{focal loss} = -(1 - p_t)^{\gamma}\log{(p_t)} $$
Expand Down Expand Up @@ -657,6 +657,7 @@ def forward(self, input):
return drop_path(input, self.drop_prob, self.training, self.scale_by_keep)


# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
class MaskFormerSwinSelfAttention(nn.Module):
def __init__(self, config, dim, num_heads):
super().__init__()
Expand Down Expand Up @@ -698,7 +699,13 @@ def transpose_for_scores(self, x):
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)

Expand Down Expand Up @@ -1061,7 +1068,6 @@ def custom_forward(*inputs):
hidden_states = layer_hidden_states

if output_attentions:
# TODO no idea if that is correct
all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)

if not return_dict:
Expand Down Expand Up @@ -1663,7 +1669,7 @@ def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class
assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
indices.append(assigned_indices)

# TODO this is a little weird, they can be stacked in one tensor
# It could be stacked in one tensor
matched_indices = [
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
]
Expand Down Expand Up @@ -1969,7 +1975,7 @@ def forward(self, down: Tensor, left: Tensor) -> Tensor:
class MaskFormerFPNModel(nn.Module):
def __init__(self, in_features: int, lateral_widths: List[int], feature_size: int = 256):
"""
Feature Pyramid Network, given an input tensor and a set of features map of different feature/spatial size, it
Feature Pyramid Network, given an input tensor and a set of feature map of different feature/spatial size, it
creates a list of feature maps with the same feature size.
Args:
Expand Down

0 comments on commit ad2db3e

Please sign in to comment.