Skip to content

Commit

Permalink
Increase the tolerance (apache#7937)
Browse files Browse the repository at this point in the history
  • Loading branch information
goswamig authored and crazy-cat committed Oct 26, 2017
1 parent 8cb6021 commit d666afa
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/python/unittest/test_operator.py
Expand Up @@ -903,28 +903,28 @@ def check_batchnorm_training(stype):
mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]

test = mx.symbol.BatchNorm_v1(data, fix_gamma=True)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-4)

test = mx.symbol.BatchNorm(data, fix_gamma=True)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-4)

test = mx.symbol.BatchNorm_v1(data, fix_gamma=True, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-4)

test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-4)

test = mx.symbol.BatchNorm_v1(data, fix_gamma=False)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-4)

test = mx.symbol.BatchNorm(data, fix_gamma=False)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-4)

test = mx.symbol.BatchNorm_v1(data, fix_gamma=False, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-4)

test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16)
check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-4)

# Test varying channel axis
dim = len(shape)
Expand Down

0 comments on commit d666afa

Please sign in to comment.