
SeCo weights issue warning when loaded #1234

Closed
adamjstewart opened this issue Apr 12, 2023 · 3 comments · Fixed by #1593
Labels
models Models and pretrained weights testing Continuous integration testing

@adamjstewart
Collaborator

Description

It seems the SeCo weights still contain the final layer from pretraining, so they always issue a warning when loaded into a trainer. We should remove the last layer and make sure the only keys in the weights file are the ones needed to load the model.
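A minimal sketch of the cleanup described above, assuming the checkpoint is a flat mapping from key names to values. The helper name strip_unexpected_keys and the toy key names are hypothetical, and plain lists stand in for tensors. In PyTorch, model.load_state_dict(state_dict, strict=False) returns missing_keys and unexpected_keys, and leftover head entries show up in the latter; filtering them out before re-saving with torch.save avoids that.

```python
from collections.abc import Iterable, Mapping
from typing import Any


def strip_unexpected_keys(
    state_dict: Mapping[str, Any], expected_keys: Iterable[str]
) -> tuple[dict[str, Any], list[str]]:
    """Split a checkpoint into the entries a model expects and the leftovers.

    In PyTorch, expected_keys would typically be model.state_dict().keys().
    """
    expected = set(expected_keys)
    kept = {k: v for k, v in state_dict.items() if k in expected}
    dropped = [k for k in state_dict if k not in expected]
    return kept, dropped


# Toy checkpoint (hypothetical key names): two backbone entries plus a leftover head
checkpoint = {
    "conv1.weight": [0.1],
    "layer1.0.conv1.weight": [0.2],
    "head.weight": [0.3],
}
kept, dropped = strip_unexpected_keys(checkpoint, ["conv1.weight", "layer1.0.conv1.weight"])
print(sorted(kept))  # ['conv1.weight', 'layer1.0.conv1.weight']
print(dropped)  # ['head.weight']
```

Returning the dropped keys rather than discarding them silently makes it easy to confirm that only the pretraining head was removed.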

Steps to reproduce

Any of the following commands will reproduce the warning:

$ pytest -m slow tests/trainers/test_regression.py::TestRegressionTask::test_weight_enum_download[ResNet50_Weights.SENTINEL2_RGB_SECO]
$ pytest -m slow tests/trainers/test_regression.py::TestRegressionTask::test_weight_str_download[ResNet50_Weights.SENTINEL2_RGB_SECO]
$ pytest -m slow tests/trainers/test_classification.py::TestClassificationTask::test_weight_enum_download[ResNet50_Weights.SENTINEL2_RGB_SECO]
$ pytest -m slow tests/trainers/test_classification.py::TestClassificationTask::test_weight_str_download[ResNet50_Weights.SENTINEL2_RGB_SECO]

No other weights issue this warning. We could wrap these tests in a with pytest.warns(): block, but that would make the tests fail for all other weights, because pytest.warns() fails when no matching warning is raised.
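As a stopgap only (not the fix this issue asks for), the warning expectation could be made conditional on the weights under test. The sketch below uses only the standard library; the helper name warns_if and the is_seco flag are hypothetical. With pytest, the same idea is usually written as pytest.warns(UserWarning) if is_seco else contextlib.nullcontext().

```python
import contextlib
import warnings
from collections.abc import Iterator


@contextlib.contextmanager
def warns_if(expected: bool, category: type[Warning] = UserWarning) -> Iterator[None]:
    """Require a warning of `category` inside the block only when `expected` is True."""
    if not expected:
        # Nothing is required; warnings propagate normally.
        yield
        return
    with warnings.catch_warnings(record=True) as caught:
        # "always" stops duplicate-warning filtering from hiding a repeat warning
        warnings.simplefilter("always")
        yield
    if not any(issubclass(w.category, category) for w in caught):
        raise AssertionError(f"expected a {category.__name__}, but none was raised")


# Hypothetical usage: only the SeCo weights are expected to warn
is_seco = True
with warns_if(is_seco):
    warnings.warn("unexpected keys in state dict", UserWarning)
```

This keeps the other weights' tests passing, but it only hides the symptom; cleaning the checkpoint removes the need for it.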

Version

0.5.0.dev0 (35525b2)

adamjstewart added the models (Models and pretrained weights) and testing (Continuous integration testing) labels Apr 12, 2023
adamjstewart added this to the 0.4.2 milestone Apr 12, 2023
@calebrob6
Member

The current SeCo weights are almost entirely incorrect: the only key that matches between our uploaded weights and the original authors' weights is "conv1.weight".
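One way to run this kind of comparison, assuming both state dicts already use the same key names, is to list the keys whose values are equal in both. The helper name matching_keys and the toy data are hypothetical. For real tensors, pass equal=torch.equal, since == on tensors compares elementwise rather than returning a single bool.

```python
from collections.abc import Callable, Mapping
from typing import Any


def matching_keys(
    ours: Mapping[str, Any],
    theirs: Mapping[str, Any],
    equal: Callable[[Any, Any], bool] = lambda a, b: a == b,
) -> list[str]:
    """Return keys present in both state dicts whose values compare equal."""
    return [k for k in ours if k in theirs and equal(ours[k], theirs[k])]


# Toy data (hypothetical): only the first layer agrees
uploaded = {"conv1.weight": [1.0, 2.0], "layer1.0.conv1.weight": [3.0]}
original = {"conv1.weight": [1.0, 2.0], "layer1.0.conv1.weight": [4.0]}
print(matching_keys(uploaded, original))  # ['conv1.weight']
```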

@calebrob6
Member

calebrob6 commented May 8, 2023

Here's how you can extract just the encoder weights and rename them into a format that timm/torchvision understands:

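from copy import deepcopy

import timm
import torch
from lightning.pytorch.utilities.migration import pl_legacy_patch

from src.moco import MocoV2  # from the SeCo repo

# pl_legacy_patch allows unpickling checkpoints saved by older Lightning versions
with pl_legacy_patch():
    backbone = MocoV2.load_from_checkpoint("../../seasonal-contrast/seco_resnet50_1m.ckpt")

# Keep only the query encoder (the backbone used downstream)
model = deepcopy(backbone.encoder_q).eval()
state_dict_original = model.state_dict()

# timm's ResNet-50 supplies the key names that timm/torchvision expect
resnet50 = timm.create_model("resnet50")
state_dict_reference = resnet50.state_dict()

# Rename positionally: this relies on both models registering their parameters
# and buffers in the same order, so check lengths and shapes along the way.
# If encoder_q has no classification head, timm's trailing fc.* keys are never reached.
assert len(state_dict_original) <= len(state_dict_reference), "encoder has extra keys"
state_dict_new = {}
for v, (k_new, v_ref) in zip(state_dict_original.values(), state_dict_reference.items()):
    assert v.shape == v_ref.shape, f"shape mismatch at {k_new}"
    state_dict_new[k_new] = v

torch.save(state_dict_new, "seco_resnet50_1m.ckpt")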

@adamjstewart
Collaborator Author

@nilsleh, if you have time after you finish #1482, can you take a look at this?
