Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 103 additions & 18 deletions chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ type Chain struct {
lastDbBlockIndex uint64
blocks []database.Block
headers []ledger.BlockHeader
waitingChan chan struct{}
waitingChanMutex sync.Mutex
iterators []*ChainIterator
}

func NewChain(
Expand Down Expand Up @@ -196,17 +199,24 @@ func (c *Chain) AddBlock(block ledger.Block, blockNonce []byte, txn *database.Tx
BlockNumber: block.BlockNumber(),
}
c.tipBlockIndex = newBlockIndex
// Notify waiting iterators
if c.waitingChan != nil {
close(c.waitingChan)
c.waitingChan = nil
}
// Generate event
c.eventBus.Publish(
ChainUpdateEventType,
event.NewEvent(
if c.eventBus != nil {
c.eventBus.Publish(
ChainUpdateEventType,
ChainBlockEvent{
Point: tmpPoint,
Block: tmpBlock,
},
),
)
event.NewEvent(
ChainUpdateEventType,
ChainBlockEvent{
Point: tmpPoint,
Block: tmpBlock,
},
),
)
}
return nil
}

Expand Down Expand Up @@ -277,16 +287,29 @@ func (c *Chain) Rollback(point ocommon.Point) error {
BlockNumber: tmpBlock.Number,
}
c.tipBlockIndex = rollbackBlockIndex
// Update iterators for rollback
for _, iter := range c.iterators {
if iter.lastPoint.Slot > point.Slot {
// Don't update rollback point if the iterator already has an older one pending
if iter.needsRollback && point.Slot > iter.rollbackPoint.Slot {
continue
}
iter.rollbackPoint = point
iter.needsRollback = true
}
}
// Generate event
c.eventBus.Publish(
ChainUpdateEventType,
event.NewEvent(
if c.eventBus != nil {
c.eventBus.Publish(
ChainUpdateEventType,
ChainRollbackEvent{
Point: point,
},
),
)
event.NewEvent(
ChainUpdateEventType,
ChainRollbackEvent{
Point: point,
},
),
)
}
return nil
}

