Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct the seco weight names #1593

Merged
merged 7 commits into from
Sep 30, 2023

Conversation

nilsleh
Copy link
Collaborator

@nilsleh nilsleh commented Sep 28, 2023

Closes #1234

@github-actions github-actions bot added the models Models and pretrained weights label Sep 28, 2023
@adamjstewart adamjstewart added this to the 0.5.0 milestone Sep 28, 2023
@adamjstewart
Copy link
Collaborator

Having trouble reproducing the original issue so it's hard for me to verify that it's now fixed.

@calebrob6
Copy link
Member

calebrob6 commented Sep 28, 2023

old_keys = set(torch.load("resnet18_sentinel2_rgb_seco-9976a9cb.pth").keys())
new_keys = set(torch.load("resnet18_sentinel2_rgb_seco-c2e6e88b.pth").keys())

print(old_keys - new_keys)

Gives:

{'layer3.0.bn1.bias',
 'layer3.0.bn1.num_batches_tracked',
 'layer3.0.bn1.running_mean',
 'layer3.0.bn1.running_var',
 'layer3.0.bn1.weight',
 'layer3.0.bn2.bias',
 'layer3.0.bn2.num_batches_tracked',
 'layer3.0.bn2.running_mean',
 'layer3.0.bn2.running_var',
 'layer3.0.bn2.weight',
 'layer3.0.conv1.weight',
 'layer3.0.conv2.weight',
 'layer3.0.downsample.0.weight',
 'layer3.0.downsample.1.bias',
 'layer3.0.downsample.1.num_batches_tracked',
 'layer3.0.downsample.1.running_mean',
 'layer3.0.downsample.1.running_var',
 'layer3.0.downsample.1.weight',
 'layer3.1.bn1.bias',
 'layer3.1.bn1.num_batches_tracked',
 'layer3.1.bn1.running_mean',
 'layer3.1.bn1.running_var',
 'layer3.1.bn1.weight',
 'layer3.1.bn2.bias',
 'layer3.1.bn2.num_batches_tracked',
 'layer3.1.bn2.running_mean',
 'layer3.1.bn2.running_var',
 'layer3.1.bn2.weight',
 'layer3.1.conv1.weight',
 'layer3.1.conv2.weight',
 'layer4.0.bn1.bias',
 'layer4.0.bn1.num_batches_tracked',
 'layer4.0.bn1.running_mean',
 'layer4.0.bn1.running_var',
 'layer4.0.bn1.weight',
 'layer4.0.bn2.bias',
 'layer4.0.bn2.num_batches_tracked',
 'layer4.0.bn2.running_mean',
 'layer4.0.bn2.running_var',
 'layer4.0.bn2.weight',
 'layer4.0.conv1.weight',
 'layer4.0.conv2.weight',
 'layer4.0.downsample.0.weight',
 'layer4.0.downsample.1.bias',
 'layer4.0.downsample.1.num_batches_tracked',
 'layer4.0.downsample.1.running_mean',
 'layer4.0.downsample.1.running_var',
 'layer4.0.downsample.1.weight',
 'layer4.1.bn1.bias',
 'layer4.1.bn1.num_batches_tracked',
 'layer4.1.bn1.running_mean',
 'layer4.1.bn1.running_var',
 'layer4.1.bn1.weight',
 'layer4.1.bn2.bias',
 'layer4.1.bn2.num_batches_tracked',
 'layer4.1.bn2.running_mean',
 'layer4.1.bn2.running_var',
 'layer4.1.bn2.weight',
 'layer4.1.conv1.weight',
 'layer4.1.conv2.weight'}

@nilsleh, is this correct? The new weights don't have any entries for layer 3 or layer 4.

@nilsleh
Copy link
Collaborator Author

nilsleh commented Sep 29, 2023

@nilsleh, is this correct? The new weights don't have any entries for layer 3 or layer 4.

I used the script you posted in the issue to update the weights and didn't check further. I will look into again.

@calebrob6
Copy link
Member

Oh oops... sorry for leading you in the wrong direction here! Good to not trust my code unless it has been reviewed by Adam.

@nilsleh
Copy link
Collaborator Author

nilsleh commented Sep 29, 2023

Oh oops... sorry for leading you in the wrong direction here! Good to not trust my code unless it has been reviewed by Adam.

Sorry, turns out the mistake was on my end of course, just always better to first blame others...

@adamjstewart
Copy link
Collaborator

Still curious how to reproduce the original issue so I can confirm that things have actually been fixed. If our (slow) tests don't catch this bug, then I'm worried about our tests.

@calebrob6
Copy link
Member

calebrob6 commented Sep 29, 2023

The following works for getting the SeCo pretrained weights in a timm resnet format:

from lightning import LightningModule
import timm
import torch
from copy import deepcopy
from torch import nn
import torchvision

# From https://github.com/ServiceNow/seasonal-contrast/blob/main/models/moco2_module.py
class MocoV2(LightningModule):

    def __init__(self, base_encoder, emb_dim, num_negatives, emb_spaces=1, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        # create the encoders
        template_model = getattr(torchvision.models, base_encoder)
        self.encoder_q = template_model(num_classes=self.hparams.emb_dim)
        self.encoder_k = template_model(num_classes=self.hparams.emb_dim)

        # remove fc layer
        self.encoder_q = nn.Sequential(*list(self.encoder_q.children())[:-1], nn.Flatten())
        self.encoder_k = nn.Sequential(*list(self.encoder_k.children())[:-1], nn.Flatten())

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the projection heads
        self.mlp_dim = 512 * (1 if base_encoder in ['resnet18', 'resnet34'] else 4)
        self.heads_q = nn.ModuleList([
            nn.Sequential(nn.Linear(self.mlp_dim, self.mlp_dim), nn.ReLU(), nn.Linear(self.mlp_dim, emb_dim))
            for _ in range(emb_spaces)
        ])
        self.heads_k = nn.ModuleList([
            nn.Sequential(nn.Linear(self.mlp_dim, self.mlp_dim), nn.ReLU(), nn.Linear(self.mlp_dim, emb_dim))
            for _ in range(emb_spaces)
        ])

        for param_q, param_k in zip(self.heads_q.parameters(), self.heads_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(emb_spaces, emb_dim, num_negatives))
        self.queue = nn.functional.normalize(self.queue, dim=1)

        self.register_buffer("queue_ptr", torch.zeros(emb_spaces, 1, dtype=torch.long))

if __name__ == "__main__":
    backbone = MocoV2.load_from_checkpoint("../seasonal-contrast/seco_resnet50_1m.ckpt")
    model = deepcopy(backbone.encoder_q).eval()
    state_dict_original = model.state_dict()

    patient_model = timm.create_model("resnet50")
    key_correct_name_list = list(patient_model.state_dict().keys())

    state_dict_new = {}
    for i, (k, v) in enumerate(state_dict_original.items()):
            state_dict_new[key_correct_name_list[i]] = v
    print(patient_model.load_state_dict(state_dict_new, strict=False))

    torch.save(patient_model.state_dict(), "seco_resnet50_1m.ckpt")

adamjstewart
adamjstewart previously approved these changes Sep 30, 2023
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Figured out how to test, just need to run the slow tests pytest -m slow tests/models and either make it strict or assert that there are no missing/unexpected keys.

@adamjstewart adamjstewart merged commit 0547a65 into microsoft:main Sep 30, 2023
20 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
models Models and pretrained weights
Projects
None yet
Development

Successfully merging this pull request may close these issues.

SeCo weights issue warning when loaded
3 participants