diff --git a/.changelog/unreleased/improvements/2841-speedup-bits-pick-random.md b/.changelog/unreleased/improvements/2841-speedup-bits-pick-random.md new file mode 100644 index 0000000000..b7103be262 --- /dev/null +++ b/.changelog/unreleased/improvements/2841-speedup-bits-pick-random.md @@ -0,0 +1,2 @@ +- `[internal/bits]` 10x speedup and remove heap overhead of bitArray.PickRandom (used extensively in consensus gossip) + ([\#2841](https://github.com/cometbft/cometbft/pull/2841)). diff --git a/.changelog/v0.37.5/improvements/2846-speedup-json-encoding.md b/.changelog/unreleased/improvements/2846-speedup-json-encoding.md similarity index 100% rename from .changelog/v0.37.5/improvements/2846-speedup-json-encoding.md rename to .changelog/unreleased/improvements/2846-speedup-json-encoding.md diff --git a/libs/bits/bit_array.go b/libs/bits/bit_array.go index 92845e98fe..ad4efe3154 100644 --- a/libs/bits/bit_array.go +++ b/libs/bits/bit_array.go @@ -3,6 +3,7 @@ package bits import ( "encoding/binary" "fmt" + "math/bits" "regexp" "strings" "sync" @@ -247,44 +248,61 @@ func (bA *BitArray) PickRandom() (int, bool) { } bA.mtx.Lock() - trueIndices := bA.getTrueIndices() + numTrueIndices := bA.getNumTrueIndices() + if numTrueIndices == 0 { // no bits set to true + bA.mtx.Unlock() + return 0, false + } + index := bA.getNthTrueIndex(cmtrand.Intn(numTrueIndices)) bA.mtx.Unlock() - - if len(trueIndices) == 0 { // no bits set to true + if index == -1 { return 0, false } + return index, true +} - return trueIndices[cmtrand.Intn(len(trueIndices))], true +func (bA *BitArray) getNumTrueIndices() int { + count := 0 + numElems := len(bA.Elems) + for i := 0; i < numElems; i++ { + count += bits.OnesCount64(bA.Elems[i]) + } + return count } -func (bA *BitArray) getTrueIndices() []int { - trueIndices := make([]int, 0, bA.Bits) - curBit := 0 +// getNthTrueIndex returns the index of the nth true bit in the bit array. +// n is 0 indexed. (e.g. for bitarray x__x, getNthTrueIndex(0) returns 0). +// If there is no such value, it returns -1. +func (bA *BitArray) getNthTrueIndex(n int) int { numElems := len(bA.Elems) - // set all true indices - for i := 0; i < numElems-1; i++ { - elem := bA.Elems[i] - if elem == 0 { - curBit += 64 - continue - } - for j := 0; j < 64; j++ { - if (elem & (uint64(1) << uint64(j))) > 0 { - trueIndices = append(trueIndices, curBit) + count := 0 + + // Iterate over each element + for i := 0; i < numElems; i++ { + // Count set bits in the current element + setBits := bits.OnesCount64(bA.Elems[i]) + + // If the count of set bits in this element plus the count so far + // is greater than or equal to n, then the nth bit must be in this element + if count+setBits >= n { + // Find the index of the nth set bit within this element + for j := 0; j < 64; j++ { + if bA.Elems[i]&(1< 0 { - trueIndices = append(trueIndices, curBit) + } else { + // If the count is not enough, continue to the next element + count += setBits } - curBit++ } - return trueIndices + + // If we reach here, it means n is out of range + return -1 } // String returns a string representation of BitArray: BA{}, diff --git a/libs/bits/bit_array_test.go b/libs/bits/bit_array_test.go index 7610a0987d..dce587ca00 100644 --- a/libs/bits/bit_array_test.go +++ b/libs/bits/bit_array_test.go @@ -12,6 +12,13 @@ import ( cmtrand "github.com/cometbft/cometbft/libs/rand" ) +var ( + empty16Bits = "________________" + empty64Bits = empty16Bits + empty16Bits + empty16Bits + empty16Bits + full16bits = "xxxxxxxxxxxxxxxx" + full64bits = full16bits + full16bits + full16bits + full16bits +) + func randBitArray(bits int) (*BitArray, []byte) { src := cmtrand.Bytes((bits + 7) / 8) bA := NewBitArray(bits) @@ -117,8 +124,6 @@ func TestSub(t *testing.T) { } func TestPickRandom(t *testing.T) { - empty16Bits := "________________" - empty64Bits := empty16Bits + empty16Bits + empty16Bits + empty16Bits testCases := []struct { bA string ok bool @@ -133,6 +138,7 @@ func TestPickRandom(t *testing.T) { {`"x` + empty64Bits + `"`, true}, {`"` + empty64Bits + `x"`, true}, {`"x` + empty64Bits + `x"`, true}, + {`"` + empty64Bits + `___x"`, true}, } for _, tc := range testCases { var bitArr *BitArray @@ -143,7 +149,87 @@ func TestPickRandom(t *testing.T) { } } -func TestBytes(t *testing.T) { +func TestGetNumTrueIndices(t *testing.T) { + type testcase struct { + Input string + ExpectedResult int + } + testCases := []testcase{ + {"x_x_x_", 3}, + {"______", 0}, + {"xxxxxx", 6}, + {"x_x_x_x_x_x_x_x_x_", 9}, + } + numOriginalTestCases := len(testCases) + for i := 0; i < numOriginalTestCases; i++ { + testCases = append(testCases, testcase{testCases[i].Input + "x", testCases[i].ExpectedResult + 1}) + testCases = append(testCases, testcase{full64bits + testCases[i].Input, testCases[i].ExpectedResult + 64}) + testCases = append(testCases, testcase{empty64Bits + testCases[i].Input, testCases[i].ExpectedResult}) + } + + for _, tc := range testCases { + var bitArr *BitArray + err := json.Unmarshal([]byte(`"`+tc.Input+`"`), &bitArr) + require.NoError(t, err) + result := bitArr.getNumTrueIndices() + require.Equal(t, tc.ExpectedResult, result, "for input %s, expected %d, got %d", tc.Input, tc.ExpectedResult, result) + } +} + +func TestGetNthTrueIndex(t *testing.T) { + type testcase struct { + Input string + N int + ExpectedResult int + } + testCases := []testcase{ + // Basic cases + {"x_x_x_", 0, 0}, + {"x_x_x_", 1, 2}, + {"x_x_x_", 2, 4}, + {"______", 1, -1}, // No true indices + {"xxxxxx", 5, 5}, // Last true index + {"x_x_x_x_x_x_x_", 9, -1}, // Out-of-range + + // Edge cases + {"xxxxxx", 7, -1}, // Out-of-range + {"______", 0, -1}, // No true indices + {"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 49, 49}, // Last true index + {"____________________________________________", 1, -1}, // No true indices + {"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 63, 63}, // last index of first word + {"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 64, 64}, // first index of second word + {"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 100, -1}, // Out-of-range + + // Input beyond 64 bits + {"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 99, 99}, // Last true index + + // Input less than 64 bits + {"x_x_x_", 3, -1}, // Out-of-range + } + + numOriginalTestCases := len(testCases) + // Add 64 underscores to each test case + for i := 0; i < numOriginalTestCases; i++ { + expectedResult := testCases[i].ExpectedResult + if expectedResult != -1 { + expectedResult += 64 + } + testCases = append(testCases, testcase{empty64Bits + testCases[i].Input, testCases[i].N, expectedResult}) + } + + for _, tc := range testCases { + var bitArr *BitArray + err := json.Unmarshal([]byte(`"`+tc.Input+`"`), &bitArr) + require.NoError(t, err) + + // Get the nth true index + result := bitArr.getNthTrueIndex(tc.N) + + require.Equal(t, tc.ExpectedResult, result, "for bit array %s, input %d, expected %d, got %d", tc.Input, tc.N, tc.ExpectedResult, result) + } +} + +func TestBytes(_ *testing.T) { bA := NewBitArray(4) bA.SetIndex(0, true) check := func(bA *BitArray, bz []byte) { @@ -303,3 +389,15 @@ func TestUnmarshalJSONDoesntCrashOnZeroBits(t *testing.T) { require.NoError(t, err) require.Equal(t, ic.BitArray, &BitArray{Bits: 0, Elems: nil}) } + +func BenchmarkPickRandomBitArray(b *testing.B) { + // A random 150 bit string to use as the benchmark bit array + benchmarkBitArrayStr := "_______xx__xxx_xx__x_xx_x_x_x__x_x_x_xx__xx__xxx__xx_x_xxx_x__xx____x____xx__xx____x_x__x_____xx_xx_xxxxxxx__xx_x_xxxx_x___x_xxxxx_xx__xxxx_xx_x___x_x" + var bitArr *BitArray + err := json.Unmarshal([]byte(`"`+benchmarkBitArrayStr+`"`), &bitArr) + require.NoError(b, err) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = bitArr.PickRandom() + } +}