The saved network is same as the initial one? #32

Closed

KimMeen opened this issue Oct 15, 2020 · 3 comments

Comments


KimMeen commented Oct 15, 2020

Firstly, thank you so much for this clean implementation!!

The self-supervised training process looks good, but the saved (i.e. improved) model is exactly the same as the initial one on my side. Have you observed the same problem?

The code I tested:

import torch
from net.byol import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)
param_1 = resnet.parameters()

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(2):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './checkpoints/improved-net.pt')

# restore the model      
resnet2 = models.resnet50()
resnet2.load_state_dict(torch.load('./checkpoints/improved-net.pt'))
param_2 = resnet2.parameters()

# test whether the two models are the same
models_differ = False
for p1, p2 in zip(param_1, param_2):
    if p1.data.ne(p2.data).sum() > 0:
        models_differ = True
        print('They are different.')
        break
if not models_differ:
    print('They are same.')
@neversap

I'm also experiencing the same issue, any solutions?

@lucidrains
Owner

@KimMeen Hi Ming! It is because your param_1 still references the same parameters, so it will follow the changes until the very end of training. If you wish to make a copy of the parameters, you need to do something like this:

from copy import deepcopy
param_copy = deepcopy(model).parameters()
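For reference, a minimal sketch of how the test script above could be adjusted (same resnet and BYOL setup as in the original post; initial_resnet is a made-up name): the snapshot is taken with deepcopy before training, so the final comparison runs against frozen copies instead of live references.

from copy import deepcopy

# snapshot the pretrained weights BEFORE training;
# deepcopy gives independent tensors rather than references
initial_resnet = deepcopy(resnet)

# ... run the BYOL training loop as above ...

# compare the snapshot against the trained network
changed = any(
    p1.data.ne(p2.data).sum() > 0
    for p1, p2 in zip(initial_resnet.parameters(), resnet.parameters())
)
print('They are different.' if changed else 'They are same.')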


KimMeen commented Oct 16, 2020

@lucidrains Thank you for your timely reply! Yes, you're correct: I do observe a different model when using copy.deepcopy(model) to keep a snapshot of my initial model.
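As a side note, the same check can also be done at the state_dict level. This is only a sketch (names like initial_state are made up); note that state_dict() values also point at the live tensors, so the snapshot itself must be deep-copied:

import torch
from copy import deepcopy

# state_dict() returns references to the live tensors, so deep-copy the snapshot
initial_state = deepcopy(resnet.state_dict())

# ... train with BYOL as above ...

trained_state = resnet.state_dict()
changed_keys = [k for k in initial_state
                if not torch.equal(initial_state[k], trained_state[k])]
print(f'{len(changed_keys)} tensors changed during training')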

@KimMeen KimMeen closed this as completed Oct 16, 2020