# Arcface vs. Cosface vs. Arcface+Cosface

In [1]:
import torch
from torch.nn.functional import linear, normalize

labels = torch.randint(low=0, high=10, size=(6,))
print(labels)
index = torch.where(labels!=-1)[0]
print(index)

tensor([4, 6, 0, 5, 5, 3])
tensor([0, 1, 2, 3, 4, 5])


In [2]:
batch_size = 6
num_class = 5

weights = torch.nn.Parameter(torch.FloatTensor(num_class, 512))
torch.nn.init.xavier_uniform_(weights)
embed_vec = torch.randn((batch_size,512))
logits = linear(normalize(embed_vec), normalize(weights)).clamp(-1,1)
print(logits)
print(logits.shape)
print()

labels = torch.randint(low=0, high=num_class, size=(batch_size,))
print(labels)
index = torch.where(labels!=-1)[0]
print(index)


tensor([[-0.0016, -0.0394, -0.0146,  0.0258, -0.0268],
        [ 0.0209, -0.0399,  0.0914, -0.0066,  0.0424],
        [-0.0661, -0.0771,  0.0389, -0.0413, -0.0049],
        [-0.0681,  0.0132, -0.0275,  0.0343,  0.0055],
        [-0.0756, -0.0029,  0.0225, -0.0044, -0.0222],
        [-0.0753, -0.0096,  0.1256,  0.0286, -0.0812]],
       grad_fn=<ClampBackward1>)
torch.Size([6, 5])

tensor([1, 0, 0, 2, 4, 1])
tensor([0, 1, 2, 3, 4, 5])


In [3]:
target = logits[index, labels[index].view(-1)]
print(target)

tensor([-0.0394,  0.0209, -0.0661, -0.0275, -0.0222, -0.0096],
       grad_fn=<IndexBackward0>)


In [4]:
## Arcface(0.5)
from marginloss import CombinedMarginLoss
softmax = CombinedMarginLoss(s=10, m1=1.0, m2=0.5, m3=0.0)
softmax(logits, labels)

tensor([[-0.0159, -5.1361, -0.1460,  0.2579, -0.2678],
        [-4.6100, -0.3994,  0.9139, -0.0662,  0.4240],
        [-5.3635, -0.7706,  0.3888, -0.4132, -0.0485],
        [-0.6811,  0.1322, -5.0339,  0.3428,  0.0551],
        [-0.7555, -0.0288,  0.2246, -0.0440, -4.9878],
        [-0.7528, -4.8778,  1.2565,  0.2857, -0.8119]], grad_fn=<MulBackward0>)

In [5]:
## Cosface(0.35)
softmax = CombinedMarginLoss(s=10, m1=0.0, m2=0.0, m3=0.35)
softmax(logits, labels)

tensor([[-0.0159, -8.6361, -0.1460,  0.2579, -0.2678],
        [-8.1100, -0.3994,  0.9139, -0.0662,  0.4240],
        [-8.8635, -0.7706,  0.3888, -0.4132, -0.0485],
        [-0.6811,  0.1322, -8.5339,  0.3428,  0.0551],
        [-0.7555, -0.0288,  0.2246, -0.0440, -8.4878],
        [-0.7528, -8.3778,  1.2565,  0.2857, -0.8119]], grad_fn=<MulBackward0>)

In [6]:
## Arcface(0.5) + Cosface(0.35)
softmax = CombinedMarginLoss(s=10, m1=0.5, m2=0.5, m3=0.35)
softmax(logits, labels)

tensor([[ -0.0159, -13.4960,  -0.1460,   0.2579,  -0.2678],
        [-13.4221,  -0.3994,   0.9139,  -0.0662,   0.4240],
        [-14.7606,  -0.7706,   0.3888,  -0.4132,  -0.0485],
        [ -0.6811,   0.1322, -13.4883,   0.3428,   0.0551],
        [ -0.7555,  -0.0288,   0.2246,  -0.0440, -13.4837],
        [ -0.7528, -13.4699,   1.2565,   0.2857,  -0.8119]],
       grad_fn=<MulBackward0>)

# ==================

