Skip to content

Commit

Permalink
dot_rev functions and uses in nmod_poly/{mul,inv}
Browse files Browse the repository at this point in the history
  • Loading branch information
vneiger committed Jun 20, 2024
1 parent 5bc71b2 commit 4f52931
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 128 deletions.
9 changes: 9 additions & 0 deletions src/nmod.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ ulong nmod_addmul(ulong a, ulong b, ulong c, nmod_t mod)
(r) = nmod_addmul((r), (a), (b), (mod)); \
} while (0)

/*
 * Returns a*b + c*d modulo mod.n.
 *
 * The product a*b is computed with nmod_mul, then c*d is folded into it
 * with NMOD_ADDMUL.
 *
 * NOTE(review): operands are presumably expected to be already reduced
 * modulo mod.n, as for nmod_mul / nmod_addmul -- confirm against their docs.
 */
NMOD_INLINE
ulong nmod_fmma(ulong a, ulong b, ulong c, ulong d, nmod_t mod)
{
    ulong acc = nmod_mul(a, b, mod);   // acc = a*b

    NMOD_ADDMUL(acc, c, d, mod);       // acc = acc + c*d
    return acc;
}

NMOD_INLINE
ulong nmod_inv(ulong a, nmod_t mod)
{
Expand Down
2 changes: 0 additions & 2 deletions src/nmod_poly/inv_series.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ _nmod_poly_inv_series_basecase_preinv1(nn_ptr Qinv, nn_srcptr Q, slong Qlen, slo
for (i = 1; i < n; i++)
{
l = FLINT_MIN(i, Qlen - 1);
//NMOD_VEC_DOT(s, j, l, Q[j + 1], Qinv[i - 1 - j], mod, params);
// FIXME macro more efficient for small l ?
s = _nmod_vec_dot_rev(Q+1, Qinv + i-l, l, mod, params);

if (q == 1)
Expand Down
13 changes: 2 additions & 11 deletions src/nmod_poly/mul_classical.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,9 @@ _nmod_poly_mul_classical(nn_ptr res, nn_srcptr poly1,

squaring = (poly1 == poly2 && len1 == len2);

// TODO could what is below make more direct use of nmod_vec_dot?
log_len = FLINT_BIT_COUNT(len2);
bits = FLINT_BITS - (slong) mod.norm;
bits = 2 * bits + log_len;
const dot_params_t params = _nmod_vec_dot_params(FLINT_MIN(len1, len2), mod);

if (bits <= FLINT_BITS)
if (params.method <= _DOT1)
{
flint_mpn_zero(res, len1 + len2 - 1);

Expand Down Expand Up @@ -83,12 +80,6 @@ _nmod_poly_mul_classical(nn_ptr res, nn_srcptr poly1,
return;
}

dot_params_t params = {_DOT2, 0};
if (bits <= 2 * FLINT_BITS)
params.method = _DOT2;
else
params.method = _DOT3;

if (squaring)
{
for (i = 0; i < 2 * len1 - 1; i++)
Expand Down
68 changes: 65 additions & 3 deletions src/nmod_vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#endif

#include "flint.h"
#include "nmod.h" // nmod_mul, nmod_fmma

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -112,6 +113,7 @@ void _nmod_vec_scalar_addmul_nmod(nn_ptr res, nn_srcptr vec, slong len, ulong c,
/* -------------------- dot product ----------------------- */
/* more comments in nmod_vec/dot.c */


// for _DOT2_SPLIT (currently restricted to FLINT_BITS == 64)
#if (FLINT_BITS == 64)
# define DOT_SPLIT_BITS 56
Expand Down Expand Up @@ -377,14 +379,31 @@ ulong _nmod_vec_dot2(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);
ulong _nmod_vec_dot3_acc(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);
ulong _nmod_vec_dot3(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);

ulong _nmod_vec_dot_pow2_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);
ulong _nmod_vec_dot1_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);
ulong _nmod_vec_dot2_half_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);
ulong _nmod_vec_dot2_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);
ulong _nmod_vec_dot3_acc_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);
ulong _nmod_vec_dot3_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod);

#if FLINT_BITS == 64
ulong _nmod_vec_dot2_split(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, ulong pow2_precomp);
ulong _nmod_vec_dot2_split_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, ulong pow2_precomp);
#endif // FLINT_BITS == 64

