Skip to content

Commit

Permalink
[Prim] Simplify bn vjp (PaddlePaddle#54012)
Browse files Browse the repository at this point in the history
* recompute bn grad

* fix test case

---------

Co-authored-by: sunli <466530738@qq.com>
  • Loading branch information
2 people authored and bukejiyu committed May 22, 2023
1 parent 322d9f4 commit a3ad7a6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
Expand Up @@ -1366,9 +1366,8 @@ void batch_norm_grad(const Tensor& x,
auto nhwc_out_grad = transpose<T>(out_grad_data, nchw_to_nhwc_dim);
auto nhwc_out_grad_sum = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);

auto x_sub_mean = nhwc_x - mean_data;
auto sum_dout_mul_diff =
sum<T>(nhwc_out_grad * x_sub_mean, reduce_axis, dtype, false);
auto sum_dout_mul_diff = sum<T>(
nhwc_out_grad * (nhwc_x - mean_data), reduce_axis, dtype, false);

if (x_grad) {
if (use_global_stats) {
Expand All @@ -1382,7 +1381,8 @@ void batch_norm_grad(const Tensor& x,
auto part1 = scale * rsqrt_var;
auto mean_temp1 = nhwc_out_grad_sum / nhw;
auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2;
auto part2 =
nhwc_out_grad - mean_temp1 - (nhwc_x - mean_data) * mean_temp2;

auto x_grad_data = part1 * part2;
auto nchw_x_grad = transpose<T>(x_grad_data, nhwc_to_nchw_dim);
Expand All @@ -1403,11 +1403,10 @@ void batch_norm_grad(const Tensor& x,
}
case DataLayout::kNHWC: {
if (x_grad) {
auto x_sub_mean = x_data - mean_data;
auto out_grad_data_sum =
sum<T>(out_grad_data, reduce_axis, dtype, false);
auto nhwc_sum_dout_mul_diff =
sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
auto nhwc_sum_dout_mul_diff = sum<T>(
out_grad_data * (x_data - mean_data), reduce_axis, dtype, false);
if (use_global_stats) {
auto x_grad_data = scale * rsqrt_var * out_grad_data;
if (x.dtype() == phi::DataType::FLOAT16) {
Expand All @@ -1420,7 +1419,8 @@ void batch_norm_grad(const Tensor& x,
auto mean_temp1 = out_grad_data_sum / nhw;
auto mean_temp2 =
nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2;
auto part2 =
out_grad_data - mean_temp1 - (x_data - mean_data) * mean_temp2;

auto x_grad_data = part1 * part2;
if (x.dtype() == phi::DataType::FLOAT16) {
Expand Down
19 changes: 9 additions & 10 deletions test/prim/model/test_resnet_prim_cinn.py
Expand Up @@ -46,17 +46,16 @@
# The results in CI are as follows:
# Ground-truth per-step losses recorded from a deterministic CI run
# (FLAGS_cudnn_deterministic is enabled below, so these are exactly
# reproducible). NOTE(review): this list was reconstructed from a
# flattened diff that interleaved the old and new baselines — the nine
# superseded values were removed; only the first entry carries over
# from the previous baseline. Confirm against the committed file.
DY2ST_PRIM_CINN_GT = [
    5.828786849975586,
    8.332863807678223,
    5.0373005867004395,
    8.464998245239258,
    8.20099925994873,
    7.576723098754883,
    9.679173469543457,
    8.381753921508789,
    8.10612678527832,
    10.124727249145508,
]

if core.is_compiled_with_cuda():
    # Force deterministic cuDNN kernels so the hard-coded ground-truth
    # losses above are exactly reproducible across CI runs.
    # (Indentation restored — it was lost in the page extraction.)
    paddle.set_flags({'FLAGS_cudnn_deterministic': True})

Expand Down

0 comments on commit a3ad7a6

Please sign in to comment.