In [None]:
from marginloss import CombinedMarginLoss
softmax = CombinedMarginLoss(s=2, m1=1.0, m2=0.5, m3=0.0)
softmax(logits, labels)

In [3]:
from dataset import get_dataloader

train_loader, x, valid_loader, y = get_dataloader(local_rank=0, batch_size=12)
print(len(train_loader))
print(len(valid_loader))

train dataset length:  5,822,653
valid dataset length:  13,233
485222
1103


In [4]:
from model import get_model
ROOT_DIR = "/home/ljj0512/private/workspace/CV-project/Computer-Vision-Project/train"
model = get_model(ROOT_DIR)

=> the number of model parameters: 24,025,600


In [5]:
for input, labels in train_loader:
    print(input.shape)
    print(labels.shape)
    break

torch.Size([12, 3, 112, 112])
torch.Size([12])


In [10]:
inputs = torch.randn((6,3,112,112))
labels = torch.randint(low=0,high=100,size=(6,))
print(inputs.shape)
print(labels.shape)
print()

margin_loss = CombinedMarginLoss(64, 1.0, 0.5, 0.0)
fc_softmax = FCSoftmax(margin_loss, 512, 85742)
criterion = nn.CrossEntropyLoss()

# with torch.no_grad():
model.train()
fc_softmax.train()
embed_vec = model(inputs)
print(embed_vec.shape)
logits = fc_softmax(embed_vec, labels)
print(logits.shape)
_, predicted = torch.max(logits.data, dim=1)
print(predicted.shape)
loss = criterion(logits, labels)
print(loss.item())
print("finish")
# model = PartialFC(margin_loss, 512, 93431, 1.0, True)

torch.Size([6, 3, 112, 112])
torch.Size([6])

torch.Size([6, 512])
torch.Size([6, 85742])
torch.Size([6])
46.125762939453125
finish


In [8]:
print('=> the number of model parameters: {:,}'.format(sum([p.data.nelement() for p in fc_softmax.parameters()])))


=> the number of model parameters: 43,899,904


# =================================

In [7]:
import torch
import collections
from typing import Callable
from torch import distributed
from torch.nn.functional import linear, normalize

class PartialFC(torch.nn.Module):
    """
    https://arxiv.org/abs/2203.15565
    A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
    When sample rate less than 1, in each iteration, positive class centers and a random subset of
    negative class centers are selected to compute the margin-based softmax loss, all class
    centers are still maintained throughout the whole training process, but only a subset is
    selected and updated in each iteration.
    .. note::
        When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).
    Example:
    --------
    >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
    >>> for img, labels in data_loader:
    >>>     embeddings = net(img)
    >>>     loss = module_pfc(embeddings, labels, optimizer)
    >>>     loss.backward()
    >>>     optimizer.step()
    """
    _version = 1 
    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        """
        Paramenters:
        -----------
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        """
        super(PartialFC, self).__init__()
        # assert (
        #     distributed.is_initialized()
        # ), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16
        self.num_local: int = num_classes // self.world_size + int(
            self.rank < num_classes % self.world_size
        )
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0
        self.weight: torch.Tensor
        self.weight_mom: torch.Tensor
        self.weight_activated: torch.nn.Parameter
        self.weight_activated_mom: torch.Tensor
        self.is_updated: bool = True
        self.init_weight_update: bool = True

        if self.sample_rate < 1:
            self.register_buffer("weight",
                tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
            self.register_buffer("weight_mom",
                tensor=torch.zeros_like(self.weight))
            self.register_parameter("weight_activated",
                param=torch.nn.Parameter(torch.empty(0, 0)))
            self.register_buffer("weight_activated_mom",
                tensor=torch.empty(0, 0))
            self.register_buffer("weight_index",
                tensor=torch.empty(0, 0))
        else:
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

        # margin_loss
        if isinstance(margin_loss, Callable):
            self.margin_softmax = margin_loss
        else:
            raise

    @torch.no_grad()
    def sample(self, 
        labels: torch.Tensor, 
        index_positive: torch.Tensor, 
        optimizer: torch.optim.Optimizer):
        """
        This functions will change the value of labels
        Parameters:
        -----------
        labels: torch.Tensor
            pass
        index_positive: torch.Tensor
            pass
        optimizer: torch.optim.Optimizer
            pass
        """
        positive = torch.unique(labels[index_positive], sorted=True).cuda()
        if self.num_sample - positive.size(0) >= 0:
            perm = torch.rand(size=[self.num_local]).cuda()
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1].cuda()
            index = index.sort()[0].cuda()
        else:
            index = positive
        self.weight_index = index

        labels[index_positive] = torch.searchsorted(index, labels[index_positive])
        
        self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
        self.weight_activated_mom = self.weight_mom[self.weight_index]
        
        if isinstance(optimizer, torch.optim.SGD):
            # TODO the params of partial fc must be last in the params list
            optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
            optimizer.param_groups[-1]["params"][0] = self.weight_activated
            optimizer.state[self.weight_activated][
                "momentum_buffer"
            ] = self.weight_activated_mom
        else:
            raise

    @torch.no_grad()
    def update(self):
        """ partial weight to global
        """
        if self.init_weight_update:
            self.init_weight_update = False
            return

        if self.sample_rate < 1:
            self.weight[self.weight_index] = self.weight_activated
            self.weight_mom[self.weight_index] = self.weight_activated_mom


    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).
        Returns:
        -------
        loss: torch.Tensor
            pass
        """
        local_labels.squeeze_()
        local_labels = local_labels.long()
        self.update()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            "last batch size do not equal current batch size: {} vs {}".format(
            self.last_batch_size, batch_size))

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size)).cuda()
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        index_positive = (self.class_start <= labels) & (
            labels < self.class_start + self.num_local
        )
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            self.sample(labels, index_positive, optimizer)

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(self.weight_activated)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if destination is None: 
            destination = collections.OrderedDict()
            destination._metadata = collections.OrderedDict()

        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
        if self.sample_rate < 1:
            destination["weight"] = self.weight.detach()
        else:
            destination["weight"] = self.weight_activated.data.detach()
        return destination

    def load_state_dict(self, state_dict, strict: bool = True):
        if self.sample_rate < 1:
            self.weight = state_dict["weight"].to(self.weight.device)
            self.weight_mom.zero_()
            self.weight_activated.data.zero_()
            self.weight_activated_mom.zero_()
            self.weight_index.zero_()
        else:
            self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)

class DistCrossEntropyFunc(torch.autograd.Function):
    """
    CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax.
    Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    """

    @staticmethod
    def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
        """ """
        batch_size = logits.size(0)
        # for numerical stability
        max_logits, _ = torch.max(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
        logits.sub_(max_logits)
        logits.exp_()
        sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
        logits.div_(sum_logits_exp)
        index = torch.where(label != -1)[0]
        # loss
        loss = torch.zeros(batch_size, 1, device=logits.device)
        loss[index] = logits[index].gather(1, label[index])
        distributed.all_reduce(loss, distributed.ReduceOp.SUM)
        ctx.save_for_backward(index, logits, label)
        return loss.clamp_min_(1e-30).log_().mean() * (-1)

    @staticmethod
    def backward(ctx, loss_gradient):
        """
        Args:
            loss_grad (torch.Tensor): gradient backward by last layer
        Returns:
            gradients for each input in forward function
            `None` gradients for one-hot label
        """
        (
            index,
            logits,
            label,
        ) = ctx.saved_tensors
        batch_size = logits.size(0)
        one_hot = torch.zeros(
            size=[index.size(0), logits.size(1)], device=logits.device
        )
        one_hot.scatter_(1, label[index], 1)
        logits[index] -= one_hot
        logits.div_(batch_size)
        return logits * loss_gradient.item(), None


class DistCrossEntropy(torch.nn.Module):
    def __init__(self):
        super(DistCrossEntropy, self).__init__()

    def forward(self, logit_part, label_part):
        return DistCrossEntropyFunc.apply(logit_part, label_part)


class AllGatherFunc(torch.autograd.Function):
    """AllGather op with gradient backward"""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        gather_list = list(gather_list)
        distributed.all_gather(gather_list, tensor)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        grad_list = list(grads)
        rank = distributed.get_rank()
        grad_out = grad_list[rank]

        dist_ops = [
            distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
            if i == rank
            else distributed.reduce(
                grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
            )
            for i in range(distributed.get_world_size())
        ]
        for _op in dist_ops:
            _op.wait()

        grad_out *= len(grad_list)  # cooperate with distributed loss function
        return (grad_out, *[None for _ in range(len(grad_list))])


AllGather = AllGatherFunc.apply
# Footer
# © 2022 GitHub, Inc.
# Footer navigation
# Terms
# Privacy
# Security
# Status
# Docs
# Contact GitHub
# Pricing
# API
# Training
# Blog
# About
# insightface/partial_fc.py at master · deepinsight/insightface

In [None]:
import torch
import torch.nn as nn
from typing import Callable
from torch.nn.functional import linear, normalize

class FCSoftmax(nn.Module):
    def __init__(self, margin_softmax: Callable, embed_size: int, num_classes: int):
        super(FCSoftmax, self).__init__()
        self.margin_softmax = margin_softmax
        self.weights = nn.Parameter(torch.FloatTensor(num_classes, embed_size))
        nn.init.xavier_uniform_(self.weights)
    
    def forward(self, embed_vec: torch.Tensor, labels: torch.Tensor):
        logits = linear(normalize(embed_vec), normalize(self.weights)).clamp(-1,1)
        logits = self.margin_softmax(logits, labels)
        return logits



In [8]:
import torch
import math


class CombinedMarginLoss(torch.nn.Module):
    def __init__(self, 
                 s, 
                 m1,
                 m2,
                 m3,
                 interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold
        
        # For ArcFace
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False


    def forward(self, logits, labels):
        index_positive = torch.where(labels != -1)[0]

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                mask.scatter_(1, labels[index_positive], 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty    
            logits = tensor_mul * logits

        target_logit = logits[index_positive, labels[index_positive].view(-1)]

        if self.m1 == 1.0 and self.m3 == 0.0:
            sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
            cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target+margin)
            if self.easy_margin:
                final_target_logit = torch.where(
                    target_logit > 0, cos_theta_m, target_logit)
            else:
                final_target_logit = torch.where(
                    target_logit > self.theta, cos_theta_m, target_logit - self.sinmm)
            logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
            logits = logits * self.s
        
        elif self.m3 > 0:
            final_target_logit = target_logit - self.m3
            logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
            logits = logits * self.s
        else:
            raise        

        return logits

class ArcFace(torch.nn.Module):
    """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    """
    def __init__(self, s=64.0, margin=0.5):
        super(ArcFace, self).__init__()
        self.scale = s
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False


    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]

        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target+margin)
        if self.easy_margin:
            final_target_logit = torch.where(
                target_logit > 0, cos_theta_m, target_logit)
        else:
            final_target_logit = torch.where(
                target_logit > self.theta, cos_theta_m, target_logit - self.sinmm)

        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.scale
        return logits


class CosFace(torch.nn.Module):
    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]
        final_target_logit = target_logit - self.m
        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.s
        return logits

In [None]:
import torch
from typing import Callable
from torch.nn.functional import linear, normalize

class FCSoftmax(nn.Module):
    def __init__(self, margin_softmax: Callable, embed_size: int, num_classes: int):
        super(FCSoftmax, self).__init__()
        self.margin_softmax = margin_softmax
        self.weights = torch.nn.Parameter(torch.FloatTensor(num_classes, embed_size))
        nn.init.xavier_uniform_(self.weights)
    
    def forward(self, embed_vec: torch.Tensor, labels: torch.Tensor):
        logits = linear(normalize(embed_vec), normalize(self.weights)).clamp(-1,1)
        logits = self.margin_softmax(logits, labels)
        return logits



# ====================================

In [9]:
distributed.is_initialized()

False

In [10]:
margin_loss = CombinedMarginLoss(64, 1.0, 0.5, 0.0)

model = PartialFC(margin_loss, 512, 93431, 1.0, True)

# model = torch.nn.Parameter(x)
# print(model.shape)

# print('=> the number of model parameters: {:,}'.format(sum([p.data.nelement() for p in model.parameters()])))

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.