diff --git a/state/chain.go b/state/chain.go index 7e11ee12..de1e8164 100644 --- a/state/chain.go +++ b/state/chain.go @@ -17,8 +17,10 @@ package state import ( "errors" "fmt" + "sync" "github.com/blinklabs-io/dingo/database" + "github.com/blinklabs-io/dingo/event" ochainsync "github.com/blinklabs-io/gouroboros/protocol/chainsync" ocommon "github.com/blinklabs-io/gouroboros/protocol/common" ) @@ -26,9 +28,15 @@ import ( var ErrIteratorChainTip = errors.New("chain iterator is at chain tip") type ChainIterator struct { - ls *LedgerState - startPoint ocommon.Point - blockNumber uint64 + mutex sync.Mutex + ls *LedgerState + startPoint ocommon.Point + blockNumber uint64 + chainUpdateSubId event.EventSubscriberId + chainUpdateChan <-chan event.Event + needsRollback bool + rollbackPoint ocommon.Point + waitingChan chan event.Event } type ChainIteratorResult struct { @@ -42,9 +50,15 @@ func newChainIterator( startPoint ocommon.Point, inclusive bool, ) (*ChainIterator, error) { + // Subscribe to chain updates + chainUpdateSubId, chainUpdateChan := ls.config.EventBus.Subscribe( + ChainUpdateEventType, + ) ci := &ChainIterator{ - ls: ls, - startPoint: startPoint, + ls: ls, + startPoint: startPoint, + chainUpdateSubId: chainUpdateSubId, + chainUpdateChan: chainUpdateChan, } // Lookup start block in metadata DB if not origin if startPoint.Slot > 0 || len(startPoint.Hash) > 0 { @@ -61,15 +75,67 @@ func newChainIterator( ci.blockNumber++ } } + go ci.handleChainUpdateEvents() return ci, nil } +func (ci *ChainIterator) handleChainUpdateEvents() { + for { + evt, ok := <-ci.chainUpdateChan + if !ok { + return + } + ci.mutex.Lock() + switch e := evt.Data.(type) { + case ChainBlockEvent: + if ci.waitingChan != nil { + // Send event without blocking + select { + case ci.waitingChan <- evt: + default: + } + } + case ChainRollbackEvent: + if ci.blockNumber > 0 { + ci.rollbackPoint = e.Point + ci.needsRollback = true + } + if ci.waitingChan != nil { + // Send event without blocking + select { + case ci.waitingChan <- evt: + default: + } + } + } + ci.mutex.Unlock() + } +} + func (ci *ChainIterator) Tip() (ochainsync.Tip, error) { return ci.ls.chainTip(nil) } func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) { - ci.ls.RLock() + ci.mutex.Lock() + // Check for pending rollback + if ci.needsRollback { + ret := &ChainIteratorResult{} + ret.Point = ci.rollbackPoint + ret.Rollback = true + ci.needsRollback = false + if ci.rollbackPoint.Slot > 0 { + // Lookup block number for rollback point + tmpBlock, err := database.BlockByPoint(ci.ls.db, ci.rollbackPoint) + if err != nil { + ci.mutex.Unlock() + return nil, err + } + ci.blockNumber = tmpBlock.Number + 1 + } + ci.mutex.Unlock() + return ret, nil + } ret := &ChainIteratorResult{} // Lookup next block in metadata DB tmpBlock, err := database.BlockByNumber(ci.ls.db, ci.blockNumber) @@ -78,42 +144,31 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) { ret.Point = ocommon.NewPoint(tmpBlock.Slot, tmpBlock.Hash) ret.Block = tmpBlock ci.blockNumber++ - ci.ls.RUnlock() + ci.mutex.Unlock() return ret, nil } // Return any actual error if !errors.Is(err, database.ErrBlockNotFound) { - ci.ls.RUnlock() + ci.mutex.Unlock() return ret, err } - // Check against current tip to see if it was rolled back - tip, err := ci.Tip() - if err != nil { - return nil, err - } - if ci.blockNumber > 0 && ci.blockNumber-1 > tip.BlockNumber { - ret.Point = tip.Point - ret.Rollback = true - ci.blockNumber = tip.BlockNumber + 1 - ci.ls.RUnlock() - return ret, nil - } // Return immediately if we're not blocking if !blocking { - ci.ls.RUnlock() + ci.mutex.Unlock() return nil, ErrIteratorChainTip } - // Wait for new block or a rollback - chainUpdateSubId, chainUpdateChan := ci.ls.config.EventBus.Subscribe( - ChainUpdateEventType, - ) + // Wait for chain update + ci.waitingChan = make(chan event.Event, 1) // Release read lock while we wait for new event - ci.ls.RUnlock() - evt, ok := <-chainUpdateChan + ci.mutex.Unlock() + evt, ok := <-ci.waitingChan if !ok { // TODO: return an actual error (#389) return nil, nil } + ci.mutex.Lock() + defer ci.mutex.Unlock() + ci.waitingChan = nil switch e := evt.Data.(type) { case ChainBlockEvent: ret.Point = e.Point @@ -122,6 +177,7 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) { case ChainRollbackEvent: ret.Point = e.Point ret.Rollback = true + ci.needsRollback = false if e.Point.Slot > 0 { // Lookup block number for rollback point tmpBlock, err := database.BlockByPoint(ci.ls.db, e.Point) @@ -133,6 +189,5 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) { default: return nil, fmt.Errorf("unexpected event type %T", e) } - ci.ls.config.EventBus.Unsubscribe(ChainUpdateEventType, chainUpdateSubId) return ret, nil }