In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.loss.sphereface2 import *

import numpy as np
import torch
import random

def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(3022)

# device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = 'cpu'
device

'cpu'

In [10]:
batch_size = 32
num_classes = 600
r = 40
m = 0.4
l = 0.7
t = 3.
margin_type = 'A'

In [11]:
x = torch.rand((batch_size, num_classes))
y = torch.zeros(batch_size).random_(num_classes).long()
y

tensor([348, 264, 411,  39, 369, 410, 172, 121, 278, 259, 305, 555,  52, 238,
        305,  96,  12, 559, 582, 234, 450, 192, 103, 112, 352, 585, 305,  82,
        211, 568, 132, 529])

In [12]:
self = RefSphereFace2(alpha=l, m=m, r=r, t=t, magn_type=margin_type)
loss_gt, logits_gt, weight_gt = self(x, y)
loss_gt, logits_gt, logits_gt.sum(), weight_gt, weight_gt.sum()

(tensor(0.0570),
 tensor([[ -6.5366,  -9.8331, -26.3281,  ...,  24.8107, -25.9402,   5.2875],
         [ -6.4215,  -2.8220,  14.3189,  ..., -22.3555,  21.4496, -26.0347],
         [ -2.2130, -21.4492,   9.2286,  ...,   5.6058, -13.0432, -14.9139],
         ...,
         [ -7.1829,   6.1392, -24.0033,  ..., -18.6522, -29.0807, -29.6197],
         [-19.6523,  -8.8938,  15.3108,  ...,  17.3855, -13.8753, -21.4755],
         [-28.6364,  -4.9285,  25.8206,  ..., -14.2998,  31.0264, -12.2747]]),
 tensor(-51663.3281),
 tensor([[0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         ...,
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075]]),
 tensor(144.3200))

In [13]:
criterion = SphereLoss(lamb=l, margin=m, scale=r, t=t, margin_type=margin_type)
loss, logits, weight = criterion(x, y)
loss, logits, logits.sum(), weight, weight.sum()

(tensor(0.0570),
 tensor([[ -6.5366,  -9.8331, -26.3281,  ...,  24.8107, -25.9402,   5.2875],
         [ -6.4215,  -2.8220,  14.3189,  ..., -22.3555,  21.4496, -26.0347],
         [ -2.2130, -21.4492,   9.2286,  ...,   5.6058, -13.0432, -14.9139],
         ...,
         [ -7.1829,   6.1392, -24.0033,  ..., -18.6522, -29.0807, -29.6197],
         [-19.6523,  -8.8938,  15.3108,  ...,  17.3855, -13.8753, -21.4755],
         [-28.6364,  -4.9285,  25.8206,  ..., -14.2998,  31.0264, -12.2747]]),
 tensor(-51663.3281),
 tensor([[0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         ...,
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075],
         [0.0075, 0.0075, 0.0075,  ..., 0.0075, 0.0075, 0.0075]]),
 tensor(144.3200))

In [14]:
index_y = torch.arange(len(y))
cos_theta = x

theta_m = torch.acos(cos_theta.clamp(-1+1e-5, 1.-1e-5))
theta_m.scatter_(1, y.view(-1, 1), self.m, reduce='add')
theta_m.clamp_(1e-5, 3.14159)

g_cos_theta = torch.cos(theta_m)
g_cos_theta = 2. * ((g_cos_theta + 1.) / 2.).pow(self.t) - 1.
g_cos_theta, g_cos_theta.sum()

(tensor([[-0.1634, -0.2458, -0.6582,  ...,  0.6203, -0.6485,  0.1322],
         [-0.1605, -0.0706,  0.3580,  ..., -0.5589,  0.5362, -0.6509],
         [-0.0553, -0.5362,  0.2307,  ...,  0.1401, -0.3261, -0.3728],
         ...,
         [-0.1796,  0.1535, -0.6001,  ..., -0.4663, -0.7270, -0.7405],
         [-0.4913, -0.2223,  0.3828,  ...,  0.4346, -0.3469, -0.5369],
         [-0.7159, -0.1232,  0.6455,  ..., -0.3575,  0.7757, -0.3069]]),
 tensor(-1291.5833))

In [7]:
index_y = torch.arange(len(y))
cos_theta = x

one_hot = torch.zeros_like(cos_theta)
one_hot[index_y, y] = 1

cos_theta_m = torch.acos(cos_theta.clamp(-1+1e-5, 1.-1e-5))
# cos_theta_m.scatter_(1, y.view(-1, 1), self.m, reduce='add')
# cos_theta_m[index_y, y] += m
cos_theta_m += m * one_hot
cos_theta_m.clamp_max_(math.pi)
cos_theta_m = torch.cos(cos_theta_m)

g_cos_theta_m = g_func(cos_theta_m, t)
g_cos_theta_m, g_cos_theta_m.sum()

(tensor([[ 0.9808, -0.1933,  0.2304,  ...,  0.0757,  0.8295, -0.4010],
         [-0.3679, -0.5634, -0.1492,  ...,  0.3682,  0.7450, -0.1906],
         [ 0.0389, -0.1680,  0.4587,  ..., -0.7314, -0.0891, -0.0954],
         ...,
         [-0.6661,  0.4484, -0.4101,  ...,  0.8546, -0.6737, -0.7250],
         [ 0.0241, -0.3533,  0.4371,  ...,  0.6928, -0.6292,  0.0742],
         [-0.1420, -0.2916, -0.0419,  ..., -0.6848, -0.7042, -0.6977]]),
 tensor(-1208.6223))

In [8]:
# index_y = torch.arange(len(y))
# cos_theta = x

# cos_theta_m = torch.acos(cos_theta)
# cos_theta_m[index_y, y] = cos_theta_m[index_y, y] * m
# cos_theta_m[index_y, y].clamp_max(math.pi)
# cos_theta_m.size(), one_hot.size()