In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18

In [None]:
# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather `tensor` from every process and concatenate the results along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient, so the
    returned tensor never carries gradient information.

    Args:
        tensor: the local tensor to share with the other processes.

    Returns:
        The per-process tensors concatenated along dim 0, in the order
        `all_gather` fills the list. When torch.distributed is unavailable or
        no process group has been initialized (e.g. a single-process notebook
        run), a detached copy of `tensor` is returned instead.
    """
    # Without an initialized process group get_world_size()/all_gather raise,
    # so behave like a world of size one: hand back a fresh, gradient-free copy.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return tensor.detach().clone()

    # One receive buffer per process; all_gather overwrites their contents.
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)

![image](https://user-images.githubusercontent.com/44194558/154207714-c25c7c84-12b4-4702-89a2-c3eaad687369.png)


In [None]:
class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    The query encoder is trained by back-propagation; the key encoder is a
    momentum (moving-average) copy of it; the queue stores previously computed
    keys that are used as negatives.
    """
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        base_encoder: accepted for interface compatibility but currently
            unused -- both encoders are always built with `resnet18`.
        dim: feature dimension (default: 128)
        K: queue size; # of negative keys (default: 65536)
        m: MoCo momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        mlp: if True, prepend a Linear + ReLU to each encoder's fc layer
            (MoCo v2 style projection head)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # Encoder networks for Query & Key. The output size must equal `dim`
        # so that the features line up with the [dim, K] queue below.
        self.encoder_q = resnet18(num_classes=dim)
        self.encoder_k = resnet18(num_classes=dim)

        if mlp:  # for MoCoV2
            # in_features of the existing fc layer
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            # Put a new Linear layer and a ReLU in front of the existing fc layer
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        # theta_q, theta_k
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize key encoder from query encoder
            param_k.requires_grad = False  # not updated by gradient (only encoder_q is trained by backprop)

        # Create the queue: [dim, K], one column-normalized feature per negative key
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        # Write pointer into the queue (used for enqueue / dequeue)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @staticmethod
    def _distributed_active():
        """Return True when a torch.distributed process group is initialized."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    # Momentum update for Key encoder
    @torch.no_grad()  # never updated through back-propagation
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        theta_k <- m * theta_k + (1 - m) * theta_q
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Overwrite the queue columns starting at the write pointer with `keys`
        (one key per row) and advance the pointer, wrapping at K.
        """
        # gather keys before updating queue (nothing to gather without a process group)
        if self._distributed_active():
            keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)  # pointer
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    # Shuffling BN (applying BN without shuffling degrades performance)
    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        Intended for DistributedDataParallel (DDP); without an initialized
        process group it shuffles the local batch only (world size 1, rank 0).

        Returns the shuffled slice for this process and the index needed to
        restore the original order.
        """
        distributed = self._distributed_active()

        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x) if distributed else x
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index: drawn on CPU, then moved to wherever x lives
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus so every process uses rank 0's permutation
        if distributed:
            torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank() if distributed else 0
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        Intended for DistributedDataParallel (DDP); without an initialized
        process group it restores the local batch only (world size 1, rank 0).
        """
        distributed = self._distributed_active()

        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x) if distributed else x
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank() if distributed else 0
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images (randomly augmented version of X)
            im_k: a batch of key images  (another randomly augmented version of X)
        Output:
            logits, targets

        Side effects: the key encoder receives a momentum update and the
        freshly computed keys are written into the queue.
        """

        # 1. Compute Query features
        q = self.encoder_q(im_q)  # queries: NxC (C=dim)
        q = nn.functional.normalize(q, dim=1)

        # 2. Compute Key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # (momentum)update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # 3. Compute logits
        l_pos = torch.einsum('nc, nc -> n', [q, k]).unsqueeze(-1)  # positive logits: Nx1
        l_neg = torch.einsum('nc, ck -> nk', [q, self.queue.clone().detach()])  # negative logits: NxK (no gradient to the queue)

        logits = torch.cat([l_pos, l_neg], dim=1)  # Nx(1+K), positive in column 0
        logits /= self.T

        # The positive key sits in column 0, so every target class is 0.
        # Created on the logits' device rather than unconditionally on CUDA.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels