Skip to content

Commit

Permalink
grid: Force compiler to generate branches for low values of lp
Browse files Browse the repository at this point in the history
  • Loading branch information
oschuett committed Mar 2, 2021
1 parent ada15fe commit b7bb701
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 39 deletions.
8 changes: 8 additions & 0 deletions src/grid/common/grid_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
#ifndef GRID_COMMON_H
#define GRID_COMMON_H

#define GRID_STRINGIFY(SYMBOL) #SYMBOL

#if defined(__GNUC__)
#define GRID_PRAGMA_UNROLL(N) _Pragma(GRID_STRINGIFY(GCC unroll N))
#else
#define GRID_PRAGMA_UNROLL(N) _Pragma(GRID_STRINGIFY(unroll(N)))
#endif

#if defined(__CUDACC__)
#define GRID_HOST_DEVICE __host__ __device__
#else
Expand Down
67 changes: 28 additions & 39 deletions src/grid/ref/grid_ref_collint.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
* \brief Collocates registers reg onto the grid for orthorhombic case.
* \author Ole Schuett
******************************************************************************/
static inline void ortho_reg_to_grid(const int kg1, const int kg2,
const int jg1, const int jg2,
const int ig1, const int ig2,
const int npts_local[3],
GRID_CONST_WHEN_COLLOCATE double *reg,
GRID_CONST_WHEN_INTEGRATE double *grid) {
static inline void __attribute__((always_inline))
ortho_reg_to_grid(const int kg1, const int kg2, const int jg1, const int jg2,
const int ig1, const int ig2, const int npts_local[3],
GRID_CONST_WHEN_COLLOCATE double *reg,
GRID_CONST_WHEN_INTEGRATE double *grid) {

const int stride = npts_local[1] * npts_local[0];
const int grid_index_0 = kg1 * stride + jg1 * npts_local[0] + ig1;
Expand Down Expand Up @@ -72,10 +71,11 @@ static inline void ortho_reg_to_grid(const int kg1, const int kg2,
* \brief Transforms coefficients C_x into registers reg by fixing grid index i.
* \author Ole Schuett
******************************************************************************/
static inline void ortho_cx_to_reg(const int lp, const double pol_i1[lp + 1],
const double pol_i2[lp + 1],
GRID_CONST_WHEN_COLLOCATE double *cx,
GRID_CONST_WHEN_INTEGRATE double *reg) {
static inline void __attribute__((always_inline))
ortho_cx_to_reg(const int lp, const double pol_i1[lp + 1],
const double pol_i2[lp + 1],
GRID_CONST_WHEN_COLLOCATE double *cx,
GRID_CONST_WHEN_INTEGRATE double *reg) {

for (int lxp = 0; lxp <= lp; lxp++) {
const double p1 = pol_i1[lxp];
Expand Down Expand Up @@ -109,7 +109,7 @@ static inline void ortho_cx_to_reg(const int lp, const double pol_i1[lp + 1],
* \brief Collocates coefficients C_x onto the grid for orthorhombic case.
* \author Ole Schuett
******************************************************************************/
static inline void
static inline void __attribute__((always_inline))
ortho_cx_to_grid(const int lp, const int kg1, const int kg2, const int jg1,
const int jg2, const int cmax,
const double pol[3][2 * cmax + 1][lp + 1],
Expand Down Expand Up @@ -142,10 +142,11 @@ ortho_cx_to_grid(const int lp, const int kg1, const int kg2, const int jg1,
* \brief Transforms coefficients C_xy into C_x by fixing grid index j.
* \author Ole Schuett
******************************************************************************/
static inline void ortho_cxy_to_cx(const int lp, const double pol_j1[lp + 1],
const double pol_j2[lp + 1],
GRID_CONST_WHEN_COLLOCATE double *cxy,
GRID_CONST_WHEN_INTEGRATE double *cx) {
static inline void __attribute__((always_inline))
ortho_cxy_to_cx(const int lp, const double pol_j1[lp + 1],
const double pol_j2[lp + 1],
GRID_CONST_WHEN_COLLOCATE double *cxy,
GRID_CONST_WHEN_INTEGRATE double *cx) {

for (int lyp = 0; lyp <= lp; lyp++) {
for (int lxp = 0; lxp <= lp - lyp; lxp++) {
Expand All @@ -172,7 +173,7 @@ static inline void ortho_cxy_to_cx(const int lp, const double pol_j1[lp + 1],
* \brief Loop body of ortho_cxy_to_grid to be inlined for low values of lp.
* \author Ole Schuett
******************************************************************************/
static inline void
static inline void __attribute__((always_inline))
ortho_cxy_to_grid_low(const int lp, const int j1, const int j2, const int kg1,
const int kg2, const int jg1, const int jg2,
const int cmax, const double pol[3][2 * cmax + 1][lp + 1],
Expand Down Expand Up @@ -222,29 +223,17 @@ static inline void ortho_cxy_to_grid(const int lp, const int kg1, const int kg2,

memset(cx, 0, cx_size * sizeof(double));

// Hopefully the compiler will inline optimized branches for low lp values.
switch (lp) {
case (0):
ortho_cxy_to_grid_low(0, j1, j2, kg1, kg2, jg1, jg2, cmax, pol, map,
npts_local, sphere_bounds_iter, cx, cxy, grid);
break;
case (1):
ortho_cxy_to_grid_low(1, j1, j2, kg1, kg2, jg1, jg2, cmax, pol, map,
npts_local, sphere_bounds_iter, cx, cxy, grid);
break;
case (2):
ortho_cxy_to_grid_low(2, j1, j2, kg1, kg2, jg1, jg2, cmax, pol, map,
npts_local, sphere_bounds_iter, cx, cxy, grid);
break;
case (3):
ortho_cxy_to_grid_low(3, j1, j2, kg1, kg2, jg1, jg2, cmax, pol, map,
npts_local, sphere_bounds_iter, cx, cxy, grid);
break;
case (4):
ortho_cxy_to_grid_low(4, j1, j2, kg1, kg2, jg1, jg2, cmax, pol, map,
npts_local, sphere_bounds_iter, cx, cxy, grid);
break;
default:
// Generate separate branches for low values of lp gives up to 30% speedup.
const int MAX_LP_OPTIMIZED = 9;
if (lp <= MAX_LP_OPTIMIZED) {
GRID_PRAGMA_UNROLL(MAX_LP_OPTIMIZED + 1)
for (int ilp = 0; ilp <= MAX_LP_OPTIMIZED; ilp++) {
if (lp == ilp) {
ortho_cxy_to_grid_low(ilp, j1, j2, kg1, kg2, jg1, jg2, cmax, pol, map,
npts_local, sphere_bounds_iter, cx, cxy, grid);
}
}
} else {
ortho_cxy_to_grid_low(lp, j1, j2, kg1, kg2, jg1, jg2, cmax, pol, map,
npts_local, sphere_bounds_iter, cx, cxy, grid);
}
Expand Down

0 comments on commit b7bb701

Please sign in to comment.