Skip to content

Commit

Permalink
perf(internal/bits): Significantly speedup bitArray.PickRandom (backp…
Browse files Browse the repository at this point in the history
…ort #2841) (#2887)

This PR significantly speeds up bitArray.PickRandom which is used in
VoteGossip and BlockPart gossip. We saw for a query serving full node,
over an hour, this was a very large amount of RAM allocations. (75GB of
RAM!)


![image](https://github.com/cometbft/cometbft/assets/6440154/755918a5-0cef-4e67-a47e-ce8a56aa1cd5)

This PR drops it down to 0 allocations, and makes the routine 10x faster
on my machine.

OLD:
```
BenchmarkPickRandomBitArray-12           1545199               846.1 ns/op          1280 B/op          1 allocs/op
```
NEW:
```
BenchmarkPickRandomBitArray-12          22192857                75.39 ns/op            0 B/op          0 allocs/op
```

I think the new tests I wrote make this more tested than the old code
that was here tbh, but pls let me know if theres more tests we'd like to
see!

---

#### PR checklist

- [x] Tests written/updated
- [x] Changelog entry added in `.changelog` (we use
[unclog](https://github.com/informalsystems/unclog) to manage our
changelog)
- [x] Updated relevant documentation (`docs/` or `spec/`) and code
comments
- [x] Title follows the [Conventional
Commits](https://www.conventionalcommits.org/en/v1.0.0/) spec
<hr>This is an automatic backport of pull request #2841 done by
[Mergify](https://mergify.com).

---------

Co-authored-by: Dev Ojha <ValarDragon@users.noreply.github.com>
Co-authored-by: Anton Kaliaev <anton.kalyaev@gmail.com>
  • Loading branch information
3 people committed Apr 24, 2024
1 parent 4899cd5 commit 983cbaa
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 31 deletions.
@@ -0,0 +1,2 @@
- `[internal/bits]` 10x speedup and remove heap overhead of bitArray.PickRandom (used extensively in consensus gossip)
([\#2841](https://github.com/cometbft/cometbft/pull/2841)).
74 changes: 46 additions & 28 deletions libs/bits/bit_array.go
Expand Up @@ -3,6 +3,7 @@ package bits
import (
"encoding/binary"
"fmt"
"math/bits"
"regexp"
"strings"
"sync"
Expand Down Expand Up @@ -247,44 +248,61 @@ func (bA *BitArray) PickRandom() (int, bool) {
}

bA.mtx.Lock()
trueIndices := bA.getTrueIndices()
numTrueIndices := bA.getNumTrueIndices()
if numTrueIndices == 0 { // no bits set to true
bA.mtx.Unlock()
return 0, false
}
index := bA.getNthTrueIndex(cmtrand.Intn(numTrueIndices))
bA.mtx.Unlock()

if len(trueIndices) == 0 { // no bits set to true
if index == -1 {
return 0, false
}
return index, true
}

return trueIndices[cmtrand.Intn(len(trueIndices))], true
func (bA *BitArray) getNumTrueIndices() int {
count := 0
numElems := len(bA.Elems)
for i := 0; i < numElems; i++ {
count += bits.OnesCount64(bA.Elems[i])
}
return count
}

func (bA *BitArray) getTrueIndices() []int {
trueIndices := make([]int, 0, bA.Bits)
curBit := 0
// getNthTrueIndex returns the index of the nth true bit in the bit array.
// n is 0 indexed. (e.g. for bitarray x__x, getNthTrueIndex(0) returns 0).
// If there is no such value, it returns -1.
func (bA *BitArray) getNthTrueIndex(n int) int {
numElems := len(bA.Elems)
// set all true indices
for i := 0; i < numElems-1; i++ {
elem := bA.Elems[i]
if elem == 0 {
curBit += 64
continue
}
for j := 0; j < 64; j++ {
if (elem & (uint64(1) << uint64(j))) > 0 {
trueIndices = append(trueIndices, curBit)
count := 0

// Iterate over each element
for i := 0; i < numElems; i++ {
// Count set bits in the current element
setBits := bits.OnesCount64(bA.Elems[i])

// If the count of set bits in this element plus the count so far
// is greater than or equal to n, then the nth bit must be in this element
if count+setBits >= n {
// Find the index of the nth set bit within this element
for j := 0; j < 64; j++ {
if bA.Elems[i]&(1<<uint(j)) != 0 {
if count == n {
// Calculate the absolute index of the set bit
return i*64 + j
}
count++
}
}
curBit++
}
}
// handle last element
lastElem := bA.Elems[numElems-1]
numFinalBits := bA.Bits - curBit
for i := 0; i < numFinalBits; i++ {
if (lastElem & (uint64(1) << uint64(i))) > 0 {
trueIndices = append(trueIndices, curBit)
} else {
// If the count is not enough, continue to the next element
count += setBits
}
curBit++
}
return trueIndices

// If we reach here, it means n is out of range
return -1
}

// String returns a string representation of BitArray: BA{<bit-string>},
Expand Down
104 changes: 101 additions & 3 deletions libs/bits/bit_array_test.go
Expand Up @@ -12,6 +12,13 @@ import (
cmtrand "github.com/cometbft/cometbft/libs/rand"
)

var (
empty16Bits = "________________"
empty64Bits = empty16Bits + empty16Bits + empty16Bits + empty16Bits
full16bits = "xxxxxxxxxxxxxxxx"
full64bits = full16bits + full16bits + full16bits + full16bits
)

func randBitArray(bits int) (*BitArray, []byte) {
src := cmtrand.Bytes((bits + 7) / 8)
bA := NewBitArray(bits)
Expand Down Expand Up @@ -117,8 +124,6 @@ func TestSub(t *testing.T) {
}

func TestPickRandom(t *testing.T) {
empty16Bits := "________________"
empty64Bits := empty16Bits + empty16Bits + empty16Bits + empty16Bits
testCases := []struct {
bA string
ok bool
Expand All @@ -133,6 +138,7 @@ func TestPickRandom(t *testing.T) {
{`"x` + empty64Bits + `"`, true},
{`"` + empty64Bits + `x"`, true},
{`"x` + empty64Bits + `x"`, true},
{`"` + empty64Bits + `___x"`, true},
}
for _, tc := range testCases {
var bitArr *BitArray
Expand All @@ -143,7 +149,87 @@ func TestPickRandom(t *testing.T) {
}
}

func TestBytes(t *testing.T) {
func TestGetNumTrueIndices(t *testing.T) {
type testcase struct {
Input string
ExpectedResult int
}
testCases := []testcase{
{"x_x_x_", 3},
{"______", 0},
{"xxxxxx", 6},
{"x_x_x_x_x_x_x_x_x_", 9},
}
numOriginalTestCases := len(testCases)
for i := 0; i < numOriginalTestCases; i++ {
testCases = append(testCases, testcase{testCases[i].Input + "x", testCases[i].ExpectedResult + 1})
testCases = append(testCases, testcase{full64bits + testCases[i].Input, testCases[i].ExpectedResult + 64})
testCases = append(testCases, testcase{empty64Bits + testCases[i].Input, testCases[i].ExpectedResult})
}

for _, tc := range testCases {
var bitArr *BitArray
err := json.Unmarshal([]byte(`"`+tc.Input+`"`), &bitArr)
require.NoError(t, err)
result := bitArr.getNumTrueIndices()
require.Equal(t, tc.ExpectedResult, result, "for input %s, expected %d, got %d", tc.Input, tc.ExpectedResult, result)
}
}

func TestGetNthTrueIndex(t *testing.T) {
type testcase struct {
Input string
N int
ExpectedResult int
}
testCases := []testcase{
// Basic cases
{"x_x_x_", 0, 0},
{"x_x_x_", 1, 2},
{"x_x_x_", 2, 4},
{"______", 1, -1}, // No true indices
{"xxxxxx", 5, 5}, // Last true index
{"x_x_x_x_x_x_x_", 9, -1}, // Out-of-range

// Edge cases
{"xxxxxx", 7, -1}, // Out-of-range
{"______", 0, -1}, // No true indices
{"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 49, 49}, // Last true index
{"____________________________________________", 1, -1}, // No true indices
{"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 63, 63}, // last index of first word
{"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 64, 64}, // first index of second word
{"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 100, -1}, // Out-of-range

// Input beyond 64 bits
{"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 99, 99}, // Last true index

// Input less than 64 bits
{"x_x_x_", 3, -1}, // Out-of-range
}

numOriginalTestCases := len(testCases)
// Add 64 underscores to each test case
for i := 0; i < numOriginalTestCases; i++ {
expectedResult := testCases[i].ExpectedResult
if expectedResult != -1 {
expectedResult += 64
}
testCases = append(testCases, testcase{empty64Bits + testCases[i].Input, testCases[i].N, expectedResult})
}

for _, tc := range testCases {
var bitArr *BitArray
err := json.Unmarshal([]byte(`"`+tc.Input+`"`), &bitArr)
require.NoError(t, err)

// Get the nth true index
result := bitArr.getNthTrueIndex(tc.N)

require.Equal(t, tc.ExpectedResult, result, "for bit array %s, input %d, expected %d, got %d", tc.Input, tc.N, tc.ExpectedResult, result)
}
}

func TestBytes(_ *testing.T) {
bA := NewBitArray(4)
bA.SetIndex(0, true)
check := func(bA *BitArray, bz []byte) {
Expand Down Expand Up @@ -303,3 +389,15 @@ func TestUnmarshalJSONDoesntCrashOnZeroBits(t *testing.T) {
require.NoError(t, err)
require.Equal(t, ic.BitArray, &BitArray{Bits: 0, Elems: nil})
}

func BenchmarkPickRandomBitArray(b *testing.B) {
// A random 150 bit string to use as the benchmark bit array
benchmarkBitArrayStr := "_______xx__xxx_xx__x_xx_x_x_x__x_x_x_xx__xx__xxx__xx_x_xxx_x__xx____x____xx__xx____x_x__x_____xx_xx_xxxxxxx__xx_x_xxxx_x___x_xxxxx_xx__xxxx_xx_x___x_x"
var bitArr *BitArray
err := json.Unmarshal([]byte(`"`+benchmarkBitArrayStr+`"`), &bitArr)
require.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = bitArr.PickRandom()
}
}

0 comments on commit 983cbaa

Please sign in to comment.