Expand Down Expand Up @@ -319,11 +342,18 @@ func (c *Chain) HeaderRange(count int) (ocommon.Point, ocommon.Point) {
// FromPoint returns a ChainIterator starting at the specified point. If inclusive is true, the iterator
// will start at the specified point. Otherwise it will start at the point following the specified point
func (c *Chain) FromPoint(point ocommon.Point, inclusive bool) (*ChainIterator, error) {
return newChainIterator(
c.mutex.Lock()
defer c.mutex.Unlock()
iter, err := newChainIterator(
c,
point,
inclusive,
)
if err != nil {
return nil, err
}
c.iterators = append(c.iterators, iter)
return iter, nil
}

func (c *Chain) BlockByPoint(point ocommon.Point, txn *database.Txn) (database.Block, error) {
Expand Down Expand Up @@ -370,3 +400,58 @@ func (c *Chain) blockByIndex(blockIndex uint64, txn *database.Txn) (database.Blo
}
return c.blocks[memBlockIndex], nil
}

func (c *Chain) iterNext(iter *ChainIterator, blocking bool) (*ChainIteratorResult, error) {
c.mutex.RLock()
// Check for pending rollback
if iter.needsRollback {
ret := &ChainIteratorResult{}
ret.Point = iter.rollbackPoint
ret.Rollback = true
iter.lastPoint = iter.rollbackPoint
iter.needsRollback = false
if iter.rollbackPoint.Slot > 0 {
// Lookup block index for rollback point
tmpBlock, err := c.BlockByPoint(iter.rollbackPoint, nil)
if err != nil {
c.mutex.RUnlock()
return nil, err
}
iter.nextBlockIndex = tmpBlock.ID + 1
}
c.mutex.RUnlock()
return ret, nil
}
ret := &ChainIteratorResult{}
// Lookup next block in metadata DB
tmpBlock, err := c.blockByIndex(iter.nextBlockIndex, nil)
// Return immedidately if a block is found
if err == nil {
ret.Point = ocommon.NewPoint(tmpBlock.Slot, tmpBlock.Hash)
ret.Block = tmpBlock
iter.nextBlockIndex++
iter.lastPoint = ret.Point
c.mutex.RUnlock()
return ret, nil
}
// Return any actual error
if !errors.Is(err, ErrBlockNotFound) {
c.mutex.RUnlock()
return ret, err
}
// Return immediately if we're not blocking
if !blocking {
c.mutex.RUnlock()
return nil, ErrIteratorChainTip
}
c.mutex.RUnlock()
// Wait for chain update
c.waitingChanMutex.Lock()
if c.waitingChan == nil {
c.waitingChan = make(chan struct{})
}
c.waitingChanMutex.Unlock()
<-c.waitingChan
// Call ourselves again now that we should have new data
return c.iterNext(iter, blocking)
}
36 changes: 13 additions & 23 deletions chain/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ import (
"encoding/hex"
"errors"
"testing"
"time"

"github.com/blinklabs-io/dingo/chain"
"github.com/blinklabs-io/dingo/event"
"github.com/blinklabs-io/gouroboros/ledger"
"github.com/blinklabs-io/gouroboros/ledger/common"
ocommon "github.com/blinklabs-io/gouroboros/protocol/common"
Expand Down Expand Up @@ -105,10 +103,9 @@ var (
)

func TestChainBasic(t *testing.T) {
eventBus := event.NewEventBus(nil)
c, err := chain.NewChain(
nil, // db
eventBus,
nil, // db
nil, // eventBus,
false, // persistent
)
if err != nil {
Expand Down Expand Up @@ -176,10 +173,9 @@ func TestChainBasic(t *testing.T) {
}

func TestChainRollback(t *testing.T) {
eventBus := event.NewEventBus(nil)
c, err := chain.NewChain(
nil, // db
eventBus,
nil, // db
nil, // eventBus,
false, // persistent
)
if err != nil {
Expand All @@ -194,7 +190,7 @@ func TestChainRollback(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error creating chain iterator: %s", err)
}
// Iterate until hitting chain tip, and make sure we get blocks in the correct order with
// Iterate until hitting chain tip, and make sure we get blocks in the correct order
testBlockIdx := 0
for {
next, err := iter.Next(false)
Expand Down Expand Up @@ -243,8 +239,6 @@ func TestChainRollback(t *testing.T) {
testRollbackPoint.Hash,
)
}
// XXX: how fast does this propagate to iterators?
time.Sleep(1000 * time.Millisecond)
// The chain iterator should give us a rollback
next, err := iter.Next(false)
if err != nil {
Expand All @@ -267,10 +261,9 @@ func TestChainRollback(t *testing.T) {

func TestChainHeaderRange(t *testing.T) {
testBlockCount := 3
eventBus := event.NewEventBus(nil)
c, err := chain.NewChain(
nil, // db
eventBus,
nil, // db
nil, // eventBus,
false, // persistent
)
if err != nil {
Expand Down Expand Up @@ -316,10 +309,9 @@ func TestChainHeaderRange(t *testing.T) {

func TestChainHeaderBlock(t *testing.T) {
testBlockCount := 3
eventBus := event.NewEventBus(nil)
c, err := chain.NewChain(
nil, // db
eventBus,
nil, // db
nil, // eventBus,
false, // persistent
)
if err != nil {
Expand Down Expand Up @@ -347,10 +339,9 @@ func TestChainHeaderBlock(t *testing.T) {

func TestChainHeaderWrongBlock(t *testing.T) {
testBlockCount := 3
eventBus := event.NewEventBus(nil)
c, err := chain.NewChain(
nil, // db
eventBus,
nil, // db
nil, // eventBus,
false, // persistent
)
if err != nil {
Expand Down Expand Up @@ -386,10 +377,9 @@ func TestChainHeaderWrongBlock(t *testing.T) {

func TestChainHeaderRollback(t *testing.T) {
testBlockCount := 3
eventBus := event.NewEventBus(nil)
c, err := chain.NewChain(
nil, // db
eventBus,
nil, // db
nil, // eventBus,
false, // persistent
)
if err != nil {
Expand Down
Loading
Loading