Skip to content

Commit

Permalink
uint256: optimize Sqrt (#174)
Browse files Browse the repository at this point in the history
* uint256: optimize Sqrt

* fix lint

* improve test coverage

* replace the first div with a right shift

* move the first div outside the loop

* move the division to the end of the loop
  • Loading branch information
AaronChen0 committed May 27, 2024
1 parent 8dfcfde commit b3cb927
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 23 deletions.
2 changes: 2 additions & 0 deletions shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ var (
unTestCases = []string{
"0x0",
"0x1",
"0x8000000000000000",
"0x12cbafcee8f60f9f",
"0x80000000000000000000000000000000",
"0x80000000000000010000000000000000",
"0x80000000000000000000000000000001",
Expand Down
63 changes: 40 additions & 23 deletions uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,34 +1269,51 @@ func (z *Int) ExtendSign(x, byteNum *Int) *Int {
// Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z.
func (z *Int) Sqrt(x *Int) *Int {
// This implementation of Sqrt is based on big.Int (see math/big/nat.go).
if x.LtUint64(2) {
return z.Set(x)
if x.IsUint64() {
var (
x0 uint64 = x.Uint64()
z1 uint64 = 1 << ((bits.Len64(x0) + 1) / 2)
z2 uint64
)
if x0 < 2 {
return z.SetUint64(x0)
}
for {
z2 = (z1 + x0 / z1) >> 1
if z2 >= z1 {
return z.SetUint64(z1)
}
z1 = z2
}
}
var (
z1 = &Int{1, 0, 0, 0}
z2 = &Int{}
)

z1 := NewInt(1)
z2 := NewInt(0)

// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
z1 = z1.Lsh(z1, uint(x.BitLen()+1)/2) // must be ≥ √x
for {
z2 = z2.Div(x, z1)
z2 = z2.Add(z2, z1)
{ //z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster
a := z2[3] << 63
z2[3] = z2[3] >> 1
b := z2[2] << 63
z2[2] = (z2[2] >> 1) | a
a = z2[1] << 63
z2[1] = (z2[1] >> 1) | b
z2[0] = (z2[0] >> 1) | a
}
// end of inlined bitshift
z1.Lsh(z1, uint(x.BitLen() + 1) / 2) // must be ≥ √x

// We can do the first division outside the loop
z2.Rsh(x, uint(x.BitLen() + 1) / 2) // The first div is equal to a right shift

if z2.Cmp(z1) >= 0 {
// z1 is answer.
for {
z2.Add(z2, z1)

// z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster
z2[0] = (z2[0] >> 1) | z2[1] << 63
z2[1] = (z2[1] >> 1) | z2[2] << 63
z2[2] = (z2[2] >> 1) | z2[3] << 63
z2[3] >>= 1

if !z2.Lt(z1) {
return z.Set(z1)
}
z1, z2 = z2, z1
z1.Set(z2)

// Next iteration of the loop
// z2.Div(x, z1) -- x > MaxUint64, x > z1 > 0
z2.Clear()
udivrem(z2[:], x[:], z1, nil)
}
}

Expand Down

0 comments on commit b3cb927

Please sign in to comment.