diff --git a/action/protocol/protocol.go b/action/protocol/protocol.go index 4d609ddd80..2e89130873 100644 --- a/action/protocol/protocol.go +++ b/action/protocol/protocol.go @@ -120,16 +120,50 @@ type ( // Views stores the view for all protocols Views struct { - vm map[string]View + snapshotID int + snapshots map[int]map[string]int + vm map[string]View } ) func NewViews() *Views { return &Views{ - vm: make(map[string]View), + snapshotID: 0, + snapshots: make(map[int]map[string]int), + vm: make(map[string]View), } } +func (views *Views) Snapshot() int { + views.snapshotID++ + views.snapshots[views.snapshotID] = make(map[string]int) + keys := make([]string, 0, len(views.vm)) + for key := range views.vm { + keys = append(keys, key) + } + for _, key := range keys { + views.snapshots[views.snapshotID][key] = views.vm[key].Snapshot() + } + return views.snapshotID +} + +func (views *Views) Revert(id int) error { + if id > views.snapshotID || id < 0 { + return errors.Errorf("invalid snapshot id %d, max id is %d", id, views.snapshotID) + } + for k, v := range views.snapshots[id] { + if err := views.vm[k].Revert(v); err != nil { + return err + } + } + views.snapshotID = id + // clean up snapshots that are not needed anymore + for i := id + 1; i <= views.snapshotID; i++ { + delete(views.snapshots, i) + } + return nil +} + func (views *Views) Fork() *Views { fork := NewViews() for key, view := range views.vm { diff --git a/action/protocol/staking/viewdata.go b/action/protocol/staking/viewdata.go index 108b33bf6d..ece4958849 100644 --- a/action/protocol/staking/viewdata.go +++ b/action/protocol/staking/viewdata.go @@ -22,6 +22,8 @@ type ( Wrap() ContractStakeView // Fork forks the contract stake view, commit will not affect the original view Fork() ContractStakeView + // IsDirty checks if the contract stake view is dirty + IsDirty() bool // Commit commits the contract stake view Commit(context.Context, protocol.StateManager) error // CreatePreStates creates pre states for the contract stake view @@ -91,7 +93,7 @@ func (v *viewData) Commit(ctx context.Context, sm protocol.StateManager) error { } func (v *viewData) IsDirty() bool { - return v.candCenter.IsDirty() || v.bucketPool.IsDirty() + return v.candCenter.IsDirty() || v.bucketPool.IsDirty() || (v.contractsStake != nil && v.contractsStake.IsDirty()) } func (v *viewData) Snapshot() int { @@ -198,6 +200,19 @@ func (csv *contractStakeView) CreatePreStates(ctx context.Context) error { return nil } +func (csv *contractStakeView) IsDirty() bool { + if csv.v1 != nil && csv.v1.IsDirty() { + return true + } + if csv.v2 != nil && csv.v2.IsDirty() { + return true + } + if csv.v3 != nil && csv.v3.IsDirty() { + return true + } + return false +} + func (csv *contractStakeView) Commit(ctx context.Context, sm protocol.StateManager) error { featureCtx, ok := protocol.GetFeatureCtx(ctx) if !ok || featureCtx.LoadContractStakingFromIndexer { diff --git a/blockindex/contractstaking/stakeview.go b/blockindex/contractstaking/stakeview.go index 44bd1d2724..6de2803567 100644 --- a/blockindex/contractstaking/stakeview.go +++ b/blockindex/contractstaking/stakeview.go @@ -53,6 +53,10 @@ func (s *stakeView) assembleBuckets(ids []uint64, types []*BucketType, infos []* return vbs } +func (s *stakeView) IsDirty() bool { + return s.cache.IsDirty() +} + func (s *stakeView) WriteBuckets(sm protocol.StateManager) error { ids, types, infos := s.cache.Buckets() cssm := contractstaking.NewContractStakingStateManager(sm) diff --git a/blockindex/contractstaking/wrappedcache.go b/blockindex/contractstaking/wrappedcache.go index 058515d75b..226e6ea074 100644 --- a/blockindex/contractstaking/wrappedcache.go +++ b/blockindex/contractstaking/wrappedcache.go @@ -287,6 +287,11 @@ func (wc *wrappedCache) Commit(ctx context.Context, ca address.Address, sm proto wc.base.PutBucketInfo(id, bi) } } + wc.updatedBucketInfos = make(map[uint64]*bucketInfo) + wc.updatedBucketTypes = make(map[uint64]*BucketType) + wc.updatedCandidates = make(map[string]map[uint64]bool) + wc.propertyBucketTypeMap = make(map[uint64]map[uint64]uint64) + return wc.base.Commit(ctx, ca, sm) } diff --git a/e2etest/expect.go b/e2etest/expect.go index 2b4fe31d8e..b81e8de105 100644 --- a/e2etest/expect.go +++ b/e2etest/expect.go @@ -97,9 +97,7 @@ func (ce *candidateExpect) expect(test *e2etest, act *action.SealedEnvelope, rec cs := test.svr.ChainService(test.cfg.Chain.ID) sr := cs.StateFactory() bc := cs.Blockchain() - prtcl, ok := cs.Registry().Find("staking") - require.True(ok) - stkPrtcl := prtcl.(*staking.Protocol) + stkPrtcl := staking.FindProtocol(cs.Registry()) reqBytes, err := proto.Marshal(r) require.NoError(err) ctx := protocol.WithRegistry(context.Background(), cs.Registry()) diff --git a/state/factory/workingset.go b/state/factory/workingset.go index beacc73233..b0c18663c1 100644 --- a/state/factory/workingset.go +++ b/state/factory/workingset.go @@ -71,6 +71,7 @@ type ( workingSetStoreFactory WorkingSetStoreFactory height uint64 views *protocol.Views + viewsSnapshots map[int]int store workingSetStore finalized bool txValidator *protocol.GenericValidator @@ -82,6 +83,7 @@ func newWorkingSet(height uint64, views *protocol.Views, store workingSetStore, ws := &workingSet{ height: height, views: views, + viewsSnapshots: make(map[int]int), store: store, workingSetStoreFactory: storeFactory, } @@ -281,14 +283,28 @@ func (ws *workingSet) finalizeTx(ctx context.Context) { } func (ws *workingSet) Snapshot() int { - return ws.store.Snapshot() + id := ws.store.Snapshot() + vid := ws.views.Snapshot() + ws.viewsSnapshots[id] = vid + + return id } func (ws *workingSet) Revert(snapshot int) error { + vid, ok := ws.viewsSnapshots[snapshot] + if !ok { + return errors.Errorf("snapshot %d not found", snapshot) + } + if err := ws.views.Revert(vid); err != nil { + return errors.Wrapf(err, "failed to revert views to snapshot %d", vid) + } return ws.store.RevertSnapshot(snapshot) } func (ws *workingSet) ResetSnapshots() { + if len(ws.viewsSnapshots) > 0 { + ws.viewsSnapshots = make(map[int]int) + } ws.store.ResetSnapshots() } diff --git a/systemcontractindex/stakingindex/stakeview.go b/systemcontractindex/stakingindex/stakeview.go index 829a478cf6..dd49e80278 100644 --- a/systemcontractindex/stakingindex/stakeview.go +++ b/systemcontractindex/stakingindex/stakeview.go @@ -50,6 +50,10 @@ func (s *stakeView) Fork() staking.ContractStakeView { } } +func (s *stakeView) IsDirty() bool { + return s.cache.IsDirty() +} + func (s *stakeView) WriteBuckets(sm protocol.StateManager) error { ids := s.cache.BucketIdxs() slices.Sort(ids)