utils.load_state_dict() does not support ViT architecture #1049

Closed
nilsleh opened this issue Jan 26, 2023 · 16 comments · Fixed by #1084
Labels: models (Models and pretrained weights), trainers (PyTorch Lightning trainers)
Milestone: 0.4.1

Comments

@nilsleh
Collaborator

nilsleh commented Jan 26, 2023

Description

When using the new pretrained weights, the load_state_dict() function looks for a "conv1" attribute on the model to determine the expected number of input channels. However, only ResNet architectures have one; ViTs begin with "patch_embed". I think the trainer tests in test_classification.py, as well as the others, only test with a ResNet18, which is why they didn't catch this.

Steps to reproduce

from torchgeo.models import ViTSmall16_Weights
from torchgeo.trainers import ClassificationTask

task = ClassificationTask(
    model="vit_small_patch16_224",
    weights=ViTSmall16_Weights.SENTINEL2_ALL_DINO,
    num_classes=10,
    in_channels=13,
)
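
For context, a minimal sketch of why this fails (assuming timm's vit_small_patch16_224; not the exact code path in trainers/utils.py):

import timm

model = timm.create_model("vit_small_patch16_224", in_chans=13, num_classes=10)

# load_state_dict() assumes a ResNet-style stem named "conv1" when it reads the
# expected number of input channels; a ViT has no such attribute:
print(hasattr(model, "conv1"))  # False -- the stem lives under model.patch_embed
print(model.patch_embed.proj)   # Conv2d(13, 384, kernel_size=(16, 16), stride=(16, 16))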

Version

0.4.0

@adamjstewart added this to the 0.4.1 milestone Jan 26, 2023
@adamjstewart
Collaborator

Good catch! We could add this to our model tests, our trainer tests, or both.

Honestly, I would love to get rid of everything in torchgeo/trainers/utils.py and instead use the builtin functions in PyTorch. The only reason we override load_state_dict is to handle the case where num_classes or in_channels doesn't match the model weights, but we could choose to only support model weights with the correct in_channels and always remove the last layer so num_classes doesn't matter. Not sure whether that would be a realistic fix for this.
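
As a rough sketch of that alternative (the checkpoint path here is hypothetical; reset_classifier() is timm's builtin head reset):

import timm
import torch

# Only accept weights whose in_channels already match the model, load them with
# the builtin method, then reset the head so num_classes never has to match.
model = timm.create_model("vit_small_patch16_224", in_chans=3)
state_dict = torch.load("checkpoint.pth")  # hypothetical path to matching weights
model.load_state_dict(state_dict)          # plain PyTorch, no custom shim
model.reset_classifier(num_classes=10)     # drop/replace the pretrained head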

@adamjstewart added the models (Models and pretrained weights) and trainers (PyTorch Lightning trainers) labels Jan 26, 2023
@nilsleh
Collaborator Author

nilsleh commented Jan 27, 2023

I think having the flexibility of num_classes and in_channels is a nice feature for running experiments as a user, so I'd be in favor of keeping it. I am working on a PR that adds more ViT weights, so I could fix the issue there, or do you want a separate PR?

@adamjstewart
Collaborator

Separate PR. New features have to wait until 0.5, but bug fixes can be backported to 0.4.1.

@isaaccorley
Collaborator

isaaccorley commented Jan 27, 2023

I think we can fix this by checking the parent class of the model

For example:

import timm

resnet = timm.create_model("resnet18")
vit = timm.create_model("vit_small_patch16_224")

type(resnet)
# <class 'timm.models.resnet.ResNet'>

type(vit)
# <class 'timm.models.vision_transformer.VisionTransformer'>

Thoughts?

@adamjstewart
Collaborator

We will eventually need to support far more models than just ResNet and ViT. Probably best to use hasattr and look for common names for the first layer. Hopefully there are only a few names for the first layer.
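
A rough sketch of what that could look like (the stem names probed here are just the ones that come up in this thread, not an exhaustive list):

import timm

def find_stem(model):
    # Probe a few common names for the first layer and return the first match.
    for name in ("conv1", "patch_embed", "conv_stem"):
        if hasattr(model, name):
            return name, getattr(model, name)
    raise ValueError("unknown stem layer")

find_stem(timm.create_model("resnet18"))               # ('conv1', Conv2d(3, 64, ...))
find_stem(timm.create_model("vit_small_patch16_224"))  # ('patch_embed', PatchEmbed(...))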

@isaaccorley
Collaborator

isaaccorley commented Jan 27, 2023

  • resnet starts with conv1
  • vit starts with patch_embed
  • efficientnet starts with conv_stem

I don't have a solution yet, but I don't think this will be the answer either. It may be easier to just force the user to input which backbone the state dict originally came from.

@adamjstewart
Collaborator

They do input the backbone, but they don't input which layer to look at to find in_channels.

@nilsleh
Collaborator Author

nilsleh commented Jan 27, 2023

I was thinking of finding the in_channels by looking at the first named parameter in the model via model.named_parameters(). However, in the ViT case index 0 is the cls_token and not the patch_embed, so we cannot just take the first index. The same goes for the state_dict. I suppose one could generate a dict that maps every backbone name in timm to the named parameter in which to look for in_channels... not very elegant.
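
For illustration, the pitfall described above (a minimal check, not a proposed fix):

import timm

# The first named parameter of a ViT is the class token, not the patch
# embedding, so index 0 doesn't point at the input layer:
vit = timm.create_model("vit_small_patch16_224")
print(next(iter(vit.named_parameters()))[0])     # 'cls_token'

resnet = timm.create_model("resnet18")
print(next(iter(resnet.named_parameters()))[0])  # 'conv1.weight'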

@adamjstewart
Collaborator

That's actually a really interesting idea! I would do that as the fallback and then only special case the few architectures like ViT where that doesn't work.

@isaaccorley
Collaborator

ViT patch_embed actually starts with a conv layer. So if you do list(list(model.children())[0].children())[0] you can access it.

If you recursively search the first named child until you reach a leaf module with no children, you can get the key, and this works for resnet, vit, efficientnet, etc. I haven't done a thorough search though.

import timm

def get_input_layer(backbone):
    model = timm.create_model(backbone)

    keys = []
    children = list(model.named_children())
    while children != []:
        name, module = children[0]
        keys.append(name)
        children = list(module.named_children())
    
    key = ".".join(keys)
    return key, module

get_input_layer("resnet18")
# ('conv1',
# Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))

get_input_layer("vit_small_patch16_224")
# ('patch_embed.proj', Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16)))

get_input_layer("efficientnet_b0")
# ('conv_stem',
# Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False))

@nilsleh
Collaborator Author

nilsleh commented Jan 27, 2023

import timm
import pandas as pd
from tqdm import tqdm


def get_input_layer(backbone):
    model = timm.create_model(backbone)

    keys = []
    children = list(model.named_children())
    while children != []:
        name, module = children[0]
        keys.append(name)
        children = list(module.named_children())

    key = ".".join(keys)
    return key, module


all_timm_model_names = timm.list_models()
result = {}
for model_name in tqdm(all_timm_model_names):
    key, module = get_input_layer(model_name)
    try:
        num_in_channels = module.in_channels
        result[model_name] = num_in_channels
    except AttributeError:
        result[model_name] = -1

df = (
    pd.DataFrame.from_dict(result, orient="index")
    .reset_index()
    .rename(columns={0: "num_in_chans", "index": "model_name"})
)

print(df[df["num_in_chans"]==-1])

If I loop through all timm models with your code, all models have num_in_channels=3, except for:

               model_name  num_in_chans
532             tresnet_l            -1
533         tresnet_l_448            -1
534             tresnet_m            -1
535         tresnet_m_448            -1
536  tresnet_m_miil_in21k            -1
537            tresnet_xl            -1
538        tresnet_xl_448            -1

@isaaccorley
Collaborator

Dang, you beat me. I'm just finishing looping through them in the next minute or so.

@adamjstewart
Collaborator

If all models default to 3 channels, then we can just hardcode it to 3.

@isaaccorley
Collaborator

I believe all models are trained on ImageNet or variants of it, so it's unlikely to be anything other than 3 channels.

@nilsleh
Collaborator Author

nilsleh commented Jan 27, 2023

But I think you still need the first parameter name to index the state_dict and get expected_in_channels from the weights you are trying to load, and for that you need Isaac's solution.
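
Something like this, reusing the get_input_layer helper from above (and assuming the weights enum exposes torchvision's get_state_dict()):

from torchgeo.models import ViTSmall16_Weights

# Read the actual in_channels from the model's stem and the expected value from
# the pretrained state_dict, using the same dotted key on both sides.
key, module = get_input_layer("vit_small_patch16_224")  # 'patch_embed.proj'
in_channels = module.in_channels                        # 3 (timm default)

state_dict = ViTSmall16_Weights.SENTINEL2_ALL_DINO.get_state_dict(progress=True)
expected_in_channels = state_dict[key + ".weight"].shape[1]  # 13 for these weights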

@isaaccorley
Collaborator

isaaccorley commented Jan 27, 2023

Some analysis on the type of the first layer: the majority are Conv2d. It looks like the only one that is not a subclass of Conv2d is SpaceToDepthModule, which doesn't depend on input channels.

I think it's fair to assume that we can use the above solution to get the state_dict key and input channels.

Counter({
         'Conv2d': 825,
         'ScaledStdConv2dSame': 7,
         'ScaledStdConv2d': 28,
         'Conv2dSame': 66,
         'StdConv2d': 15,
         'SpaceToDepthModule': 8,
         'StdConv2dSame': 15
})
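
A sketch of how a tally like this can be produced, again reusing get_input_layer from the earlier comment (slow, since it instantiates every timm model):

from collections import Counter

import timm

# Count the class name of the first leaf module across all timm models.
counts = Counter(
    type(get_input_layer(name)[1]).__name__ for name in timm.list_models()
)
print(counts)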
