Skip to content

Commit

Permalink
ENH: einsum: Specialize contiguous reduction, add SSE prefetching
Browse files Browse the repository at this point in the history
Also fix some compiler warnings. The biggest performance improvement
was from adding SSE prefetching.
  • Loading branch information
mwiebe committed Feb 10, 2011
1 parent 8598315 commit 260824f
Showing 1 changed file with 239 additions and 6 deletions.
245 changes: 239 additions & 6 deletions numpy/core/src/multiarray/einsum.c.src
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ static void
#define _SUMPROD_NOP nop #define _SUMPROD_NOP nop
# endif # endif
npy_@temp@ re, im, tmp; npy_@temp@ re, im, tmp;
int i;
re = ((npy_@temp@ *)dataptr[0])[0]; re = ((npy_@temp@ *)dataptr[0])[0];
im = ((npy_@temp@ *)dataptr[0])[1]; im = ((npy_@temp@ *)dataptr[0])[1];
int i;
for (i = 1; i < _SUMPROD_NOP; ++i) { for (i = 1; i < _SUMPROD_NOP; ++i) {
tmp = re * ((npy_@temp@ *)dataptr[i])[0] - tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
im * ((npy_@temp@ *)dataptr[i])[1]; im * ((npy_@temp@ *)dataptr[i])[1];
Expand Down Expand Up @@ -202,7 +202,8 @@ static void
npy_@name@ *data0 = (npy_@name@ *)dataptr[0]; npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
npy_@name@ *data_out = (npy_@name@ *)dataptr[1]; npy_@name@ *data_out = (npy_@name@ *)dataptr[1];


NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_one (%d)\n", (int)count); NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_one (%d)\n",
(int)count);


/* This is placed before the main loop to make small counts faster */ /* This is placed before the main loop to make small counts faster */
finish_after_unrolled_loop: finish_after_unrolled_loop:
Expand Down Expand Up @@ -268,7 +269,8 @@ static void
__m128 a, b; __m128 a, b;
#endif #endif


NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_two (%d)\n", (int)count); NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_two (%d)\n",
(int)count);


