utils.load_state_dict() does not support ViT architecture #1049
Comments
Good catch! We could either add this to our model tests or our trainer tests, or both. Honestly, I would love to get rid of everything in
I think having the flexibility of
Separate PR. New features have to wait until 0.5, but bug fixes can be backported to 0.4.1.
I think we can fix this by checking the parent class of the model. For example:

```python
import timm

resnet = timm.create_model("resnet18")
vit = timm.create_model("vit_small_patch16_224")

type(resnet)
# <class 'timm.models.resnet.ResNet'>
type(vit)
# <class 'timm.models.vision_transformer.VisionTransformer'>
```

Thoughts?
We will eventually need to support far more models than just ResNet and ViT. Probably best to use
I don't have a solution yet, but I don't think this will be the answer either. It may be easier to just force the user to input which backbone the state dict was originally from.
They do input the backbone, but they don't input which layer to look at to find in_channels. |
I was thinking of finding the in_channels by looking at the first named parameter in the model with
That's actually a really interesting idea! I would do that as the fallback and then only special case the few architectures like ViT where that doesn't work. |
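That fallback idea can be sketched with plain `torch.nn` (the helper name and the toy model below are made up for illustration; this is not timm's or torchgeo's API):

```python
import torch.nn as nn

def infer_in_channels(model: nn.Module):
    """Return (name, in_channels) from the first 4-D weight parameter.

    A conv weight is shaped (out_channels, in_channels, kH, kW), so
    dimension 1 of the first such parameter gives the input channels.
    """
    for name, param in model.named_parameters():
        if param.dim() == 4:
            return name, param.shape[1]
    raise ValueError("no conv-style weight found")

# Toy model standing in for a timm backbone:
toy = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3),
)
print(infer_in_channels(toy))  # ('0.weight', 3)
```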
> ViT

If you recursively follow the first named child until you reach a base module with no children, you can actually get the key, and this works for resnet, vit, efficientnet, etc. Haven't done a thorough search though.

```python
import timm

def get_input_layer(backbone):
    model = timm.create_model(backbone)
    keys = []
    children = list(model.named_children())
    while children:
        name, module = children[0]
        keys.append(name)
        children = list(module.named_children())
    key = ".".join(keys)
    return key, module

get_input_layer("resnet18")
# ('conv1',
#  Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))

get_input_layer("vit_small_patch16_224")
# ('patch_embed.proj', Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16)))

get_input_layer("efficientnet_b0")
# ('conv_stem',
#  Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False))
```
If I loop through all timm models with your code, all models have
Dang, you beat me. I'm just finishing looping through them in the next minute or so.
If all models default to 3 channels, then we can just hardcode it to 3.
I believe all models are trained on ImageNet or variants of ImageNet, so it's unlikely that it's not 3 channels.
But I think you still need to obtain the first parameter name to index the state_dict to get |
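That point can be illustrated with a toy `torch.nn` model (hypothetical names, standing in for a pretrained checkpoint): even with in_channels hard-coded to 3, the first layer's key is still what lets you index the state_dict.

```python
import torch.nn as nn

# Toy stand-in for a pretrained backbone's checkpoint.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3))
state_dict = model.state_dict()

# Without the first layer's key you cannot reach its weight tensor;
# with it, in_channels falls out of the weight shape.
first_key = next(iter(state_dict))
in_channels = state_dict[first_key].shape[1]
print(first_key, in_channels)  # 0.weight 3
```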
Some analysis on the type of the first layer: the majority are the same module type. I think it's fair to assume that we can use the above solution to get the state_dict key and input channels.
Description
When using the new pretrained weights, the `load_state_dict()` function looks for a "conv1" in the model to determine the expected input channels to the model. However, only ResNet architecture weights have this; ViTs begin with "patch_embed". I think the trainer `test_classification.py`, as well as the others, only tests with a ResNet18, so that is why the tests didn't catch it.

Steps to reproduce
Version
0.4.0