Missing batch normalization functionality.
mkskeller committed Apr 19, 2024
1 parent 1616922 commit 0468abb
Showing 1 changed file with 10 additions and 2 deletions.
--- a/Compiler/ml.py
+++ b/Compiler/ml.py
@@ -1355,6 +1355,7 @@ def __init__(self, shape, approx=True, args=None):
         else:
             print('Precise square root inverse in batch normalization')
             self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x)
+        self.is_trained = False

     def __repr__(self):
         return '%s(%s, approx=%s)' % \
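
For context, the layer's output is computed from stored statistics, so those statistics must exist before inference; the new flag records whether they have ever been set. A plain-NumPy sketch of the normalization that the _output method below applies (illustrative only, assuming factor = InvertSqrt(var + epsilon); this is not the MP-SPDZ code):

import numpy as np

# Cleartext analogue of BatchNorm._output:
# y = bias + weights * (x - mu) * factor, with factor = 1 / sqrt(var + epsilon)
def batch_norm_output(x, weights, bias, mu, var, epsilon=1e-5):
    factor = 1 / np.sqrt(var + epsilon)
    return bias + weights * (x - mu) * factor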
@@ -1372,11 +1373,12 @@ def _output(self, batch, mu, var):
         @for_range_opt_multithread(self.n_threads,
                                    [len(batch), self.X.sizes[1]])
         def _(i, j):
-            tmp = self.weights[:] * (self.X[i][j][:] - self.mu[:]) * factor[:]
+            tmp = self.weights[:] * (self.X[i][j][:] - mu[:]) * factor[:]
             self.Y[i][j][:] = self.bias[:] + tmp

     def forward(self, batch, training=False):
-        if training:
+        if training or not self.is_trained:
+            self.is_trained = True
             d = self.X.sizes[1]
             d_in = self.X.sizes[2]
             s = sfix.Array(d_in)
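
Two fixes land in this hunk: _output now normalizes with its mu argument rather than the stored self.mu, and forward computes statistics on the first call even when training=False, tracked via the new is_trained flag, so inference on a freshly constructed layer never runs with unset statistics. A standalone sketch of that lazy-statistics pattern (plain Python with assumed details, not the MP-SPDZ class):

import numpy as np

class LazyBatchNorm:
    # Illustration of the is_trained pattern only; epsilon and shapes are assumptions.
    def __init__(self, epsilon=1e-5):
        self.epsilon = epsilon
        self.is_trained = False
        self.mu = self.var = None

    def forward(self, batch, training=False):
        # Mirrors the new condition: refresh statistics when training or when
        # the layer has never seen any data.
        if training or not self.is_trained:
            self.is_trained = True
            self.mu = batch.mean(axis=0)
            self.var = batch.var(axis=0)
        return (batch - self.mu) / np.sqrt(self.var + self.epsilon)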
@@ -3264,6 +3266,12 @@ def process(item):
             pass
         elif name == 'BatchNorm2d':
             layers.append(BatchNorm(layers[-1].Y.sizes))
+            if input_via is not None:
+                layers[-1].epsilon = item.eps
+                layers[-1].weights = sfix.input_tensor_via(input_via,
+                                                           item.weight.detach())
+                layers[-1].bias = sfix.input_tensor_via(input_via,
+                                                        item.bias.detach())
         elif name == 'Dropout':
             layers.append(Dropout(input_shape[0], mul(layers[-1].Y.sizes[1:]),
                                   alpha=item.p))
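This branch is exercised when importing a PyTorch model whose Sequential contains BatchNorm2d: with input_via set, the layer's epsilon, weight, and bias are loaded as private input from the given player. A hedged usage sketch, assuming the layers_from_torch interface from Compiler/ml.py with roughly this argument order (batch size third, NHWC input shape); the exact signature may differ between versions:

import torch.nn as nn
from Compiler import ml

# Small model containing BatchNorm2d, which triggers the branch added above.
net = nn.Sequential(
    nn.Conv2d(1, 16, 3),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# input_via=0 makes player 0 supply the trained parameters as private input,
# including the batch-norm eps/weight/bias handled by this commit.
layers = ml.layers_from_torch(net, (128, 28, 28, 1), 128, input_via=0)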