/* general dot functions */

NMOD_VEC_INLINE ulong _nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t params)
{
if (len <= 2 && params.method > _DOT1)
{
if (len == 2)
return nmod_fmma(vec1[0], vec2[0], vec1[1], vec2[1], mod);
if (len == 1)
return nmod_mul(vec1[0], vec2[0], mod);
return 0;
}

if (params.method == _DOT1)
return _nmod_vec_dot1(vec1, vec2, len, mod);

Expand All @@ -407,19 +426,62 @@ NMOD_VEC_INLINE ulong _nmod_vec_dot(nn_srcptr vec1, nn_srcptr vec2, slong len, n

if (params.method == _DOT_POW2)
{
#if defined(__AVX2__)
#if defined(__AVX2__) && FLINT_BITS == 64
if (mod.n < (UWORD(1) << (FLINT_BITS / 2)))
return _nmod_vec_dot1(vec1, vec2, len, mod);
else // make sure not to use avx 32-bit mul
#endif // defined(__AVX2__)
#endif // defined(__AVX2__) && FLINT_BITS == 64
return _nmod_vec_dot_pow2(vec1, vec2, len, mod);
}

else // params.method == _DOT0
return UWORD(0);
}

ulong _nmod_vec_dot_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t);
/*
 * "Reversed" dot product of vec1 and vec2 modulo mod.n, dispatched on
 * params.method.
 *
 * The length-2 shortcut below pairs vec1[0] with vec2[1] and vec1[1] with
 * vec2[0], i.e. vec1[i] goes with vec2[len - 1 - i]; the _rev helpers called
 * for longer inputs presumably use the same pairing (TODO confirm in
 * nmod_vec/dot.c).
 *
 * Returns 0 when no branch below matches params.method (this covers _DOT0,
 * and _DOT2_SPLIT when FLINT_BITS != 64 on vectors longer than 2).
 */
NMOD_VEC_INLINE ulong _nmod_vec_dot_rev(nn_srcptr vec1, nn_srcptr vec2, slong len, nmod_t mod, dot_params_t params)
{
    // very short vectors, for every method above _DOT1: handled inline
    // with nmod_mul / nmod_fmma instead of calling a dedicated routine
    if (params.method > _DOT1 && len <= 2)
    {
        switch (len)
        {
            case 2:
                return nmod_fmma(vec1[0], vec2[1], vec1[1], vec2[0], mod);
            case 1:
                return nmod_mul(vec1[0], vec2[0], mod);
            default:  // len <= 0: nothing to accumulate
                return 0;
        }
    }

    switch (params.method)
    {
        case _DOT1:
            return _nmod_vec_dot1_rev(vec1, vec2, len, mod);

#if FLINT_BITS == 64
        case _DOT2_SPLIT:
            return _nmod_vec_dot2_split_rev(vec1, vec2, len, mod, params.pow2_precomp);
#endif // FLINT_BITS == 64

        case _DOT2:
            return _nmod_vec_dot2_rev(vec1, vec2, len, mod);

        case _DOT3_ACC:
            return _nmod_vec_dot3_acc_rev(vec1, vec2, len, mod);

        case _DOT3:
            return _nmod_vec_dot3_rev(vec1, vec2, len, mod);

        case _DOT2_HALF:
            return _nmod_vec_dot2_half_rev(vec1, vec2, len, mod);

        case _DOT_POW2:
#if defined(__AVX2__) && FLINT_BITS == 64
            // moduli below 2^(FLINT_BITS/2) go through the _DOT1 routine;
            // larger ones must not (original note: make sure not to use
            // avx 32-bit mul)
            if (mod.n < (UWORD(1) << (FLINT_BITS / 2)))
                return _nmod_vec_dot1_rev(vec1, vec2, len, mod);
#endif // defined(__AVX2__) && FLINT_BITS == 64
            return _nmod_vec_dot_pow2_rev(vec1, vec2, len, mod);

        default:  // params.method == _DOT0, or a method not handled above
            return UWORD(0);
    }
}

ulong _nmod_vec_dot_ptr(nn_srcptr vec1, const nn_ptr * vec2, slong offset, slong len, nmod_t mod, dot_params_t);

Expand Down
Loading

0 comments on commit 4f52931

Please sign in to comment.