From b3cb927d99d1568191199357c7121ad075f790d6 Mon Sep 17 00:00:00 2001 From: Aaron Chen Date: Mon, 27 May 2024 22:09:36 +0800 Subject: [PATCH] uint256: optimize Sqrt (#174) * uint256: optimize Sqrt * fix lint * improve test coverage * replace the first div with a right shift * move the first div outsite the loop * move the division to the end of the loop --- shared_test.go | 2 ++ uint256.go | 63 ++++++++++++++++++++++++++++++++------------------ 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/shared_test.go b/shared_test.go index 42a2519..2c3fbb6 100644 --- a/shared_test.go +++ b/shared_test.go @@ -24,6 +24,8 @@ var ( unTestCases = []string{ "0x0", "0x1", + "0x8000000000000000", + "0x12cbafcee8f60f9f", "0x80000000000000000000000000000000", "0x80000000000000010000000000000000", "0x80000000000000000000000000000001", diff --git a/uint256.go b/uint256.go index 5d1df27..10dfcb9 100644 --- a/uint256.go +++ b/uint256.go @@ -1269,34 +1269,51 @@ func (z *Int) ExtendSign(x, byteNum *Int) *Int { // Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z. func (z *Int) Sqrt(x *Int) *Int { // This implementation of Sqrt is based on big.Int (see math/big/nat.go). - if x.LtUint64(2) { - return z.Set(x) + if x.IsUint64() { + var ( + x0 uint64 = x.Uint64() + z1 uint64 = 1 << ((bits.Len64(x0) + 1) / 2) + z2 uint64 + ) + if x0 < 2 { + return z.SetUint64(x0) + } + for { + z2 = (z1 + x0 / z1) >> 1 + if z2 >= z1 { + return z.SetUint64(z1) + } + z1 = z2 + } } - var ( - z1 = &Int{1, 0, 0, 0} - z2 = &Int{} - ) + + z1 := NewInt(1) + z2 := NewInt(0) + // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. - z1 = z1.Lsh(z1, uint(x.BitLen()+1)/2) // must be ≥ √x - for { - z2 = z2.Div(x, z1) - z2 = z2.Add(z2, z1) - { //z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster - a := z2[3] << 63 - z2[3] = z2[3] >> 1 - b := z2[2] << 63 - z2[2] = (z2[2] >> 1) | a - a = z2[1] << 63 - z2[1] = (z2[1] >> 1) | b - z2[0] = (z2[0] >> 1) | a - } - // end of inlined bitshift + z1.Lsh(z1, uint(x.BitLen() + 1) / 2) // must be ≥ √x + + // We can do the first division outside the loop + z2.Rsh(x, uint(x.BitLen() + 1) / 2) // The first div is equal to a right shift - if z2.Cmp(z1) >= 0 { - // z1 is answer. + for { + z2.Add(z2, z1) + + // z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster + z2[0] = (z2[0] >> 1) | z2[1] << 63 + z2[1] = (z2[1] >> 1) | z2[2] << 63 + z2[2] = (z2[2] >> 1) | z2[3] << 63 + z2[3] >>= 1 + + if !z2.Lt(z1) { return z.Set(z1) } - z1, z2 = z2, z1 + z1.Set(z2) + + // Next iteration of the loop + // z2.Div(x, z1) -- x > MaxUint64, x > z1 > 0 + z2.Clear() + udivrem(z2[:], x[:], z1, nil) } }