proxynca.py

import torch
from torch.nn import Parameter
import torch.nn.functional as F


def binarize_and_smooth_labels(T, nb_classes, smoothing_const = 0.1):
    # Optional: BNInception uses label smoothing; apply it for retraining too.
    # See "Rethinking the Inception Architecture for Computer Vision", p. 6.
    # Local import, so sklearn stays an optional dependency of this module.
    import sklearn.preprocessing
    T = T.cpu().numpy()
    T = sklearn.preprocessing.label_binarize(
        T, classes = range(0, nb_classes)
    )
    # keep 1 - smoothing_const on the true class, spread the rest evenly
    T = T * (1 - smoothing_const)
    T[T == 0] = smoothing_const / (nb_classes - 1)
    T = torch.FloatTensor(T).cuda()
    return T
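
# Worked example of the smoothing above (exact for these inputs, assuming a
# CUDA device is available, since the helper calls .cuda()):
#   binarize_and_smooth_labels(torch.tensor([0, 2]), nb_classes = 3) gives
#     [[0.90, 0.05, 0.05],
#      [0.05, 0.05, 0.90]]
# i.e. each row sums to 1, with 1 - smoothing_const on the true class.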


class ProxyNCA(torch.nn.Module):
    def __init__(self,
        nb_classes,
        sz_embedding,
        smoothing_const = 0.1,
        scaling_x = 1,
        scaling_p = 3
    ):
        torch.nn.Module.__init__(self)
        # initialize proxies s.t. the norm of each proxy is ~1 (div by 8)
        # i.e. proxies.norm(2, dim = 1) should be close to [1, 1, ..., 1]
        # TODO: use norm instead of div 8, because of embedding size
        self.proxies = Parameter(torch.randn(nb_classes, sz_embedding) / 8)
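        # A sketch for the TODO above (hypothetical, not the author's fix):
        # initialize each proxy to exactly unit norm, independent of size:
        #   self.proxies = Parameter(F.normalize(
        #       torch.randn(nb_classes, sz_embedding), p = 2, dim = -1))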
        self.smoothing_const = smoothing_const
        self.scaling_x = scaling_x
        self.scaling_p = scaling_p

    def forward(self, X, T):
        # scale the unit-normalized proxies and embeddings; the scaling
        # factors act like an inverse softmax temperature
        P = F.normalize(self.proxies, p = 2, dim = -1) * self.scaling_p
        X = F.normalize(X, p = 2, dim = -1) * self.scaling_x
        # squared Euclidean distances from each embedding to every proxy
        D = torch.cdist(X, P) ** 2
        T = binarize_and_smooth_labels(T, len(P), self.smoothing_const)
        # note: unlike the original ProxyNCA, the positive proxy is
        # included in the softmax denominator here
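        # per sample: loss = -sum_c T_c * log(exp(-D_c) / sum_j exp(-D_j)),
        # i.e. cross-entropy between the smoothed targets and softmax(-D)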
        loss = torch.sum(-T * F.log_softmax(-D, -1), -1)
        return loss.mean()
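

# A dependency-free alternative (a sketch, not part of the original code):
# the same binarize-and-smooth step in pure PyTorch via F.one_hot, avoiding
# the sklearn import. The name smooth_labels_torch is hypothetical.
def smooth_labels_torch(T, nb_classes, smoothing_const = 0.1):
    # one-hot encode on T's device, then redistribute probability mass
    T = F.one_hot(T.long(), nb_classes).float()
    T = T * (1 - smoothing_const)
    T[T == 0] = smoothing_const / (nb_classes - 1)
    return T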


if __name__ == '__main__':
    nb_classes = 100
    sz_batch = 32
    sz_embedding = 64
    X = torch.randn(sz_batch, sz_embedding).cuda()
    T = torch.randint(low=0, high=nb_classes, size=[sz_batch]).cuda()
    criterion = ProxyNCA(nb_classes, sz_embedding).cuda()
    # prints a scalar tensor: the batch-mean ProxyNCA loss
    print(criterion(X, T))