Skip to content

Commit

Permalink
Improve performance of BigInteger.Multiply(large, small) (#92208)
Browse files Browse the repository at this point in the history
  • Loading branch information
kzrnm committed Nov 6, 2023
1 parent 7f7702e commit e733539
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public static void Multiply(ReadOnlySpan<uint> left, uint right, Span<uint> bits
int i = 0;
ulong carry = 0UL;

for ( ; i < left.Length; i++)
for (; i < left.Length; i++)
{
ulong digits = (ulong)left[i] * right + carry;
bits[i] = unchecked((uint)digits);
Expand All @@ -151,9 +151,9 @@ public static void Multiply(ReadOnlySpan<uint> left, uint right, Span<uint> bits

#if DEBUG
// Mutable for unit testing...
private static
internal static
#else
private const
internal const
#endif
int MultiplyThreshold = 32;

Expand All @@ -162,6 +162,216 @@ public static void Multiply(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, S
Debug.Assert(left.Length >= right.Length);
Debug.Assert(bits.Length == left.Length + right.Length);

if (left.Length - right.Length < 3)
{
MultiplyNearLength(left, right, bits);
}
else
{
MultiplyFarLength(left, right, bits);
}
}

private static void MultiplyFarLength(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, Span<uint> bits)
{
Debug.Assert(left.Length - right.Length >= 3);
Debug.Assert(bits.Length == left.Length + right.Length);

// Executes different algorithms for computing z = a * b
// based on the actual length of b. If b is "small" enough
// we stick to the classic "grammar-school" method; for the
// rest we switch to implementations with less complexity
// albeit more overhead (which needs to pay off!).

// NOTE: useful thresholds needs some "empirical" testing,
// which are smaller in DEBUG mode for testing purpose.

if (right.Length < MultiplyThreshold)
{
// Switching to managed references helps eliminating
// index bounds check...
ref uint resultPtr = ref MemoryMarshal.GetReference(bits);

// Multiplies the bits using the "grammar-school" method.
// Envisioning the "rhombus" of a pen-and-paper calculation
// should help getting the idea of these two loops...
// The inner multiplication operations are safe, because
// z_i+j + a_j * b_i + c <= 2(2^32 - 1) + (2^32 - 1)^2 =
// = 2^64 - 1 (which perfectly matches with ulong!).

for (int i = 0; i < right.Length; i++)
{
ulong carry = 0UL;
for (int j = 0; j < left.Length; j++)
{
ref uint elementPtr = ref Unsafe.Add(ref resultPtr, i + j);
ulong digits = elementPtr + carry + (ulong)left[j] * right[i];
elementPtr = unchecked((uint)digits);
carry = digits >> 32;
}
Unsafe.Add(ref resultPtr, i + left.Length) = (uint)carry;
}
}
else
{
// Based on the Toom-Cook multiplication we split left/right
// into two smaller values, doing recursive multiplication.
// The special form of this multiplication, where we
// split both operands into two operands, is also known
// as the Karatsuba algorithm...

// https://en.wikipedia.org/wiki/Toom-Cook_multiplication
// https://en.wikipedia.org/wiki/Karatsuba_algorithm

// Say we want to compute z = a * b ...

// ... we need to determine our new length (just the half)
int n = left.Length >> 1;
if (right.Length <= n + 1)
{
// ... split left like a = (a_1 << n) + a_0
ReadOnlySpan<uint> leftLow = left.Slice(0, n);
ReadOnlySpan<uint> leftHigh = left.Slice(n);

// ... split right like b = (b_1 << n) + b_0
ReadOnlySpan<uint> rightLow;
uint rightHigh;
if (n < right.Length)
{
Debug.Assert(right.Length == n + 1);
rightLow = right.Slice(0, n);
rightHigh = right[n];
}
else
{
rightLow = right;
rightHigh = 0;
}

// ... prepare our result array (to reuse its memory)
Span<uint> bitsLow = bits.Slice(0, n + rightLow.Length);
Span<uint> bitsHigh = bits.Slice(n);

int carryLength = rightLow.Length;
uint[]? carryFromPool = null;
Span<uint> carry = ((uint)carryLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: carryFromPool = ArrayPool<uint>.Shared.Rent(carryLength)).Slice(0, carryLength);

// ... compute low
Multiply(leftLow, rightLow, bitsLow);
Span<uint> carryOrig = bits.Slice(n, rightLow.Length);
carryOrig.CopyTo(carry);
carryOrig.Clear();

if (rightHigh != 0)
{
// ... compute high
MultiplyNearLength(leftHigh, rightLow, bitsHigh.Slice(0, leftHigh.Length + n));

int upperRightLength = left.Length + 1;
uint[]? upperRightFromPool = null;
Span<uint> upperRight = ((uint)upperRightLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: upperRightFromPool = ArrayPool<uint>.Shared.Rent(upperRightLength)).Slice(0, upperRightLength);
upperRight.Clear();

Multiply(left, rightHigh, upperRight);

AddSelf(bitsHigh, upperRight);

if (upperRightFromPool != null)
ArrayPool<uint>.Shared.Return(upperRightFromPool);
}
else
{
// ... compute high
Multiply(leftHigh, rightLow, bitsHigh);
}

AddSelf(bitsHigh, carry);

if (carryFromPool != null)
ArrayPool<uint>.Shared.Return(carryFromPool);
}
else
{
int n2 = n << 1;

Debug.Assert(left.Length > right.Length);

// ... split left like a = (a_1 << n) + a_0
ReadOnlySpan<uint> leftLow = left.Slice(0, n);
ReadOnlySpan<uint> leftHigh = left.Slice(n);

// ... split right like b = (b_1 << n) + b_0
ReadOnlySpan<uint> rightLow = right.Slice(0, n);
ReadOnlySpan<uint> rightHigh = right.Slice(n);

// ... prepare our result array (to reuse its memory)
Span<uint> bitsLow = bits.Slice(0, n2);
Span<uint> bitsHigh = bits.Slice(n2);

// ... compute z_0 = a_0 * b_0 (multiply again)
MultiplyNearLength(rightLow, leftLow, bitsLow);

// ... compute z_2 = a_1 * b_1 (multiply again)
MultiplyFarLength(leftHigh, rightHigh, bitsHigh);

int leftFoldLength = leftHigh.Length + 1;
uint[]? leftFoldFromPool = null;
Span<uint> leftFold = ((uint)leftFoldLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: leftFoldFromPool = ArrayPool<uint>.Shared.Rent(leftFoldLength)).Slice(0, leftFoldLength);
leftFold.Clear();

int rightFoldLength = n + 1;
uint[]? rightFoldFromPool = null;
Span<uint> rightFold = ((uint)rightFoldLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: rightFoldFromPool = ArrayPool<uint>.Shared.Rent(rightFoldLength)).Slice(0, rightFoldLength);
rightFold.Clear();

int coreLength = leftFoldLength + rightFoldLength;
uint[]? coreFromPool = null;
Span<uint> core = ((uint)coreLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: coreFromPool = ArrayPool<uint>.Shared.Rent(coreLength)).Slice(0, coreLength);
core.Clear();

Debug.Assert(bits.Length - n >= core.Length);
Debug.Assert(rightLow.Length >= rightHigh.Length);

// ... compute z_a = a_1 + a_0 (call it fold...)
Add(leftHigh, leftLow, leftFold);

// ... compute z_b = b_1 + b_0 (call it fold...)
Add(rightLow, rightHigh, rightFold);

// ... compute z_1 = z_a * z_b - z_0 - z_2
MultiplyNearLength(leftFold, rightFold, core);

if (leftFoldFromPool != null)
ArrayPool<uint>.Shared.Return(leftFoldFromPool);

if (rightFoldFromPool != null)
ArrayPool<uint>.Shared.Return(rightFoldFromPool);

SubtractCore(bitsLow, bitsHigh, core);

// ... and finally merge the result! :-)
AddSelf(bits.Slice(n), core);

if (coreFromPool != null)
ArrayPool<uint>.Shared.Return(coreFromPool);
}
}
}
private static void MultiplyNearLength(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, Span<uint> bits)
{
Debug.Assert(left.Length - right.Length < 3);
Debug.Assert(bits.Length == left.Length + right.Length);

// Executes different algorithms for computing z = a * b
// based on the actual length of b. If b is "small" enough
// we stick to the classic "grammar-school" method; for the
Expand Down Expand Up @@ -227,10 +437,10 @@ public static void Multiply(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, S
Span<uint> bitsHigh = bits.Slice(n2);

// ... compute z_0 = a_0 * b_0 (multiply again)
Multiply(leftLow, rightLow, bitsLow);
MultiplyNearLength(leftLow, rightLow, bitsLow);

// ... compute z_2 = a_1 * b_1 (multiply again)
Multiply(leftHigh, rightHigh, bitsHigh);
MultiplyNearLength(leftHigh, rightHigh, bitsHigh);

int leftFoldLength = leftHigh.Length + 1;
uint[]? leftFoldFromPool = null;
Expand Down Expand Up @@ -260,7 +470,7 @@ public static void Multiply(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, S
Add(rightHigh, rightLow, rightFold);

// ... compute z_1 = z_a * z_b - z_0 - z_2
Multiply(leftFold, rightFold, core);
MultiplyNearLength(leftFold, rightFold, core);

if (leftFoldFromPool != null)
ArrayPool<uint>.Shared.Return(leftFoldFromPool);
Expand Down Expand Up @@ -298,21 +508,21 @@ private static void SubtractCore(ReadOnlySpan<uint> left, ReadOnlySpan<uint> rig
ref uint leftPtr = ref MemoryMarshal.GetReference(left);
ref uint corePtr = ref MemoryMarshal.GetReference(core);

for ( ; i < right.Length; i++)
for (; i < right.Length; i++)
{
long digit = (Unsafe.Add(ref corePtr, i) + carry) - Unsafe.Add(ref leftPtr, i) - right[i];
Unsafe.Add(ref corePtr, i) = unchecked((uint)digit);
carry = digit >> 32;
}

for ( ; i < left.Length; i++)
for (; i < left.Length; i++)
{
long digit = (Unsafe.Add(ref corePtr, i) + carry) - left[i];
Unsafe.Add(ref corePtr, i) = unchecked((uint)digit);
carry = digit >> 32;
}

for ( ; carry != 0 && i < core.Length; i++)
for (; carry != 0 && i < core.Length; i++)
{
long digit = core[i] + carry;
core[i] = (uint)digit;
Expand Down
22 changes: 22 additions & 0 deletions src/libraries/System.Runtime.Numerics/tests/BigInteger/multiply.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,28 @@ public static void RunMultiply_Boundary()
VerifyMultiplyString(Math.Pow(2, 33) + " 2 bMultiply");
}

[Fact]
public static void RunMultiplyKaratsubaBoundary()
{
Random random = new Random(s_seed);
byte[] tempByteArray1 = new byte[0];
byte[] tempByteArray2 = new byte[0];

// Multiply Method - One Large BigInteger
for (int i = 0; i < s_samples; i++)
{
for (int d1 = -2; d1 <= 2; d1++)
{
tempByteArray1 = GetRandomByteArray(random, BigIntegerCalculator.MultiplyThreshold + d1);
for (int d2 = -4; d2 <= 4; d2++)
{
tempByteArray2 = GetRandomByteArray(random, (BigIntegerCalculator.MultiplyThreshold + 1) * 2 + d2);
VerifyMultiplyString(Print(tempByteArray1) + Print(tempByteArray2) + "bMultiply");
}
}
}
}

[Fact]
public static void RunMultiply_OnePositiveOneNegative()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,27 @@ public static void RunMultiplyBoundary()
VerifyMultiplyString(Math.Pow(2, 33) + " 2 b*");
}

[Fact]
public static void RunMultiplyKaratsubaBoundary()
{
byte[] tempByteArray1 = new byte[0];
byte[] tempByteArray2 = new byte[0];

// Multiply Method - One Large BigInteger
for (int i = 0; i < s_samples; i++)
{
for (int d1 = -2; d1 <= 2; d1++)
{
tempByteArray1 = GetRandomByteArray(s_random, BigIntegerCalculator.MultiplyThreshold + d1);
for (int d2 = -4; d2 <= 4; d2++)
{
tempByteArray2 = GetRandomByteArray(s_random, (BigIntegerCalculator.MultiplyThreshold + 1) * 2 + d2);
VerifyMultiplyString(Print(tempByteArray1) + Print(tempByteArray2) + "b*");
}
}
}
}

[Fact]
public static void RunMultiplyTests()
{
Expand Down

0 comments on commit e733539

Please sign in to comment.