🚨 🚨 Bring some dinos to modern standards #46266

Open
molbap wants to merge 11 commits into main from improve_dinos

@molbap
Contributor

@molbap molbap commented May 28, 2026

What does this PR do?

Part of the larger vision model refactor #41693, focused on dinov2, which still has some usage and downloads but mostly serves as a basis for many other models. This is an attempt at bringing it in line with the rest of the library.

image

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@guarin guarin left a comment

Thanks, this will make things much easier! Left more questions than comments :)

Comment on lines +48 to +76
self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
Member

Might this break things downstream if now every model expects mask token to exist?

Contributor Author

totally, missing ternary with None default
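
A minimal sketch of what "ternary with a `None` default" could look like. `EmbeddingsSketch` and its constructor arguments are hypothetical stand-ins, not the actual `Dinov2Embeddings` code from this PR; the idea is only that the parameter is created when requested, so checkpoints and downstream code that never had a mask token do not suddenly expect one.

```python
import torch
from torch import nn


class EmbeddingsSketch(nn.Module):
    """Hypothetical embeddings module, used only to illustrate the pattern."""

    def __init__(self, hidden_size: int, use_mask_token: bool = False):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        # Ternary with a None default: the parameter only exists on request.
        self.mask_token = nn.Parameter(torch.zeros(1, hidden_size)) if use_mask_token else None


with_mask = EmbeddingsSketch(hidden_size=8, use_mask_token=True)
without_mask = EmbeddingsSketch(hidden_size=8)

# Only the first module registers (and would save or load) a mask token.
names_with_mask = [name for name, _ in with_mask.named_parameters()]
names_without_mask = [name for name, _ in without_mask.named_parameters()]
```

Code that reads the attribute then needs an `if module.mask_token is not None` guard, which is also why such checks in an inherited `_init_weights` remain meaningful.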

Comment on lines -194 to +162
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.dropout_prob = config.attention_probs_dropout_prob
- self.scaling = self.attention_head_size**-0.5
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.attention_dropout = config.attention_probs_dropout_prob
+ self.scaling = self.head_dim**-0.5
Member

Do we consider attributes as part of the API that shouldn't change? E.g. here the rename from self.dropout_prob to self.attention_dropout could be backwards incompatible

Contributor Author

In theory it's OK, in the sense that we can still read old hub configs and make them work with the new code, so it is backwards-compatible in that sense. So it's a minor breakage I'd say, and it would also let us align more closely with ViT naming.
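
For reference, one common way to soften an attribute rename like this is a deprecated read-only alias. This is a sketch only: `AttentionSketch` is a hypothetical plain-Python stand-in, and nothing in this thread says the PR keeps such an alias.

```python
import warnings


class AttentionSketch:
    """Hypothetical stand-in for an attention module; not the real Dinov2 class."""

    def __init__(self, attention_probs_dropout_prob: float = 0.0):
        # New attribute name, aligned with ViT naming.
        self.attention_dropout = attention_probs_dropout_prob

    @property
    def dropout_prob(self) -> float:
        # Old attribute name kept as a read-only alias that warns on access.
        warnings.warn(
            "`dropout_prob` is deprecated, use `attention_dropout` instead.",
            FutureWarning,
            stacklevel=2,
        )
        return self.attention_dropout


layer = AttentionSketch(attention_probs_dropout_prob=0.1)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_value = layer.dropout_prob  # still readable, but emits a FutureWarning
```

The trade-off is extra surface area to maintain, which may not be worth it if the old attribute was never considered public API.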

Comment on lines +336 to +339
if isinstance(module, Dinov2Embeddings):
    init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
    init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
    init.zeros_(module.mask_token)
Member

Looks like the custom init is already correctly inherited from ViTPreTrainedModel and we don't have to overwrite it. The if xyz is not None checks will always pass.

Contributor Author

Yes!

Comment on lines +367 to +368
use_mask_token (`bool`, *optional*, defaults to `False`):
    Whether to use a mask token for masked image modeling.
Member

Remove use_mask_token from the docstring or mention that it is ignored? add_pooling_layer also doesn't seem to be used

Contributor Author

well it'll be used in the end

Contributor Author

(not pooling_layer though)

Comment on lines +375 to +376
self.pooler = None
self.encoder = Dinov2Encoder(config)
Member

This is super minor but it feels off if modules are not declared in the order they are accessed. For example now self.encoder is declared after self.layernorm. This impacts module printing and some torch utils which rely on order of modules. I also don't see self.layers being accessed, where is it needed?

Contributor Author

Modules should absolutely be declared in inheritance order! The PR is in draft so I haven't checked yet, but yes. As for self.layers, it's inherited from ViTModel.

Comment on lines +52 to +53
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
Member

Inline num_patches?

Comment on lines +193 to 194
if isinstance(module, Dinov2WithRegistersEmbeddings):
    init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
Member

Maybe also not needed

Comment on lines -274 to -275

self.num_register_tokens = config.num_register_tokens
Member

Removal of self.num_register_tokens might also break backwards compat

Contributor Author

That can too, in a minor way, yes. I'm pretty sure this PR needs 🚨 🚨 because we might not be able to get around some breakage (even if we keep all old attributes).

def get_input_embeddings(self):
    return self.embeddings.patch_embeddings

@can_return_tuple
Member

Is can_return_tuple not needed anymore?

Contributor Author

@capture_outputs(tie_last_hidden_states=False) supersedes it!

Comment on lines -259 to +265
- torch.testing.assert_close(predicted_depth[0, :3, :3], expected_slice, rtol=1e-6, atol=1e-6)
+ torch.testing.assert_close(predicted_depth[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
Member

Why is tolerance so much higher?

Contributor Author

ah, because it was broken, haha. Noticed that when running on the DGX

E       AssertionError: Tensor-likes are not close!
E       
E       Mismatched elements: 8 / 9 (88.9%)
E       Greatest absolute difference: 5.53131103515625e-05 at index (0, 2) (up to 1e-06 allowed)
E       Greatest relative difference: 6.415643383661518e-06 at index (0, 2) (up to 1e-06 allowed)

(this is on main). Related because dinov2 is the main backbone
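
To make the numbers in that log concrete, here is a small pure-Python sketch of the closeness criterion documented for `torch.testing.assert_close`, `|actual - expected| <= atol + rtol * |expected|`. The helper `within_tolerance` is hypothetical, and the expected magnitude is inferred from the log on the assumption that the reported relative difference is the absolute difference divided by `|expected|` at the same index.

```python
def within_tolerance(abs_diff: float, expected_abs: float, rtol: float, atol: float) -> bool:
    # Closeness criterion: |actual - expected| <= atol + rtol * |expected|
    return abs_diff <= atol + rtol * expected_abs


# Values reported in the assertion error above, both at index (0, 2).
greatest_abs_diff = 5.53131103515625e-05
greatest_rel_diff = 6.415643383661518e-06

# Assumption: rel_diff = abs_diff / |expected|, so |expected| can be recovered.
expected_abs = greatest_abs_diff / greatest_rel_diff

# Old tolerances (rtol=1e-6, atol=1e-6) versus the new ones (rtol=1e-4, atol=1e-4).
passes_old = within_tolerance(greatest_abs_diff, expected_abs, rtol=1e-6, atol=1e-6)
passes_new = within_tolerance(greatest_abs_diff, expected_abs, rtol=1e-4, atol=1e-4)
```

Under these assumptions the worst element fails the old bounds and passes the new ones with a wide margin.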

@molbap molbap mentioned this pull request May 29, 2026
39 tasks
@molbap molbap marked this pull request as ready for review June 4, 2026 13:04
@molbap molbap changed the title from "Improve dinos" to "🚨 🚨 Bring some dinos to modern standards" Jun 4, 2026
@github-actions
Contributor

github-actions Bot commented Jun 4, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: depth_anything, dinov2, dinov2_with_registers, dinov3_convnext, dinov3_vit, eomt, eomt_dinov3, pixio, rf_detr, sapiens2, videomt

3 participants