Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 83 additions & 28 deletions state/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,26 @@ package state
import (
"errors"
"fmt"
"sync"

"github.com/blinklabs-io/dingo/database"
"github.com/blinklabs-io/dingo/event"
ochainsync "github.com/blinklabs-io/gouroboros/protocol/chainsync"
ocommon "github.com/blinklabs-io/gouroboros/protocol/common"
)

var ErrIteratorChainTip = errors.New("chain iterator is at chain tip")

type ChainIterator struct {
ls *LedgerState
startPoint ocommon.Point
blockNumber uint64
mutex sync.Mutex
ls *LedgerState
startPoint ocommon.Point
blockNumber uint64
chainUpdateSubId event.EventSubscriberId
chainUpdateChan <-chan event.Event
needsRollback bool
rollbackPoint ocommon.Point
waitingChan chan event.Event
}

type ChainIteratorResult struct {
Expand All @@ -42,9 +50,15 @@ func newChainIterator(
startPoint ocommon.Point,
inclusive bool,
) (*ChainIterator, error) {
// Subscribe to chain updates
chainUpdateSubId, chainUpdateChan := ls.config.EventBus.Subscribe(
ChainUpdateEventType,
)
ci := &ChainIterator{
ls: ls,
startPoint: startPoint,
ls: ls,
startPoint: startPoint,
chainUpdateSubId: chainUpdateSubId,
chainUpdateChan: chainUpdateChan,
}
// Lookup start block in metadata DB if not origin
if startPoint.Slot > 0 || len(startPoint.Hash) > 0 {
Expand All @@ -61,15 +75,67 @@ func newChainIterator(
ci.blockNumber++
}
}
go ci.handleChainUpdateEvents()
return ci, nil
}

func (ci *ChainIterator) handleChainUpdateEvents() {
for {
evt, ok := <-ci.chainUpdateChan
if !ok {
return
}
ci.mutex.Lock()
switch e := evt.Data.(type) {
case ChainBlockEvent:
if ci.waitingChan != nil {
// Send event without blocking
select {
case ci.waitingChan <- evt:
default:
}
}
case ChainRollbackEvent:
if ci.blockNumber > 0 {
ci.rollbackPoint = e.Point
ci.needsRollback = true
}
if ci.waitingChan != nil {
// Send event without blocking
select {
case ci.waitingChan <- evt:
default:
}
}
}
ci.mutex.Unlock()
}
}

func (ci *ChainIterator) Tip() (ochainsync.Tip, error) {
return ci.ls.chainTip(nil)
}

func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
ci.ls.RLock()
ci.mutex.Lock()
// Check for pending rollback
if ci.needsRollback {
ret := &ChainIteratorResult{}
ret.Point = ci.rollbackPoint
ret.Rollback = true
ci.needsRollback = false
if ci.rollbackPoint.Slot > 0 {
// Lookup block number for rollback point
tmpBlock, err := database.BlockByPoint(ci.ls.db, ci.rollbackPoint)
if err != nil {
ci.mutex.Unlock()
return nil, err
}
ci.blockNumber = tmpBlock.Number + 1
}
ci.mutex.Unlock()
return ret, nil
}
ret := &ChainIteratorResult{}
// Lookup next block in metadata DB
tmpBlock, err := database.BlockByNumber(ci.ls.db, ci.blockNumber)
Expand All @@ -78,42 +144,31 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
ret.Point = ocommon.NewPoint(tmpBlock.Slot, tmpBlock.Hash)
ret.Block = tmpBlock
ci.blockNumber++
ci.ls.RUnlock()
ci.mutex.Unlock()
return ret, nil
}
// Return any actual error
if !errors.Is(err, database.ErrBlockNotFound) {
ci.ls.RUnlock()
ci.mutex.Unlock()
return ret, err
}
// Check against current tip to see if it was rolled back
tip, err := ci.Tip()
if err != nil {
return nil, err
}
if ci.blockNumber > 0 && ci.blockNumber-1 > tip.BlockNumber {
ret.Point = tip.Point
ret.Rollback = true
ci.blockNumber = tip.BlockNumber + 1
ci.ls.RUnlock()
return ret, nil
}
// Return immediately if we're not blocking
if !blocking {
ci.ls.RUnlock()
ci.mutex.Unlock()
return nil, ErrIteratorChainTip
}
// Wait for new block or a rollback
chainUpdateSubId, chainUpdateChan := ci.ls.config.EventBus.Subscribe(
ChainUpdateEventType,
)
// Wait for chain update
ci.waitingChan = make(chan event.Event, 1)
// Release read lock while we wait for new event
ci.ls.RUnlock()
evt, ok := <-chainUpdateChan
ci.mutex.Unlock()
evt, ok := <-ci.waitingChan
if !ok {
// TODO: return an actual error (#389)
return nil, nil
}
ci.mutex.Lock()
defer ci.mutex.Unlock()
ci.waitingChan = nil
switch e := evt.Data.(type) {
case ChainBlockEvent:
ret.Point = e.Point
Expand All @@ -122,6 +177,7 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
case ChainRollbackEvent:
ret.Point = e.Point
ret.Rollback = true
ci.needsRollback = false
if e.Point.Slot > 0 {
// Lookup block number for rollback point
tmpBlock, err := database.BlockByPoint(ci.ls.db, e.Point)
Expand All @@ -133,6 +189,5 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
default:
return nil, fmt.Errorf("unexpected event type %T", e)
}
ci.ls.config.EventBus.Unsubscribe(ChainUpdateEventType, chainUpdateSubId)
return ret, nil
}
Loading