🚨 Refactor DETR to updated standards #41549
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…sks, vision input embeds and query embeds
if not isinstance(line, str):
    line = line.decode()
line was an str when I tried to use this, not sure why! I can open a separate PR for it though
if pixel_values is None and inputs_embeds is None:
    raise ValueError("You have to specify either pixel_values or inputs_embeds")

if inputs_embeds is None:
    batch_size, num_channels, height, width = pixel_values.shape
    device = pixel_values.device

    if pixel_mask is None:
        pixel_mask = torch.ones(((batch_size, height, width)), device=device)
    vision_features = self.backbone(pixel_values, pixel_mask)
    feature_map, mask = vision_features[-1]

    # Apply 1x1 conv to map (N, C, H, W) -> (N, d_model, H, W), then flatten to (N, HW, d_model)
    # (feature map and position embeddings are flattened and permuted to (batch_size, sequence_length, hidden_size))
    projected_feature_map = self.input_projection(feature_map)
    flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
    spatial_position_embeddings = (
        self.position_embedding(shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask)
        .flatten(2)
        .permute(0, 2, 1)
    )
    flattened_mask = mask.flatten(1)
else:
    batch_size = inputs_embeds.shape[0]
    device = inputs_embeds.device
    flattened_features = inputs_embeds
    # When using inputs_embeds, we need to infer spatial dimensions for position embeddings
    # Assume square feature map
    seq_len = inputs_embeds.shape[1]
    feat_dim = int(seq_len**0.5)
    # Create position embeddings for the inferred spatial size
    spatial_position_embeddings = (
        self.position_embedding(
            shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]),
            device=device,
            dtype=inputs_embeds.dtype,
        )
        .flatten(2)
        .permute(0, 2, 1)
    )
    # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten.
    if pixel_mask is not None:
        mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0]
        flattened_mask = mask.flatten(1)
    else:
        # If no mask provided, assume all positions are valid
        flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long)
Now truly supports passing inputs_embeds instead of silently doing nothing with it
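For illustration, a minimal sketch of how this new path could be exercised (the config values, shapes and the square seq_len below are assumptions for the example, not taken from the PR):

import torch
from transformers import DetrConfig, DetrModel

# Randomly initialized model; avoids downloading pretrained backbone weights (timm is still required).
config = DetrConfig(use_pretrained_backbone=False)
model = DetrModel(config).eval()

# Pre-computed, already flattened vision features of shape (batch, seq_len, d_model),
# where seq_len is assumed to be a perfect square (here 25 * 25 = 625).
inputs_embeds = torch.randn(1, 625, config.d_model)

with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds)

print(outputs.last_hidden_state.shape)  # expected: (1, config.num_queries, config.d_model)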
if decoder_inputs_embeds is not None:
    queries = decoder_inputs_embeds
else:
    queries = torch.zeros_like(object_queries_position_embeddings)
Same, truly supports decoder_inputs_embeds as input
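Reusing `model` and `config` from the sketch above, the decoder could then be seeded with custom query embeddings; a hedged example, not code from the PR:

# Custom query embeddings replace the default all-zeros queries; shape is assumed
# to be (batch, num_queries, d_model).
decoder_inputs_embeds = torch.randn(1, config.num_queries, config.d_model)

with torch.no_grad():
    outputs = model(
        pixel_values=torch.randn(1, 3, 640, 640),
        decoder_inputs_embeds=decoder_inputs_embeds,
    )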
attention_mask=None,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
attention_mask=decoder_attention_mask,
Supports masking of queries (as advertised)
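Continuing the same sketch, a query mask of shape (batch, num_queries) could be passed, with 1 = attend and 0 = mask (again an assumed example, reusing `model` and `config` from above):

# Mask out the last half of the object queries.
decoder_attention_mask = torch.ones(1, config.num_queries, dtype=torch.long)
decoder_attention_mask[:, config.num_queries // 2 :] = 0

with torch.no_grad():
    outputs = model(
        pixel_values=torch.randn(1, 3, 640, 640),
        decoder_attention_mask=decoder_attention_mask,
    )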
if attention_mask is not None:
    # [batch_size, num_queries] -> [batch_size, 1, num_queries, num_queries]
    attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

# expand encoder attention mask
# expand encoder attention mask (for cross-attention on encoder outputs)
if encoder_hidden_states is not None and encoder_attention_mask is not None:
    # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
    encoder_attention_mask = _prepare_4d_attention_mask(
        encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
    )

# optional intermediate hidden states
intermediate = () if self.config.auxiliary_loss else None

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

for idx, decoder_layer in enumerate(self.layers):
    # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
    if self.training:
        dropout_probability = torch.rand([])
        if dropout_probability < self.layerdrop:
            continue

    layer_outputs = decoder_layer(
    hidden_states = decoder_layer(
        hidden_states,
        combined_attention_mask,
        object_queries,
        query_position_embeddings,
        attention_mask,
        spatial_position_embeddings,
        object_queries_position_embeddings,
        encoder_hidden_states,  # as a positional argument for gradient checkpointing
        encoder_attention_mask=encoder_attention_mask,
        output_attentions=output_attentions,
        **kwargs,
Truly supports attention mask on vision features (it was always None before)
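As background on the `_prepare_4d_attention_mask` calls above, here is a minimal stand-alone sketch of the expansion they perform; this only illustrates the behaviour, it is not the library implementation, and `expand_to_4d_mask` is a made-up helper name:

import torch

def expand_to_4d_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # (batch, src_len) with 1 = keep and 0 = pad -> (batch, 1, tgt_len, src_len) additive mask,
    # where padded positions get a large negative value so softmax ignores them.
    batch_size, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

pad_mask = torch.tensor([[1, 1, 1, 0]])
print(expand_to_4d_mask(pad_mask, torch.float32, tgt_len=2).shape)  # torch.Size([1, 1, 2, 4])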
Hello @molbap @ArthurZucker!
Some initial thoughts, focused on the masks/interface part
):
    if use_attention_mask:
        self.skipTest(
            "This test uses attention masks which are not compatible with DETR. Skipping when use_attention_mask is True."
Hmm, why tho? Are the attention masks perhaps 3D instead?
It's more that _test_eager_matches_sdpa_inference is not adapted to the vision space (+ object queries here). It tries to add a "decoder_input_ids" to the inputs, and the seq lengths created for the dummy masks were wrong. Seeing as the function is already quite cluttered and difficult to read, I figured trying to add support for vision models directly there would not be ideal. We can either override the tests in this model specifically, or try to have a more general test for vision models. Another option would be to parameterize the tests by specifying how to find the correct seq length and input names.
I would love some help on this!
I see, is this specific to DETR or will we encounter this more for other models in the vision family? It's best not to skip too much if more come down the line. Depending on how many are affected by this, we should either
- Fix the base test, e.g. with parametrization, splitting the test a bit (more models with similar problems)
- Overwrite the test and make specific changes (low amount of models with similar problems)
The problem is with the test's base design indeed. It will lead to more skipped tests down the line because the division encoder/encoder-decoder/decoder isn't drawn that clearly. The number of models with similar problems isn't "low" imo.
Yes I think it will increase too with us fixing the attention masks for vision models, so we definitely need to improve the base test
Thanks for the review @vasqu! I standardized attention and masking following your advice :)
vasqu left a comment
Looking good from my side, amazing work! Just left some smaller comments but nothing crazy
    position_embedding = DetrLearnedPositionEmbedding(n_steps)
else:
    raise ValueError(f"Not supported {config.position_embedding_type}")

def eager_attention_forward(
Let's use a # Copied from ... statement so that during changes we have the connection here
I'd prefer not to come back to Copied from statements 😅
It's only for models that do not use modular yet. The issue is that if we don't do this and change anything upstream on Bert, it won't be reflected here (and the dependent models) - otherwise, we need to manually keep track of these deps 😅 Don't mind this if we refactor to modular here!
The goal is to use modular eventually (in this PR even) :)
Perfect! Then don't mind this comment :p I only want to have some dep here
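For readers unfamiliar with the convention discussed in this thread: a "# Copied from" marker is a single comment placed right above the duplicated object, and `make fix-copies` keeps the body in sync with the referenced source. The module path and signature below are only an illustration of the pattern, not code from this PR:

# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    ...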
_no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
_supports_sdpa = True
_supports_flash_attn = True
_supports_attention_backend = True
We should be able to _support_flex_attention (not 100% sure about the flag name) since we use create_bidirectional_masks
it's _supports_flex_attn, and yep
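Concretely, the proposal would make the flag block look something like the sketch below (flag names as used in recent transformers releases; whether they all land on this class is up to the PR):

from transformers import PreTrainedModel

class DetrPreTrainedModel(PreTrainedModel):
    _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True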
molbap left a comment
Looks niiiice
For the unhappy CI, let's throw the Check Copies away!
| "qwen2_5_vl", | ||
| "videollava", | ||
| "vipllava", | ||
| "detr", |
I'm not sure, do we need to add this here?
Yes that's what made me go crazy haha otherwise _checkpoint_conversion_mapping doesn't work.
Note that this is temporary and will be replaced by the new way to convert weights on the fly that @ArthurZucker and @Cyrilvallez are working on.
def __init__(self, config: DetrConfig):
    super().__init__()
    self.embed_dim = config.d_model
    self.hidden_size = config.d_model
won't that break BC? (at least on the attribute names)
In what way? If users access it directly? In any case I think we really need to standardize these types of variable names, it might be worth slightly breaking BC imo
yeah in case of non-config access. I agree I prefer to standardize
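If keeping the old attribute around were preferred over a hard break, a deprecation alias is one option; this is a hypothetical sketch (class name illustrative), not what the PR does:

import warnings
import torch.nn as nn

class SomeDetrModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.d_model

    @property
    def embed_dim(self):
        # Old attribute name kept readable for code that accessed it directly.
        warnings.warn("`embed_dim` is deprecated, use `hidden_size` instead.", FutureWarning)
        return self.hidden_size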
if self.training:
    dropout_probability = torch.rand([])
    if dropout_probability < self.layerdrop:
        continue
not exactly the typical dropout interface, we can maybe take the occasion to update it?
Yes 😫, I was scared of breaking BC in that case, but maybe it's not so important. It would be great to get rid of non-standard dropout elsewhere as well, really
I think it's ok to break it in here, it does not affect inference and clearly it would be an improvement to get rid of it haha
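For context, the pattern in question is LayerDrop (https://huggingface.co/papers/1909.11556): whole layers are skipped at random during training only, while inference always runs every layer. A stand-alone sketch of the idea (not the model code; `forward_with_layerdrop` is a made-up helper):

import torch
import torch.nn as nn

def forward_with_layerdrop(layers: nn.ModuleList, hidden_states: torch.Tensor, layerdrop: float, training: bool):
    for layer in layers:
        # During training, skip the entire layer with probability `layerdrop`.
        if training and torch.rand(()) < layerdrop:
            continue
        hidden_states = layer(hidden_states)
    return hidden_states

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
out = forward_with_layerdrop(layers, torch.randn(2, 8), layerdrop=0.5, training=True)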
def freeze_backbone(self):
    for name, param in self.backbone.conv_encoder.model.named_parameters():
    for _, param in self.backbone.model.named_parameters():
        param.requires_grad_(False)

def unfreeze_backbone(self):
    for name, param in self.backbone.conv_encoder.model.named_parameters():
    for _, param in self.backbone.model.named_parameters():
        param.requires_grad_(True)
these methods should really be user-side responsibilities 😨 I would be pro-removal! We can always communicate on it
Yes agreed, we could start a deprecation cycle, or just remove it for v5. It's present in several other vision models
Just asked @merveenoyan, who's an avid finetuner and is not using these methods anymore; I think they were good initially but they're ok to go now. Agreed it's out of scope for the current PR, will create another to remove all of it (cc @ariG23498 as we chatted on finetuning too)
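For reference, if the helpers go away, freezing stays a short user-side snippet; the exact submodule path depends on the final refactor, so `model.backbone` below is an assumption:

from transformers import DetrConfig, DetrModel

model = DetrModel(DetrConfig(use_pretrained_backbone=False))

# Freeze every backbone parameter; pass True instead to unfreeze again.
for param in model.backbone.parameters():
    param.requires_grad_(False)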
def forward(self, q, k, mask: Optional[torch.Tensor] = None):
    q = self.q_linear(q)
    k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
    queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
My nit here would be: if we can update the single-letter variable names a bit, that'd be great!
Yes I think we could even try to refactor this to use the standard attention module and only take the attention weights! It could be interesting to compare the performance of eager attention vs this implementation (conv2d instead of linear for key proj, and no multiplication by value) vs other attention impl.
ahah that's a tough one to benchmark but indeed sounds good, LMK if you want to do that in this PR or move to another
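To make the rename discussion concrete, a hedged sketch of the projection part with descriptive names; shapes follow the snippet above, the rest of the module is unchanged and omitted, and `project_queries_and_keys` is a made-up helper for illustration:

import torch
import torch.nn as nn

def project_queries_and_keys(
    query_embeds: torch.Tensor,    # (batch, num_queries, hidden_dim)
    key_feature_map: torch.Tensor, # (batch, hidden_dim, height, width)
    q_linear: nn.Linear,
    k_linear: nn.Linear,
    num_heads: int,
):
    hidden_dim = q_linear.out_features
    head_dim = hidden_dim // num_heads
    query_states = q_linear(query_embeds)
    # The 1x1 conv acts as a linear projection over the channel dimension of the feature map.
    key_states = nn.functional.conv2d(
        key_feature_map, k_linear.weight.unsqueeze(-1).unsqueeze(-1), k_linear.bias
    )
    queries_per_head = query_states.view(
        query_states.shape[0], query_states.shape[1], num_heads, head_dim
    )
    keys_per_head = key_states.view(
        key_states.shape[0], num_heads, head_dim, key_states.shape[-2], key_states.shape[-1]
    )
    return queries_per_head, keys_per_head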
[For maintainers] Suggested jobs to run (before merge): run-slow: detr
What does this PR do?
This PR aims at refactoring DETR as part of an effort to standardize vision models in the library, in the same vein as #41546.
Expect to see many more PRs like this for vision models as we approach v5!