Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions protocol/blockfetch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"

"github.com/blinklabs-io/gouroboros/cbor"
"github.com/blinklabs-io/gouroboros/ledger"
Expand All @@ -36,6 +37,7 @@ type Client struct {
blockUseCallback bool // Whether to use callback for blocks
onceStart sync.Once // Ensures Start is only called once
onceStop sync.Once // Ensures Stop is only called once
started atomic.Bool // Whether the protocol has been started
}

// NewClient creates a new Block Fetch protocol client with the given options and configuration.
Expand Down Expand Up @@ -93,13 +95,8 @@ func (c *Client) Start() {
"protocol", ProtocolName,
"connection_id", c.callbackContext.ConnectionId.String(),
)
c.started.Store(true)
c.Protocol.Start()
// Start goroutine to cleanup resources on protocol shutdown
go func() {
<-c.DoneChan()
close(c.blockChan)
close(c.startBatchResultChan)
}()
})
}

Expand All @@ -114,7 +111,22 @@ func (c *Client) Stop() error {
"connection_id", c.callbackContext.ConnectionId.String(),
)
msg := NewMsgClientDone()
err = c.SendMessage(msg)
if sendErr := c.SendMessage(msg); sendErr != nil {
err = sendErr
}
_ = c.Protocol.Stop() // Always stop to signal muxerDoneChan
// Defer closing channels until protocol fully shuts down (only if started)
if c.started.Load() {
go func() {
<-c.DoneChan()
close(c.blockChan)
close(c.startBatchResultChan)
}()
} else {
// If protocol was never started, close channels immediately
close(c.blockChan)
close(c.startBatchResultChan)
}
})
return err
}
Expand Down Expand Up @@ -222,13 +234,11 @@ func (c *Client) handleStartBatch() error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-c.DoneChan():
return protocol.ErrProtocolShuttingDown
default:
case c.startBatchResultChan <- nil:
}
c.startBatchResultChan <- nil
return nil
}

Expand All @@ -241,14 +251,12 @@ func (c *Client) handleNoBlocks() error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
err := errors.New("block(s) not found")
select {
case <-c.DoneChan():
return protocol.ErrProtocolShuttingDown
default:
case c.startBatchResultChan <- err:
}
err := errors.New("block(s) not found")
c.startBatchResultChan <- err
return nil
}

Expand Down Expand Up @@ -298,7 +306,11 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error {
return errors.New("received block-fetch Block message but no callback function is defined")
}
} else {
c.blockChan <- block
select {
case <-c.DoneChan():
return protocol.ErrProtocolShuttingDown
case c.blockChan <- block:
}
}
return nil
}
Expand Down
21 changes: 21 additions & 0 deletions protocol/blockfetch/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,24 @@ func TestGetBlockNoBlocks(t *testing.T) {
},
)
}

