Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix load_state_dict for all timm models #1084

Merged
merged 10 commits into from
Feb 11, 2023

Conversation

nilsleh
Copy link
Collaborator

@nilsleh nilsleh commented Feb 3, 2023

This PR closes #1049 , by implementing Isaac's solution.

@nilsleh nilsleh marked this pull request as draft February 3, 2023 10:43
@github-actions github-actions bot added testing Continuous integration testing trainers PyTorch Lightning trainers labels Feb 3, 2023
tests/trainers/test_classification.py Show resolved Hide resolved
torchgeo/trainers/utils.py Outdated Show resolved Hide resolved
@adamjstewart adamjstewart added this to the 0.4.1 milestone Feb 3, 2023
@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 8, 2023

I looked into the seco weights again. Since they are originally saved as part of a pytorch-lightning module, the keys have different names then the default timm keys. Looking at the seco code there are a "q" and a "k" network and the "q" network is used as a pretrained backbone for downstream tasks. The encoder_q.0.weight is the only weight that has the same shape as the default timm model and there is also no bias at this stage, so I think encoder_q.0.weight is the conv1.weight we look for. I can rename it and reupload to huggingface and then I think we should be good. Or do you have a suggestion?

import torch
import timm

timm_model = timm.create_model("resnet18")
checkpoint = torch.load("path/to/seco_resnet18_1m.ckpt", map_location="cpu")

state_dict = checkpoint["state_dict"]
assert timm_model.conv1.weight.shape == state_dict["encoder_q.0.weight"].shape

for key, val in checkpoint["state_dict"].items():
    if key not in ["encoder_q.0.weight", "encoder_k.0.weight"]:
        if state_dict["encoder_q.0.weight"].shape == val.shape:
            print(key)

state_dict["conv1.weight"] = state_dict.pop("encoder_q.0.weight")
# save this state dict with just the encoder_q weights and renamed 0.weight and upload to huggingface

@adamjstewart
Copy link
Collaborator

Your guess is as good as mine. Do the authors have any code for loading the pretrained model like your group does? If not, then I think your analysis makes sense. If you really want to be sure, you can try to train a model with those pretrained weights and make sure if converges quickly.

@calebrob6
Copy link
Member

You could use something like this -- https://gist.github.com/calebrob6/44f2e42017e2d192e837f0a1cd526c50 -- to make sure that linear-probing on a downstream dataset with the model achieves good performance. This is the notebook that I used to verify one of the SSL4EO weights I think.

@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 8, 2023

Yes, so the encoder_q I got from this code where they load a backbone from their pretrained model for a downstream task.

@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 9, 2023

I did the above described extraction of the weights and tried @calebrob6 script with the extracted seco resnet18 weights. As they are pretrained on RGB I only use the Eurosat RGB bands. Here are the scores I get:

Edit:
Correction, I forgot to call model.eval()...

This is with preprocessing step of just dividing image values by 1000:

  • seco resnet18 weights 0.7625
  • timm resnet imagenet 0.79
  • timm resnet random init 0.46

This is with preprocessing step of using the provided normalization stats for bands

  • seco resnet18 weights 0.75
  • timm resnet imagenet 0.89
  • timm resnet random init 0.65

@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 9, 2023

I will try to investigate again tomorrow. This is what I am using as a script (takes about 3 minutes to run on cpu locally). And for some reason one needs pytorch-lightning==1.1.8 in order to do load the original checkpoint file.

@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 10, 2023

I downloaded the bigearthnet dataset to try the linear probing script on that since that is the dataset they also report in their paper. For bigearthnet with 10,000 samples (not the full dataset), I get the following scores:

  • seco resnet18 weights: 0.4983
  • timm imagenet: 0.4512
  • timm random: 0.442

@adamjstewart
Copy link
Collaborator

Interesting evidence that the weights may not be very transferable...

@adamjstewart
Copy link
Collaborator

But this should be sufficient proof that your approach to extracting the first layer of weights is correct and we can move forward with this PR.

@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 10, 2023

Screenshot from 2023-02-10 19-57-31

In the paper, they report quiet a significant improvement when using Seco in linear probing, so I must be doing something wrong. I can also contact the authors to sort it out.

@adamjstewart
Copy link
Collaborator

That's prob for MSI, not RGB

@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 10, 2023

I think everything is only RGB, at least that is how I interpret it when they state "Although the collected dataset contains up to 12 spectral bands, in this work we focus on the RGB channels since it is a more general modality."

@nilsleh nilsleh marked this pull request as ready for review February 10, 2023 21:14
@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 10, 2023

For the moment I updated the seco weights on huggingface, and the loading works for all weights now.

@adamjstewart
Copy link
Collaborator

I'm still seeing the same issue:

$ pytest -m slow tests/trainers/
...
FAILED tests/trainers/test_byol.py::TestBYOLTask::test_weight_enum_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_byol.py::TestBYOLTask::test_weight_enum_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_byol.py::TestBYOLTask::test_weight_str_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_byol.py::TestBYOLTask::test_weight_str_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_classification.py::TestClassificationTask::test_weight_enum_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_classification.py::TestClassificationTask::test_weight_enum_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_classification.py::TestClassificationTask::test_weight_str_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_classification.py::TestClassificationTask::test_weight_str_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_regression.py::TestRegressionTask::test_weight_enum_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_regression.py::TestRegressionTask::test_weight_enum_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_regression.py::TestRegressionTask::test_weight_str_download[ResNet18_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'
FAILED tests/trainers/test_regression.py::TestRegressionTask::test_weight_str_download[ResNet50_Weights.SENTINEL2_RGB_SECO] - KeyError: 'conv1.weight'

@nilsleh
Copy link
Collaborator Author

nilsleh commented Feb 11, 2023

Mhm I am not getting those errors. Maybe, the old weights are still cached in your torch/hub? I had to delete those so it would reload the new ones after I uploaded them to huggingface.

@adamjstewart
Copy link
Collaborator

Oh mine are prob cached, let me delete

@adamjstewart
Copy link
Collaborator

Yep, works now. Thanks!

@adamjstewart adamjstewart merged commit a461d58 into microsoft:main Feb 11, 2023
calebrob6 pushed a commit that referenced this pull request Apr 10, 2023
* implement isaacs solution

* simple test for function

* private function but failing tests

* Fix in_channels

* Fix model

* Test real weights

* Real weights have no final layer

* Style fixes

* expand test coverage of other trainers

* revert byol image_size

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* implement isaacs solution

* simple test for function

* private function but failing tests

* Fix in_channels

* Fix model

* Test real weights

* Real weights have no final layer

* Style fixes

* expand test coverage of other trainers

* revert byol image_size

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

utils.load_state_dict() does not support ViT architecture
3 participants