[Q] What's the best way to override a PyTorch module used in a timm model? #2101
Replies: 3 comments 1 reply
-
`F.sdpa` (`F.scaled_dot_product_attention`) usually works fine, but there was a sequence of bugs related to masking over recent PyTorch 2.x releases. `F.sdpa` usage can be disabled in timm by setting `TIMM_FUSED_ATTN=0` in your environment.
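A minimal sketch of using that flag from Python (assuming, as I read the timm source, that the variable is checked when timm is imported, so it must be set beforehand):

```python
import os

# Disable timm's fused attention (F.scaled_dot_product_attention) path.
# Set the flag before importing timm, since it is read at import time.
os.environ['TIMM_FUSED_ATTN'] = '0'

import timm

model = timm.create_model('vit_base_patch16_224')
```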
-
To replace modules, the best way is to do as the freeze_bn or syncbn examples do: iterate over the modules recursively and rebuild the model, swapping out the ones that match your criteria. See https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/norm_act.py#L148-L187 ... But if the bits of code you want to replace are called functionally, you cannot do that. You could FX-trace the model and manipulate the graph, but that is complicated and has limitations; at that point you typically need to alter the code itself, or have flags in place to allow different paths.
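A minimal sketch of that recursive rebuild pattern, loosely modeled on the linked norm_act.py helpers (`MySiLU`, `replace_modules`, and the match criterion are placeholders of mine, not timm APIs):

```python
import torch.nn as nn
import timm

class MySiLU(nn.SiLU):
    """Custom activation; put your overridden forward logic here."""
    def forward(self, x):
        return super().forward(x)

def replace_modules(module, match=nn.SiLU, factory=MySiLU):
    # Swap the module itself if it matches the criterion.
    if isinstance(module, match):
        return factory(inplace=module.inplace)  # carry over relevant state
    # Otherwise recurse into children, reassigning any that were swapped.
    for name, child in module.named_children():
        setattr(module, name, replace_modules(child, match, factory))
    return module

# EfficientNets in timm use SiLU activations, so this swaps all of them.
model = replace_modules(timm.create_model('efficientnet_b0'))
```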
-
Simply write an unbound function `f` with `self` as its first input. Then import transformers and assign `f` to a module's `forward`. This way you override the Hugging Face functions.
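A minimal sketch of that trick, using `nn.SiLU` as the target (note this is a class-level assignment, so every instance in the process is affected):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# An unbound function with `self` as the first argument, like a method.
def f(self, x):
    # Custom behavior goes here; this just reproduces stock SiLU.
    return F.silu(x, inplace=self.inplace)

# Assign at the class level: all nn.SiLU instances now route through f.
nn.SiLU.forward = f

print(nn.SiLU()(torch.randn(2)))
```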
-
For my research, I need to override the forward method of some modules, such as `nn.SiLU`. I have been doing this manually so far, going through the code of the ViT model and changing the modules it uses to my overridden versions. But I'd like to have a general way to use any timm (or, ideally, any PyTorch) model with my modifications.

I have seen forward hooks, but ideally I want to be able to use `class MySiLU(nn.SiLU)`, which gives me more flexibility. A forward hook would also be inefficient for my purposes, as it would compute the forward pass twice (I need to recompute the forward pass for my changes).

I also need to override `F.scaled_dot_product_attention`. Should I just do `F.scaled_dot_product_attention = my_scaled_dot_product_attention` (roughly the sketch at the end of this post)? Ideally, I also need access to the parent module that called `F.scaled_dot_product_attention`. Is that possible through some hack?

Is `F.scaled_dot_product_attention` even reliable? @rwightman mentions that `F.scaled_dot_product_attention` is buggy in "DINOv2 worse performance compared to the original version · Issue #2094 · huggingface/pytorch-image-models". I am using PyTorch `1.13.1+cu116`, as it's the latest version my vGPU driver supports.

Any suggestions?
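For reference, this is roughly the monkey-patch I have in mind (placeholder names of mine, and it assumes a PyTorch version that actually has `F.scaled_dot_product_attention`):

```python
import torch.nn.functional as F

_orig_sdpa = F.scaled_dot_product_attention

def my_scaled_dot_product_attention(query, key, value, attn_mask=None,
                                    dropout_p=0.0, is_causal=False):
    # Custom logic would go here; delegating to the original for now.
    return _orig_sdpa(query, key, value, attn_mask=attn_mask,
                      dropout_p=dropout_p, is_causal=is_causal)

# Rebind the attribute; any code that looks up
# F.scaled_dot_product_attention at call time picks up the patch.
F.scaled_dot_product_attention = my_scaled_dot_product_attention
```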