/* This is placed before the main loop to make small counts faster */ /* This is placed before the main loop to make small counts faster */
finish_after_unrolled_loop: finish_after_unrolled_loop:
Expand Down Expand Up @@ -592,6 +594,9 @@ finish_after_unrolled_loop:
while (count >= 8) { while (count >= 8) {
count -= 8; count -= 8;


_mm_prefetch(data0 + 512, _MM_HINT_T0);
_mm_prefetch(data1 + 512, _MM_HINT_T0);

/**begin repeat2 /**begin repeat2
* #i = 0, 4# * #i = 0, 4#
*/ */
Expand Down Expand Up @@ -623,6 +628,9 @@ finish_after_unrolled_loop:
while (count >= 8) { while (count >= 8) {
count -= 8; count -= 8;


_mm_prefetch(data0 + 512, _MM_HINT_T0);
_mm_prefetch(data1 + 512, _MM_HINT_T0);

/**begin repeat2 /**begin repeat2
* #i = 0, 2, 4, 6# * #i = 0, 2, 4, 6#
*/ */
Expand Down Expand Up @@ -652,6 +660,9 @@ finish_after_unrolled_loop:
count -= 8; count -= 8;


#if EINSUM_USE_SSE1 && @float32@ #if EINSUM_USE_SSE1 && @float32@
_mm_prefetch(data0 + 512, _MM_HINT_T0);
_mm_prefetch(data1 + 512, _MM_HINT_T0);

/**begin repeat2 /**begin repeat2
* #i = 0, 4# * #i = 0, 4#
*/ */
Expand All @@ -663,6 +674,9 @@ finish_after_unrolled_loop:
accum_sse = _mm_add_ps(accum_sse, a); accum_sse = _mm_add_ps(accum_sse, a);
/**end repeat2**/ /**end repeat2**/
#elif EINSUM_USE_SSE2 && @float64@ #elif EINSUM_USE_SSE2 && @float64@
_mm_prefetch(data0 + 512, _MM_HINT_T0);
_mm_prefetch(data1 + 512, _MM_HINT_T0);

/**begin repeat2 /**begin repeat2
* #i = 0, 2, 4, 6# * #i = 0, 2, 4, 6#
*/ */
Expand Down Expand Up @@ -943,7 +957,7 @@ static void
/**end repeat2**/ /**end repeat2**/
} }


#else #else /* @nop@ > 3 || @complex */


static void static void
@name@_sum_of_products_contig_@noplabel@(int nop, char **dataptr, @name@_sum_of_products_contig_@noplabel@(int nop, char **dataptr,
Expand Down Expand Up @@ -971,9 +985,9 @@ static void
# define _SUMPROD_NOP nop # define _SUMPROD_NOP nop
# endif # endif
npy_@temp@ re, im, tmp; npy_@temp@ re, im, tmp;
int i;
re = ((npy_@temp@ *)dataptr[0])[0]; re = ((npy_@temp@ *)dataptr[0])[0];
im = ((npy_@temp@ *)dataptr[0])[1]; im = ((npy_@temp@ *)dataptr[0])[1];
int i;
for (i = 1; i < _SUMPROD_NOP; ++i) { for (i = 1; i < _SUMPROD_NOP; ++i) {
tmp = re * ((npy_@temp@ *)dataptr[i])[0] - tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
im * ((npy_@temp@ *)dataptr[i])[1]; im * ((npy_@temp@ *)dataptr[i])[1];
Expand All @@ -994,7 +1008,186 @@ static void
} }
} }


#endif /* functions for various @nop@ */

#if @nop@ == 1

static void
@name@_sum_of_products_contig_outstride0_one(int nop, char **dataptr,
npy_intp *strides, npy_intp count)
{
#if @complex@
npy_@temp@ accum_re = 0, accum_im = 0;
npy_@temp@ *data0 = (npy_@temp@ *)dataptr[0];
#else
npy_@temp@ accum = 0;
npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
#endif

#if EINSUM_USE_SSE1 && @float32@
__m128 a, accum_sse = _mm_setzero_ps();
#elif EINSUM_USE_SSE2 && @float64@
__m128d a, accum_sse = _mm_setzero_pd();
#endif


NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_outstride0_one (%d)\n",
(int)count);

/* This is placed before the main loop to make small counts faster */
finish_after_unrolled_loop:
switch (count) {
/**begin repeat2
* #i = 6, 5, 4, 3, 2, 1, 0#
*/
case @i@+1:
#if !@complex@
accum += @from@(data0[@i@]);
#else /* complex */
accum_re += data0[2*@i@+0];
accum_im += data0[2*@i@+1];
#endif
/**end repeat2**/
case 0:
#if @complex@
((npy_@temp@ *)dataptr[1])[0] += accum_re;
((npy_@temp@ *)dataptr[1])[1] += accum_im;
#else
*((npy_@name@ *)dataptr[1]) = @to@(accum +
@from@(*((npy_@name@ *)dataptr[1])));
#endif
return;
}

#if EINSUM_USE_SSE1 && @float32@
/* Use aligned instructions if possible */
if (EINSUM_IS_SSE_ALIGNED(data0)) {
/* Unroll the loop by 8 */
while (count >= 8) {
count -= 8;

_mm_prefetch(data0 + 512, _MM_HINT_T0);

/**begin repeat2
* #i = 0, 4#
*/
/*
* NOTE: This accumulation changes the order, so will likely
* produce slightly different results.
*/
accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@));
/**end repeat2**/
data0 += 8;
}

/* Add the four SSE values and put in accum */
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
accum_sse = _mm_add_ps(a, accum_sse);
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
accum_sse = _mm_add_ps(a, accum_sse);
_mm_store_ss(&accum, accum_sse);

/* Finish off the loop */
goto finish_after_unrolled_loop;
}
#elif EINSUM_USE_SSE2 && @float64@
/* Use aligned instructions if possible */
if (EINSUM_IS_SSE_ALIGNED(data0)) {
/* Unroll the loop by 8 */
while (count >= 8) {
count -= 8;

_mm_prefetch(data0 + 512, _MM_HINT_T0);

/**begin repeat2
* #i = 0, 2, 4, 6#
*/
/*
* NOTE: This accumulation changes the order, so will likely
* produce slightly different results.
*/
accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data0+@i@));
/**end repeat2**/
data0 += 8;
}

/* Add the two SSE2 values and put in accum */
a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1));
accum_sse = _mm_add_pd(a, accum_sse);
_mm_store_sd(&accum, accum_sse);

/* Finish off the loop */
goto finish_after_unrolled_loop;
}
#endif

/* Unroll the loop by 8 */
while (count >= 8) {
count -= 8;

#if EINSUM_USE_SSE1 && @float32@
_mm_prefetch(data0 + 512, _MM_HINT_T0);

/**begin repeat2
* #i = 0, 4#
*/
/*
* NOTE: This accumulation changes the order, so will likely
* produce slightly different results.
*/
accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data0+@i@));
/**end repeat2**/
#elif EINSUM_USE_SSE2 && @float64@
_mm_prefetch(data0 + 512, _MM_HINT_T0);

