/
_batchnorm.py
183 lines (142 loc) · 6.35 KB
/
_batchnorm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import torch
from torch import nn
from e3nn import o3
from e3nn.util.jit import compile_mode
@compile_mode("unsupported")
class BatchNorm(nn.Module):
    """Batch normalization for orthonormal representations.

    Every irrep block of the input is rescaled by a statistic of its squared
    norm; scalar irreps are additionally centred on their mean.
    Note that the norm is invariant only for orthonormal representations.
    Irreducible representations `wigner_D` are orthonormal.

    Parameters
    ----------
    irreps : `o3.Irreps`
        representation carried by the last dimension of the input

    eps : float
        added to the squared-norm statistic before the inverse square root,
        to avoid division by zero

    momentum : float
        weight given to the newest batch statistic in the running averages

    affine : bool
        whether to learn one weight per irrep and one bias per scalar irrep

    reduce : {'mean', 'max'}
        method used to reduce the squared norm over the sample dimension

    instance : bool
        apply instance norm instead of batch norm: statistics are computed
        per batch entry and no running statistics are kept

    normalization : {'norm', 'component'}
        ``'norm'`` sums the squared components of each irrep,
        ``'component'`` averages them
    """

    def __init__(self, irreps, eps=1e-5, momentum=0.1, affine=True, reduce="mean", instance=False, normalization="component"):
        super().__init__()

        self.irreps = o3.Irreps(irreps)
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.instance = instance

        # One mean/bias entry per scalar irrep, one variance/weight entry per irrep.
        scalar_count = sum(multiplicity for multiplicity, irrep in self.irreps if irrep.is_scalar())
        irrep_count = self.irreps.num_irreps

        # Instance norm keeps no running statistics; the buffers are still
        # registered (as None) so the attribute names always exist.
        if self.instance:
            initial_mean, initial_var = None, None
        else:
            initial_mean = torch.zeros(scalar_count)
            initial_var = torch.ones(irrep_count)
        self.register_buffer("running_mean", initial_mean)
        self.register_buffer("running_var", initial_var)

        weight = nn.Parameter(torch.ones(irrep_count)) if affine else None
        bias = nn.Parameter(torch.zeros(scalar_count)) if affine else None
        self.register_parameter("weight", weight)
        self.register_parameter("bias", bias)

        assert isinstance(reduce, str), "reduce should be passed as a string value"
        assert reduce in ("mean", "max"), "reduce needs to be 'mean' or 'max'"
        self.reduce = reduce

        assert normalization in ("norm", "component"), "normalization needs to be 'norm' or 'component'"
        self.normalization = normalization

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.irreps}, eps={self.eps}, momentum={self.momentum})"

    def _roll_avg(self, curr, update):
        """Blend ``update`` (detached from autograd) into ``curr`` using ``self.momentum``."""
        retained = 1 - self.momentum
        return retained * curr + self.momentum * update.detach()

    def _reduced_sq_norm(self, field):
        """Squared norm of ``field`` reduced over the sample dimension.

        ``field`` is indexed as ``[batch, sample, mul, repr]``; the result is
        indexed as ``[batch, mul]``.

        Raises
        ------
        ValueError
            if ``self.normalization`` or ``self.reduce`` holds an unsupported value
        """
        if self.normalization == "norm":
            squared = field.pow(2).sum(3)  # [batch, sample, mul]
        elif self.normalization == "component":
            squared = field.pow(2).mean(3)  # [batch, sample, mul]
        else:
            raise ValueError("Invalid normalization option {}".format(self.normalization))

        if self.reduce == "mean":
            return squared.mean(1)  # [batch, mul]
        if self.reduce == "max":
            return squared.max(1).values  # [batch, mul]
        raise ValueError("Invalid reduce option {}".format(self.reduce))

    def forward(self, input):
        """Normalize ``input`` irrep by irrep.

        Parameters
        ----------
        input : `torch.Tensor`
            tensor of shape ``(batch, ..., irreps.dim)``

        Returns
        -------
        `torch.Tensor`
            tensor of shape ``(batch, ..., irreps.dim)``

        Raises
        ------
        AssertionError
            if the irreps do not span exactly the last dimension of ``input``
        """
        n_batch, *inner_shape, feat_dim = input.shape
        # [batch, sample, stacked features]
        input = input.reshape(n_batch, -1, feat_dim)

        # Statistics come from the data itself when training or in instance
        # mode; running buffers are only updated for batch norm in training.
        use_batch_stats = self.training or self.instance
        track_running = self.training and not self.instance

        updated_means = []
        updated_vars = []
        normalized = []

        # Read positions into the input features and into the flat
        # per-irrep / per-scalar buffers and parameters.
        col = 0
        mean_at = 0
        var_at = 0
        weight_at = 0
        bias_at = 0

        for mul, ir in self.irreps:
            width = mul * ir.dim
            scalar = ir.is_scalar()

            # [batch, sample, mul, repr]
            block = input[:, :, col : col + width].reshape(n_batch, -1, mul, ir.dim)
            col += width

            if scalar:
                mean_slot = slice(mean_at, mean_at + mul)
                mean_at += mul

                if self.instance:
                    center = block.mean(1).reshape(n_batch, mul)  # [batch, mul]
                elif self.training:
                    center = block.mean([0, 1]).reshape(mul)  # [mul]
                    updated_means.append(self._roll_avg(self.running_mean[mean_slot], center))
                else:
                    center = self.running_mean[mean_slot]

                # [batch, sample, mul, repr]
                block = block - center.reshape(-1, 1, mul, 1)

            var_slot = slice(var_at, var_at + mul)
            var_at += mul

            if use_batch_stats:
                scale = self._reduced_sq_norm(block)  # [batch, mul]
                if not self.instance:
                    scale = scale.mean(0)  # [mul]
                    updated_vars.append(self._roll_avg(self.running_var[var_slot], scale))
            else:
                scale = self.running_var[var_slot]

            scale = (scale + self.eps).pow(-0.5)  # [(batch,) mul]

            if self.affine:
                scale = scale * self.weight[weight_at : weight_at + mul]  # [(batch,) mul]
                weight_at += mul

            block = block * scale.reshape(-1, 1, mul, 1)  # [batch, sample, mul, repr]

            if self.affine and scalar:
                # Only scalar irreps receive a bias.
                block += self.bias[bias_at : bias_at + mul].reshape(mul, 1)
                bias_at += mul

            normalized.append(block.reshape(n_batch, -1, width))  # [batch, sample, mul * repr]

        if col != feat_dim:
            fmt = "`ix` should have reached input.size(-1) ({}), but it ended at {}"
            raise AssertionError(fmt.format(feat_dim, col))

        # Sanity checks: every buffer/parameter entry was consumed exactly once.
        if track_running:
            assert mean_at == self.running_mean.numel()
            assert var_at == self.running_var.size(0)
        if self.affine:
            assert weight_at == self.weight.size(0)
            assert bias_at == self.bias.numel()

        if track_running:
            # Write the new running statistics into the existing buffers in place.
            if updated_means:
                torch.cat(updated_means, out=self.running_mean)
            if updated_vars:
                torch.cat(updated_vars, out=self.running_var)

        output = torch.cat(normalized, dim=2)  # [batch, sample, stacked features]
        return output.reshape(n_batch, *inner_shape, feat_dim)