Skip to content

Commit

Permalink
add init-frz-model support for se-t type descriptor
Browse files Browse the repository at this point in the history
  • Loading branch information
denghuilu committed Oct 27, 2021
1 parent 1a8fd73 commit 023fc6c
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 115 deletions.
67 changes: 31 additions & 36 deletions source/lib/src/cuda/tabulate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -420,49 +420,43 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
const FPTYPE max,
const FPTYPE stride0,
const FPTYPE stride1,
const int nnei,
const int nnei_i,
const int nnei_j,
const int last_layer_size)
{
extern __shared__ int _data[];
const int block_idx = blockIdx.x; // nloc
const int block_idx = blockIdx.x; // nloc
const int thread_idx = threadIdx.x; // last_layer_size
FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei + nnei - 1], 0);
bool unloop = false;
int breakpoint = nnei - 1;
FPTYPE * iteratorC = (FPTYPE*) &_data[0];
for (int kk = 0; kk < MTILE; kk++)
iteratorC[kk * last_layer_size + thread_idx] = 0.f;
__syncthreads();

for (int ii = 0; ii < nnei; ii++) {
FPTYPE var[6];
FPTYPE xx = em_x[block_idx * nnei + ii];
FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei + ii];
if (xx == ago) {
unloop = true;
breakpoint = ii;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
var[0] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 0];
var[1] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 1];
var[2] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 2];
var[3] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 3];
var[4] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 4];
var[5] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 5];
FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;
FPTYPE res_grad = var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx;
FPTYPE sum = 0.f;
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago = __shfl_sync(0xffffffff, em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
for (int jj = 0; ii < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE var[6];
if (ago == xx) {
unloop = true;
}

for (int kk = 0; kk < MTILE; kk++) {
int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
iteratorC[kk * last_layer_size + thread_idx] += (nnei - breakpoint) * (em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res);
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
var[0] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 0];
var[1] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 1];
var[2] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 2];
var[3] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 3];
var[4] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 4];
var[5] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 5];
FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;
FPTYPE res_grad = var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx;

sum += (tmp * res_grad * dz_xx + dz_em * res);
if (unloop) break;
}
if (unloop) break;
}
for (int ii = 0; ii < MTILE; ii++) {
dz_dy[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = iteratorC[ii * last_layer_size + thread_idx];
}
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
}

namespace deepmd {
Expand Down Expand Up @@ -604,7 +598,8 @@ void tabulate_fusion_se_t_grad_grad_gpu_cuda(
DPErrcheck(cudaMemset(
dz_dy,
0.0, sizeof(FPTYPE) * nloc * last_layer_size));
tabulate_fusion_se_t_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK> <<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(

tabulate_fusion_se_t_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK> <<<nloc, last_layer_size>>>(
dz_dy,
table, em_x, em, dz_dy_dem_x, dz_dy_dem, table_info[0], table_info[1], table_info[2], table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size);
DPErrcheck(cudaGetLastError());
Expand Down
64 changes: 29 additions & 35 deletions source/lib/src/rocm/tabulate.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -430,45 +430,39 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
const int nnei_j,
const int last_layer_size)
{
extern __shared__ int _data[];
const int block_idx = blockIdx.x; // nloc
const int block_idx = blockIdx.x; // nloc
const int thread_idx = threadIdx.x; // last_layer_size
FPTYPE ago = __shfl( em_x[block_idx * nnei + nnei - 1], 0);
bool unloop = false;
int breakpoint = nnei - 1;
FPTYPE * iteratorC = (FPTYPE*) &_data[0];
for (int kk = 0; kk < MTILE; kk++)
iteratorC[kk * last_layer_size + thread_idx] = 0.f;
__syncthreads();

for (int ii = 0; ii < nnei; ii++) {
FPTYPE var[6];
FPTYPE xx = em_x[block_idx * nnei + ii];
FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei + ii];
if (xx == ago) {
unloop = true;
breakpoint = ii;
}
int table_idx = 0;
locate_xx(xx, table_idx, lower, upper, max, stride0, stride1);
var[0] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 0];
var[1] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 1];
var[2] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 2];
var[3] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 3];
var[4] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 4];
var[5] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 5];
FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;
FPTYPE res_grad = var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx;
FPTYPE sum = 0.f;
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago = __shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
for (int jj = 0; ii < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
FPTYPE dz_xx = dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE var[6];
if (ago == xx) {
unloop = true;
}

for (int kk = 0; kk < MTILE; kk++) {
int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
iteratorC[kk * last_layer_size + thread_idx] += (nnei - breakpoint) * (em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res);
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
var[0] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 0];
var[1] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 1];
var[2] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 2];
var[3] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 3];
var[4] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 4];
var[5] = table[table_idx * last_layer_size * 6 + thread_idx * 6 + 5];
FPTYPE res = var[0] + (var[1] + (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx;
FPTYPE res_grad = var[1] + (2 * var[2] + (3 * var[3] + (4 * var[4] + 5 * var[5] * xx) * xx) * xx) * xx;

sum += (tmp * res_grad * dz_xx + dz_em * res);
if (unloop) break;
}
if (unloop) break;
}
for (int ii = 0; ii < MTILE; ii++) {
dz_dy[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx] = iteratorC[ii * last_layer_size + thread_idx];
}
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
}