/**begin repeat2
* #i = 0, 2, 4, 6#
*/
/*
* NOTE: This accumulation changes the order, so will likely
* produce slightly different results.
*/
accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data0+@i@));
/**end repeat2**/
#else
/**begin repeat2
* #i = 0, 1, 2, 3, 4, 5, 6, 7#
*/
# if !@complex@
accum += @from@(data0[@i@]);
# else /* complex */
accum_re += data0[2*@i@+0];
accum_im += data0[2*@i@+1];
# endif
/**end repeat2**/
#endif

#if !@complex@
data0 += 8;
#else
data0 += 8*2;
#endif #endif
}

#if EINSUM_USE_SSE1 && @float32@
/* Add the four SSE values and put in accum */
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
accum_sse = _mm_add_ps(a, accum_sse);
a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
accum_sse = _mm_add_ps(a, accum_sse);
_mm_store_ss(&accum, accum_sse);
#elif EINSUM_USE_SSE2 && @float64@
/* Add the two SSE2 values and put in accum */
a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1));
accum_sse = _mm_add_pd(a, accum_sse);
_mm_store_sd(&accum, accum_sse);
#endif

/* Finish off the loop */
goto finish_after_unrolled_loop;
}

#endif /* @nop@ == 1 */


static void static void
@name@_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr, @name@_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
Expand Down Expand Up @@ -1062,9 +1255,9 @@ static void
#define _SUMPROD_NOP nop #define _SUMPROD_NOP nop
# endif # endif
npy_@temp@ re, im, tmp; npy_@temp@ re, im, tmp;
int i;
re = ((npy_@temp@ *)dataptr[0])[0]; re = ((npy_@temp@ *)dataptr[0])[0];
im = ((npy_@temp@ *)dataptr[0])[1]; im = ((npy_@temp@ *)dataptr[0])[1];
int i;
for (i = 1; i < _SUMPROD_NOP; ++i) { for (i = 1; i < _SUMPROD_NOP; ++i) {
tmp = re * ((npy_@temp@ *)dataptr[i])[0] - tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
im * ((npy_@temp@ *)dataptr[i])[1]; im * ((npy_@temp@ *)dataptr[i])[1];
Expand Down Expand Up @@ -1347,6 +1540,37 @@ bool_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
typedef void (*sum_of_products_fn)(int, char **, npy_intp *, npy_intp); typedef void (*sum_of_products_fn)(int, char **, npy_intp *, npy_intp);


/* These tables need to match up with the type enum */ /* These tables need to match up with the type enum */
static sum_of_products_fn
_contig_outstride0_unary_specialization_table[NPY_NTYPES] = {
/**begin repeat
* #name = bool,
* byte, ubyte,
* short, ushort,
* int, uint,
* long, ulong,
* longlong, ulonglong,
* float, double, longdouble,
* cfloat, cdouble, clongdouble,
* object, string, unicode, void,
* datetime, timedelta, half#
* #use = 0,
* 1, 1,
* 1, 1,
* 1, 1,
* 1, 1,
* 1, 1,
* 1, 1, 1,
* 1, 1, 1,
* 0, 0, 0, 0,
* 0, 0, 1#
*/
#if @use@
&@name@_sum_of_products_contig_outstride0_one,
#else
NULL,
#endif
/**end repeat**/
}; /* End of _contig_outstride0_unary_specialization_table */


static sum_of_products_fn _binary_specialization_table[NPY_NTYPES][5] = { static sum_of_products_fn _binary_specialization_table[NPY_NTYPES][5] = {
/**begin repeat /**begin repeat
Expand Down Expand Up @@ -1503,6 +1727,15 @@ get_sum_of_products_function(int nop, int type_num,
return NULL; return NULL;
} }


/* contiguous reduction */
if (nop == 1 && fixed_strides[0] == itemsize && fixed_strides[1] == 0) {
sum_of_products_fn ret =
_contig_outstride0_unary_specialization_table[type_num];
if (ret != NULL) {
return ret;
}
}

/* nop of 2 has more specializations */ /* nop of 2 has more specializations */
if (nop == 2) { if (nop == 2) {
/* Encode the zero/contiguous strides */ /* Encode the zero/contiguous strides */
Expand Down

0 comments on commit 260824f

Please sign in to comment.