# Contrastive Loss in MoCo

## Contrastive Loss

> The following text is quoted from [MoCo v3](https://arxiv.org/pdf/2104.02057.pdf).

> As common practice (e.g., [20, 10]), we take two crops for each image under random data augmentation. They are encoded by two encoders, $f_q$ and $f_k$, with output vectors $q$ and $k$. Intuitively, $q$ behaves like a “query” [20], and the goal of learning is to retrieve the corresponding “key”. This is formulated as minimizing a contrastive loss function [19]. We adopt the form of InfoNCE [34]:

$$
\mathcal{L}_{q}=-\log \frac{\exp \left(q \cdot k^{+} / \tau\right)}{\exp \left(q \cdot k^{+} / \tau\right)+\sum_{k^{-}} \exp \left(q \cdot k^{-} / \tau\right)} \tag{1}
$$

> Here $k^{+}$ is $f_k$’s output on the same image as $q$, known as $q$’s positive sample. The set $\{k^{-}\}$ consists of $f_k$’s outputs from other images, known as $q$’s negative samples. $\tau$ is a temperature hyper-parameter [45] for $\ell_2$-normalized $q$, $k$.

> Following [46, 22, 2, 10], in MoCo v3 we use the keys `that naturally co-exist in the same batch`. We abandon the memory queue [20], which we find has diminishing gain if the batch is sufficiently large (e.g., 4096). With this simplification, the contrastive loss in (1) can be implemented by a few lines of code: see `ctr(q, k)` in Alg. 1. We adopt a symmetrized loss [18, 7, 13]: `ctr(q1, k2)+ctr(q2, k1)`.
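
To make the formula concrete, here is a minimal single-process sketch of InfoNCE with in-batch keys, written in plain Python rather than PyTorch so it runs anywhere. The helper names (`l2_normalize`, `dot`, `info_nce`, `ctr`) and the toy embeddings are made up for illustration, and the sketch leaves out the `2 * T` scaling that the MoCo v3 code applies to the loss.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def info_nce(q, keys, pos_index, tau):
    """InfoNCE for one query: -log softmax(q . k / tau)[pos_index]."""
    logits = [dot(q, k) / tau for k in keys]
    m = max(logits)  # subtract the max for numerical stability
    log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_denominator - logits[pos_index]

def ctr(queries, keys, tau):
    """Mean InfoNCE over a batch; key i is the positive of query i,
    all other keys in the batch act as negatives."""
    queries = [l2_normalize(q) for q in queries]
    keys = [l2_normalize(k) for k in keys]
    losses = [info_nce(q, keys, i, tau) for i, q in enumerate(queries)]
    return sum(losses) / len(losses)

# Toy batch of 3 images with 2-d embeddings (made-up numbers).
q_batch = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
k_batch = [[0.9, 0.1], [0.1, 0.9], [-0.9, 0.1]]

loss_aligned = ctr(q_batch, k_batch, tau=0.2)         # key i matches query i
loss_shuffled = ctr(q_batch, k_batch[::-1], tau=0.2)  # positives misplaced
print(loss_aligned, loss_shuffled)
```

When each key sits at the same index as its query, the loss is small; reversing the keys puts most positives in the wrong place and the loss grows.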

## Code Samples in MoCo v3

### Contrastive loss

```python
def contrastive_loss(self, q, k):
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        # gather all targets
        k = concat_all_gather(k)
        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
        N = logits.shape[0]  # batch size per GPU
        labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)

```

The loss relies on the helper function `concat_all_gather(tensor)` shown below, which wraps `torch.distributed.all_gather()`.

```python
# CCJ's Note:
# > see: https://amsword.medium.com/gradient-backpropagation-with-torch-distributed-all-gather-9f3941a381f8
# Typically, each GPU can calculate the loss of g(x), and
# autograd then computes the gradient for
# all parameters. Normally, this is paired with DistributedDataParallel,
# which does the averaging automatically. In this case, we don't need
# to gather the other GPUs' outputs. But what if the loss is not separable?
# This is why we use `torch.distributed.all_gather()`.

import torch

# utils:
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # one placeholder per process, each shaped like the local tensor
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]

    # CCJ's Note: Gathers tensors from the whole group in a list.
    # > see: https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    # concatenate along the batch dimension, in rank order
    output = torch.cat(tensors_gather, dim=0)
    return output
```
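
The remark that the loss is "not separable" can be illustrated with a small sketch: the loss of a single query changes once keys from other GPUs enter the softmax denominator, so it cannot be computed from local keys alone. The logits below are hypothetical numbers, and `cross_entropy_row` is an illustrative plain-Python helper, not a PyTorch function.

```python
import math

def cross_entropy_row(logits, target):
    """Cross-entropy of one row of logits: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]

# Hypothetical logits (q . k / tau) of ONE query against 4 keys.
# Keys 0-1 live on the query's own GPU; keys 2-3 live on another GPU.
logits_all = [5.0, 1.0, 4.0, 0.5]
positive = 0  # index of the positive key

loss_local_only = cross_entropy_row(logits_all[:2], positive)
loss_all_keys = cross_entropy_row(logits_all, positive)
print(loss_local_only, loss_all_keys)
```

The extra negatives from the other GPU enlarge the denominator, so the loss over all keys is larger than the loss over the local keys only.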

## My Notes

### Questions
- 1) Why is `torch.distributed.all_gather()` required?
- 2) How should we understand and implement the statement that `in MoCo v3 we use the keys that naturally co-exist in the same batch`?

### My Understanding

- `all_gather()` collects the keys from every process (GPU) in the group into a list. The keys are the targets fed into the cross-entropy loss function, and the contrastive loss compares each query on the current GPU against all the keys, i.e., the key batches $k_r$ produced on every rank $r = 0, 1, 2, \dots, W-1$, where $W$ is the world size.

- For example, with two nodes (machines), each with 2 GPUs (local ranks $0, 1$), the world size is $2 \times 2 = 4$ and the global ranks are $0, 1, 2, 3$.

- Syntax: `torch.distributed.all_gather(tensor_list, tensor, group=None, async_op=False)`:

```python
# All tensors below are of torch.int64 dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1

tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1

dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1
```

- As for this contrastive loss:

```python
def contrastive_loss(self, q, k):
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        # gather all targets
        k = concat_all_gather(k)
        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
        N = logits.shape[0]  # batch size per GPU
        labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)

```

Pay attention to the gather step:

```python
# gather all targets
k = concat_all_gather(k)  # calls torch.distributed.all_gather()
```
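
The effect of the gather step can be mimicked without a process group. The sketch below is a plain-Python simulation, not the real collective: `simulate_concat_all_gather` is a made-up helper, string ids stand in for key vectors, and the mapping `global_rank = node_rank * gpus_per_node + local_rank` is the usual launcher convention, assumed here.

```python
def simulate_concat_all_gather(per_rank_batches):
    """Mimic all_gather followed by torch.cat(dim=0): every rank receives
    the batches of all ranks, concatenated in rank order."""
    gathered = [row for batch in per_rank_batches for row in batch]
    # each rank ends up with its own copy of the same concatenated list
    return [list(gathered) for _ in per_rank_batches]

nodes, gpus_per_node, N = 2, 2, 2  # N = batch size per GPU

# Assumed convention: global_rank = node_rank * gpus_per_node + local_rank
global_ranks = [
    node * gpus_per_node + local
    for node in range(nodes)
    for local in range(gpus_per_node)
]

# Tag each key with its rank and local index to see where it lands.
keys_per_rank = [[f"k[rank{r}][{i}]" for i in range(N)] for r in global_ranks]
keys_after_gather = simulate_concat_all_gather(keys_per_rank)

for rank, keys in zip(global_ranks, keys_after_gather):
    print(rank, keys)
```

Every rank sees the same `world_size * N` keys, and the keys produced on rank `r` occupy positions `r * N` to `r * N + N - 1`.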

For convenience, assume a batch size of $N=2$ on each of the 4 GPUs, and suppose the current process runs on the GPU with global rank $2$. Its query batch $q_2$ is compared against all the gathered key batches $\{k_0, k_1, k_2, k_3\}$, i.e., $4 \times 2 = 8$ keys. Positive pairs exist only between $q_2$ and $k_2$: the $i$-th query in $q_2$ and the $i$-th key in $k_2$ come from the same image, so the label of that query must be the index of this key. Every other key, including the other key in $k_2$, is a negative for that query.

Specifying the labels from the batch size $N$ and the global rank achieves exactly this: the $i$-th query on rank $r$ gets the label $i + N \cdot r$, which is the position of its own key in the gathered key tensor. A query $q$ and a key $k$ that come from the same image in the batch therefore form the positive pair.

```python
# batch size N=2 per GPU, world_size = 4 (2 nodes, each with 2 GPUs)
labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
>>> labels  # device info omitted
tensor([0, 1]) # Rank 0
tensor([2, 3]) # Rank 1
tensor([4, 5]) # Rank 2
tensor([6, 7]) # Rank 3
```
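
The label arithmetic can be checked in plain Python, under the assumption that the gathered keys are concatenated in rank order as in `concat_all_gather`. `make_labels` is an illustrative stand-in for `torch.arange(N) + N * rank`, and `(rank, local index)` tuples stand in for the key vectors.

```python
def make_labels(batch_size, rank):
    """Plain-Python version of torch.arange(N) + N * rank."""
    return [i + batch_size * rank for i in range(batch_size)]

N, world_size = 2, 4

# Ids of the gathered keys as (rank, local index), in rank order.
gathered_keys = [(r, i) for r in range(world_size) for i in range(N)]

for rank in range(world_size):
    labels = make_labels(N, rank)
    for local_i, label in enumerate(labels):
        # The label-th gathered key is the key this rank produced
        # for the same image as its local_i-th query.
        assert gathered_keys[label] == (rank, local_i)

labels_rank2 = make_labels(N, 2)
print(labels_rank2)  # [4, 5]
```

Across the 4 ranks the labels cover each of the 8 gathered keys exactly once, so every key is the positive of exactly one query.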



### Complete Code in MoCo v3

```python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class MoCo(nn.Module):
    """
    Build a MoCo model with a base encoder, a momentum encoder, and two MLPs
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
        """
        dim: feature dimension (default: 256)
        mlp_dim: hidden dimension in MLPs (default: 4096)
        T: softmax temperature (default: 1.0)
        """
        super(MoCo, self).__init__()

        self.T = T

        # build encoders
        self.base_encoder = base_encoder(num_classes=mlp_dim)
        self.momentum_encoder = base_encoder(num_classes=mlp_dim)

        self._build_projector_and_predictor_mlps(dim, mlp_dim)

        for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
            param_m.data.copy_(param_b.data)  # initialize
            param_m.requires_grad = False  # not update by gradient

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
        mlp = []
        for l in range(num_layers):
            dim1 = input_dim if l == 0 else mlp_dim
            dim2 = output_dim if l == num_layers - 1 else mlp_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if l < num_layers - 1:
                mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.ReLU(inplace=True))
            elif last_bn:
                # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
                # for simplicity, we further removed gamma in BN
                mlp.append(nn.BatchNorm1d(dim2, affine=False))

        return nn.Sequential(*mlp)

    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        pass

    @torch.no_grad()
    def _update_momentum_encoder(self, m):
        """Momentum update of the momentum encoder"""
        for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
            param_m.data = param_m.data * m + param_b.data * (1. - m)

    def contrastive_loss(self, q, k):
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        # gather all targets
        k = concat_all_gather(k)
        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
        N = logits.shape[0]  # batch size per GPU
        labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)

    def forward(self, x1, x2, m):
        """
        Input:
            x1: first views of images
            x2: second views of images
            m: moco momentum
        Output:
            loss
        """

        # compute features
        q1 = self.predictor(self.base_encoder(x1))
        q2 = self.predictor(self.base_encoder(x2))

        with torch.no_grad():  # no gradient
            self._update_momentum_encoder(m)  # update the momentum encoder

            # compute momentum features as targets
            # CCJ's Note: MoCo v3 used the keys that naturally co-exist
            # in the same batch
            k1 = self.momentum_encoder(x1)
            k2 = self.momentum_encoder(x2)

        return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1)


class MoCo_ResNet(MoCo):
    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        hidden_dim = self.base_encoder.fc.weight.shape[1]
        del self.base_encoder.fc, self.momentum_encoder.fc # remove original fc layer

        # projectors
        self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim)
        self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim)

        # predictor
        self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False)


class MoCo_ViT(MoCo):
    def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
        hidden_dim = self.base_encoder.head.weight.shape[1]
        del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer

        # projectors
        self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim)
        self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim)

        # predictor
        self.predictor = self._build_mlp(2, dim, mlp_dim, dim)
```
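
The rule in `_update_momentum_encoder` is an exponential moving average of the base encoder's parameters. A minimal plain-Python sketch of that rule follows; the weights are made-up scalars and `m = 0.99` is an illustrative value, not one taken from these notes.

```python
def momentum_update(momentum_params, base_params, m):
    """param_m <- m * param_m + (1 - m) * param_b, element by element."""
    return [pm * m + pb * (1.0 - m) for pm, pb in zip(momentum_params, base_params)]

base = [1.0, -2.0]      # hypothetical base-encoder weights, held fixed here
momentum = [0.0, 0.0]   # momentum-encoder weights start elsewhere and lag behind
m = 0.99                # illustrative momentum value

for _ in range(100):
    momentum = momentum_update(momentum, base, m)

# With a fixed base, the remaining gap shrinks by a factor of m per step.
gap = abs(momentum[0] - base[0])
print(momentum, gap)
```

With `m` close to 1 the momentum encoder changes slowly, so the keys it produces drift more smoothly than the queries from the base encoder.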