diff --git a/blockindex/contractstaking/cache.go b/blockindex/contractstaking/cache.go index 68ca596fda..fad4b39af4 100644 --- a/blockindex/contractstaking/cache.go +++ b/blockindex/contractstaking/cache.go @@ -25,7 +25,7 @@ type ( BucketInfo(id uint64) (*bucketInfo, bool) MustGetBucketInfo(id uint64) *bucketInfo MustGetBucketType(id uint64) *BucketType - MatchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType, bool) + MatchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType) BucketType(id uint64) (*BucketType, bool) BucketTypeCount() int Buckets() ([]uint64, []*BucketType, []*bucketInfo) @@ -210,15 +210,15 @@ func (s *contractStakingCache) DeleteBucketInfo(id uint64) { s.deltaBuckets[id] = nil } -func (s *contractStakingCache) MatchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType, bool) { +func (s *contractStakingCache) MatchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType) { s.mutex.RLock() defer s.mutex.RUnlock() id, ok := s.getBucketTypeIndex(amount, duration) if !ok { - return 0, nil, false + return 0, nil } - return id, s.mustGetBucketType(id), true + return id, s.mustGetBucketType(id) } func (s *contractStakingCache) LoadFromDB(kvstore db.KVStore) error { diff --git a/blockindex/contractstaking/cache_test.go b/blockindex/contractstaking/cache_test.go index f96d609908..41893b0276 100644 --- a/blockindex/contractstaking/cache_test.go +++ b/blockindex/contractstaking/cache_test.go @@ -376,33 +376,29 @@ func TestContractStakingCache_MatchBucketType(t *testing.T) { cache := newContractStakingCache() // no bucket types - _, bucketType, ok := cache.MatchBucketType(big.NewInt(100), 100) - require.False(ok) + _, bucketType := cache.MatchBucketType(big.NewInt(100), 100) require.Nil(bucketType) // one bucket type cache.PutBucketType(1, &BucketType{Amount: big.NewInt(100), Duration: 100, ActivatedAt: 1}) // match exact bucket type - id, bucketType, ok := cache.MatchBucketType(big.NewInt(100), 100) - require.True(ok) + id, bucketType := cache.MatchBucketType(big.NewInt(100), 100) + require.NotNil(bucketType) require.EqualValues(1, id) require.EqualValues(100, bucketType.Amount.Int64()) require.EqualValues(100, bucketType.Duration) require.EqualValues(1, bucketType.ActivatedAt) // match bucket type with different amount - _, bucketType, ok = cache.MatchBucketType(big.NewInt(200), 100) - require.False(ok) + _, bucketType = cache.MatchBucketType(big.NewInt(200), 100) require.Nil(bucketType) // match bucket type with different duration - _, bucketType, ok = cache.MatchBucketType(big.NewInt(100), 200) - require.False(ok) + _, bucketType = cache.MatchBucketType(big.NewInt(100), 200) require.Nil(bucketType) // no match - _, bucketType, ok = cache.MatchBucketType(big.NewInt(200), 200) - require.False(ok) + _, bucketType = cache.MatchBucketType(big.NewInt(200), 200) require.Nil(bucketType) } @@ -533,8 +529,8 @@ func TestContractStakingCache_LoadFromDB(t *testing.T) { require.Equal(bucketInfo, bi) btc = cache.BucketTypeCount() require.EqualValues(1, btc) - id, bt, ok := cache.MatchBucketType(big.NewInt(100), 100) - require.True(ok) + id, bt := cache.MatchBucketType(big.NewInt(100), 100) + require.NotNil(bt) require.EqualValues(1, id) require.EqualValues(100, bt.Amount.Int64()) require.EqualValues(100, bt.Duration) diff --git a/blockindex/contractstaking/dirty_cache.go b/blockindex/contractstaking/dirty_cache.go index 8f2ce90972..88a04cd0c3 100644 --- a/blockindex/contractstaking/dirty_cache.go +++ b/blockindex/contractstaking/dirty_cache.go @@ -73,8 +73,8 @@ func (dirty *contractStakingDirty) deleteBucketInfo(id uint64) { } func (dirty *contractStakingDirty) putBucketType(bt *BucketType) { - id, _, ok := dirty.matchBucketType(bt.Amount, bt.Duration) - if !ok { + id, old := dirty.matchBucketType(bt.Amount, bt.Duration) + if old == nil { id = dirty.getBucketTypeCount() dirty.addBucketType(id, bt) } @@ -112,7 +112,7 @@ func (dirty *contractStakingDirty) addBucketType(id uint64, bt *BucketType) { dirty.cache.PutBucketType(id, bt) } -func (dirty *contractStakingDirty) matchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType, bool) { +func (dirty *contractStakingDirty) matchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType) { return dirty.cache.MatchBucketType(amount, duration) } diff --git a/blockindex/contractstaking/dirty_cache_test.go b/blockindex/contractstaking/dirty_cache_test.go index b6070a28d6..9d6736bb37 100644 --- a/blockindex/contractstaking/dirty_cache_test.go +++ b/blockindex/contractstaking/dirty_cache_test.go @@ -96,15 +96,14 @@ func TestContractStakingDirty_matchBucketType(t *testing.T) { dirty := newContractStakingDirty(clean) // no bucket type - id, bt, ok := dirty.matchBucketType(big.NewInt(100), 100) - require.False(ok) + id, bt := dirty.matchBucketType(big.NewInt(100), 100) require.Nil(bt) require.EqualValues(0, id) // bucket type in clean cache clean.PutBucketType(1, &BucketType{Amount: big.NewInt(100), Duration: 100, ActivatedAt: 1}) - id, bt, ok = dirty.matchBucketType(big.NewInt(100), 100) - require.True(ok) + id, bt = dirty.matchBucketType(big.NewInt(100), 100) + require.NotNil(bt) require.EqualValues(100, bt.Amount.Int64()) require.EqualValues(100, bt.Duration) require.EqualValues(1, bt.ActivatedAt) @@ -112,8 +111,8 @@ func TestContractStakingDirty_matchBucketType(t *testing.T) { // added bucket type dirty.addBucketType(2, &BucketType{Amount: big.NewInt(200), Duration: 200, ActivatedAt: 2}) - id, bt, ok = dirty.matchBucketType(big.NewInt(200), 200) - require.True(ok) + id, bt = dirty.matchBucketType(big.NewInt(200), 200) + require.NotNil(bt) require.EqualValues(200, bt.Amount.Int64()) require.EqualValues(200, bt.Duration) require.EqualValues(2, bt.ActivatedAt) diff --git a/blockindex/contractstaking/event_handler.go b/blockindex/contractstaking/event_handler.go index c9b4aae092..abcea1db29 100644 --- a/blockindex/contractstaking/event_handler.go +++ b/blockindex/contractstaking/event_handler.go @@ -491,8 +491,8 @@ func (eh *contractStakingEventHandler) handleBucketTypeDeactivatedEvent(event ev return err } - id, bt, ok := eh.dirty.matchBucketType(amountParam, durationParam.Uint64()) - if !ok { + id, bt := eh.dirty.matchBucketType(amountParam, durationParam.Uint64()) + if bt == nil { return errors.Wrapf(errBucketTypeNotExist, "amount %d, duration %d", amountParam.Int64(), durationParam.Uint64()) } bt.ActivatedAt = maxBlockNumber @@ -519,8 +519,8 @@ func (eh *contractStakingEventHandler) handleStakedEvent(event eventParam, heigh return err } - btIdx, _, ok := eh.dirty.matchBucketType(amountParam, durationParam.Uint64()) - if !ok { + btIdx, bt := eh.dirty.matchBucketType(amountParam, durationParam.Uint64()) + if bt == nil { return errors.Wrapf(errBucketTypeNotExist, "amount %d, duration %d", amountParam.Int64(), durationParam.Uint64()) } owner, ok := eh.tokenOwner[tokenIDParam.Uint64()] @@ -557,8 +557,8 @@ func (eh *contractStakingEventHandler) handleLockedEvent(event eventParam) error if !ok { return errors.Wrapf(errBucketTypeNotExist, "id %d", b.TypeIndex) } - newBtIdx, _, ok := eh.dirty.matchBucketType(bt.Amount, durationParam.Uint64()) - if !ok { + newBtIdx, newBt := eh.dirty.matchBucketType(bt.Amount, durationParam.Uint64()) + if newBt == nil { return errors.Wrapf(errBucketTypeNotExist, "amount %v, duration %d", bt.Amount, durationParam.Uint64()) } b.TypeIndex = newBtIdx @@ -615,8 +615,8 @@ func (eh *contractStakingEventHandler) handleMergedEvent(event eventParam) error } // merge to the first bucket - btIdx, _, ok := eh.dirty.matchBucketType(amountParam, durationParam.Uint64()) - if !ok { + btIdx, bt := eh.dirty.matchBucketType(amountParam, durationParam.Uint64()) + if bt == nil { return errors.Wrapf(errBucketTypeNotExist, "amount %d, duration %d", amountParam.Int64(), durationParam.Uint64()) } b, ok := eh.dirty.getBucketInfo(tokenIDsParam[0].Uint64()) @@ -651,8 +651,8 @@ func (eh *contractStakingEventHandler) handleBucketExpandedEvent(event eventPara if !ok { return errors.Wrapf(ErrBucketNotExist, "token id %d", tokenIDParam.Uint64()) } - newBtIdx, _, ok := eh.dirty.matchBucketType(amountParam, durationParam.Uint64()) - if !ok { + newBtIdx, newBucketType := eh.dirty.matchBucketType(amountParam, durationParam.Uint64()) + if newBucketType == nil { return errors.Wrapf(errBucketTypeNotExist, "amount %d, duration %d", amountParam.Int64(), durationParam.Uint64()) } b.TypeIndex = newBtIdx diff --git a/blockindex/contractstaking/indexer.go b/blockindex/contractstaking/indexer.go index d9adc6dfc8..7e21fe5bf2 100644 --- a/blockindex/contractstaking/indexer.go +++ b/blockindex/contractstaking/indexer.go @@ -128,8 +128,8 @@ func (s *Indexer) LoadStakeView(ctx context.Context, sr protocol.StateReader) (s if buckets[i] == nil { return nil, errors.New("bucket is nil") } - tid, _, ok := cache.MatchBucketType(buckets[i].StakedAmount, buckets[i].StakedDuration) - if !ok { + tid, bt := cache.MatchBucketType(buckets[i].StakedAmount, buckets[i].StakedDuration) + if bt == nil { return nil, errors.Errorf( "no bucket type found for bucket %d with staked amount %s and duration %d", id, diff --git a/blockindex/contractstaking/wrappedcache.go b/blockindex/contractstaking/wrappedcache.go index 226e6ea074..7ce7f15324 100644 --- a/blockindex/contractstaking/wrappedcache.go +++ b/blockindex/contractstaking/wrappedcache.go @@ -198,8 +198,8 @@ func (wc *wrappedCache) PutBucketType(id uint64, bt *BucketType) { panic("bucket type amount or duration cannot be changed") } } - oldId, _, ok := wc.matchBucketType(bt.Amount, bt.Duration) - if ok && oldId != id { + oldId, oldBucketType := wc.matchBucketType(bt.Amount, bt.Duration) + if oldBucketType != nil && oldId != id { panic("bucket type with same amount and duration already exists") } if _, ok := wc.propertyBucketTypeMap[bt.Amount.Uint64()]; !ok { @@ -313,21 +313,21 @@ func (wc *wrappedCache) DeleteBucketInfo(id uint64) { wc.updatedBucketInfos[id] = nil } -func (wc *wrappedCache) MatchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType, bool) { +func (wc *wrappedCache) MatchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType) { wc.mu.RLock() defer wc.mu.RUnlock() return wc.matchBucketType(amount, duration) } -func (wc *wrappedCache) matchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType, bool) { +func (wc *wrappedCache) matchBucketType(amount *big.Int, duration uint64) (uint64, *BucketType) { amountUint64 := amount.Uint64() if amountMap, ok := wc.propertyBucketTypeMap[amountUint64]; ok { if id, ok := amountMap[duration]; ok { if bt, ok := wc.updatedBucketTypes[id]; ok { if bt != nil { - return id, bt, true + return id, bt } - return 0, nil, false + return 0, nil } } }