Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/mpn_extras.h
Original file line number Diff line number Diff line change
Expand Up @@ -549,20 +549,25 @@ flint_mpn_sqr(mp_ptr r, mp_srcptr x, mp_size_t n)
/* Each predicate is nonzero when the length n does not exceed the width of
   the matching function table, i.e. when the table may be consulted for a
   hardcoded routine of that length. */
#define FLINT_HAVE_MULHIGH_FUNC(n) ((n) <= FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH)
#define FLINT_HAVE_SQRHIGH_FUNC(n) ((n) <= FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH)
#define FLINT_HAVE_MULHIGH_NORMALISED_FUNC(n) ((n) <= FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH)
#define FLINT_HAVE_SQRHIGH_NORMALISED_FUNC(n) ((n) <= FLINT_MPN_SQRHIGH_NORMALISED_FUNC_TAB_WIDTH)

/* Two limbs returned by value from the *_normalised routines. */
typedef struct { mp_limb_t m1; mp_limb_t m2; } mp_limb_pair_t;
/* Fixed-length normalised high squaring: (result pointer, operand pointer). */
typedef mp_limb_pair_t (* flint_mpn_sqrhigh_normalised_func_t)(mp_ptr, mp_srcptr);
/* Fixed-length normalised high multiplication: (result pointer, two operand pointers). */
typedef mp_limb_pair_t (* flint_mpn_mulhigh_normalised_func_t)(mp_ptr, mp_srcptr, mp_srcptr);

/* Tables of hardcoded fixed-length routines. They are only declared here;
   the definitions live in separate source files. */
FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_mullow_func_tab[];
FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_mulhigh_func_tab[];
FLINT_DLL extern const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[];
FLINT_DLL extern const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[];
FLINT_DLL extern const flint_mpn_sqrhigh_normalised_func_t flint_mpn_sqrhigh_normalised_func_tab[];

#if FLINT_HAVE_ASSEMBLY_x86_64_adx
# define FLINT_MPN_MULLOW_FUNC_TAB_WIDTH 8
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 9
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 8
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 9
# define FLINT_MPN_SQRHIGH_NORMALISED_FUNC_TAB_WIDTH 8

# define FLINT_HAVE_NATIVE_mpn_mullow_basecase 1
/* NOTE: This function only works for n >= 6 */
# define FLINT_HAVE_NATIVE_mpn_mulhigh_basecase 1
Expand All @@ -574,6 +579,8 @@ FLINT_DLL extern const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_nor
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 8
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 8
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 0
# define FLINT_MPN_SQRHIGH_NORMALISED_FUNC_TAB_WIDTH 0

/* NOTE: This function only works for n > 8 */
# define FLINT_HAVE_NATIVE_mpn_mulhigh_basecase 1

Expand All @@ -583,6 +590,7 @@ FLINT_DLL extern const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_nor
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 16
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 2
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 0
# define FLINT_MPN_SQRHIGH_NORMALISED_FUNC_TAB_WIDTH 0

#endif

Expand Down Expand Up @@ -715,6 +723,19 @@ mp_limb_pair_t flint_mpn_mulhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_srcptr y
return _flint_mpn_mulhigh_normalised(rp, xp, yp, n);
}

/* General-length routine, used when no hardcoded routine covers n. */
mp_limb_pair_t _flint_mpn_sqrhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_size_t n);

/* Normalised high squaring of the n-limb operand xp into rp (n >= 1).
   Lengths covered by the hardcoded table are dispatched through it;
   every other length goes to _flint_mpn_sqrhigh_normalised. */
MPN_EXTRAS_INLINE
mp_limb_pair_t flint_mpn_sqrhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_size_t n)
{
    FLINT_ASSERT(n >= 1);

    /* No table entry for this length: use the general routine. */
    if (!FLINT_HAVE_SQRHIGH_NORMALISED_FUNC(n))
        return _flint_mpn_sqrhigh_normalised(rp, xp, n);

    return flint_mpn_sqrhigh_normalised_func_tab[n](rp, xp);
}

/* division ******************************************************************/

#if FLINT_HAVE_NATIVE_mpn_modexact_1_odd
Expand Down
70 changes: 0 additions & 70 deletions src/mpn_extras/mulhigh_basecase.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,6 @@ mp_limb_pair_t flint_mpn_mulhigh_normalised_7(mp_ptr, mp_srcptr, mp_srcptr);
mp_limb_pair_t flint_mpn_mulhigh_normalised_8(mp_ptr, mp_srcptr, mp_srcptr);
mp_limb_pair_t flint_mpn_mulhigh_normalised_9(mp_ptr, mp_srcptr, mp_srcptr);

