import math

import torch
from torch.nn import Module, Parameter

def build_head(head_type,
               embedding_size,
               class_num,
               m,
               t_alpha,
               h,
               s,
               ):
    # Factory for classification heads; only the AdaFace head is implemented here.
    if head_type == 'adaface':
        head = AdaFace(embedding_size=embedding_size,
                       classnum=class_num,
                       m=m,
                       h=h,
                       s=s,
                       t_alpha=t_alpha,
                       )
    else:
        raise ValueError('unknown head type: {}'.format(head_type))
    return head
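
# Example usage (a minimal sketch; the values mirror the AdaFace defaults
# below and are illustrative, not prescribed by this file):
#   head = build_head('adaface', embedding_size=512, class_num=70722,
#                     m=0.4, t_alpha=1.0, h=0.333, s=64.)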

def l2_norm(input, axis=1):
    # Scale `input` to unit L2 norm along `axis`.
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output
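
# Quick sanity check (sketch): the row [3, 4] has L2 norm 5, so
#   l2_norm(torch.tensor([[3., 4.]]))  ->  tensor([[0.6000, 0.8000]])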

class AdaFace(Module):
    def __init__(self,
                 embedding_size=512,
                 classnum=70722,
                 m=0.4,
                 h=0.333,
                 s=64.,
                 t_alpha=1.0,
                 ):
        super(AdaFace, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))

        # initial kernel: uniform init, then renormalize each column (class prototype)
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m
        self.eps = 1e-3
        self.h = h
        self.s = s

        # EMA prep: running mean/std of the batch feature norms
        self.t_alpha = t_alpha
        self.register_buffer('t', torch.zeros(1))
        self.register_buffer('batch_mean', torch.ones(1) * 20)
        self.register_buffer('batch_std', torch.ones(1) * 100)

        print('\nAdaFace with the following properties')
        print('self.m', self.m)
        print('self.h', self.h)
        print('self.s', self.s)
        print('self.t_alpha', self.t_alpha)

    def forward(self, embeddings, norms, label):
        kernel_norm = l2_norm(self.kernel, axis=0)
        cosine = torch.mm(embeddings, kernel_norm)
        cosine = cosine.clamp(-1 + self.eps, 1 - self.eps)  # for numerical stability
        safe_norms = torch.clip(norms, min=0.001, max=100)  # for numerical stability
        safe_norms = safe_norms.clone().detach()

        # update batch_mean / batch_std with an exponential moving average
        with torch.no_grad():
            mean = safe_norms.mean().detach()
            std = safe_norms.std().detach()
            self.batch_mean = mean * self.t_alpha + (1 - self.t_alpha) * self.batch_mean
            self.batch_std = std * self.t_alpha + (1 - self.t_alpha) * self.batch_std

        margin_scaler = (safe_norms - self.batch_mean) / (self.batch_std + self.eps)  # ~68% fall in [-1, 1]
        margin_scaler = margin_scaler * self.h  # ~68% fall in [-0.333, 0.333] when h = 0.333
        margin_scaler = torch.clip(margin_scaler, -1, 1)
        # ex: m=0.5, h=0.333
        # range (~68% of values)
        #   -1     -0.333   0.333   1      (margin_scaler)
        #   -0.5   -0.166   0.166   0.5    (m * margin_scaler)
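        # A worked instance with this file's defaults (m=0.4; illustrative
        # numbers, not produced by the code): margin_scaler ~ +1 (high-norm
        # sample) gives g_angular = -0.4 and g_add = 0.8 below, while
        # margin_scaler ~ -1 (low-norm sample) gives g_angular = +0.4 and
        # g_add = 0.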

        # g_angular: norm-adaptive additive angular margin on the target class
        m_arc = torch.zeros(label.size()[0], cosine.size()[1], device=cosine.device)
        m_arc.scatter_(1, label.reshape(-1, 1), 1.0)
        g_angular = self.m * margin_scaler * -1
        m_arc = m_arc * g_angular
        theta = cosine.acos()
        theta_m = torch.clip(theta + m_arc, min=self.eps, max=math.pi - self.eps)
        cosine = theta_m.cos()

        # g_additive: norm-adaptive additive cosine margin on the target class
        m_cos = torch.zeros(label.size()[0], cosine.size()[1], device=cosine.device)
        m_cos.scatter_(1, label.reshape(-1, 1), 1.0)
        g_add = self.m + (self.m * margin_scaler)
        m_cos = m_cos * g_add
        cosine = cosine - m_cos

        # scale logits
        scaled_cosine_m = cosine * self.s
        return scaled_cosine_m
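

# Minimal smoke-test sketch (not part of the original file): build the head and
# run one forward pass on random data. Shapes and hyperparameter values here
# are assumptions chosen to match the defaults above.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch_size, embedding_size, class_num = 4, 512, 100

    head = build_head('adaface', embedding_size=embedding_size,
                      class_num=class_num, m=0.4, t_alpha=1.0, h=0.333, s=64.)

    # AdaFace takes L2-normalized embeddings plus the pre-normalization norms.
    features = torch.randn(batch_size, embedding_size)
    norms = torch.norm(features, 2, dim=1, keepdim=True)
    embeddings = features / norms
    labels = torch.randint(0, class_num, (batch_size,))

    logits = head(embeddings, norms, labels)  # (batch_size, class_num) scaled logits
    loss = torch.nn.functional.cross_entropy(logits, labels)
    print('logits shape:', logits.shape, 'loss:', loss.item())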