func TestClientShutdown(t *testing.T) {
runTest(
t,
[]ouroboros_mock.ConversationEntry{
ouroboros_mock.ConversationEntryHandshakeRequestGeneric,
ouroboros_mock.ConversationEntryHandshakeNtNResponse,
},
func(t *testing.T, oConn *ouroboros.Connection) {
if oConn.BlockFetch() == nil {
t.Fatalf("BlockFetch client is nil")
}
// Start the client
oConn.BlockFetch().Client.Start()
// Stop the client
if err := oConn.BlockFetch().Client.Stop(); err != nil {
t.Fatalf("unexpected error when stopping client: %s", err)
}
},
)
}
4 changes: 3 additions & 1 deletion protocol/blockfetch/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ func (s *Server) handleClientDone() error {
"connection_id", s.callbackContext.ConnectionId.String(),
)
// Restart protocol
s.Stop()
if err := s.Stop(); err != nil {
return err
}
s.initProtocol()
s.Start()
return nil
Expand Down
75 changes: 60 additions & 15 deletions protocol/chainsync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ type Client struct {
readyForNextBlockChan chan bool
onceStart sync.Once
onceStop sync.Once
started atomic.Bool
stopped atomic.Bool // prevents Start() after Stop()
syncPipelinedRequestNext int

// waitingForCurrentTipChan will process all the requests for the current tip until the channel
// is empty.
//
// want* only processes one request per message reply received from the server. If the message
// request fails, it is the responsibility of the caller to clear the channel.
Expand Down Expand Up @@ -120,25 +119,25 @@ func NewClient(

func (c *Client) Start() {
c.onceStart.Do(func() {
if c.stopped.Load() {
return
}
c.Protocol.Logger().
Debug("starting client protocol",
"component", "network",
"protocol", ProtocolName,
"connection_id", c.callbackContext.ConnectionId.String(),
)
c.started.Store(true)
c.Protocol.Start()
// Start goroutine to cleanup resources on protocol shutdown
go func() {
<-c.DoneChan()
close(c.readyForNextBlockChan)
}()
})
}

// Stop transitions the protocol to the Done state. No more protocol operations will be possible afterward
func (c *Client) Stop() error {
var err error
c.onceStop.Do(func() {
c.stopped.Store(true)
c.Protocol.Logger().
Debug("stopping client protocol",
"component", "network",
Expand All @@ -148,8 +147,30 @@ func (c *Client) Stop() error {
c.busyMutex.Lock()
defer c.busyMutex.Unlock()
msg := NewMsgDone()
if err = c.SendMessage(msg); err != nil {
return
if c.started.Load() {
if sendErr := c.SendMessage(msg); sendErr != nil {
err = sendErr
// Still proceed to stopping the protocol
}
}
if stopErr := c.Protocol.Stop(); stopErr != nil {
c.Protocol.Logger().
Error("error stopping protocol",
"component", "network",
"protocol", ProtocolName,
"connection_id", c.callbackContext.ConnectionId.String(),
"error", stopErr,
)
}
// Defer closing channel until protocol fully shuts down (only if started)
if c.started.Load() {
go func() {
<-c.DoneChan()
close(c.readyForNextBlockChan)
}()
} else {
// If protocol was never started, close channel immediately
close(c.readyForNextBlockChan)
}
})
return err
Expand Down Expand Up @@ -334,7 +355,15 @@ func (c *Client) GetAvailableBlockRange(
)
}
start = firstBlock.point
case <-c.readyForNextBlockChan:
case ready, ok := <-c.readyForNextBlockChan:
if !ok {
// Channel closed, protocol shutting down
return start, end, protocol.ErrProtocolShuttingDown
}
// Only proceed if ready is true
if !ready {
return start, end, ErrSyncCancelled
}
// Request the next block
msg := NewMsgRequestNext()
if err := c.SendMessage(msg); err != nil {
Expand Down Expand Up @@ -721,14 +750,22 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error {
if callbackErr != nil {
if errors.Is(callbackErr, ErrStopSyncProcess) {
// Signal that we're cancelling the sync
c.readyForNextBlockChan <- false
select {
case <-c.DoneChan():
return protocol.ErrProtocolShuttingDown
case c.readyForNextBlockChan <- false:
}
return nil
} else {
return callbackErr
}
}
// Signal that we're ready for the next block
c.readyForNextBlockChan <- true
select {
case <-c.DoneChan():
return protocol.ErrProtocolShuttingDown
case c.readyForNextBlockChan <- true:
}
return nil
}

Expand All @@ -752,15 +789,23 @@ func (c *Client) handleRollBackward(msgGeneric protocol.Message) error {
if callbackErr := c.config.RollBackwardFunc(c.callbackContext, msgRollBackward.Point, msgRollBackward.Tip); callbackErr != nil {
if errors.Is(callbackErr, ErrStopSyncProcess) {
// Signal that we're cancelling the sync
c.readyForNextBlockChan <- false
select {
case <-c.DoneChan():
return protocol.ErrProtocolShuttingDown
case c.readyForNextBlockChan <- false:
}
return nil
} else {
return callbackErr
}
}
}
// Signal that we're ready for the next block
c.readyForNextBlockChan <- true
select {
case <-c.DoneChan():
return protocol.ErrProtocolShuttingDown
case c.readyForNextBlockChan <- true:
}
return nil
}

Expand Down
Loading
Loading