From a3ad7a65126b379f76e428b9e04efc5f0e2da7bd Mon Sep 17 00:00:00 2001
From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com>
Date: Mon, 22 May 2023 14:29:28 +0800
Subject: [PATCH] [Prim] Simplify bn vjp (#54012)

* recompute bn grad

* fix test case

---------

Co-authored-by: sunli <466530738@qq.com>
---
 .../composite_backward_api.h              | 16 ++++++++--------
 test/prim/model/test_resnet_prim_cinn.py  | 19 +++++++++----------
 2 files changed, 17 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 45f62db4c70e4..ca99b818dbaf0 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1366,9 +1366,8 @@ void batch_norm_grad(const Tensor& x,
         auto nhwc_out_grad = transpose(out_grad_data, nchw_to_nhwc_dim);
         auto nhwc_out_grad_sum = sum(nhwc_out_grad, reduce_axis, dtype, false);
 
-        auto x_sub_mean = nhwc_x - mean_data;
-        auto sum_dout_mul_diff =
-            sum(nhwc_out_grad * x_sub_mean, reduce_axis, dtype, false);
+        auto sum_dout_mul_diff = sum(
+            nhwc_out_grad * (nhwc_x - mean_data), reduce_axis, dtype, false);
 
         if (x_grad) {
           if (use_global_stats) {
@@ -1382,7 +1381,8 @@ void batch_norm_grad(const Tensor& x,
             auto part1 = scale * rsqrt_var;
             auto mean_temp1 = nhwc_out_grad_sum / nhw;
             auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
-            auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2;
+            auto part2 =
+                nhwc_out_grad - mean_temp1 - (nhwc_x - mean_data) * mean_temp2;
 
             auto x_grad_data = part1 * part2;
             auto nchw_x_grad = transpose(x_grad_data, nhwc_to_nchw_dim);
@@ -1403,11 +1403,10 @@ void batch_norm_grad(const Tensor& x,
     }
     case DataLayout::kNHWC: {
       if (x_grad) {
-        auto x_sub_mean = x_data - mean_data;
         auto out_grad_data_sum = sum(out_grad_data, reduce_axis, dtype, false);
-        auto nhwc_sum_dout_mul_diff =
-            sum(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
+        auto nhwc_sum_dout_mul_diff = sum(
+            out_grad_data * (x_data - mean_data), reduce_axis, dtype, false);
 
         if (use_global_stats) {
           auto x_grad_data = scale * rsqrt_var * out_grad_data;
           if (x.dtype() == phi::DataType::FLOAT16) {
@@ -1420,7 +1419,8 @@ void batch_norm_grad(const Tensor& x,
           auto mean_temp1 = out_grad_data_sum / nhw;
           auto mean_temp2 =
               nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
-          auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2;
+          auto part2 =
+              out_grad_data - mean_temp1 - (x_data - mean_data) * mean_temp2;
 
           auto x_grad_data = part1 * part2;
           if (x.dtype() == phi::DataType::FLOAT16) {
diff --git a/test/prim/model/test_resnet_prim_cinn.py b/test/prim/model/test_resnet_prim_cinn.py
index 76295fbf9b4c8..0acf625393403 100644
--- a/test/prim/model/test_resnet_prim_cinn.py
+++ b/test/prim/model/test_resnet_prim_cinn.py
@@ -46,17 +46,16 @@
 # The results in ci as as follows:
 DY2ST_PRIM_CINN_GT = [
     5.828786849975586,
-    8.332858085632324,
-    5.026939868927002,
-    8.475804328918457,
-    8.017110824584961,
-    7.8353095054626465,
-    9.731267929077148,
-    8.193124771118164,
-    8.155317306518555,
-    10.185102462768555,
+    8.332863807678223,
+    5.0373005867004395,
+    8.464998245239258,
+    8.20099925994873,
+    7.576723098754883,
+    9.679173469543457,
+    8.381753921508789,
+    8.10612678527832,
+    10.124727249145508,
 ]
-
 if core.is_compiled_with_cuda():
     paddle.set_flags({'FLAGS_cudnn_deterministic': True})
 
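
Note: the sketch below is illustrative only and not part of the patch. It is a minimal NumPy rendering of the composite batch-norm input gradient that the hunks above compute, assuming an NHWC input and per-channel mean, rsqrt_var (1/sqrt(var + eps)), and scale; the function name and signature are hypothetical. It shows why the x_sub_mean temporary can be dropped: the (x - mean) difference feeds only the reduced sum and part2, so it can be recomputed inline at both use sites instead of being kept live as an intermediate tensor.

    # Hypothetical sketch of the batch-norm x-gradient (not the Paddle API).
    import numpy as np

    def bn_input_grad(dout, x, mean, rsqrt_var, scale, reduce_axis=(0, 1, 2)):
        # Number of elements reduced per channel (N * H * W for NHWC input).
        nhw = np.prod([x.shape[i] for i in reduce_axis])
        part1 = scale * rsqrt_var
        mean_temp1 = dout.sum(axis=reduce_axis) / nhw
        mean_temp2 = ((dout * (x - mean)).sum(axis=reduce_axis) / nhw
                      * rsqrt_var * rsqrt_var)
        # (x - mean) is recomputed inline, mirroring the simplification above.
        part2 = dout - mean_temp1 - (x - mean) * mean_temp2
        return part1 * part2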