From 023fc6cc8a5ae38239b5729affbeb5723a37cdce Mon Sep 17 00:00:00 2001 From: denghuilu Date: Wed, 27 Oct 2021 09:15:23 +0800 Subject: [PATCH] add init-frz-model support for se-t type descriptor --- source/lib/src/cuda/tabulate.cu | 67 +++++++++++++--------------- source/lib/src/rocm/tabulate.hip.cu | 64 ++++++++++++-------------- source/lib/src/tabulate.cc | 69 +++++++++++------------------ source/op/tabulate_multi_device.cc | 4 +- 4 files changed, 89 insertions(+), 115 deletions(-) diff --git a/source/lib/src/cuda/tabulate.cu b/source/lib/src/cuda/tabulate.cu index 47ae73577f..4cc6b112ee 100644 --- a/source/lib/src/cuda/tabulate.cu +++ b/source/lib/src/cuda/tabulate.cu @@ -420,49 +420,43 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( const FPTYPE max, const FPTYPE stride0, const FPTYPE stride1, - const int nnei, + const int nnei_i, const int nnei_j, const int last_layer_size) { - extern __shared__ int _data[]; - const int block_idx = blockIdx.x; // nloc + const int block_idx = blockIdx.x; // nloc const int thread_idx = threadIdx.x; // last_layer_size - FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0); - bool unloop = false; - int breakpoint = nnei - 1; - FPTYPE * iteratorC = (FPTYPE*) &_data[0]; - for (int kk = 0; kk < MTILE; kk++) - iteratorC[kk * last_layer_size + thread_idx] = 0.f; - __syncthreads(); - for (int ii = 0; ii < nnei; ii++) { - FPTYPE var[6]; - FPTYPE xx = em_x[block_idx * nnei + ii]; - FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei + ii]; - if (xx == ago) { - unloop = true; - breakpoint = ii; - } - int table_idx = 0; - locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); - var[0] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 0]; - var[1] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 1]; - var[2] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 2]; - var[3] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 3]; - var[4] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 4]; - var[5] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 5]; - FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; - FPTYPE res_grad = var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx; + FPTYPE sum = 0.f; + for (int ii = 0; ii < nnei_i; ii++) { + FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0); + bool unloop = false; + for (int jj = 0; ii < nnei_j; jj++) { + FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; + FPTYPE tmp = xx; + FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; + FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; + FPTYPE var[6]; + if (ago == xx) { + unloop = true; + } - for (int kk = 0; kk < MTILE; kk++) { - int em_index = block_idx * nnei * MTILE + ii * MTILE + kk; - iteratorC[kk * last_layer_size + thread_idx] += (nnei - breakpoint) * (em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res); + int table_idx = 0; + locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); + var[0] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 0]; + var[1] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 1]; + var[2] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 2]; + var[3] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 3]; + var[4] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 4]; + var[5] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 5]; + FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; + FPTYPE res_grad = var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx; + + sum += (tmp * res_grad * dz_xx + dz_em * res); + if (unloop) break; } - if (unloop) break; - } - for (int ii = 0; ii < MTILE; ii++) { - dz_dy[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = iteratorC[ii * last_layer_size + thread_idx]; } + dz_dy[block_idx * last_layer_size + thread_idx] = sum; } namespace deepmd { @@ -604,7 +598,8 @@ void tabulate_fusion_se_t_grad_grad_gpu_cuda( DPErrcheck(cudaMemset( dz_dy, 0.0, sizeof(FPTYPE) * nloc * last_layer_size)); - tabulate_fusion_se_t_grad_grad_fifth_order_polynomial <<>>( + + tabulate_fusion_se_t_grad_grad_fifth_order_polynomial <<>>( dz_dy, table, em_x, em, dz_dy_dem_x, dz_dy_dem, table_info[0], table_info[1], table_info[2], table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size); DPErrcheck(cudaGetLastError()); diff --git a/source/lib/src/rocm/tabulate.hip.cu b/source/lib/src/rocm/tabulate.hip.cu index 055d52d7b8..6b6270c18f 100644 --- a/source/lib/src/rocm/tabulate.hip.cu +++ b/source/lib/src/rocm/tabulate.hip.cu @@ -430,45 +430,39 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( const int nnei_j, const int last_layer_size) { - extern __shared__ int _data[]; - const int block_idx = blockIdx.x; // nloc + const int block_idx = blockIdx.x; // nloc const int thread_idx = threadIdx.x; // last_layer_size - FPTYPE ago = __shfl( em_x[block_idx * nnei + nnei - 1], 0); - bool unloop = false; - int breakpoint = nnei - 1; - FPTYPE * iteratorC = (FPTYPE*) &_data[0]; - for (int kk = 0; kk < MTILE; kk++) - iteratorC[kk * last_layer_size + thread_idx] = 0.f; - __syncthreads(); - for (int ii = 0; ii < nnei; ii++) { - FPTYPE var[6]; - FPTYPE xx = em_x[block_idx * nnei + ii]; - FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei + ii]; - if (xx == ago) { - unloop = true; - breakpoint = ii; - } - int table_idx = 0; - locate_xx(xx, table_idx, lower, upper, max, stride0, stride1); - var[0] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 0]; - var[1] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 1]; - var[2] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 2]; - var[3] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 3]; - var[4] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 4]; - var[5] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 5]; - FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; - FPTYPE res_grad = var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx; + FPTYPE sum = 0.f; + for (int ii = 0; ii < nnei_i; ii++) { + FPTYPE ago = __shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0); + bool unloop = false; + for (int jj = 0; ii < nnei_j; jj++) { + FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; + FPTYPE tmp = xx; + FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; + FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; + FPTYPE var[6]; + if (ago == xx) { + unloop = true; + } - for (int kk = 0; kk < MTILE; kk++) { - int em_index = block_idx * nnei * MTILE + ii * MTILE + kk; - iteratorC[kk * last_layer_size + thread_idx] += (nnei - breakpoint) * (em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res); + int table_idx = 0; + locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); + var[0] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 0]; + var[1] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 1]; + var[2] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 2]; + var[3] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 3]; + var[4] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 4]; + var[5] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 5]; + FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; + FPTYPE res_grad = var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx; + + sum += (tmp * res_grad * dz_xx + dz_em * res); + if (unloop) break; } - if (unloop) break; - } - for (int ii = 0; ii < MTILE; ii++) { - dz_dy[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = iteratorC[ii * last_layer_size + thread_idx]; } + dz_dy[block_idx * last_layer_size + thread_idx] = sum; } namespace deepmd { @@ -610,7 +604,7 @@ void tabulate_fusion_se_t_grad_grad_gpu_rocm( DPErrcheck(hipMemset( dz_dy, 0.0, sizeof(FPTYPE) * nloc * last_layer_size)); - hipLaunchKernelGGL(HIP_KERNEL_NAME(tabulate_fusion_se_t_grad_grad_fifth_order_polynomial), nloc, last_layer_size, sizeof(FPTYPE) * last_layer_size, 0, + hipLaunchKernelGGL(HIP_KERNEL_NAME(tabulate_fusion_se_t_grad_grad_fifth_order_polynomial), nloc, last_layer_size, 0, 0, dz_dy, table, em_x, em, dz_dy_dem_x, dz_dy_dem, table_info[0], table_info[1], table_info[2], table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size); DPErrcheck(hipGetLastError()); diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index ffacb96fdb..840284b93b 100644 --- a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -422,11 +422,11 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu( const FPTYPE * dz_dy_dem_x, const FPTYPE * dz_dy_dem, const int nloc, - const int nnei, + const int nnei_i, const int nnei_j, const int last_layer_size) { - memset(dz_dy, 0.0, sizeof(FPTYPE) * nloc * 4 * last_layer_size); + memset(dz_dy, 0.0, sizeof(FPTYPE) * nloc * last_layer_size); const FPTYPE lower = table_info[0]; const FPTYPE upper = table_info[1]; const FPTYPE _max = table_info[2]; @@ -436,49 +436,34 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu( // FPTYPE * res = new FPTYPE[4 * last_layer_size]; #pragma omp parallel for for (int ii = 0; ii < nloc; ii++) { - FPTYPE ll[4]; - FPTYPE hh[4]; - FPTYPE ago = em_x[ii * nnei + nnei - 1]; - bool unloop = false; - for (int jj = 0; jj < nnei; jj++) { - ll[0] = em[ii * nnei * 4 + jj * 4 + 0]; - ll[1] = em[ii * nnei * 4 + jj * 4 + 1]; - ll[2] = em[ii * nnei * 4 + jj * 4 + 2]; - ll[3] = em[ii * nnei * 4 + jj * 4 + 3]; - hh[0] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 0]; - hh[1] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 1]; - hh[2] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 2]; - hh[3] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 3]; - FPTYPE xx = em_x[ii * nnei + jj]; - FPTYPE dz_xx = dz_dy_dem_x[ii * nnei + jj]; - if (ago == xx) { - unloop = true; - } - int table_idx = 0; - locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx); - for (int kk = 0; kk < last_layer_size; kk++) { - FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * kk + 0]; - FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * kk + 1]; - FPTYPE a2 = table[table_idx * last_layer_size * 6 + 6 * kk + 2]; - FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * kk + 3]; - FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * kk + 4]; - FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * kk + 5]; - FPTYPE var = a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; - FPTYPE var_grad = a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx; - if (unloop) { - dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] += (nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]); - dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] += (nnei - jj) * (var * hh[1] + dz_xx * var_grad * ll[1]); - dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] += (nnei - jj) * (var * hh[2] + dz_xx * var_grad * ll[2]); - dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] += (nnei - jj) * (var * hh[3] + dz_xx * var_grad * ll[3]); + for (int jj = 0; jj < nnei_i; jj++) { + FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1]; + bool unloop = false; + for (int kk = 0; kk < nnei_j; kk++) { + FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; + FPTYPE tmp = xx; + FPTYPE dz_em = dz_dy_dem [ii * nnei_i * nnei_j + jj * nnei_j + kk]; + FPTYPE dz_xx = dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; + + if (ago == xx) { + unloop = true; } - else { - dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] += var * hh[0] + dz_xx * var_grad * ll[0]; - dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] += var * hh[1] + dz_xx * var_grad * ll[1]; - dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] += var * hh[2] + dz_xx * var_grad * ll[2]; - dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] += var * hh[3] + dz_xx * var_grad * ll[3]; + int table_idx = 0; + locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, table_idx); + for (int mm = 0; mm < last_layer_size; mm++) { + FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * mm + 0]; + FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * mm + 1]; + FPTYPE a2 = table[table_idx * last_layer_size * 6 + 6 * mm + 2]; + FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * mm + 3]; + FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * mm + 4]; + FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5]; + FPTYPE var = a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; + FPTYPE var_grad = a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx; + + dz_dy[ii * last_layer_size + mm] += var * dz_em + dz_xx * var_grad * tmp; } + if (unloop) break; } - if (unloop) break; } } } diff --git a/source/op/tabulate_multi_device.cc b/source/op/tabulate_multi_device.cc index c19c88e48b..aba6efc5b5 100644 --- a/source/op/tabulate_multi_device.cc +++ b/source/op/tabulate_multi_device.cc @@ -472,10 +472,10 @@ class TabulateFusionSeTGradGradOp : public OpKernel { const FPTYPE * em = em_tensor.flat().data(); const FPTYPE * dz_dy_dem_x = dz_dy_dem_x_tensor.flat().data(); const FPTYPE * dz_dy_dem = dz_dy_dem_tensor.flat().data(); - const int nloc = em_tensor.shape().dim_size(0); + const int nloc = em_tensor.shape().dim_size(0); const int nnei_i = em_tensor.shape().dim_size(1); const int nnei_j = em_tensor.shape().dim_size(2); - const int last_layer_size = descriptor_tensor.shape().dim_size(2); + const int last_layer_size = descriptor_tensor.shape().dim_size(1); if (device == "GPU") { #if GOOGLE_CUDA