/* Fixed-length high-squaring routines for n = 1..8; only declared here. */
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_3(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_4(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_5(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_6(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_7(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_8(mp_ptr, mp_srcptr);

const flint_mpn_mul_func_t flint_mpn_mulhigh_func_tab[] =
{
NULL,
Expand Down Expand Up @@ -69,19 +60,6 @@ const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[
flint_mpn_mulhigh_normalised_8,
flint_mpn_mulhigh_normalised_9
};

/* Entry k holds flint_mpn_sqrhigh_k for k = 1..8; slot 0 is a NULL
   placeholder. */
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] =
{
NULL,
flint_mpn_sqrhigh_1,
flint_mpn_sqrhigh_2,
flint_mpn_sqrhigh_3,
flint_mpn_sqrhigh_4,
flint_mpn_sqrhigh_5,
flint_mpn_sqrhigh_6,
flint_mpn_sqrhigh_7,
flint_mpn_sqrhigh_8
};
#elif FLINT_HAVE_ASSEMBLY_armv8
mp_limb_t flint_mpn_mulhigh_1(mp_ptr, mp_srcptr, mp_srcptr);
mp_limb_t flint_mpn_mulhigh_2(mp_ptr, mp_srcptr, mp_srcptr);
Expand All @@ -92,15 +70,6 @@ mp_limb_t flint_mpn_mulhigh_6(mp_ptr, mp_srcptr, mp_srcptr);
mp_limb_t flint_mpn_mulhigh_7(mp_ptr, mp_srcptr, mp_srcptr);
mp_limb_t flint_mpn_mulhigh_8(mp_ptr, mp_srcptr, mp_srcptr);

/* Fixed-length high-squaring routines for n = 1..8 in the armv8 branch;
   only declared here. */
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_3(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_4(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_5(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_6(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_7(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_8(mp_ptr, mp_srcptr);

const flint_mpn_mul_func_t flint_mpn_mulhigh_func_tab[] =
{
NULL,
Expand All @@ -118,25 +87,10 @@ const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[
{
NULL,
};

/* armv8 branch: entry k holds flint_mpn_sqrhigh_k for k = 1..8; slot 0 is a
   NULL placeholder. */
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] =
{
NULL,
flint_mpn_sqrhigh_1,
flint_mpn_sqrhigh_2,
flint_mpn_sqrhigh_3,
flint_mpn_sqrhigh_4,
flint_mpn_sqrhigh_5,
flint_mpn_sqrhigh_6,
flint_mpn_sqrhigh_7,
flint_mpn_sqrhigh_8,
};
#else

/* todo: add MPFR-like basecase for use in mulders */
/* todo: squaring code */
/* todo: define the generic basecase also on x86_64_adx,
and use to test the assembly versions */

mp_limb_t _flint_mpn_mulhigh_basecase(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
{
Expand Down Expand Up @@ -428,28 +382,4 @@ const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[
{
NULL,
};

/* One-limb high squaring: stores the high limb of u[0]^2 in res[0] and
   returns the low limb. */
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr res, mp_srcptr u)
{
    mp_limb_t x = u[0];
    mp_limb_t hi, lo;

    umul_ppmm(hi, lo, x, x);
    res[0] = hi;

    return lo;
}

/* todo */
/* Two-limb high squaring built on the generic 2x2 multiplication macro.
   The first two macro outputs go to res[1] and res[0], the third is
   returned and the fourth is discarded (outputs are presumably ordered from
   most to least significant -- confirm against the macro definition). */
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr res, mp_srcptr u)
{
    mp_limb_t x1 = u[1], x0 = u[0];
    mp_limb_t r3, r2, r1, r0;

    FLINT_MPN_MUL_2X2(r3, r2, r1, r0, x1, x0, x1, x0);

    res[1] = r3;
    res[0] = r2;

    return r1;
}

/* todo: higher cases */

/* Generic branch: only the 1- and 2-limb routines exist; slot 0 is a NULL
   placeholder. */
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] = {
NULL,
flint_mpn_sqrhigh_1,
flint_mpn_sqrhigh_2,
};

#endif
24 changes: 24 additions & 0 deletions src/mpn_extras/sqrhigh.c
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,27 @@ _flint_mpn_sqrhigh(mp_ptr res, mp_srcptr u, mp_size_t n)
else
return _flint_mpn_sqrhigh_sqr(res, u, n);
}

/* General-length normalised high squaring.

   Calls flint_mpn_sqrhigh(rp, xp, n). If the top bit of rp[n - 1] is then
   clear, the n limbs in rp are shifted left by one bit and the top bit of
   the limb returned by flint_mpn_sqrhigh is moved into the vacated bit of
   rp[0].

   Returns a pair with
     m1: the limb returned by flint_mpn_sqrhigh, shifted left by one bit
         if the result was shifted;
     m2: 1 if the shift was performed, 0 otherwise.

   Requires n >= 1 and rp != xp. */
mp_limb_pair_t _flint_mpn_sqrhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_size_t n)
{
    mp_limb_pair_t out;
    mp_limb_t low;

    FLINT_ASSERT(n >= 1);
    FLINT_ASSERT(rp != xp);

    low = flint_mpn_sqrhigh(rp, xp, n);

    /* Top bit already set: nothing to adjust. */
    if ((rp[n - 1] >> (FLINT_BITS - 1)) != 0)
    {
        out.m1 = low;
        out.m2 = 0;
        return out;
    }

    /* Shift the whole result up by one bit, carrying in the top bit of the
       returned limb. */
    mpn_lshift(rp, rp, n, 1);
    rp[0] |= low >> (FLINT_BITS - 1);

    out.m1 = low << 1;
    out.m2 = 1;
    return out;
}
88 changes: 88 additions & 0 deletions src/mpn_extras/sqrhigh_basecase.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
Copyright (C) 2024 Albin Ahlbäck

This file is part of FLINT.

FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "mpn_extras.h"

#if FLINT_HAVE_ASSEMBLY_x86_64_adx || FLINT_HAVE_ASSEMBLY_armv8
/* Fixed-length high-squaring routines for n = 1..8. This branch is compiled
   when x86_64_adx or armv8 assembly is available; the routines are only
   declared here (presumably provided by the assembly sources). */
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_3(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_4(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_5(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_6(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_7(mp_ptr, mp_srcptr);
mp_limb_t flint_mpn_sqrhigh_8(mp_ptr, mp_srcptr);

/* Entry k holds flint_mpn_sqrhigh_k for k = 1..8; slot 0 is a NULL
   placeholder. */
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] =
{
NULL,
flint_mpn_sqrhigh_1,
flint_mpn_sqrhigh_2,
flint_mpn_sqrhigh_3,
flint_mpn_sqrhigh_4,
flint_mpn_sqrhigh_5,
flint_mpn_sqrhigh_6,
flint_mpn_sqrhigh_7,
flint_mpn_sqrhigh_8
};
#else
/* One-limb high squaring: stores the high limb of u[0]^2 in res[0] and
   returns the low limb. */
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr res, mp_srcptr u)
{
    mp_limb_t x = u[0];
    mp_limb_t hi, lo;

    umul_ppmm(hi, lo, x, x);
    res[0] = hi;

    return lo;
}

/* todo */
/* Two-limb high squaring built on the generic 2x2 multiplication macro.
   The first two macro outputs go to res[1] and res[0], the third is
   returned and the fourth is discarded (outputs are presumably ordered from
   most to least significant -- confirm against the macro definition). */
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr res, mp_srcptr u)
{
    mp_limb_t x1 = u[1], x0 = u[0];
    mp_limb_t r3, r2, r1, r0;

    FLINT_MPN_MUL_2X2(r3, r2, r1, r0, x1, x0, x1, x0);

    res[1] = r3;
    res[0] = r2;

    return r1;
}

/* todo: higher cases */

/* Generic branch: only the 1- and 2-limb routines exist; slot 0 is a NULL
   placeholder. */
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] = {
NULL,
flint_mpn_sqrhigh_1,
flint_mpn_sqrhigh_2,
};
#endif

#if FLINT_HAVE_ASSEMBLY_x86_64_adx
/* Fixed-length normalised high-squaring routines for n = 1..8, available
   only in the x86_64_adx configuration; they are only declared here. */
mp_limb_pair_t flint_mpn_sqrhigh_normalised_1(mp_ptr, mp_srcptr);
mp_limb_pair_t flint_mpn_sqrhigh_normalised_2(mp_ptr, mp_srcptr);
mp_limb_pair_t flint_mpn_sqrhigh_normalised_3(mp_ptr, mp_srcptr);
mp_limb_pair_t flint_mpn_sqrhigh_normalised_4(mp_ptr, mp_srcptr);
mp_limb_pair_t flint_mpn_sqrhigh_normalised_5(mp_ptr, mp_srcptr);
mp_limb_pair_t flint_mpn_sqrhigh_normalised_6(mp_ptr, mp_srcptr);
mp_limb_pair_t flint_mpn_sqrhigh_normalised_7(mp_ptr, mp_srcptr);
mp_limb_pair_t flint_mpn_sqrhigh_normalised_8(mp_ptr, mp_srcptr);

/* Entry k holds flint_mpn_sqrhigh_normalised_k for k = 1..8; slot 0 is a
   NULL placeholder. */
const flint_mpn_sqrhigh_normalised_func_t flint_mpn_sqrhigh_normalised_func_tab[] =
{
NULL,
flint_mpn_sqrhigh_normalised_1,
flint_mpn_sqrhigh_normalised_2,
flint_mpn_sqrhigh_normalised_3,
flint_mpn_sqrhigh_normalised_4,
flint_mpn_sqrhigh_normalised_5,
flint_mpn_sqrhigh_normalised_6,
flint_mpn_sqrhigh_normalised_7,
flint_mpn_sqrhigh_normalised_8
};
#else
/* No hardcoded normalised squaring routines in this configuration: the
   table holds only the slot-0 NULL placeholder. */
const flint_mpn_sqrhigh_normalised_func_t flint_mpn_sqrhigh_normalised_func_tab[] =
{
NULL
};
#endif
4 changes: 3 additions & 1 deletion src/mpn_extras/test/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "t-remove_power.c"
#include "t-sqr.c"
#include "t-sqrhigh.c"
#include "t-sqrhigh_normalised.c"

/* Array of test functions ***************************************************/

Expand All @@ -55,7 +56,8 @@ test_struct tests[] =
TEST_FUNCTION(flint_mpn_remove_2exp),
TEST_FUNCTION(flint_mpn_remove_power),
TEST_FUNCTION(flint_mpn_sqr),
TEST_FUNCTION(flint_mpn_sqrhigh)
TEST_FUNCTION(flint_mpn_sqrhigh),
TEST_FUNCTION(flint_mpn_sqrhigh_normalised)
};

/* main function *************************************************************/
Expand Down
93 changes: 93 additions & 0 deletions src/mpn_extras/test/t-sqrhigh_normalised.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
Copyright (C) 2024 Albin Ahlbäck

This file is part of FLINT.

FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "test_helpers.h"
#include "mpn_extras.h"

#define N_MIN 1
#define N_MAX 64

TEST_FUNCTION_START(flint_mpn_sqrhigh_normalised, state)
{
slong ix;
int result;

for (ix = 0; ix < 10000 * flint_test_multiplier(); ix++)
{
mp_ptr rp_n, rp_u, xp;
mp_size_t n;
mp_limb_pair_t res_norm;
mp_limb_t retlimb, normalised;

n = N_MIN + n_randint(state, N_MAX - N_MIN + 1);

rp_n = flint_malloc(sizeof(mp_limb_t) * (n + 1));
rp_u = flint_malloc(sizeof(mp_limb_t) * (n + 1));
xp = flint_malloc(sizeof(mp_limb_t) * n);

flint_mpn_rrandom(xp, state, n);
xp[n - 1] |= (UWORD(1) << (FLINT_BITS - 1));

rp_u[0] = flint_mpn_sqrhigh(rp_u + 1, xp, n);
res_norm = flint_mpn_sqrhigh_normalised(rp_n + 1, xp, n);

retlimb = res_norm.m1;
normalised = res_norm.m2;
rp_n[0] = retlimb;

result = ((rp_n[n] & (UWORD(1) << (FLINT_BITS - 1))) != UWORD(0));
if (!result)
TEST_FUNCTION_FAIL(
"Top bit not set in normalised result\n"
"ix = %wd\n"
"n = %wd\n"
"xp = %{ulong*}\n"
"rp_n = %{ulong*}\n"
"rp_u = %{ulong*}\n",
ix, n, xp, n, rp_n, n + 1, rp_u, n + 1);

if (normalised)
{
result = (mpn_lshift(rp_u, rp_u, n + 1, 1) == 0);
result = result && (mpn_cmp(rp_n, rp_u, n + 1) == 0);
if (!result)
TEST_FUNCTION_FAIL(
"rp_n != rp_u << 1 when normalised\n"
"ix = %wd\n"
"n = %wd\n"
"xp = %{ulong*}\n"
"rp_n = %{ulong*}\n"
"rp_u = %{ulong*}\n",
ix, n, xp, n, rp_n, n + 1, rp_u, n + 1);
}
else
{
result = (mpn_cmp(rp_n, rp_u, n + 1) == 0);
if (!result)
TEST_FUNCTION_FAIL(
"rp_n != rp_u when unnormalised\n"
"ix = %wd\n"
"n = %wd\n"
"xp = %{ulong*}\n"
"rp_n = %{ulong*}\n"
"rp_u = %{ulong*}\n",
ix, n, xp, n, rp_n, n + 1, rp_u, n + 1);
}

flint_free(rp_n);
flint_free(rp_u);
flint_free(xp);
}

TEST_FUNCTION_END(state);
}
#undef N_MIN
#undef N_MAX
Loading