Firstly, thank you so much for this clean implementation!!
The self-supervised training process looks good, but the saved (i.e. improved) model is exactly the same as the initial one on my side. Have you observed the same problem?
The code I tested:
import torch
from net.byol import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)
param_1 = resnet.parameters()

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(2):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './checkpoints/improved-net.pt')

# restore the model
resnet2 = models.resnet50()
resnet2.load_state_dict(torch.load('./checkpoints/improved-net.pt'))
param_2 = resnet2.parameters()
# test whether two models are the same
for p1, p2 in zip(param_1, param_2):
    if p1.data.ne(p2.data).sum() > 0:
        print('They are different.')
        break
else:
    print('They are same.')
@KimMeen Hi Ming! It is because your param_1 still references the same parameter tensors, so it will follow the in-place changes until the very end of training. If you wish to keep a copy of the initial parameters, you need to do something like this
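A minimal sketch of the idea, using a small nn.Linear as a stand-in for the ResNet: deep-copy the state dict before training, since parameters() only hands back references to tensors that the optimizer mutates in place.

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the ResNet in the snippet above

# Snapshot the initial weights: `model.parameters()` only returns references
# to the live tensors, which the optimizer later updates in place.
initial_state = copy.deepcopy(model.state_dict())

# One dummy training step, mutating the parameters in place.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()

# The snapshot now differs from the live weights.
changed = any(not torch.equal(initial_state[k], v)
              for k, v in model.state_dict().items())
print('They are different.' if changed else 'They are same.')
```

With the deep copy, the comparison correctly reports a difference after training; comparing two views of the same live tensors would always report them as equal.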
@lucidrains Thank you for your timely reply! And yes you're correct! I have observed a different model when using copy.deepcopy(model) to keep a snapshot of my initial model.