Permalink
Browse files

ENH: einsum: Specialize contiguous reduction, add SSE prefetching

Also fix some compiler warnings. The biggest performance improvement
was from adding SSE prefetching.
  • Loading branch information...
1 parent 8598315 commit 260824fe05b1a314d67420669ee0d012c072c064 @mwiebe mwiebe committed Feb 10, 2011
Showing with 239 additions and 6 deletions.
  1. +239 −6 numpy/core/src/multiarray/einsum.c.src
@@ -169,9 +169,9 @@ static void
#define _SUMPROD_NOP nop
# endif
npy_@temp@ re, im, tmp;
+ int i;
re = ((npy_@temp@ *)dataptr[0])[0];
im = ((npy_@temp@ *)dataptr[0])[1];
- int i;
for (i = 1; i < _SUMPROD_NOP; ++i) {
tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
im * ((npy_@temp@ *)dataptr[i])[1];
@@ -202,7 +202,8 @@ static void
npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
npy_@name@ *data_out = (npy_@name@ *)dataptr[1];
- NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_one (%d)\n", (int)count);
+ NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_one (%d)\n",
+ (int)count);
/* This is placed before the main loop to make small counts faster */
finish_after_unrolled_loop:
@@ -268,7 +269,8 @@ static void
__m128 a, b;
#endif
- NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_two (%d)\n", (int)count);
+ NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_two (%d)\n",
+ (int)count);
/* This is placed before the main loop to make small counts faster */
finish_after_unrolled_loop:
@@ -592,6 +594,9 @@ finish_after_unrolled_loop:
while (count >= 8) {
count -= 8;
+ _mm_prefetch(data0 + 512, _MM_HINT_T0);
+ _mm_prefetch(data1 + 512, _MM_HINT_T0);
+
/**begin repeat2
* #i = 0, 4#
*/
@@ -623,6 +628,9 @@ finish_after_unrolled_loop:
while (count >= 8) {
count -= 8;
+ _mm_prefetch(data0 + 512, _MM_HINT_T0);
+ _mm_prefetch(data1 + 512, _MM_HINT_T0);
+
/**begin repeat2
* #i = 0, 2, 4, 6#
*/
@@ -652,6 +660,9 @@ finish_after_unrolled_loop:
count -= 8;
#if EINSUM_USE_SSE1 && @float32@
+ _mm_prefetch(data0 + 512, _MM_HINT_T0);
+ _mm_prefetch(data1 + 512, _MM_HINT_T0);
+
/**begin repeat2
* #i = 0, 4#
*/
@@ -663,6 +674,9 @@ finish_after_unrolled_loop:
accum_sse = _mm_add_ps(accum_sse, a);
/**end repeat2**/
#elif EINSUM_USE_SSE2 && @float64@
+ _mm_prefetch(data0 + 512, _MM_HINT_T0);
+ _mm_prefetch(data1 + 512, _MM_HINT_T0);
+
/**begin repeat2
* #i = 0, 2, 4, 6#
*/
@@ -943,7 +957,7 @@ static void
/**end repeat2**/
}
-#else
+#else /* @nop@ > 3 || @complex */
static void
@name@_sum_of_products_contig_@noplabel@(int nop, char **dataptr,
@@ -971,9 +985,9 @@ static void
# define _SUMPROD_NOP nop
# endif
npy_@temp@ re, im, tmp;
+ int i;
re = ((npy_@temp@ *)dataptr[0])[0];
im = ((npy_@temp@ *)dataptr[0])[1];
- int i;
for (i = 1; i < _SUMPROD_NOP; ++i) {
tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
im * ((npy_@temp@ *)dataptr[i])[1];
@@ -994,7 +1008,186 @@ static void
}
}
+#endif /* functions for various @nop@ */
+
+#if @nop@ == 1
+
+static void
+@name@_sum_of_products_contig_outstride0_one(int nop, char **dataptr,
+ npy_intp *strides, npy_intp count)
+{
+#if @complex@
+ npy_@temp@ accum_re = 0, accum_im = 0;
+ npy_@temp@ *data0 = (npy_@temp@ *)dataptr[0];
+#else
+ npy_@temp@ accum = 0;
+ npy_@name@ *data0 = (npy_@name@ *)dataptr[0];
+#endif
+
+#if EINSUM_USE_SSE1 && @float32@
+ __m128 a, accum_sse = _mm_setzero_ps();
+#elif EINSUM_USE_SSE2 && @float64@
+ __m128d a, accum_sse = _mm_setzero_pd();
+#endif
+
+
+ NPY_EINSUM_DBG_PRINTF("@name@_sum_of_products_contig_outstride0_one (%d)\n",
+ (int)count);
+
+/* This is placed before the main loop to make small counts faster */
+finish_after_unrolled_loop:
+ switch (count) {
+/**begin repeat2
+ * #i = 6, 5, 4, 3, 2, 1, 0#
+ */
+ case @i@+1:
+#if !@complex@
+ accum += @from@(data0[@i@]);
+#else /* complex */
+ accum_re += data0[2*@i@+0];
+ accum_im += data0[2*@i@+1];
+#endif
+/**end repeat2**/
+ case 0:
+#if @complex@
+ ((npy_@temp@ *)dataptr[1])[0] += accum_re;
+ ((npy_@temp@ *)dataptr[1])[1] += accum_im;
+#else
+ *((npy_@name@ *)dataptr[1]) = @to@(accum +
+ @from@(*((npy_@name@ *)dataptr[1])));
+#endif
+ return;
+ }
+
+#if EINSUM_USE_SSE1 && @float32@
+ /* Use aligned instructions if possible */
+ if (EINSUM_IS_SSE_ALIGNED(data0)) {
+ /* Unroll the loop by 8 */
+ while (count >= 8) {
+ count -= 8;
+
+ _mm_prefetch(data0 + 512, _MM_HINT_T0);
+
+/**begin repeat2
+ * #i = 0, 4#
+ */
+ /*
+ * NOTE: This accumulation changes the order, so will likely
+ * produce slightly different results.
+ */
+ accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@));
+/**end repeat2**/
+ data0 += 8;
+ }
+
+ /* Add the four SSE values and put in accum */
+ a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
+ accum_sse = _mm_add_ps(a, accum_sse);
+ a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
+ accum_sse = _mm_add_ps(a, accum_sse);
+ _mm_store_ss(&accum, accum_sse);
+
+ /* Finish off the loop */
+ goto finish_after_unrolled_loop;
+ }
+#elif EINSUM_USE_SSE2 && @float64@
+ /* Use aligned instructions if possible */
+ if (EINSUM_IS_SSE_ALIGNED(data0)) {
+ /* Unroll the loop by 8 */
+ while (count >= 8) {
+ count -= 8;
+
+ _mm_prefetch(data0 + 512, _MM_HINT_T0);
+
+/**begin repeat2
+ * #i = 0, 2, 4, 6#
+ */
+ /*
+ * NOTE: This accumulation changes the order, so will likely
+ * produce slightly different results.
+ */
+ accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data0+@i@));
+/**end repeat2**/
+ data0 += 8;
+ }
+
+ /* Add the two SSE2 values and put in accum */
+ a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1));
+ accum_sse = _mm_add_pd(a, accum_sse);
+ _mm_store_sd(&accum, accum_sse);
+
+ /* Finish off the loop */
+ goto finish_after_unrolled_loop;
+ }
+#endif
+
+ /* Unroll the loop by 8 */
+ while (count >= 8) {
+ count -= 8;
+
+#if EINSUM_USE_SSE1 && @float32@
+ _mm_prefetch(data0 + 512, _MM_HINT_T0);
+
+/**begin repeat2
+ * #i = 0, 4#
+ */
+ /*
+ * NOTE: This accumulation changes the order, so will likely
+ * produce slightly different results.
+ */
+ accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data0+@i@));
+/**end repeat2**/
+#elif EINSUM_USE_SSE2 && @float64@
+ _mm_prefetch(data0 + 512, _MM_HINT_T0);
+
+/**begin repeat2
+ * #i = 0, 2, 4, 6#
+ */
+ /*
+ * NOTE: This accumulation changes the order, so will likely
+ * produce slightly different results.
+ */
+ accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data0+@i@));
+/**end repeat2**/
+#else
+/**begin repeat2
+ * #i = 0, 1, 2, 3, 4, 5, 6, 7#
+ */
+# if !@complex@
+ accum += @from@(data0[@i@]);
+# else /* complex */
+ accum_re += data0[2*@i@+0];
+ accum_im += data0[2*@i@+1];
+# endif
+/**end repeat2**/
+#endif
+
+#if !@complex@
+ data0 += 8;
+#else
+ data0 += 8*2;
#endif
+ }
+
+#if EINSUM_USE_SSE1 && @float32@
+ /* Add the four SSE values and put in accum */
+ a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1));
+ accum_sse = _mm_add_ps(a, accum_sse);
+ a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2));
+ accum_sse = _mm_add_ps(a, accum_sse);
+ _mm_store_ss(&accum, accum_sse);
+#elif EINSUM_USE_SSE2 && @float64@
+ /* Add the two SSE2 values and put in accum */
+ a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1));
+ accum_sse = _mm_add_pd(a, accum_sse);
+ _mm_store_sd(&accum, accum_sse);
+#endif
+
+ /* Finish off the loop */
+ goto finish_after_unrolled_loop;
+}
+
+#endif /* @nop@ == 1 */
static void
@name@_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
@@ -1062,9 +1255,9 @@ static void
#define _SUMPROD_NOP nop
# endif
npy_@temp@ re, im, tmp;
+ int i;
re = ((npy_@temp@ *)dataptr[0])[0];
im = ((npy_@temp@ *)dataptr[0])[1];
- int i;
for (i = 1; i < _SUMPROD_NOP; ++i) {
tmp = re * ((npy_@temp@ *)dataptr[i])[0] -
im * ((npy_@temp@ *)dataptr[i])[1];
@@ -1347,6 +1540,37 @@ bool_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr,
typedef void (*sum_of_products_fn)(int, char **, npy_intp *, npy_intp);
/* These tables need to match up with the type enum */
+static sum_of_products_fn
+_contig_outstride0_unary_specialization_table[NPY_NTYPES] = {
+/**begin repeat
+ * #name = bool,
+ * byte, ubyte,
+ * short, ushort,
+ * int, uint,
+ * long, ulong,
+ * longlong, ulonglong,
+ * float, double, longdouble,
+ * cfloat, cdouble, clongdouble,
+ * object, string, unicode, void,
+ * datetime, timedelta, half#
+ * #use = 0,
+ * 1, 1,
+ * 1, 1,
+ * 1, 1,
+ * 1, 1,
+ * 1, 1,
+ * 1, 1, 1,
+ * 1, 1, 1,
+ * 0, 0, 0, 0,
+ * 0, 0, 1#
+ */
+#if @use@
+ &@name@_sum_of_products_contig_outstride0_one,
+#else
+ NULL,
+#endif
+/**end repeat**/
+}; /* End of _contig_outstride0_unary_specialization_table */
static sum_of_products_fn _binary_specialization_table[NPY_NTYPES][5] = {
/**begin repeat
@@ -1503,6 +1727,15 @@ get_sum_of_products_function(int nop, int type_num,
return NULL;
}
+ /* contiguous reduction */
+ if (nop == 1 && fixed_strides[0] == itemsize && fixed_strides[1] == 0) {
+ sum_of_products_fn ret =
+ _contig_outstride0_unary_specialization_table[type_num];
+ if (ret != NULL) {
+ return ret;
+ }
+ }
+
/* nop of 2 has more specializations */
if (nop == 2) {
/* Encode the zero/contiguous strides */

0 comments on commit 260824f

Please sign in to comment.