Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trie mutex refactoring #4984

Merged
merged 8 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 3 additions & 24 deletions trie/patriciaMerkleTrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,6 @@ func (tr *patriciaMerkleTrie) Commit() error {

// Recreate returns a new trie that has the given root hash and database.
// NOTE(review): reconstructed post-merge form — the diff residue showed the
// removed mutOperation.Lock()/Unlock() pair still interleaved with the
// surviving line; per this PR the lock was deliberately dropped here.
func (tr *patriciaMerkleTrie) Recreate(root []byte) (common.Trie, error) {
	return tr.recreate(root, tr.trieStorage)
}

Expand All @@ -255,9 +252,6 @@ func (tr *patriciaMerkleTrie) RecreateFromEpoch(options common.RootHashHolder) (
return nil, ErrNilRootHashHolder
}

tr.mutOperation.Lock()
defer tr.mutOperation.Unlock()

if !options.GetEpoch().HasValue {
return tr.recreate(options.GetRootHash(), tr.trieStorage)
}
Expand Down Expand Up @@ -301,6 +295,9 @@ func (tr *patriciaMerkleTrie) recreate(root []byte, tsm common.StorageManager) (

// String outputs a graphical view of the trie. Mainly used in tests/debugging
func (tr *patriciaMerkleTrie) String() string {
tr.mutOperation.Lock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need Lock here?
you could use RLock instead

same for Get(), GetObsoleteHashes(), VerifyProof(), GetNumNodes(), GetOldRoot()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, we need the Lock/Unlock pair here and on all methods that are doing only get operations.
Reason: if the trie is collapsed/partially collapsed (as happens after each Commit call), the nodes will hold only the collapsed version of the nodes on the last stored level. Then, at each traversal, the trie might try to fetch a collapsed node, which will trigger DB loading and alter the containing pointers. So, because of the usage of the resolveCollapsed function, we cannot use the RLock/RUnlock functions. The concurrency test might not fail because it is not big enough, but otherwise it will fail on concurrent operations. Changed the test to have only one level in memory so it will fail if we change the Lock/Unlock to RLock/RUnlock.

defer tr.mutOperation.Unlock()

writer := bytes.NewBuffer(make([]byte, 0))

if tr.root == nil {
Expand Down Expand Up @@ -377,19 +374,13 @@ func (tr *patriciaMerkleTrie) recreateFromDb(rootHash []byte, tsm common.Storage

// GetSerializedNode returns the serialized node (if existing) provided the node's hash.
// NOTE(review): reconstructed post-merge form — the removed
// mutOperation.Lock()/Unlock() pair from the diff was deleted; the method now
// delegates directly to the trie storage, which handles its own concurrency.
func (tr *patriciaMerkleTrie) GetSerializedNode(hash []byte) ([]byte, error) {
	log.Trace("GetSerializedNode", "hash", hash)

	return tr.trieStorage.Get(hash)
}

// GetSerializedNodes returns a batch of serialized nodes from the trie, starting from the given hash
func (tr *patriciaMerkleTrie) GetSerializedNodes(rootHash []byte, maxBuffToSend uint64) ([][]byte, uint64, error) {
tr.mutOperation.Lock()
defer tr.mutOperation.Unlock()

log.Trace("GetSerializedNodes", "rootHash", rootHash)
size := uint64(0)

Expand Down Expand Up @@ -452,24 +443,20 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel(
return ErrNilTrieIteratorErrChannel
}

tr.mutOperation.RLock()
newTrie, err := tr.recreate(rootHash, tr.trieStorage)
if err != nil {
tr.mutOperation.RUnlock()
close(leavesChannels.LeavesChan)
close(leavesChannels.ErrChan)
return err
}

if check.IfNil(newTrie) || newTrie.root == nil {
tr.mutOperation.RUnlock()
close(leavesChannels.LeavesChan)
close(leavesChannels.ErrChan)
return nil
}

tr.trieStorage.EnterPruningBufferingMode()
tr.mutOperation.RUnlock()

go func() {
err = newTrie.root.getAllLeavesOnChannel(
Expand All @@ -485,9 +472,7 @@ func (tr *patriciaMerkleTrie) GetAllLeavesOnChannel(
log.Error("could not get all trie leaves: ", "error", err)
}

tr.mutOperation.Lock()
tr.trieStorage.ExitPruningBufferingMode()
tr.mutOperation.Unlock()

close(leavesChannels.LeavesChan)
close(leavesChannels.ErrChan)
Expand Down Expand Up @@ -619,9 +604,6 @@ func (tr *patriciaMerkleTrie) GetNumNodes() common.NumNodesDTO {

// GetStorageManager returns the storage manager for the trie.
// NOTE(review): reconstructed post-merge form — the diff-removed
// mutOperation.Lock()/Unlock() pair was deleted. tr.trieStorage is set at
// construction and never reassigned here, so returning it needs no lock.
func (tr *patriciaMerkleTrie) GetStorageManager() common.StorageManager {
	return tr.trieStorage
}

Expand All @@ -635,13 +617,10 @@ func (tr *patriciaMerkleTrie) GetOldRoot() []byte {

// GetTrieStats will collect and return the statistics for the given rootHash
func (tr *patriciaMerkleTrie) GetTrieStats(address string, rootHash []byte) (*statistics.TrieStatsDTO, error) {
tr.mutOperation.RLock()
newTrie, err := tr.recreate(rootHash, tr.trieStorage)
if err != nil {
tr.mutOperation.RUnlock()
return nil, err
}
tr.mutOperation.RUnlock()

ts := statistics.NewTrieStatistics()
err = newTrie.root.collectStats(ts, rootDepthLevel, newTrie.trieStorage)
Expand Down
180 changes: 176 additions & 4 deletions trie/patriciaMerkleTrie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strconv"
"sync"
"testing"
"time"

"github.com/multiversx/mx-chain-core-go/core"
"github.com/multiversx/mx-chain-core-go/hashing"
Expand All @@ -18,6 +19,7 @@ import (
"github.com/multiversx/mx-chain-go/common/holders"
"github.com/multiversx/mx-chain-go/config"
"github.com/multiversx/mx-chain-go/testscommon"
"github.com/multiversx/mx-chain-go/testscommon/storage"
trieMock "github.com/multiversx/mx-chain-go/testscommon/trie"
"github.com/multiversx/mx-chain-go/trie"
"github.com/multiversx/mx-chain-go/trie/hashesHolder"
Expand All @@ -35,7 +37,7 @@ func emptyTrie() common.Trie {
return tr
}

func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, uint) {
func getDefaultTrieStorageManagerParameters() trie.NewTrieStorageManagerArgs {
marshalizer := &testscommon.ProtobufMarshalizerMock{}
hasher := &testscommon.KeccakMock{}

Expand All @@ -44,7 +46,8 @@ func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, has
SnapshotsBufferLen: 10,
SnapshotsGoroutineNum: 1,
}
args := trie.NewTrieStorageManagerArgs{

return trie.NewTrieStorageManagerArgs{
MainStorer: testscommon.NewSnapshotPruningStorerMock(),
CheckpointsStorer: testscommon.NewSnapshotPruningStorerMock(),
Marshalizer: marshalizer,
Expand All @@ -53,10 +56,14 @@ func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, has
CheckpointHashesHolder: hashesHolder.NewCheckpointHashesHolder(10000000, testscommon.HashSize),
IdleProvider: &testscommon.ProcessStatusHandlerStub{},
}
}

// getDefaultTrieParameters returns the default constructor arguments for a
// test trie: storage manager, marshalizer, hasher and maxTrieLevelInMemory.
// NOTE(review): reconstructed post-merge form — the scraped diff kept both the
// old and the new lines (duplicate maxTrieLevelInMemory declarations and a
// duplicate return). maxTrieLevelInMemory is 1 so the trie keeps only one
// level in memory, forcing resolveCollapsed on traversal; this makes the
// concurrency test fail if Lock/Unlock is ever weakened to RLock/RUnlock.
func getDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher, uint) {
	args := getDefaultTrieStorageManagerParameters()
	trieStorageManager, _ := trie.NewTrieStorageManager(args)
	maxTrieLevelInMemory := uint(1)

	return trieStorageManager, args.Marshalizer, args.Hasher, maxTrieLevelInMemory
}

func initTrieMultipleValues(nr int) (common.Trie, [][]byte) {
Expand Down Expand Up @@ -964,6 +971,171 @@ func TestPatriciaMerkleTree_GetValueReturnsTrieDepth(t *testing.T) {
assert.Equal(t, uint32(3), depth)
}

// TestPatriciaMerkleTrie_ConcurrentOperations hammers the trie with 1000
// goroutines, each performing one of 20 public operations chosen by index,
// to surface data races (run with -race) and deadlocks in the mutex handling.
// Read-only operations assert no error; operations whose result depends on
// concurrent root-hash changes (GetProof, VerifyProof) deliberately ignore it.
func TestPatriciaMerkleTrie_ConcurrentOperations(t *testing.T) {
	t.Parallel()

	tr := initTrie()
	_ = tr.Commit()
	numOperations := 1000
	wg := sync.WaitGroup{}
	wg.Add(numOperations)
	numFunctions := 20

	// captured before the goroutines start so every worker targets a root
	// hash that is guaranteed to exist in storage
	initialRootHash, _ := tr.RootHash()

	for i := 0; i < numOperations; i++ {
		go func(idx int) {
			// small stagger so goroutines overlap rather than run serially
			time.Sleep(time.Millisecond * 10)

			operation := idx % numFunctions
			switch operation {
			case 0:
				_, _, err := tr.Get([]byte("dog"))
				assert.Nil(t, err)
			case 1:
				err := tr.Update([]byte("doe"), []byte("alt"))
				assert.Nil(t, err)
			case 2:
				err := tr.Delete([]byte("alt"))
				assert.Nil(t, err)
			case 3:
				_, err := tr.RootHash()
				assert.Nil(t, err)
			case 4:
				err := tr.Commit()
				assert.Nil(t, err)
			case 5:
				_, err := tr.Recreate(initialRootHash)
				assert.Nil(t, err)
			case 6:
				epoch := core.OptionalUint32{
					Value:    3,
					HasValue: true,
				}
				rootHashHolder := holders.NewRootHashHolder(initialRootHash, epoch)
				_, err := tr.RecreateFromEpoch(rootHashHolder)
				assert.Nil(t, err)
			case 7:
				_ = tr.String()
			case 8:
				_ = tr.GetObsoleteHashes()
			case 9:
				_, err := tr.GetDirtyHashes()
				assert.Nil(t, err)
			case 10:
				_, err := tr.GetSerializedNode(initialRootHash)
				assert.Nil(t, err)
			case 11:
				// NOTE(review): named size1KB but 1024*1024 is 1 MiB — rename candidate
				size1KB := uint64(1024 * 1024)
				_, _, err := tr.GetSerializedNodes(initialRootHash, size1KB)
				assert.Nil(t, err)
			case 12:
				trieIteratorChannels := &common.TrieIteratorChannels{
					LeavesChan: make(chan core.KeyValueHolder, 1000),
					ErrChan:    make(chan error, 1000),
				}

				err := tr.GetAllLeavesOnChannel(
					trieIteratorChannels,
					context.Background(),
					initialRootHash,
					keyBuilder.NewKeyBuilder(),
				)
				assert.Nil(t, err)
			case 13:
				_, err := tr.GetAllHashes()
				assert.Nil(t, err)
			case 14:
				_, _, _ = tr.GetProof(initialRootHash) // this might error due to concurrent operations that change the roothash
			case 15:
				// extremely hard to compute an existing hash due to concurrent changes.
				_, _ = tr.VerifyProof([]byte("dog"), []byte("puppy"), [][]byte{[]byte("proof1")}) // this might error due to concurrent operations that change the roothash
			case 16:
				numNodes := tr.GetNumNodes()
				assert.Equal(t, 4, numNodes.MaxLevel)
			case 17:
				sm := tr.GetStorageManager()
				assert.NotNil(t, sm)
			case 18:
				_ = tr.GetOldRoot()
			case 19:
				trieStatsHandler := tr.(common.TrieStats)
				_, err := trieStatsHandler.GetTrieStats("address", initialRootHash)
				assert.Nil(t, err)
			default:
				assert.Fail(t, fmt.Sprintf("invalid numFunctions value %d, operation: %d", numFunctions, operation))
			}

			wg.Done()
		}(i)
	}

	wg.Wait()
}

// TestPatriciaMerkleTrie_GetSerializedNodesClose checks that Close returns
// without waiting for in-flight GetSerializedNode(s) calls to finish: with
// every storer Get artificially slowed down, Close must win the race against
// 1000 pending reads (and against a 10s timeout).
func TestPatriciaMerkleTrie_GetSerializedNodesClose(t *testing.T) {
	t.Parallel()

	args := getDefaultTrieStorageManagerParameters()
	args.MainStorer = &storage.StorerStub{
		GetCalled: func(key []byte) ([]byte, error) {
			// gets take a long time
			time.Sleep(time.Millisecond * 10)
			return key, nil
		},
	}

	trieStorageManager, _ := trie.NewTrieStorageManager(args)
	tr, _ := trie.NewTrie(trieStorageManager, args.Marshalizer, args.Hasher, 5)

	numGoRoutines := 1000
	var startedWG, finishedWG sync.WaitGroup
	startedWG.Add(numGoRoutines)
	finishedWG.Add(numGoRoutines)

	// alternate between the batch getter and the single-node getter
	for i := 0; i < numGoRoutines; i++ {
		useBatchGetter := i%2 == 0
		go func(batch bool) {
			time.Sleep(time.Millisecond * 100)
			startedWG.Done()

			if batch {
				_, _, _ = tr.GetSerializedNodes([]byte("dog"), 1024)
			} else {
				_, _ = tr.GetSerializedNode([]byte("dog"))
			}
			finishedWG.Done()
		}(useBatchGetter)
	}

	startedWG.Wait()

	chanClosed := make(chan struct{})
	go func() {
		_ = tr.Close()
		close(chanClosed)
	}()

	chanGetsEnded := make(chan struct{})
	go func() {
		finishedWG.Wait()
		close(chanGetsEnded)
	}()

	timeout := time.Second * 10
	select {
	case <-chanClosed: // ok
	case <-chanGetsEnded:
		assert.Fail(t, "trie should have been closed before all gets ended")
	case <-time.After(timeout):
		assert.Fail(t, "timeout waiting for trie to be closed")
	}
}

func BenchmarkPatriciaMerkleTree_Insert(b *testing.B) {
tr := emptyTrie()
hsh := keccak.NewKeccak()
Expand Down