Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions action/protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,50 @@ type (

// Views stores the view for all protocols
Views struct {
vm map[string]View
snapshotID int
snapshots map[int]map[string]int
vm map[string]View
}
)

func NewViews() *Views {
return &Views{
vm: make(map[string]View),
snapshotID: 0,
snapshots: make(map[int]map[string]int),
vm: make(map[string]View),
}
}

func (views *Views) Snapshot() int {
views.snapshotID++
views.snapshots[views.snapshotID] = make(map[string]int)
keys := make([]string, 0, len(views.vm))
for key := range views.vm {
keys = append(keys, key)
}
for _, key := range keys {
views.snapshots[views.snapshotID][key] = views.vm[key].Snapshot()
}
return views.snapshotID
}

func (views *Views) Revert(id int) error {
if id > views.snapshotID || id < 0 {
return errors.Errorf("invalid snapshot id %d, max id is %d", id, views.snapshotID)
}
for k, v := range views.snapshots[id] {
if err := views.vm[k].Revert(v); err != nil {
return err
}
}
views.snapshotID = id
// clean up snapshots that are not needed anymore
for i := id + 1; i <= views.snapshotID; i++ {
delete(views.snapshots, i)
}
return nil
}

func (views *Views) Fork() *Views {
fork := NewViews()
for key, view := range views.vm {
Expand Down
17 changes: 16 additions & 1 deletion action/protocol/staking/viewdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type (
Wrap() ContractStakeView
// Fork forks the contract stake view, commit will not affect the original view
Fork() ContractStakeView
// IsDirty checks if the contract stake view is dirty
IsDirty() bool
// Commit commits the contract stake view
Commit(context.Context, protocol.StateManager) error
// CreatePreStates creates pre states for the contract stake view
Expand Down Expand Up @@ -91,7 +93,7 @@ func (v *viewData) Commit(ctx context.Context, sm protocol.StateManager) error {
}

func (v *viewData) IsDirty() bool {
return v.candCenter.IsDirty() || v.bucketPool.IsDirty()
return v.candCenter.IsDirty() || v.bucketPool.IsDirty() || (v.contractsStake != nil && v.contractsStake.IsDirty())
}

func (v *viewData) Snapshot() int {
Expand Down Expand Up @@ -198,6 +200,19 @@ func (csv *contractStakeView) CreatePreStates(ctx context.Context) error {
return nil
}

func (csv *contractStakeView) IsDirty() bool {
if csv.v1 != nil && csv.v1.IsDirty() {
return true
}
if csv.v2 != nil && csv.v2.IsDirty() {
return true
}
if csv.v3 != nil && csv.v3.IsDirty() {
return true
}
return false
}

func (csv *contractStakeView) Commit(ctx context.Context, sm protocol.StateManager) error {
featureCtx, ok := protocol.GetFeatureCtx(ctx)
if !ok || featureCtx.LoadContractStakingFromIndexer {
Expand Down
4 changes: 4 additions & 0 deletions blockindex/contractstaking/stakeview.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ func (s *stakeView) assembleBuckets(ids []uint64, types []*BucketType, infos []*
return vbs
}

func (s *stakeView) IsDirty() bool {
return s.cache.IsDirty()
}

func (s *stakeView) WriteBuckets(sm protocol.StateManager) error {
ids, types, infos := s.cache.Buckets()
cssm := contractstaking.NewContractStakingStateManager(sm)
Expand Down
5 changes: 5 additions & 0 deletions blockindex/contractstaking/wrappedcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ func (wc *wrappedCache) Commit(ctx context.Context, ca address.Address, sm proto
wc.base.PutBucketInfo(id, bi)
}
}
wc.updatedBucketInfos = make(map[uint64]*bucketInfo)
wc.updatedBucketTypes = make(map[uint64]*BucketType)
wc.updatedCandidates = make(map[string]map[uint64]bool)
wc.propertyBucketTypeMap = make(map[uint64]map[uint64]uint64)

return wc.base.Commit(ctx, ca, sm)
}

Expand Down
4 changes: 1 addition & 3 deletions e2etest/expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ func (ce *candidateExpect) expect(test *e2etest, act *action.SealedEnvelope, rec
cs := test.svr.ChainService(test.cfg.Chain.ID)
sr := cs.StateFactory()
bc := cs.Blockchain()
prtcl, ok := cs.Registry().Find("staking")
require.True(ok)
stkPrtcl := prtcl.(*staking.Protocol)
stkPrtcl := staking.FindProtocol(cs.Registry())
reqBytes, err := proto.Marshal(r)
require.NoError(err)
ctx := protocol.WithRegistry(context.Background(), cs.Registry())
Expand Down
18 changes: 17 additions & 1 deletion state/factory/workingset.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type (
workingSetStoreFactory WorkingSetStoreFactory
height uint64
views *protocol.Views
viewsSnapshots map[int]int
store workingSetStore
finalized bool
txValidator *protocol.GenericValidator
Expand All @@ -82,6 +83,7 @@ func newWorkingSet(height uint64, views *protocol.Views, store workingSetStore,
ws := &workingSet{
height: height,
views: views,
viewsSnapshots: make(map[int]int),
store: store,
workingSetStoreFactory: storeFactory,
}
Expand Down Expand Up @@ -281,14 +283,28 @@ func (ws *workingSet) finalizeTx(ctx context.Context) {
}

func (ws *workingSet) Snapshot() int {
return ws.store.Snapshot()
id := ws.store.Snapshot()
vid := ws.views.Snapshot()
ws.viewsSnapshots[id] = vid

return id
}

func (ws *workingSet) Revert(snapshot int) error {
vid, ok := ws.viewsSnapshots[snapshot]
if !ok {
return errors.Errorf("snapshot %d not found", snapshot)
}
if err := ws.views.Revert(vid); err != nil {
return errors.Wrapf(err, "failed to revert views to snapshot %d", vid)
}
return ws.store.RevertSnapshot(snapshot)
}

func (ws *workingSet) ResetSnapshots() {
if len(ws.viewsSnapshots) > 0 {
ws.viewsSnapshots = make(map[int]int)
}
ws.store.ResetSnapshots()
}

Expand Down
4 changes: 4 additions & 0 deletions systemcontractindex/stakingindex/stakeview.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ func (s *stakeView) Fork() staking.ContractStakeView {
}
}

func (s *stakeView) IsDirty() bool {
return s.cache.IsDirty()
}

func (s *stakeView) WriteBuckets(sm protocol.StateManager) error {
ids := s.cache.BucketIdxs()
slices.Sort(ids)
Expand Down