namespace deepmd {
Expand Down Expand Up @@ -610,7 +604,7 @@ void tabulate_fusion_se_t_grad_grad_gpu_rocm(
DPErrcheck(hipMemset(
dz_dy,
0.0, sizeof(FPTYPE) * nloc * last_layer_size));
hipLaunchKernelGGL(HIP_KERNEL_NAME(tabulate_fusion_se_t_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>), nloc, last_layer_size, sizeof(FPTYPE) * last_layer_size, 0,
hipLaunchKernelGGL(HIP_KERNEL_NAME(tabulate_fusion_se_t_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>), nloc, last_layer_size, 0, 0,
dz_dy,
table, em_x, em, dz_dy_dem_x, dz_dy_dem, table_info[0], table_info[1], table_info[2], table_info[3], table_info[4], nnei_i, nnei_j, last_layer_size);
DPErrcheck(hipGetLastError());
Expand Down
69 changes: 27 additions & 42 deletions source/lib/src/tabulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,11 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(
const FPTYPE * dz_dy_dem_x,
const FPTYPE * dz_dy_dem,
const int nloc,
const int nnei,
const int nnei_i,
const int nnei_j,
const int last_layer_size)
{
memset(dz_dy, 0.0, sizeof(FPTYPE) * nloc * 4 * last_layer_size);
memset(dz_dy, 0.0, sizeof(FPTYPE) * nloc * last_layer_size);
const FPTYPE lower = table_info[0];
const FPTYPE upper = table_info[1];
const FPTYPE _max = table_info[2];
Expand All @@ -436,49 +436,34 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(
// FPTYPE * res = new FPTYPE[4 * last_layer_size];
#pragma omp parallel for
for (int ii = 0; ii < nloc; ii++) {
FPTYPE ll[4];
FPTYPE hh[4];
FPTYPE ago = em_x[ii * nnei + nnei - 1];
bool unloop = false;
for (int jj = 0; jj < nnei; jj++) {
ll[0] = em[ii * nnei * 4 + jj * 4 + 0];
ll[1] = em[ii * nnei * 4 + jj * 4 + 1];
ll[2] = em[ii * nnei * 4 + jj * 4 + 2];
ll[3] = em[ii * nnei * 4 + jj * 4 + 3];
hh[0] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 0];
hh[1] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 1];
hh[2] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 2];
hh[3] = dz_dy_dem[ii * nnei * 4 + jj * 4 + 3];
FPTYPE xx = em_x[ii * nnei + jj];
FPTYPE dz_xx = dz_dy_dem_x[ii * nnei + jj];
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx(lower, upper, _max, stride0, stride1, xx, table_idx);
for (int kk = 0; kk < last_layer_size; kk++) {
FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * kk + 0];
FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * kk + 1];
FPTYPE a2 = table[table_idx * last_layer_size * 6 + 6 * kk + 2];
FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * kk + 3];
FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * kk + 4];
FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * kk + 5];
FPTYPE var = a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;
FPTYPE var_grad = a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx;
if (unloop) {
dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] += (nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]);
dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] += (nnei - jj) * (var * hh[1] + dz_xx * var_grad * ll[1]);
dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] += (nnei - jj) * (var * hh[2] + dz_xx * var_grad * ll[2]);
dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] += (nnei - jj) * (var * hh[3] + dz_xx * var_grad * ll[3]);
for (int jj = 0; jj < nnei_i; jj++) {
FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1];
bool unloop = false;
for (int kk = 0; kk < nnei_j; kk++) {
FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];
FPTYPE tmp = xx;
FPTYPE dz_em = dz_dy_dem [ii * nnei_i * nnei_j + jj * nnei_j + kk];
FPTYPE dz_xx = dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];

if (ago == xx) {
unloop = true;
}
else {
dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] += var * hh[0] + dz_xx * var_grad * ll[0];
dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] += var * hh[1] + dz_xx * var_grad * ll[1];
dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] += var * hh[2] + dz_xx * var_grad * ll[2];
dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] += var * hh[3] + dz_xx * var_grad * ll[3];
int table_idx = 0;
locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, table_idx);
for (int mm = 0; mm < last_layer_size; mm++) {
FPTYPE a0 = table[table_idx * last_layer_size * 6 + 6 * mm + 0];
FPTYPE a1 = table[table_idx * last_layer_size * 6 + 6 * mm + 1];
FPTYPE a2 = table[table_idx * last_layer_size * 6 + 6 * mm + 2];
FPTYPE a3 = table[table_idx * last_layer_size * 6 + 6 * mm + 3];
FPTYPE a4 = table[table_idx * last_layer_size * 6 + 6 * mm + 4];
FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5];
FPTYPE var = a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;
FPTYPE var_grad = a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx;

dz_dy[ii * last_layer_size + mm] += var * dz_em + dz_xx * var_grad * tmp;
}
if (unloop) break;
}
if (unloop) break;
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions source/op/tabulate_multi_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,10 @@ class TabulateFusionSeTGradGradOp : public OpKernel {
const FPTYPE * em = em_tensor.flat<FPTYPE>().data();
const FPTYPE * dz_dy_dem_x = dz_dy_dem_x_tensor.flat<FPTYPE>().data();
const FPTYPE * dz_dy_dem = dz_dy_dem_tensor.flat<FPTYPE>().data();
const int nloc = em_tensor.shape().dim_size(0);
const int nloc = em_tensor.shape().dim_size(0);
const int nnei_i = em_tensor.shape().dim_size(1);
const int nnei_j = em_tensor.shape().dim_size(2);
const int last_layer_size = descriptor_tensor.shape().dim_size(2);
const int last_layer_size = descriptor_tensor.shape().dim_size(1);

if (device == "GPU") {
#if GOOGLE_CUDA
Expand Down

0 comments on commit 023fc6c

Please sign in to comment.