/
modified_linear.py
59 lines (53 loc) · 2.18 KB
/
modified_linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch.nn import Module
class CosineLinear(Module):
    """Linear layer whose outputs are (optionally scaled) cosine similarities.

    Both the input vectors and the rows of ``weight`` are L2-normalized
    before the matrix product, so each output entry is the cosine of the
    angle between an input vector and one weight row, multiplied by the
    learnable scalar ``sigma`` when it is enabled.

    Args:
        in_features: size of each input vector (last dimension of the input).
        out_features: number of weight rows, i.e. size of each output vector.
        sigma: if true, register a learnable one-element scale initialised
            to 1; otherwise ``self.sigma`` is registered as ``None`` and the
            raw cosine similarities are returned.
    """

    def __init__(self, in_features, out_features, sigma=True):
        super(CosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialised here; reset_parameters() fills the values.
        self.weight = Parameter(torch.empty(out_features, in_features))
        if sigma:
            self.sigma = Parameter(torch.empty(1))
        else:
            # Keep the attribute present (as None) so forward() can test it.
            self.register_parameter('sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` uniformly in [-1/sqrt(in), 1/sqrt(in)] and ``sigma`` to 1."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        # In-place initialisation must not be recorded by autograd.
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            if self.sigma is not None:
                self.sigma.fill_(1)  # initial value of the scale factor

    def forward(self, input):
        """Return ``sigma * cos(input, weight)``.

        Args:
            input: tensor whose last dimension has size ``in_features``;
                any number of leading dimensions (including none) is kept.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a last
            dimension of size ``out_features``.
        """
        # Normalise over the feature axis. dim=-1 equals dim=1 for the usual
        # (batch, features) input and stays correct for other ranks.
        out = F.linear(
            F.normalize(input, p=2, dim=-1),
            F.normalize(self.weight, p=2, dim=1),
        )
        if self.sigma is not None:
            out = self.sigma * out
        return out

    def extra_repr(self):
        """Summarise the layer configuration for ``repr(module)``."""
        return 'in_features={}, out_features={}, sigma={}'.format(
            self.in_features, self.out_features, self.sigma is not None
        )
class SplitCosineLinear(Module):
    """Two ``CosineLinear`` heads over the same input with one shared scale.

    The input is passed through ``fc1`` and ``fc2`` (both built without their
    own ``sigma``), the two outputs are concatenated along the last
    dimension, and the result is multiplied by this module's ``sigma`` when
    it is enabled.

    Args:
        in_features: size of each input vector.
        out_features1: number of outputs produced by ``fc1``.
        out_features2: number of outputs produced by ``fc2``.
        sigma: if true, register a learnable one-element scale initialised
            to 1; otherwise ``self.sigma`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features1, out_features2, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features1 + out_features2
        # The sub-layers carry no scale of their own; a single sigma is
        # applied to the concatenated output in forward().
        self.fc1 = CosineLinear(in_features, out_features1, False)
        self.fc2 = CosineLinear(in_features, out_features2, False)
        if sigma:
            self.sigma = Parameter(torch.ones(1))
        else:
            # Keep the attribute present (as None) so forward() can test it.
            self.register_parameter('sigma', None)

    def forward(self, x):
        """Return the concatenated outputs of ``fc1`` and ``fc2``, scaled by ``sigma``.

        Args:
            x: input tensor forwarded unchanged to both sub-layers.

        Returns:
            Tensor whose last dimension has size ``out_features``
            (``fc1`` outputs first, then ``fc2`` outputs).
        """
        out1 = self.fc1(x)
        out2 = self.fc2(x)
        # Join along the output-feature axis. dim=-1 equals dim=1 for
        # (batch, features) input and does not depend on the input rank.
        out = torch.cat((out1, out2), dim=-1)
        if self.sigma is not None:
            out = self.sigma * out
        return out