Skip to content

Commit

Permalink
Avoid syntax error in Python 2
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Jan 24, 2018
1 parent 2155c0a commit 1899ff3
Showing 1 changed file with 5 additions and 4 deletions.
Expand Up @@ -18,7 +18,8 @@ def _to_noncontiguous(arrays):
return [xp.asfortranarray(a) for a in arrays]


def _batch_normalization(x, gamma, beta, mean, var, expander):
def _batch_normalization(args):
x, gamma, beta, mean, var, expander = args
mean = mean[expander]
std = numpy.sqrt(var)[expander]
y_expect = (gamma[expander] * (x - mean) / std + beta[expander])
Expand Down Expand Up @@ -90,7 +91,7 @@ def setUp(self):

def forward_cpu(self, inputs):
y_expect = _batch_normalization(
*inputs, self.mean, self.var, self.expander)
inputs + [self.mean, self.var, self.expander])
return y_expect,

def check_forward(self, inputs, backend_config):
Expand Down Expand Up @@ -221,7 +222,7 @@ def setUp(self):
'dtype': numpy.float64, 'atol': 1e-2, 'rtol': 1e-2}

def forward_cpu(self, inputs):
y_expect = _batch_normalization(*inputs, self.expander)
y_expect = _batch_normalization(inputs + [self.expander])
return y_expect,

def check_forward(self, inputs, backend_config):
Expand All @@ -237,7 +238,7 @@ def check_forward(self, inputs, backend_config):
assert y.data.dtype == self.dtype

testing.assert_allclose(
y_expected.data, y.data, **self.check_forward_options)
y_expected, y.data, **self.check_forward_options)

def test_forward(self, backend_config):
self.check_forward(self.inputs, backend_config)
Expand Down

0 comments on commit 1899ff3

Please sign in to comment.