Skip to content

Commit dae2cb2

Browse files
authored
Merge pull request #2286 from weinbe2/fea-snap-ui-optim
Kokkos SNAP optimizations – Pre-computing Cayley-Klein parameters, symmetrized data layouts for host and device backends, reducing number of atomics
2 parents fa0aa7f + 6dfe2f3 commit dae2cb2

File tree

5 files changed

+524
-298
lines changed

5 files changed

+524
-298
lines changed

src/KOKKOS/kokkos_type.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,7 @@ struct alignas(2*sizeof(real)) SNAComplex
10581058
{
10591059
real re,im;
10601060

1061-
KOKKOS_FORCEINLINE_FUNCTION SNAComplex() = default;
1061+
SNAComplex() = default;
10621062

10631063
KOKKOS_FORCEINLINE_FUNCTION SNAComplex(real re)
10641064
: re(re), im(static_cast<real>(0.)) { ; }
@@ -1100,6 +1100,17 @@ KOKKOS_FORCEINLINE_FUNCTION SNAComplex<real> operator*(const real& r, const SNAC
11001100

11011101
typedef SNAComplex<SNAreal> SNAcomplex;
11021102

1103+
// Cayley-Klein pack
1104+
// Can guarantee it's aligned to 2 complex
1105+
struct alignas(32) CayleyKleinPack {
1106+
1107+
SNAcomplex a, b;
1108+
SNAcomplex da[3], db[3];
1109+
SNAreal sfac;
1110+
SNAreal dsfacu[3];
1111+
1112+
};
1113+
11031114

11041115
#if defined(KOKKOS_ENABLE_CXX11)
11051116
#undef ISFINITE

src/KOKKOS/pair_snap_kokkos.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ struct TagPairSNAPComputeFusedDeidrj{};
5050
// CPU backend only
5151
struct TagPairSNAPPreUiCPU{};
5252
struct TagPairSNAPComputeUiCPU{};
53+
struct TagPairSNAPTransformUiCPU{};
5354
struct TagPairSNAPComputeZiCPU{};
5455
struct TagPairSNAPBetaCPU{};
5556
struct TagPairSNAPComputeBiCPU{};
@@ -104,7 +105,7 @@ class PairSNAPKokkos : public PairSNAP {
104105
void operator() (TagPairSNAPComputeUi,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeUi>::member_type& team) const;
105106

106107
KOKKOS_INLINE_FUNCTION
107-
void operator() (TagPairSNAPTransformUi,const int iatom_mod, const int idxu, const int iatom_div) const;
108+
void operator() (TagPairSNAPTransformUi,const int iatom_mod, const int j, const int iatom_div) const;
108109

109110
KOKKOS_INLINE_FUNCTION
110111
void operator() (TagPairSNAPComputeZi,const int iatom_mod, const int idxz, const int iatom_div) const;
@@ -135,13 +136,13 @@ class PairSNAPKokkos : public PairSNAP {
135136
void operator() (TagPairSNAPComputeUiCPU,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeUiCPU>::member_type& team) const;
136137

137138
KOKKOS_INLINE_FUNCTION
138-
void operator() (TagPairSNAPComputeZiCPU,const int& ii) const;
139+
void operator() (TagPairSNAPTransformUiCPU, const int j, const int iatom) const;
139140

140141
KOKKOS_INLINE_FUNCTION
141-
void operator() (TagPairSNAPComputeBiCPU,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeBiCPU>::member_type& team) const;
142+
void operator() (TagPairSNAPComputeZiCPU,const int& ii) const;
142143

143144
KOKKOS_INLINE_FUNCTION
144-
void operator() (TagPairSNAPZeroYiCPU,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPZeroYiCPU>::member_type& team) const;
145+
void operator() (TagPairSNAPComputeBiCPU,const typename Kokkos::TeamPolicy<DeviceType, TagPairSNAPComputeBiCPU>::member_type& team) const;
145146

146147
KOKKOS_INLINE_FUNCTION
147148
void operator() (TagPairSNAPComputeYiCPU,const int& ii) const;

src/KOKKOS/pair_snap_kokkos_impl.h

Lines changed: 112 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,6 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
206206

207207
EV_FLOAT ev;
208208

209-
int idxu_max = snaKK.idxu_max;
210-
211209
while (chunk_offset < inum) { // chunk up loop to prevent running out of memory
212210

213211
EV_FLOAT ev_tmp;
@@ -246,6 +244,13 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
246244
Kokkos::parallel_for("ComputeUiCPU",policy_ui_cpu,*this);
247245
}
248246

247+
{
248+
// Expand ulisttot -> ulisttot_full
249+
// Zero out ylist
250+
typename Kokkos::MDRangePolicy<DeviceType, Kokkos::IndexType<int>, Kokkos::Rank<2, Kokkos::Iterate::Left, Kokkos::Iterate::Left>, TagPairSNAPTransformUiCPU> policy_transform_ui_cpu({0,0},{twojmax+1,chunk_size});
251+
Kokkos::parallel_for("TransformUiCPU",policy_transform_ui_cpu,*this);
252+
}
253+
249254
//Compute bispectrum
250255
if (quadraticflag || eflag) {
251256
//ComputeZi
@@ -261,20 +266,12 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
261266
Kokkos::parallel_for("ComputeBiCPU",policy_bi_cpu,*this);
262267
}
263268

264-
//ZeroYi,ComputeYi
269+
//ComputeYi
265270
{
266-
int vector_length = vector_length_default;
267-
int team_size = team_size_default;
268-
269271
//Compute beta = dE_i/dB_i for all i in list
270272
typename Kokkos::RangePolicy<DeviceType,TagPairSNAPBetaCPU> policy_beta(0,chunk_size);
271273
Kokkos::parallel_for("ComputeBetaCPU",policy_beta,*this);
272274

273-
//ZeroYi
274-
check_team_size_for<TagPairSNAPZeroYiCPU>(chunk_size,team_size,vector_length);
275-
typename Kokkos::TeamPolicy<DeviceType,TagPairSNAPZeroYiCPU> policy_zero_yi(((idxu_max+team_size-1)/team_size)*chunk_size,team_size,vector_length);
276-
Kokkos::parallel_for("ZeroYiCPU",policy_zero_yi,*this);
277-
278275
//ComputeYi
279276
int idxz_max = snaKK.idxz_max;
280277
typename Kokkos::RangePolicy<DeviceType,TagPairSNAPComputeYiCPU> policy_yi_cpu(0,chunk_size*idxz_max);
@@ -294,6 +291,7 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
294291

295292
Kokkos::parallel_for("ComputeDeidrjCPU",policy_deidrj_cpu,*this);
296293
}
294+
297295
} else { // GPU
298296

299297
#ifdef LMP_KOKKOS_GPU
@@ -313,10 +311,10 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
313311
int team_size = 4; // need to cap b/c of shared memory reqs
314312
check_team_size_for<TagPairSNAPComputeUi>(chunk_size,team_size,vector_length);
315313

316-
// scratch size: 2 * team_size * (twojmax+1)^2, to cover all `m1`,`m2` values
314+
// scratch size: 2 * team_size * (twojmax+1)^2, to cover all `m1`,`m2` values, div 2 for symmetry
317315
// 2 is for double buffer
318316

319-
const int tile_size = (twojmax+1)*(twojmax+1);
317+
const int tile_size = (twojmax+1)*(twojmax/2+1);
320318
typedef Kokkos::View< SNAcomplex*,
321319
Kokkos::DefaultExecutionSpace::scratch_memory_space,
322320
Kokkos::MemoryTraits<Kokkos::Unmanaged> >
@@ -329,7 +327,7 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
329327
Kokkos::parallel_for("ComputeUi",policy_ui,*this);
330328

331329
//Transform data layout of ulisttot to AoSoA, zero ylist
332-
typename Kokkos::MDRangePolicy<DeviceType, Kokkos::IndexType<int>, Kokkos::Rank<3, Kokkos::Iterate::Left, Kokkos::Iterate::Left>, TagPairSNAPTransformUi> policy_transform_ui({0,0,0},{32,idxu_max,(chunk_size + 32 - 1) / 32},{32,4,1});
330+
typename Kokkos::MDRangePolicy<DeviceType, Kokkos::IndexType<int>, Kokkos::Rank<3, Kokkos::Iterate::Left, Kokkos::Iterate::Left>, TagPairSNAPTransformUi> policy_transform_ui({0,0,0},{32,twojmax+1,(chunk_size + 32 - 1) / 32},{32,4,1});
333331
Kokkos::parallel_for("TransformUi",policy_transform_ui,*this);
334332

335333
}
@@ -367,7 +365,8 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
367365
Kokkos::parallel_for("ComputeYi",policy_compute_yi,*this);
368366

369367
//Transform data layout of ylist out of AoSoA
370-
typename Kokkos::MDRangePolicy<DeviceType, Kokkos::IndexType<int>, Kokkos::Rank<3, Kokkos::Iterate::Left, Kokkos::Iterate::Left>, TagPairSNAPTransformYi> policy_transform_yi({0,0,0},{32,idxu_max,(chunk_size + 32 - 1) / 32},{32,4,1});
368+
const int idxu_half_max = snaKK.idxu_half_max;
369+
typename Kokkos::MDRangePolicy<DeviceType, Kokkos::IndexType<int>, Kokkos::Rank<3, Kokkos::Iterate::Left, Kokkos::Iterate::Left>, TagPairSNAPTransformYi> policy_transform_yi({0,0,0},{32,idxu_half_max,(chunk_size + 32 - 1) / 32},{32,4,1});
371370
Kokkos::parallel_for("TransformYi",policy_transform_yi,*this);
372371

373372
}
@@ -397,7 +396,7 @@ void PairSNAPKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
397396
}
398397
}
399398

