DenseCL: copy query encoder weights to the key encoder in init_weights #411

Merged · 2 commits · Aug 19, 2022
13 changes: 8 additions & 5 deletions mmselfsup/models/algorithms/densecl.py
@@ -49,11 +49,6 @@ def __init__(self,
         self.encoder_k = nn.Sequential(
             build_backbone(backbone), build_neck(neck))
 
-        for param_q, param_k in zip(self.encoder_q.parameters(),
-                                    self.encoder_k.parameters()):
-            param_k.data.copy_(param_q.data)
-            param_k.requires_grad = False
-
         self.backbone = self.encoder_q[0]
         assert head is not None
         self.head = build_head(head)
@@ -71,6 +66,14 @@ def __init__(self,
         self.queue2 = nn.functional.normalize(self.queue2, dim=0)
         self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long))
 
+    def init_weights(self):
+        """Init weights and copy query encoder init weights to key encoder."""
+        super().init_weights()
+        for param_q, param_k in zip(self.encoder_q.parameters(),
+                                    self.encoder_k.parameters()):
+            param_k.data.copy_(param_q.data)
+            param_k.requires_grad = False
+
     @torch.no_grad()
     def _momentum_update_key_encoder(self):
         """Momentum update of the key encoder."""
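The point of the change: previously the key encoder was synchronized inside __init__, before any pretrained weights could be loaded, so encoder_k could end up holding stale constructor-time weights. Moving the copy into init_weights(), after super().init_weights() has run, means the key encoder mirrors whatever weights the query encoder actually starts training with. A minimal, self-contained sketch of the pattern follows; the toy module and names here are illustrative, not the mmselfsup API:

import torch
import torch.nn as nn

class TinyMomentumPair(nn.Module):
    """Toy stand-in for DenseCL's query/key encoder pair."""

    def __init__(self):
        super().__init__()
        self.encoder_q = nn.Linear(4, 4)
        self.encoder_k = nn.Linear(4, 4)
        # Copying here would only capture constructor-time weights;
        # weights loaded later (e.g. a pretrained backbone) would be missed.

    def init_weights(self):
        # Stand-in for super().init_weights(), which may (re)initialize
        # or load pretrained weights into the query encoder.
        nn.init.xavier_uniform_(self.encoder_q.weight)
        # Only now is it safe to sync the key encoder.
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False  # key encoder is not trained by SGD

model = TinyMomentumPair()
model.init_weights()
assert all(torch.equal(q, k) for q, k in zip(model.encoder_q.parameters(),
                                             model.encoder_k.parameters()))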
6 changes: 6 additions & 0 deletions tests/test_models/test_algorithms/test_densecl.py
@@ -57,6 +57,12 @@ def test_densecl():
     assert alg.queue.size() == torch.Size([feat_dim, queue_len])
     assert alg.queue2.size() == torch.Size([feat_dim, queue_len])
 
+    alg.init_weights()
+    for param_q, param_k in zip(alg.encoder_q.parameters(),
+                                alg.encoder_k.parameters()):
+        assert torch.equal(param_q, param_k)
+        assert param_k.requires_grad is False
+
     fake_input = torch.randn((2, 3, 224, 224))
     with pytest.raises(AssertionError):
         fake_out = alg.forward_train(fake_input)
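For context on why the test also checks param_k.requires_grad is False: after the one-time copy, the key encoder is never updated by gradients; it only tracks the query encoder as an exponential moving average via the _momentum_update_key_encoder method left intact above. A hedged sketch of that standard MoCo-style update; the function name and the 0.999 coefficient are illustrative rather than mmselfsup's exact values:

import torch

@torch.no_grad()
def momentum_update(encoder_q, encoder_k, momentum=0.999):
    # EMA update used by MoCo-style methods such as DenseCL:
    # theta_k <- momentum * theta_k + (1 - momentum) * theta_q
    for param_q, param_k in zip(encoder_q.parameters(),
                                encoder_k.parameters()):
        param_k.data.mul_(momentum).add_(param_q.data, alpha=1. - momentum)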