In [6]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models


In [7]:
BATCH_SIZE = 64
H,W = 224, 224
Z_DIM = 128

In [8]:
base_encoder = models.__dict__["resnet50"]
# base_encoder()
# base_encoder = base_encoder(num_classes = Z_DIM)
# base_encoder.fc = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), nn.Linear(in_features= 2048, out_features= 128))
# for p in base_encoder.parameters():
#     print(p)

In [35]:
class MoCoV2(nn.Module):
    def __init__(self, base_encoder, similarity_dim = 128, q_size=65536, momentum=0.999, temperature=0.07, ):
        super(MoCoV2, self).__init__()
        self.K = q_size
        self.m = momentum
        self.T = temperature
        
        self.q_enc = base_encoder() # torchvision.models.__dict__['resnet50']
        self.k_enc = base_encoder() # torchvision.models.__dict__['resnet50']
        
        # for mlp
        in_features = self.q_enc.fc.weight.size(1)
        self.q_enc.fc = nn.Sequential(nn.Linear(in_features, in_features), nn.ReLU(), nn.Linear(in_features= in_features, out_features= similarity_dim))
        self.k_enc.fc = nn.Sequential(nn.Linear(in_features, in_features), nn.ReLU(), nn.Linear(in_features= in_features, out_features= similarity_dim))
        
        # initialize k_enc params
        for q, k in zip(self.q_enc.parameters(), self.k_enc.parameters()):
            k.data.copy_(q.data)
            k.requires_grad = False
        
        # initialize dynamic queue : (simil_dim,K)
        self.register_buffer("queue", torch.randn(similarity_dim, self.K))
        self.register_buffer("queue_pointer", torch.tensor(0,dtype= torch.long))
        
        self.queue = F.normalize(self.queue, dim = 0)
        
    def forward(self, img_q, img_k):
        """
        no need to shuffle
        @param img_q : from moco_dataloader (bs, 3, 224, 224)
        @param img_k : from moco_dataloader (bs, 3, 224, 224)
        """
        self.momentum_update()
        
        simil_vec_q = self.q_enc(img_q) # (bs, similarity_dim)
        simil_vec_q = F.normalize(simil_vec_q, dim = 1)
        
        simil_vec_k = self.k_enc(img_k) # (bs, similarity_dim)
        simil_vec_k = F.normalize(simil_vec_k, dim = 1)
        # no grad to key
        simil_vec_k = simil_vec_k.detach() 
        
        # positive logits : (bs,1) = (bs, 1, simil_dim) * (bs, simil_dim, 1)
        l_pos = torch.bmm(simil_vec_q.unsqueeze(1), simil_vec_k.unsqueeze(2))
#         l_pos.requires_grad_()
        print("l_pos.requires_grad",l_pos.requires_grad)
        
        # negative logits : (bs,K) = (bs, simil_dim) * (simil_dim, K)
        l_neg = torch.mm(simil_vec_q, self.queue.clone().detach())
        print("l_neg.requires_grad",l_neg.requires_grad)
        
        # output, y
        logits = torch.cat([l_pos.squeeze(2), l_neg], dim = 1, ) 
        logits /= self.T
        labels = torch.zeros(logits.size(0)) #(bs,)
        print("logits.requires_grad",logits.requires_grad)
        
        self.replace_queue_with(simil_vec_k)
        
        return logits, labels
    
    def momentum_update(self,):
        for q_param, k_param in zip(self.q_enc.parameters(), self.k_enc.parameters()):
            k_param.data = self.m * k_param.data + (1-self.m)*q_param
    
    
    def replace_queue_with(self, simil_vec_k):
        """
        self.queue : (simil_dim,K)
        simil_vec_k : (bs, simil_dim)
        """
        bs = simil_vec_k.size(0)
        self.queue[:,self.queue_pointer:self.queue_pointer+bs] = simil_vec_k.T
        self.queue_pointer = (self.queue_pointer + bs) % self.K

In [36]:
ex_input1 = torch.randn(BATCH_SIZE, 3, H,W)
ex_input2 = torch.randn(BATCH_SIZE, 3, H,W)

moco_v2 = MoCoV2(base_encoder=models.__dict__["resnet50"])

In [37]:
logits, labels = moco_v2(ex_input1, ex_input2)
logits.size(), labels.size()

l_pos.requires_grad True
l_neg.requires_grad False
logits.requires_grad True


(torch.Size([64, 65537]), torch.Size([64]))

In [27]:
for p,v in moco_v2.named_parameters():
    print(p,v.requires_grad)
    


q_enc.conv1.weight True
q_enc.bn1.weight True
q_enc.bn1.bias True
q_enc.layer1.0.conv1.weight True
q_enc.layer1.0.bn1.weight True
q_enc.layer1.0.bn1.bias True
q_enc.layer1.0.conv2.weight True
q_enc.layer1.0.bn2.weight True
q_enc.layer1.0.bn2.bias True
q_enc.layer1.0.conv3.weight True
q_enc.layer1.0.bn3.weight True
q_enc.layer1.0.bn3.bias True
q_enc.layer1.0.downsample.0.weight True
q_enc.layer1.0.downsample.1.weight True
q_enc.layer1.0.downsample.1.bias True
q_enc.layer1.1.conv1.weight True
q_enc.layer1.1.bn1.weight True
q_enc.layer1.1.bn1.bias True
q_enc.layer1.1.conv2.weight True
q_enc.layer1.1.bn2.weight True
q_enc.layer1.1.bn2.bias True
q_enc.layer1.1.conv3.weight True
q_enc.layer1.1.bn3.weight True
q_enc.layer1.1.bn3.bias True
q_enc.layer1.2.conv1.weight True
q_enc.layer1.2.bn1.weight True
q_enc.layer1.2.bn1.bias True
q_enc.layer1.2.conv2.weight True
q_enc.layer1.2.bn2.weight True
q_enc.layer1.2.bn2.bias True
q_enc.layer1.2.conv3.weight True
q_enc.layer1.2.bn3.weight True
q_enc.la

In [15]:
for k in moco_v2.q_enc.state_dict():
    print(k)

conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.num_batches_tracked
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.num_batches_tracked
layer1.0.conv3.weight
layer1.0.bn3.weight
layer1.0.bn3.bias
layer1.0.bn3.running_mean
layer1.0.bn3.running_var
layer1.0.bn3.num_batches_tracked
layer1.0.downsample.0.weight
layer1.0.downsample.1.weight
layer1.0.downsample.1.bias
layer1.0.downsample.1.running_mean
layer1.0.downsample.1.running_var
layer1.0.downsample.1.num_batches_tracked
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.num_batches_tracked
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.num_batches_tr

In [69]:
for q in moco_v2.q_enc.parameters():
    q.data = q.data + 1
    

In [71]:
k_params = []
for k in moco_v2.k_enc.parameters():
    k_params.append(k.data)

In [79]:
logits, labels = moco_v2(ex_input1, ex_input2)
labels

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [74]:
# 1*0.001 should be added : and yes, it's correct
for k, old_k in zip(moco_v2.k_enc.parameters(), k_params):
    print(k.data - old_k)

tensor([[[[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          ...,
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],

         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          ...,
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]],

         [[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
          [0.0010, 0.0010, 0.0010,  ..., 0

In [78]:
for k in moco_v2.k_enc.parameters():
    print(k.grad)

None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None


In [5]:
torch.zeros(1, dtype=torch.long)
torch.tensor(0, dtype = torch.long) + 128

tensor(128)