400-
#endif // KOKKOS_ENABLE_CUDA
399+
#endif // LMP_KOKKOS_GPU
401400

402401
}
403402

@@ -608,12 +607,21 @@ void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPComputeNeigh,const typen
608607

609608
if ( rsq < rnd_cutsq(itype,jtype) ) {
610609
if (final) {
611-
my_sna.rij(ii,offset,0) = dx;
612-
my_sna.rij(ii,offset,1) = dy;
613-
my_sna.rij(ii,offset,2) = dz;
610+
#ifdef LMP_KOKKOS_GPU
611+
if (std::is_same<DeviceType,Kokkos::Cuda>::value) {
612+
my_sna.compute_cayley_klein(ii, offset, dx, dy, dz, (radi + d_radelem[elem_j])*rcutfac,
613+
d_wjelem[elem_j]);
614+
} else {
615+
#endif
616+
my_sna.rij(ii,offset,0) = dx;
617+
my_sna.rij(ii,offset,1) = dy;
618+
my_sna.rij(ii,offset,2) = dz;
619+
my_sna.wj(ii,offset) = d_wjelem[elem_j];
620+
my_sna.rcutij(ii,offset) = (radi + d_radelem[elem_j])*rcutfac;
621+
#ifdef LMP_KOKKOS_GPU
622+
}
623+
#endif
614624
my_sna.inside(ii,offset) = j;
615-
my_sna.wj(ii,offset) = d_wjelem[elem_j];
616-
my_sna.rcutij(ii,offset) = (radi + d_radelem[elem_j])*rcutfac;
617625
if (chemflag)
618626
my_sna.element(ii,offset) = elem_j;
619627
else
@@ -704,27 +712,54 @@ void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPComputeUi,const typename
704712

705713
template<class DeviceType>
706714
KOKKOS_INLINE_FUNCTION
707-
void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPTransformUi,const int iatom_mod, const int idxu, const int iatom_div) const {
715+
void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPTransformUi,const int iatom_mod, const int j, const int iatom_div) const {
708716
SNAKokkos<DeviceType> my_sna = snaKK;
709717

710718
const int iatom = iatom_mod + iatom_div * 32;
711719
if (iatom >= chunk_size) return;
712720

713-
if (idxu >= my_sna.idxu_max) return;
721+
if (j > twojmax) return;
714722

715723
int elem_count = chemflag ? nelements : 1;
716724

717725
for (int ielem = 0; ielem < elem_count; ielem++) {
726+
const int jju_half = my_sna.idxu_half_block(j);
727+
const int jju = my_sna.idxu_block(j);
728+
729+
for (int mb = 0; 2*mb <= j; mb++) {
730+
for (int ma = 0; ma <= j; ma++) {
731+
// Extract top half
732+
733+
const int idxu_shift = mb * (j + 1) + ma;
734+
const int idxu_half = jju_half + idxu_shift;
735+
const int idxu = jju + idxu_shift;
736+
737+
auto utot_re = my_sna.ulisttot_re(idxu_half, ielem, iatom);
738+
auto utot_im = my_sna.ulisttot_im(idxu_half, ielem, iatom);
739+
740+
// Store
741+
my_sna.ulisttot_pack(iatom_mod, idxu, ielem, iatom_div) = { utot_re, utot_im };
742+
743+
// Also zero yi
744+
my_sna.ylist_pack_re(iatom_mod, idxu_half, ielem, iatom_div) = 0.;
745+
my_sna.ylist_pack_im(iatom_mod, idxu_half, ielem, iatom_div) = 0.;
746+
747+
// Symmetric term
748+
const int sign_factor = (((ma+mb)%2==0)?1:-1);
749+
const int idxu_flip = jju + (j + 1 - mb) * (j + 1) - (ma + 1);
750+
751+
if (sign_factor == 1) {
752+
utot_im = -utot_im;
753+
} else {
754+
utot_re = -utot_re;
755+
}
718756

719-
const auto utot_re = my_sna.ulisttot_re(idxu, ielem, iatom);
720-
const auto utot_im = my_sna.ulisttot_im(idxu, ielem, iatom);
721-
722-
my_sna.ulisttot_pack(iatom_mod, idxu, ielem, iatom_div) = { utot_re, utot_im };
757+
my_sna.ulisttot_pack(iatom_mod, idxu_flip, ielem, iatom_div) = { utot_re, utot_im };
723758

724-
my_sna.ylist_pack_re(iatom_mod, idxu, ielem, iatom_div) = 0.;
725-
my_sna.ylist_pack_im(iatom_mod, idxu, ielem, iatom_div) = 0.;
759+
// No need to zero symmetrized ylist
760+
}
761+
}
726762
}
727-
728763
}
729764

730765
template<class DeviceType>
@@ -742,20 +777,20 @@ void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPComputeYi,const int iato
742777

743778
template<class DeviceType>
744779
KOKKOS_INLINE_FUNCTION
745-
void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPTransformYi,const int iatom_mod, const int idxu, const int iatom_div) const {
780+
void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPTransformYi,const int iatom_mod, const int idxu_half, const int iatom_div) const {
746781
SNAKokkos<DeviceType> my_sna = snaKK;
747782

748783
const int iatom = iatom_mod + iatom_div * 32;
749784
if (iatom >= chunk_size) return;
750785

751-
if (idxu >= my_sna.idxu_max) return;
786+
if (idxu_half >= my_sna.idxu_half_max) return;
752787

753788
int elem_count = chemflag ? nelements : 1;
754789
for (int ielem = 0; ielem < elem_count; ielem++) {
755-
const auto y_re = my_sna.ylist_pack_re(iatom_mod, idxu, ielem, iatom_div);
756-
const auto y_im = my_sna.ylist_pack_im(iatom_mod, idxu, ielem, iatom_div);
790+
const auto y_re = my_sna.ylist_pack_re(iatom_mod, idxu_half, ielem, iatom_div);
791+
const auto y_im = my_sna.ylist_pack_im(iatom_mod, idxu_half, ielem, iatom_div);
757792

758-
my_sna.ylist(idxu, ielem, iatom) = { y_re, y_im };
793+
my_sna.ylist(idxu_half, ielem, iatom) = { y_re, y_im };
759794
}
760795

761796
}
@@ -904,22 +939,52 @@ void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPComputeUiCPU,const typen
904939

905940
template<class DeviceType>
906941
KOKKOS_INLINE_FUNCTION
907-
void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPZeroYiCPU,const typename Kokkos::TeamPolicy<DeviceType,TagPairSNAPZeroYiCPU>::member_type& team) const {
942+
void PairSNAPKokkos<DeviceType>::operator() (TagPairSNAPTransformUiCPU, const int j, const int iatom) const {
908943
SNAKokkos<DeviceType> my_sna = snaKK;
909944

910-
// Extract the quantum number
911-
const int idx = team.team_rank() + team.team_size() * (team.league_rank() % ((my_sna.idxu_max+team.team_size()-1)/team.team_size()));
912-
if (idx >= my_sna.idxu_max) return;
945+
if (iatom >= chunk_size) return;
913946

914-
// Extract the atomic index
915-
const int ii = team.league_rank() / ((my_sna.idxu_max+team.team_size()-1)/team.team_size());
916-
if (ii >= chunk_size) return;
947+
if (j > twojmax) return;
948+
949+
int elem_count = chemflag ? nelements : 1;
950+
951+
// De-symmetrize ulisttot
952+
for (int ielem = 0; ielem < elem_count; ielem++) {
953+
954+
const int jju_half = my_sna.idxu_half_block(j);
955+
const int jju = my_sna.idxu_block(j);
956+
957+
for (int mb = 0; 2*mb <= j; mb++) {
958+
for (int ma = 0; ma <= j; ma++) {
959+
// Extract top half
960+
961+
const int idxu_shift = mb * (j + 1) + ma;
962+
const int idxu_half = jju_half + idxu_shift;
963+
const int idxu = jju + idxu_shift;
917964

918-
if (chemflag)
919-
for(int ielem = 0; ielem < nelements; ielem++)
920-
my_sna.zero_yi_cpu(idx,ii,ielem);
921-
else
922-
my_sna.zero_yi_cpu(idx,ii,0);
965+
// Load ulist
966+
auto utot = my_sna.ulisttot(idxu_half, ielem, iatom);
967+
968+
// Store
969+
my_sna.ulisttot_full(idxu, ielem, iatom) = utot;
970+
971+
// Zero Yi
972+
my_sna.ylist(idxu_half, ielem, iatom) = {0., 0.};
973+
974+
// Symmetric term
975+
const int sign_factor = (((ma+mb)%2==0)?1:-1);
976+
const int idxu_flip = jju + (j + 1 - mb) * (j + 1) - (ma + 1);
977+
978+
if (sign_factor == 1) {
979+
utot.im = -utot.im;
980+
} else {
981+
utot.re = -utot.re;
982+
}
983+
984+
my_sna.ulisttot_full(idxu_flip, ielem, iatom) = utot;
985+
}
986+
}
987+
}
923988
}
924989

925990
template<class DeviceType>

0 commit comments

Comments
 (0)