diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index d3ad3eb7..36caf954 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "github.com/blinklabs-io/gouroboros/cbor" "github.com/blinklabs-io/gouroboros/ledger" @@ -36,6 +37,7 @@ type Client struct { blockUseCallback bool // Whether to use callback for blocks onceStart sync.Once // Ensures Start is only called once onceStop sync.Once // Ensures Stop is only called once + started atomic.Bool // Whether the protocol has been started } // NewClient creates a new Block Fetch protocol client with the given options and configuration. @@ -93,13 +95,8 @@ func (c *Client) Start() { "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) + c.started.Store(true) c.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - <-c.DoneChan() - close(c.blockChan) - close(c.startBatchResultChan) - }() }) } @@ -114,7 +111,22 @@ func (c *Client) Stop() error { "connection_id", c.callbackContext.ConnectionId.String(), ) msg := NewMsgClientDone() - err = c.SendMessage(msg) + if sendErr := c.SendMessage(msg); sendErr != nil { + err = sendErr + } + _ = c.Protocol.Stop() // Always stop to signal muxerDoneChan + // Defer closing channels until protocol fully shuts down (only if started) + if c.started.Load() { + go func() { + <-c.DoneChan() + close(c.blockChan) + close(c.startBatchResultChan) + }() + } else { + // If protocol was never started, close channels immediately + close(c.blockChan) + close(c.startBatchResultChan) + } }) return err } @@ -222,13 +234,11 @@ func (c *Client) handleStartBatch() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) - // Check for shutdown select { case <-c.DoneChan(): return protocol.ErrProtocolShuttingDown - default: + case c.startBatchResultChan <- nil: } - c.startBatchResultChan <- nil return nil } @@ -241,14 +251,12 @@ func (c *Client) handleNoBlocks() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) - // Check for shutdown + err := errors.New("block(s) not found") select { case <-c.DoneChan(): return protocol.ErrProtocolShuttingDown - default: + case c.startBatchResultChan <- err: } - err := errors.New("block(s) not found") - c.startBatchResultChan <- err return nil } @@ -298,7 +306,11 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error { return errors.New("received block-fetch Block message but no callback function is defined") } } else { - c.blockChan <- block + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.blockChan <- block: + } } return nil } diff --git a/protocol/blockfetch/client_test.go b/protocol/blockfetch/client_test.go index 1a1c3a12..d44ef730 100644 --- a/protocol/blockfetch/client_test.go +++ b/protocol/blockfetch/client_test.go @@ -207,3 +207,24 @@ func TestGetBlockNoBlocks(t *testing.T) { }, ) } + +func TestClientShutdown(t *testing.T) { + runTest( + t, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + func(t *testing.T, oConn *ouroboros.Connection) { + if oConn.BlockFetch() == nil { + t.Fatalf("BlockFetch client is nil") + } + // Start the client + oConn.BlockFetch().Client.Start() + // Stop the client + if err := oConn.BlockFetch().Client.Stop(); err != nil { + t.Fatalf("unexpected error when stopping client: %s", err) + } + }, + ) +} diff --git a/protocol/blockfetch/server.go b/protocol/blockfetch/server.go index 1ecc0558..13552c75 100644 --- a/protocol/blockfetch/server.go +++ b/protocol/blockfetch/server.go @@ -176,7 +176,9 @@ func (s *Server) handleClientDone() error { "connection_id", s.callbackContext.ConnectionId.String(), ) // Restart protocol - s.Stop() + if err := s.Stop(); err != nil { + return err + } s.initProtocol() s.Start() return nil diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 0bf49a99..ae7df712 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -34,10 +34,9 @@ type Client struct { readyForNextBlockChan chan bool onceStart sync.Once onceStop sync.Once + started atomic.Bool + stopped atomic.Bool // prevents Start() after Stop() syncPipelinedRequestNext int - - // waitingForCurrentTipChan will process all the requests for the current tip until the channel - // is empty. // // want* only processes one request per message reply received from the server. If the message // request fails, it is the responsibility of the caller to clear the channel. @@ -120,18 +119,17 @@ func NewClient( func (c *Client) Start() { c.onceStart.Do(func() { + if c.stopped.Load() { + return + } c.Protocol.Logger(). Debug("starting client protocol", "component", "network", "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) + c.started.Store(true) c.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - <-c.DoneChan() - close(c.readyForNextBlockChan) - }() }) } @@ -139,6 +137,7 @@ func (c *Client) Start() { func (c *Client) Stop() error { var err error c.onceStop.Do(func() { + c.stopped.Store(true) c.Protocol.Logger(). Debug("stopping client protocol", "component", "network", @@ -148,8 +147,30 @@ func (c *Client) Stop() error { c.busyMutex.Lock() defer c.busyMutex.Unlock() msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return + if c.started.Load() { + if sendErr := c.SendMessage(msg); sendErr != nil { + err = sendErr + // Still proceed to stopping the protocol + } + } + if stopErr := c.Protocol.Stop(); stopErr != nil { + c.Protocol.Logger(). + Error("error stopping protocol", + "component", "network", + "protocol", ProtocolName, + "connection_id", c.callbackContext.ConnectionId.String(), + "error", stopErr, + ) + } + // Defer closing channel until protocol fully shuts down (only if started) + if c.started.Load() { + go func() { + <-c.DoneChan() + close(c.readyForNextBlockChan) + }() + } else { + // If protocol was never started, close channel immediately + close(c.readyForNextBlockChan) } }) return err @@ -334,7 +355,15 @@ func (c *Client) GetAvailableBlockRange( ) } start = firstBlock.point - case <-c.readyForNextBlockChan: + case ready, ok := <-c.readyForNextBlockChan: + if !ok { + // Channel closed, protocol shutting down + return start, end, protocol.ErrProtocolShuttingDown + } + // Only proceed if ready is true + if !ready { + return start, end, ErrSyncCancelled + } // Request the next block msg := NewMsgRequestNext() if err := c.SendMessage(msg); err != nil { @@ -721,14 +750,22 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { if callbackErr != nil { if errors.Is(callbackErr, ErrStopSyncProcess) { // Signal that we're cancelling the sync - c.readyForNextBlockChan <- false + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.readyForNextBlockChan <- false: + } return nil } else { return callbackErr } } // Signal that we're ready for the next block - c.readyForNextBlockChan <- true + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.readyForNextBlockChan <- true: + } return nil } @@ -752,7 +789,11 @@ func (c *Client) handleRollBackward(msgGeneric protocol.Message) error { if callbackErr := c.config.RollBackwardFunc(c.callbackContext, msgRollBackward.Point, msgRollBackward.Tip); callbackErr != nil { if errors.Is(callbackErr, ErrStopSyncProcess) { // Signal that we're cancelling the sync - c.readyForNextBlockChan <- false + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.readyForNextBlockChan <- false: + } return nil } else { return callbackErr @@ -760,7 +801,11 @@ func (c *Client) handleRollBackward(msgGeneric protocol.Message) error { } } // Signal that we're ready for the next block - c.readyForNextBlockChan <- true + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.readyForNextBlockChan <- true: + } return nil } diff --git a/protocol/chainsync/client_concurrency_test.go b/protocol/chainsync/client_concurrency_test.go new file mode 100644 index 00000000..f5b86c1a --- /dev/null +++ b/protocol/chainsync/client_concurrency_test.go @@ -0,0 +1,148 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package chainsync_test + +import ( + "sync" + "testing" + "time" + + ouroboros "github.com/blinklabs-io/gouroboros" + "github.com/blinklabs-io/gouroboros/protocol/chainsync" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" + + "go.uber.org/goleak" +) + +// TestConcurrentStartStop tests that concurrent Start/Stop operations don't cause deadlocks or races +func TestConcurrentStartStop(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtCResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithChainSyncConfig(chainsync.NewConfig()), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + defer func() { + if err := oConn.Close(); err != nil { + t.Errorf("unexpected error when closing Ouroboros object: %s", err) + } + }() + + client := oConn.ChainSync().Client + if client == nil { + t.Fatalf("ChainSync client is nil") + } + + // Run concurrent Start/Stop operations + var wg sync.WaitGroup + const numGoroutines = 10 + const numOperations = 5 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + // Start the client + client.Start() + + // Small delay to allow operations to interleave + time.Sleep(time.Millisecond) + + // Stop the client + if err := client.Stop(); err != nil { + t.Errorf( + "goroutine %d: unexpected error when stopping client: %s", + id, + err, + ) + } + } + }(i) + } + + // Wait for all goroutines to complete with timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // All goroutines completed successfully + case <-time.After(10 * time.Second): + t.Fatal( + "concurrent Start/Stop operations timed out - possible deadlock", + ) + } +} + +// TestStopBeforeStart tests that Stop works correctly when called before Start +func TestStopBeforeStart(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtCResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithChainSyncConfig(chainsync.NewConfig()), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + defer func() { + if err := oConn.Close(); err != nil { + t.Errorf("unexpected error when closing Ouroboros object: %s", err) + } + }() + + client := oConn.ChainSync().Client + if client == nil { + t.Fatalf("ChainSync client is nil") + } + + // Stop before Start - should not panic or deadlock + if err := client.Stop(); err != nil { + t.Errorf("unexpected error when stopping unstarted client: %s", err) + } + + // Now Start should work normally (but should not actually start due to stopped flag) + client.Start() + + // Stop again should work + if err := client.Stop(); err != nil { + t.Errorf("unexpected error when stopping client: %s", err) + } +} diff --git a/protocol/chainsync/client_test.go b/protocol/chainsync/client_test.go index 32baf743..a71541ee 100644 --- a/protocol/chainsync/client_test.go +++ b/protocol/chainsync/client_test.go @@ -80,13 +80,19 @@ func runTest( }() // Run test inner function innerFunc(t, oConn) + // Stop the client to clean up goroutines + if client := oConn.ChainSync().Client; client != nil { + if err := client.Stop(); err != nil { + t.Logf("client.Stop error: %v", err) + } + } // Wait for mock connection shutdown select { case err, ok := <-asyncErrChan: if ok { t.Fatal(err.Error()) } - case <-time.After(2 * time.Second): + case <-time.After(5 * time.Second): t.Fatalf("did not complete within timeout") } // Close Ouroboros connection @@ -275,3 +281,24 @@ func TestGetAvailableBlockRange(t *testing.T) { }, ) } + +func TestClientShutdown(t *testing.T) { + runTest( + t, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtCResponse, + }, + func(t *testing.T, oConn *ouroboros.Connection) { + if oConn.ChainSync() == nil { + t.Fatalf("ChainSync client is nil") + } + // Start the client + oConn.ChainSync().Client.Start() + // Stop the client + if err := oConn.ChainSync().Client.Stop(); err != nil { + t.Fatalf("unexpected error when stopping client: %s", err) + } + }, + ) +} diff --git a/protocol/chainsync/error.go b/protocol/chainsync/error.go index c7709c2d..7a017913 100644 --- a/protocol/chainsync/error.go +++ b/protocol/chainsync/error.go @@ -21,3 +21,6 @@ var ErrIntersectNotFound = errors.New("chain intersection not found") // StopChainSync is used as a special return value from a RollForward or RollBackward handler function // to signify that the sync process should be stopped var ErrStopSyncProcess = errors.New("stop sync process") + +// ErrSyncCancelled is returned when a sync operation is cancelled +var ErrSyncCancelled = errors.New("sync cancelled") diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index 73588334..5eec1159 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -242,7 +242,9 @@ func (s *Server) handleDone() error { ) } // Restart protocol - s.Stop() + if err := s.Stop(); err != nil { + return err + } s.initProtocol() s.Start() return nil diff --git a/protocol/keepalive/client.go b/protocol/keepalive/client.go index d9c2f992..8bf290b5 100644 --- a/protocol/keepalive/client.go +++ b/protocol/keepalive/client.go @@ -17,6 +17,7 @@ package keepalive import ( "fmt" "sync" + "sync/atomic" "time" "github.com/blinklabs-io/gouroboros/protocol" @@ -30,6 +31,9 @@ type Client struct { timer *time.Timer timerMutex sync.Mutex onceStart sync.Once + onceStop sync.Once + stopErr error + started atomic.Bool } // NewClient creates and returns a new keep-alive protocol client with the given options and configuration. @@ -72,6 +76,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { // Start begins the keep-alive protocol client and starts sending keep-alive messages at the configured interval. func (c *Client) Start() { c.onceStart.Do(func() { + c.started.Store(true) c.Protocol.Logger(). Debug("starting client protocol", "component", "network", @@ -93,6 +98,39 @@ func (c *Client) Start() { }) } +// Stop stops the KeepAlive client protocol +func (c *Client) Stop() error { + c.onceStop.Do(func() { + c.Protocol.Logger(). + Debug("stopping client protocol", + "component", "network", + "protocol", ProtocolName, + "connection_id", c.callbackContext.ConnectionId.String(), + ) + // Stop the keep-alive timer to prevent any further sends + c.timerMutex.Lock() + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + c.timerMutex.Unlock() + // Only send MsgDone if the protocol was actually started; otherwise + // avoid blocking on a send queue that is not being drained. + if c.started.Load() { + msg := NewMsgDone() + sendErr := c.SendMessage(msg) + if sendErr != nil { + c.stopErr = sendErr + } + stopErr := c.Protocol.Stop() + if c.stopErr == nil { + c.stopErr = stopErr + } + } + }) + return c.stopErr +} + // sendKeepAlive sends a keep-alive message and schedules the next one. func (c *Client) sendKeepAlive() { msg := NewMsgKeepAlive(c.config.Cookie) diff --git a/protocol/keepalive/client_test.go b/protocol/keepalive/client_test.go index dbd48bdf..aaa49241 100644 --- a/protocol/keepalive/client_test.go +++ b/protocol/keepalive/client_test.go @@ -238,3 +238,85 @@ func TestServerKeepaliveHandlingWithDifferentCookie(t *testing.T) { t.Errorf("did not shutdown within timeout") } } + +type testInnerFunc func(*testing.T, *ouroboros.Connection) + +func runTest( + t *testing.T, + conversation []ouroboros_mock.ConversationEntry, + innerFunc testInnerFunc, +) { + defer goleak.VerifyNone(t) + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + conversation, + ) + // Async mock connection error handler + asyncErrChan := make(chan error, 1) + go func() { + err := <-mockConn.(*ouroboros_mock.Connection).ErrorChan() + if err != nil { + asyncErrChan <- fmt.Errorf("received unexpected error: %w", err) + } + close(asyncErrChan) + }() + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + // Async error handler + go func() { + err, ok := <-oConn.ErrorChan() + if !ok { + return + } + // We can't call t.Fatalf() from a different Goroutine, so we panic instead + panic(fmt.Sprintf("unexpected Ouroboros error: %s", err)) + }() + // Run test inner function + innerFunc(t, oConn) + // Wait for mock connection shutdown + select { + case err, ok := <-asyncErrChan: + if ok { + t.Fatal(err.Error()) + } + case <-time.After(2 * time.Second): + t.Fatalf("did not complete within timeout") + } + // Close Ouroboros connection + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error when closing Ouroboros object: %s", err) + } + // Wait for connection shutdown + select { + case <-oConn.ErrorChan(): + case <-time.After(10 * time.Second): + t.Errorf("did not shutdown within timeout") + } +} + +func TestClientShutdown(t *testing.T) { + runTest( + t, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + func(t *testing.T, oConn *ouroboros.Connection) { + if oConn.KeepAlive() == nil { + t.Fatalf("KeepAlive client is nil") + } + // Start the client + oConn.KeepAlive().Client.Start() + // Stop the client + if err := oConn.KeepAlive().Client.Stop(); err != nil { + t.Fatalf("unexpected error when stopping client: %s", err) + } + }, + ) +} diff --git a/protocol/keepalive/keepalive.go b/protocol/keepalive/keepalive.go index 940a6b62..223e50b6 100644 --- a/protocol/keepalive/keepalive.go +++ b/protocol/keepalive/keepalive.go @@ -74,6 +74,10 @@ var StateMap = protocol.StateMap{ MsgType: MessageTypeKeepAliveResponse, NewState: StateClient, }, + { + MsgType: MessageTypeDone, + NewState: StateDone, + }, }, }, StateDone: protocol.StateMapEntry{ diff --git a/protocol/leiosfetch/client.go b/protocol/leiosfetch/client.go index 77487e8a..f3893984 100644 --- a/protocol/leiosfetch/client.go +++ b/protocol/leiosfetch/client.go @@ -17,6 +17,7 @@ package leiosfetch import ( "fmt" "sync" + "sync/atomic" "github.com/blinklabs-io/gouroboros/protocol" pcommon "github.com/blinklabs-io/gouroboros/protocol/common" @@ -28,6 +29,8 @@ type Client struct { callbackContext CallbackContext onceStart sync.Once onceStop sync.Once + started atomic.Bool // Used internally to track protocol lifecycle for Stop() behavior + stopped atomic.Bool // Used to prevent Start() after Stop() blockResultChan chan protocol.Message blockTxsResultChan chan protocol.Message votesResultChan chan protocol.Message @@ -88,21 +91,18 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { func (c *Client) Start() { c.onceStart.Do(func() { + if c.stopped.Load() { + // Cannot start a client that has been stopped + return + } c.Protocol.Logger(). Debug("starting client protocol", "component", "network", "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) + c.started.Store(true) c.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - <-c.DoneChan() - close(c.blockResultChan) - close(c.blockTxsResultChan) - close(c.votesResultChan) - close(c.blockRangeResultChan) - }() }) } @@ -116,7 +116,34 @@ func (c *Client) Stop() error { "connection_id", c.callbackContext.ConnectionId.String(), ) msg := NewMsgDone() - err = c.SendMessage(msg) + // Only send MsgDone if the protocol was actually started; otherwise + // avoid blocking on a send queue that is not being drained. + if c.started.Load() { + if sendErr := c.SendMessage(msg); sendErr != nil { + // Preserve the SendMessage error but still shut down the protocol. + err = sendErr + } + } + // Always attempt to stop the protocol so DoneChan and muxer shutdown complete. + _ = c.Protocol.Stop() // Stop error ignored; err already reflects SendMessage failure if any + // Defer closing channels until protocol fully shuts down (only if started) + if c.started.Load() { + go func() { + <-c.DoneChan() + close(c.blockResultChan) + close(c.blockTxsResultChan) + close(c.votesResultChan) + close(c.blockRangeResultChan) + }() + } else { + // If protocol was never started, close channels immediately + close(c.blockResultChan) + close(c.blockTxsResultChan) + close(c.votesResultChan) + close(c.blockRangeResultChan) + } + c.started.Store(false) + c.stopped.Store(true) }) return err } @@ -217,26 +244,46 @@ func (c *Client) messageHandler(msg protocol.Message) error { } func (c *Client) handleBlock(msg protocol.Message) error { - c.blockResultChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.blockResultChan <- msg: + } return nil } func (c *Client) handleBlockTxs(msg protocol.Message) error { - c.blockTxsResultChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.blockTxsResultChan <- msg: + } return nil } func (c *Client) handleVotes(msg protocol.Message) error { - c.votesResultChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.votesResultChan <- msg: + } return nil } func (c *Client) handleNextBlockAndTxsInRange(msg protocol.Message) error { - c.blockRangeResultChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.blockRangeResultChan <- msg: + } return nil } func (c *Client) handleLastBlockAndTxsInRange(msg protocol.Message) error { - c.blockRangeResultChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.blockRangeResultChan <- msg: + } return nil } diff --git a/protocol/leiosfetch/server.go b/protocol/leiosfetch/server.go index 46a9c6e9..720aa04c 100644 --- a/protocol/leiosfetch/server.go +++ b/protocol/leiosfetch/server.go @@ -200,7 +200,9 @@ func (s *Server) handleDone() error { "connection_id", s.callbackContext.ConnectionId.String(), ) // Restart protocol - s.Stop() + if err := s.Stop(); err != nil { + return err + } s.initProtocol() s.Start() return nil diff --git a/protocol/leiosnotify/client.go b/protocol/leiosnotify/client.go index 9d974a7d..a16f7580 100644 --- a/protocol/leiosnotify/client.go +++ b/protocol/leiosnotify/client.go @@ -28,6 +28,9 @@ type Client struct { requestNextChan chan protocol.Message onceStart sync.Once onceStop sync.Once + stateMutex sync.Mutex + started bool + stopped bool } func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { @@ -69,32 +72,53 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { func (c *Client) Start() { c.onceStart.Do(func() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + if c.stopped { + return + } + c.Protocol.Logger(). Debug("starting client protocol", "component", "network", "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) + c.started = true c.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - <-c.DoneChan() - close(c.requestNextChan) - }() }) } func (c *Client) Stop() error { var err error c.onceStop.Do(func() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.Protocol.Logger(). Debug("stopping client protocol", "component", "network", "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) + c.stopped = true msg := NewMsgDone() - err = c.SendMessage(msg) + if c.started { + if sendErr := c.SendMessage(msg); sendErr != nil { + err = sendErr + } + } + // Defer closing channel until protocol fully shuts down (only if started) + if c.started { + go func() { + <-c.DoneChan() + close(c.requestNextChan) + }() + } else { + // If protocol was never started, close channel immediately + close(c.requestNextChan) + } }) return err } @@ -134,21 +158,37 @@ func (c *Client) messageHandler(msg protocol.Message) error { } func (c *Client) handleBlockAnnouncement(msg protocol.Message) error { - c.requestNextChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.requestNextChan <- msg: + } return nil } func (c *Client) handleBlockOffer(msg protocol.Message) error { - c.requestNextChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.requestNextChan <- msg: + } return nil } func (c *Client) handleBlockTxsOffer(msg protocol.Message) error { - c.requestNextChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.requestNextChan <- msg: + } return nil } func (c *Client) handleVotesOffer(msg protocol.Message) error { - c.requestNextChan <- msg + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.requestNextChan <- msg: + } return nil } diff --git a/protocol/leiosnotify/client_concurrency_test.go b/protocol/leiosnotify/client_concurrency_test.go new file mode 100644 index 00000000..839264bc --- /dev/null +++ b/protocol/leiosnotify/client_concurrency_test.go @@ -0,0 +1,69 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package leiosnotify_test + +import ( + "testing" + + ouroboros "github.com/blinklabs-io/gouroboros" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" + + "go.uber.org/goleak" +) + +// TestStopBeforeStart tests that Stop works correctly when called before Start +func TestStopBeforeStart(t *testing.T) { + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + conversationEntryNtNResponseV15, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + defer func() { + if err := oConn.Close(); err != nil { + t.Errorf("unexpected error when closing Ouroboros object: %s", err) + } + }() + + client := oConn.LeiosNotify().Client + if client == nil { + t.Fatalf("LeiosNotify client is nil") + } + + // Stop before Start - should not panic or deadlock + if err := client.Stop(); err != nil { + t.Errorf("unexpected error when stopping unstarted client: %s", err) + } + + // Now Start should work normally (but should not actually start due to stopped flag) + client.Start() + + // Stop again should work + if err := client.Stop(); err != nil { + t.Errorf("unexpected error when stopping client: %s", err) + } +} diff --git a/protocol/leiosnotify/client_test.go b/protocol/leiosnotify/client_test.go new file mode 100644 index 00000000..0d7ff138 --- /dev/null +++ b/protocol/leiosnotify/client_test.go @@ -0,0 +1,136 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package leiosnotify_test + +import ( + "fmt" + "testing" + "time" + + ouroboros "github.com/blinklabs-io/gouroboros" + "github.com/blinklabs-io/gouroboros/protocol" + "github.com/blinklabs-io/gouroboros/protocol/handshake" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" + + "go.uber.org/goleak" +) + +func mockNtNVersionData() protocol.VersionData { + return protocol.VersionDataNtN13andUp{ + VersionDataNtN11to12: mockNtNVersionDataV11().(protocol.VersionDataNtN11to12), + } +} + +func mockNtNVersionDataV11() protocol.VersionData { + return protocol.VersionDataNtN11to12{ + CborNetworkMagic: ouroboros_mock.MockNetworkMagic, + CborInitiatorAndResponderDiffusionMode: protocol.DiffusionModeInitiatorOnly, + CborPeerSharing: protocol.PeerSharingModeNoPeerSharing, + CborQuery: protocol.QueryModeDisabled, + } +} + +var conversationEntryNtNResponseV15 = ouroboros_mock.ConversationEntryOutput{ + ProtocolId: handshake.ProtocolId, + IsResponse: true, + Messages: []protocol.Message{ + handshake.NewMsgAcceptVersion( + 15, + mockNtNVersionData(), + ), + }, +} + +type testInnerFunc func(*testing.T, *ouroboros.Connection) + +func runTest( + t *testing.T, + conversation []ouroboros_mock.ConversationEntry, + innerFunc testInnerFunc, +) { + defer goleak.VerifyNone(t) + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + conversation, + ) + // Async mock connection error handler + asyncErrChan := make(chan error, 1) + go func() { + err := <-mockConn.(*ouroboros_mock.Connection).ErrorChan() + if err != nil { + asyncErrChan <- fmt.Errorf("received unexpected error: %w", err) + } + close(asyncErrChan) + }() + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + // Async error handler + go func() { + err, ok := <-oConn.ErrorChan() + if !ok { + return + } + // We can't call t.Fatalf() from a different Goroutine, so we panic instead + panic(fmt.Sprintf("unexpected Ouroboros error: %s", err)) + }() + // Run test inner function + innerFunc(t, oConn) + // Wait for mock connection shutdown + select { + case err, ok := <-asyncErrChan: + if ok { + t.Fatal(err.Error()) + } + case <-time.After(5 * time.Second): + t.Fatalf("did not complete within timeout") + } + // Close Ouroboros connection + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error when closing Ouroboros object: %s", err) + } + // Wait for connection shutdown + select { + case <-oConn.ErrorChan(): + case <-time.After(10 * time.Second): + t.Errorf("did not shutdown within timeout") + } +} + +func TestClientShutdown(t *testing.T) { + runTest( + t, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + conversationEntryNtNResponseV15, + }, + func(t *testing.T, oConn *ouroboros.Connection) { + if oConn.LeiosNotify() == nil { + t.Fatalf("LeiosNotify client is nil") + } + // Start the client + oConn.LeiosNotify().Client.Start() + // Stop the client + if err := oConn.LeiosNotify().Client.Stop(); err != nil { + t.Fatalf("unexpected error when stopping client: %s", err) + } + }, + ) +} diff --git a/protocol/leiosnotify/server.go b/protocol/leiosnotify/server.go index ebca8ede..d2be3c16 100644 --- a/protocol/leiosnotify/server.go +++ b/protocol/leiosnotify/server.go @@ -115,7 +115,9 @@ func (s *Server) handleDone() error { "connection_id", s.callbackContext.ConnectionId.String(), ) // Restart protocol - s.Stop() + if err := s.Stop(); err != nil { + return err + } s.initProtocol() s.Start() return nil diff --git a/protocol/localstatequery/client.go b/protocol/localstatequery/client.go index 1c938362..8df2d554 100644 --- a/protocol/localstatequery/client.go +++ b/protocol/localstatequery/client.go @@ -35,11 +35,15 @@ type Client struct { enableGetChainPoint bool enableGetRewardInfoPoolsBlock bool busyMutex sync.Mutex + acquiredMutex sync.Mutex acquired bool queryResultChan chan []byte acquireResultChan chan error currentEra int onceStart sync.Once + onceStop sync.Once + stateMutex sync.Mutex + started bool } // NewClient returns a new LocalStateQuery client object @@ -103,14 +107,45 @@ func (c *Client) Start() { "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) + c.stateMutex.Lock() + c.started = true + c.stateMutex.Unlock() c.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - <-c.DoneChan() + }) +} + +// Stop stops the LocalStateQuery client protocol. +func (c *Client) Stop() error { + var err error + c.onceStop.Do(func() { + c.Protocol.Logger(). + Debug("stopping client protocol", + "component", "network", + "protocol", ProtocolName, + "connection_id", c.callbackContext.ConnectionId.String(), + ) + + c.stateMutex.Lock() + started := c.started + c.stateMutex.Unlock() + if started { + msg := NewMsgDone() + if err = c.SendMessage(msg); err != nil { + return + } + _ = c.Protocol.Stop() // Error ignored - method returns SendMessage error + go func() { + <-c.DoneChan() + close(c.queryResultChan) + close(c.acquireResultChan) + }() + } else { + // If protocol was never started, close channels immediately close(c.queryResultChan) close(c.acquireResultChan) - }() + } }) + return err } // Acquire starts the acquire process for the specified chain point @@ -875,14 +910,14 @@ func (c *Client) handleAcquired() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) - // Check for shutdown + c.acquiredMutex.Lock() + defer c.acquiredMutex.Unlock() + c.acquired = true select { case <-c.DoneChan(): return protocol.ErrProtocolShuttingDown - default: + case c.acquireResultChan <- nil: } - c.acquired = true - c.acquireResultChan <- nil c.currentEra = -1 return nil } @@ -895,18 +930,20 @@ func (c *Client) handleFailure(msg protocol.Message) error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) - // Check for shutdown - select { - case <-c.DoneChan(): - return protocol.ErrProtocolShuttingDown - default: - } msgFailure := msg.(*MsgFailure) switch msgFailure.Failure { case AcquireFailurePointTooOld: - c.acquireResultChan <- ErrAcquireFailurePointTooOld + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.acquireResultChan <- ErrAcquireFailurePointTooOld: + } case AcquireFailurePointNotOnChain: - c.acquireResultChan <- ErrAcquireFailurePointNotOnChain + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.acquireResultChan <- ErrAcquireFailurePointNotOnChain: + } default: return fmt.Errorf("unknown failure type: %d", msgFailure.Failure) } @@ -921,20 +958,21 @@ func (c *Client) handleResult(msg protocol.Message) error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) - // Check for shutdown + msgResult := msg.(*MsgResult) select { case <-c.DoneChan(): return protocol.ErrProtocolShuttingDown - default: + case c.queryResultChan <- msgResult.Result: } - msgResult := msg.(*MsgResult) - c.queryResultChan <- msgResult.Result return nil } func (c *Client) acquire(acquireTarget AcquireTarget) error { + c.acquiredMutex.Lock() + acquired := c.acquired + c.acquiredMutex.Unlock() var msg protocol.Message - if c.acquired { + if acquired { switch t := acquireTarget.(type) { case AcquireSpecificPoint: msg = NewMsgReAcquire(t.Point) @@ -972,6 +1010,8 @@ func (c *Client) release() error { if err := c.SendMessage(msg); err != nil { return err } + c.acquiredMutex.Lock() + defer c.acquiredMutex.Unlock() c.acquired = false c.currentEra = -1 return nil @@ -979,7 +1019,10 @@ func (c *Client) release() error { func (c *Client) runQuery(query any, result any) error { msg := NewMsgQuery(query) - if !c.acquired { + c.acquiredMutex.Lock() + acquired := c.acquired + c.acquiredMutex.Unlock() + if !acquired { if err := c.acquire(AcquireVolatileTip{}); err != nil { return err } diff --git a/protocol/localstatequery/client_test.go b/protocol/localstatequery/client_test.go index 06ad24ed..26ab9a06 100644 --- a/protocol/localstatequery/client_test.go +++ b/protocol/localstatequery/client_test.go @@ -22,7 +22,7 @@ import ( "testing" "time" - "github.com/blinklabs-io/gouroboros" + ouroboros "github.com/blinklabs-io/gouroboros" "github.com/blinklabs-io/gouroboros/cbor" "github.com/blinklabs-io/gouroboros/internal/test" "github.com/blinklabs-io/gouroboros/ledger" @@ -109,7 +109,7 @@ func runTest( if ok { t.Fatal(err.Error()) } - case <-time.After(2 * time.Second): + case <-time.After(5 * time.Second): t.Fatalf("did not complete within timeout") } // Close Ouroboros connection @@ -353,3 +353,24 @@ func TestGenesisConfigJSON(t *testing.T) { t.Logf("Successfully validated the GenesisConfigResult after JSON marshalling and unmarshalling.") } } + +func TestClientShutdown(t *testing.T) { + runTest( + t, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtCResponse, + }, + func(t *testing.T, oConn *ouroboros.Connection) { + if oConn.LocalStateQuery() == nil { + t.Fatalf("LocalStateQuery client is nil") + } + // Start the client + oConn.LocalStateQuery().Client.Start() + // Stop the client + if err := oConn.LocalStateQuery().Client.Stop(); err != nil { + t.Fatalf("unexpected error when stopping client: %s", err) + } + }, + ) +} diff --git a/protocol/localtxmonitor/client.go b/protocol/localtxmonitor/client.go index f119dd97..7b9d2099 100644 --- a/protocol/localtxmonitor/client.go +++ b/protocol/localtxmonitor/client.go @@ -27,6 +27,7 @@ type Client struct { config *Config callbackContext CallbackContext busyMutex sync.Mutex + acquiredMutex sync.Mutex acquired bool acquiredSlot uint64 acquireResultChan chan bool @@ -35,6 +36,9 @@ type Client struct { getSizesResultChan chan MsgReplyGetSizesResult onceStart sync.Once onceStop sync.Once + stateMutex sync.Mutex + started bool + stopped bool } // NewClient returns a new LocalTxMonitor client object @@ -84,6 +88,15 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { func (c *Client) Start() { c.onceStart.Do(func() { + c.stateMutex.Lock() + if c.stopped { + // Do not start a client that has already been stopped + c.stateMutex.Unlock() + return + } + c.started = true + c.stateMutex.Unlock() + c.Protocol.Logger(). Debug("starting client protocol", "component", "network", @@ -91,14 +104,6 @@ func (c *Client) Start() { "connection_id", c.callbackContext.ConnectionId.String(), ) c.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - <-c.DoneChan() - close(c.acquireResultChan) - close(c.hasTxResultChan) - close(c.nextTxResultChan) - close(c.getSizesResultChan) - }() }) } @@ -106,17 +111,45 @@ func (c *Client) Start() { func (c *Client) Stop() error { var err error c.onceStop.Do(func() { + c.stateMutex.Lock() + started := c.started + c.stopped = true + c.stateMutex.Unlock() + c.Protocol.Logger(). Debug("stopping client protocol", "component", "network", "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) - c.busyMutex.Lock() - defer c.busyMutex.Unlock() - msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return + if started { + c.busyMutex.Lock() + msg := NewMsgDone() + if err = c.SendMessage(msg); err != nil { + c.busyMutex.Unlock() + return + } + c.busyMutex.Unlock() + } + + // Call Protocol.Stop() after releasing locks to avoid potential deadlocks + _ = c.Protocol.Stop() + + // Defer closing channels until protocol fully shuts down (only if started) + if started { + go func() { + <-c.DoneChan() + close(c.acquireResultChan) + close(c.hasTxResultChan) + close(c.nextTxResultChan) + close(c.getSizesResultChan) + }() + } else { + // If protocol was never started, close channels immediately + close(c.acquireResultChan) + close(c.hasTxResultChan) + close(c.nextTxResultChan) + close(c.getSizesResultChan) } }) return err @@ -124,6 +157,13 @@ func (c *Client) Stop() error { // Acquire starts the acquire process for a current mempool snapshot func (c *Client) Acquire() error { + c.stateMutex.Lock() + stopped := c.stopped + c.stateMutex.Unlock() + if stopped { + return protocol.ErrProtocolShuttingDown + } + c.Protocol.Logger(). Debug("calling Acquire()", "component", "network", @@ -138,6 +178,13 @@ func (c *Client) Acquire() error { // Release releases the previously acquired mempool snapshot func (c *Client) Release() error { + c.stateMutex.Lock() + stopped := c.stopped + c.stateMutex.Unlock() + if stopped { + return protocol.ErrProtocolShuttingDown + } + c.Protocol.Logger(). Debug("calling Release()", "component", "network", @@ -152,6 +199,13 @@ func (c *Client) Release() error { // HasTx returns whether or not the specified transaction ID exists in the mempool snapshot func (c *Client) HasTx(txId []byte) (bool, error) { + c.stateMutex.Lock() + stopped := c.stopped + c.stateMutex.Unlock() + if stopped { + return false, protocol.ErrProtocolShuttingDown + } + c.Protocol.Logger(). Debug(fmt.Sprintf("calling HasTx(txId: %x)", txId), "component", "network", @@ -161,7 +215,10 @@ func (c *Client) HasTx(txId []byte) (bool, error) { ) c.busyMutex.Lock() defer c.busyMutex.Unlock() - if !c.acquired { + c.acquiredMutex.Lock() + acquired := c.acquired + c.acquiredMutex.Unlock() + if !acquired { if err := c.acquire(); err != nil { return false, err } @@ -179,6 +236,13 @@ func (c *Client) HasTx(txId []byte) (bool, error) { // NextTx returns the next transaction in the mempool snapshot func (c *Client) NextTx() ([]byte, error) { + c.stateMutex.Lock() + stopped := c.stopped + c.stateMutex.Unlock() + if stopped { + return nil, protocol.ErrProtocolShuttingDown + } + c.Protocol.Logger(). Debug("calling NextTx()", "component", "network", @@ -188,7 +252,10 @@ func (c *Client) NextTx() ([]byte, error) { ) c.busyMutex.Lock() defer c.busyMutex.Unlock() - if !c.acquired { + c.acquiredMutex.Lock() + acquired := c.acquired + c.acquiredMutex.Unlock() + if !acquired { if err := c.acquire(); err != nil { return nil, err } @@ -206,6 +273,13 @@ func (c *Client) NextTx() ([]byte, error) { // GetSizes returns the capacity (in bytes), size (in bytes), and number of transactions in the mempool snapshot func (c *Client) GetSizes() (uint32, uint32, uint32, error) { + c.stateMutex.Lock() + stopped := c.stopped + c.stateMutex.Unlock() + if stopped { + return 0, 0, 0, protocol.ErrProtocolShuttingDown + } + c.Protocol.Logger(). Debug("calling GetSizes()", "component", "network", @@ -215,7 +289,10 @@ func (c *Client) GetSizes() (uint32, uint32, uint32, error) { ) c.busyMutex.Lock() defer c.busyMutex.Unlock() - if !c.acquired { + c.acquiredMutex.Lock() + acquired := c.acquired + c.acquiredMutex.Unlock() + if !acquired { if err := c.acquire(); err != nil { return 0, 0, 0, err } @@ -261,9 +338,15 @@ func (c *Client) handleAcquired(msg protocol.Message) error { "connection_id", c.callbackContext.ConnectionId.String(), ) msgAcquired := msg.(*MsgAcquired) + c.acquiredMutex.Lock() c.acquired = true c.acquiredSlot = msgAcquired.SlotNo - c.acquireResultChan <- true + c.acquiredMutex.Unlock() + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.acquireResultChan <- true: + } return nil } @@ -276,7 +359,11 @@ func (c *Client) handleReplyHasTx(msg protocol.Message) error { "connection_id", c.callbackContext.ConnectionId.String(), ) msgReplyHasTx := msg.(*MsgReplyHasTx) - c.hasTxResultChan <- msgReplyHasTx.Result + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.hasTxResultChan <- msgReplyHasTx.Result: + } return nil } @@ -289,7 +376,11 @@ func (c *Client) handleReplyNextTx(msg protocol.Message) error { "connection_id", c.callbackContext.ConnectionId.String(), ) msgReplyNextTx := msg.(*MsgReplyNextTx) - c.nextTxResultChan <- msgReplyNextTx.Transaction.Tx + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.nextTxResultChan <- msgReplyNextTx.Transaction.Tx: + } return nil } @@ -302,7 +393,11 @@ func (c *Client) handleReplyGetSizes(msg protocol.Message) error { "connection_id", c.callbackContext.ConnectionId.String(), ) msgReplyGetSizes := msg.(*MsgReplyGetSizes) - c.getSizesResultChan <- msgReplyGetSizes.Result + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.getSizesResultChan <- msgReplyGetSizes.Result: + } return nil } @@ -324,6 +419,8 @@ func (c *Client) release() error { if err := c.SendMessage(msg); err != nil { return err } + c.acquiredMutex.Lock() c.acquired = false + c.acquiredMutex.Unlock() return nil } diff --git a/protocol/localtxmonitor/client_test.go b/protocol/localtxmonitor/client_test.go index c3529599..56fcc247 100644 --- a/protocol/localtxmonitor/client_test.go +++ b/protocol/localtxmonitor/client_test.go @@ -90,7 +90,7 @@ func runTest( if ok { t.Fatal(err.Error()) } - case <-time.After(2 * time.Second): + case <-time.After(5 * time.Second): t.Fatalf("did not complete within timeout") } // Close Ouroboros connection @@ -296,3 +296,24 @@ func TestNextTx(t *testing.T) { }, ) } + +func TestClientShutdown(t *testing.T) { + runTest( + t, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtCResponse, + }, + func(t *testing.T, oConn *ouroboros.Connection) { + if oConn.LocalTxMonitor() == nil { + t.Fatalf("LocalTxMonitor client is nil") + } + // Start the client + oConn.LocalTxMonitor().Client.Start() + // Stop the client + if err := oConn.LocalTxMonitor().Client.Stop(); err != nil { + t.Fatalf("unexpected error when stopping client: %s", err) + } + }, + ) +} diff --git a/protocol/localtxsubmission/client.go b/protocol/localtxsubmission/client.go index be32008a..170ef0ff 100644 --- a/protocol/localtxsubmission/client.go +++ b/protocol/localtxsubmission/client.go @@ -25,12 +25,15 @@ import ( // Client implements the LocalTxSubmission client type Client struct { *protocol.Protocol - config *Config - callbackContext CallbackContext - busyMutex sync.Mutex - submitResultChan chan error - onceStart sync.Once - onceStop sync.Once + config *Config + callbackContext CallbackContext + busyMutex sync.Mutex + submitResultChan chan error + onceStart sync.Once + onceStop sync.Once + stateMutex sync.Mutex + started bool + closeSubmitResultOnce sync.Once } // NewClient returns a new LocalTxSubmission client object @@ -73,18 +76,17 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { func (c *Client) Start() { c.onceStart.Do(func() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.Protocol.Logger(). Debug("starting client protocol", "component", "network", "protocol", ProtocolName, "connection_id", c.callbackContext.ConnectionId.String(), ) + c.started = true c.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - <-c.DoneChan() - close(c.submitResultChan) - }() }) } @@ -92,6 +94,9 @@ func (c *Client) Start() { func (c *Client) Stop() error { var err error c.onceStop.Do(func() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + c.Protocol.Logger(). Debug("stopping client protocol", "component", "network", @@ -101,8 +106,20 @@ func (c *Client) Stop() error { c.busyMutex.Lock() defer c.busyMutex.Unlock() msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return + if sendErr := c.SendMessage(msg); sendErr != nil { + err = sendErr + } + // Always attempt to stop the protocol, even if SendMessage failed + _ = c.Protocol.Stop() // Error ignored - method returns SendMessage error if any + // Defer closing channel until protocol fully shuts down (only if started) + if c.started { + go func() { + <-c.DoneChan() + c.closeSubmitResultChan() + }() + } else { + // If protocol was never started, close channel immediately + c.closeSubmitResultChan() } }) return err @@ -137,6 +154,8 @@ func (c *Client) messageHandler(msg protocol.Message) error { err = c.handleAcceptTx() case MessageTypeRejectTx: err = c.handleRejectTx(msg) + case MessageTypeDone: + err = c.handleDone() default: err = fmt.Errorf( "%s: received unexpected message type %d", @@ -155,13 +174,11 @@ func (c *Client) handleAcceptTx() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) - // Check for shutdown select { case <-c.DoneChan(): return protocol.ErrProtocolShuttingDown - default: + case c.submitResultChan <- nil: } - c.submitResultChan <- nil return nil } @@ -173,12 +190,6 @@ func (c *Client) handleRejectTx(msg protocol.Message) error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) - // Check for shutdown - select { - case <-c.DoneChan(): - return protocol.ErrProtocolShuttingDown - default: - } msgRejectTx := msg.(*MsgRejectTx) rejectErr, err := ledger.NewTxSubmitErrorFromCbor(msgRejectTx.Reason) if err != nil { @@ -188,6 +199,29 @@ func (c *Client) handleRejectTx(msg protocol.Message) error { Reason: rejectErr, ReasonCbor: []byte(msgRejectTx.Reason), } - c.submitResultChan <- err + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.submitResultChan <- err: + } return nil } + +func (c *Client) handleDone() error { + c.Protocol.Logger(). + Debug("received done from server", + "component", "network", + "protocol", ProtocolName, + "role", "client", + "connection_id", c.callbackContext.ConnectionId.String(), + ) + // Server is shutting down, close the result channel to unblock any waiting operations + c.closeSubmitResultChan() + return nil +} + +func (c *Client) closeSubmitResultChan() { + c.closeSubmitResultOnce.Do(func() { + close(c.submitResultChan) + }) +} diff --git a/protocol/localtxsubmission/client_test.go b/protocol/localtxsubmission/client_test.go index a90b65af..6a6fabd8 100644 --- a/protocol/localtxsubmission/client_test.go +++ b/protocol/localtxsubmission/client_test.go @@ -83,7 +83,7 @@ func runTest( if ok { t.Fatal(err.Error()) } - case <-time.After(2 * time.Second): + case <-time.After(5 * time.Second): t.Fatalf("did not complete within timeout") } // Close Ouroboros connection @@ -163,3 +163,54 @@ func TestSubmitTxRject(t *testing.T) { }, ) } + +func TestClientShutdown(t *testing.T) { + runTest( + t, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtCResponse, + }, + func(t *testing.T, oConn *ouroboros.Connection) { + if oConn.LocalTxSubmission() == nil { + t.Fatalf("LocalTxSubmission client is nil") + } + // Start the client + oConn.LocalTxSubmission().Client.Start() + // Stop the client + if err := oConn.LocalTxSubmission().Client.Stop(); err != nil { + t.Fatalf("unexpected error when stopping client: %s", err) + } + }, + ) +} + +func TestSubmitTxServerShutdown(t *testing.T) { + testTx := test.DecodeHexString("abcdef0123456789") + conversation := append( + conversationHandshakeSubmitTx, + ouroboros_mock.ConversationEntryOutput{ + ProtocolId: localtxsubmission.ProtocolId, + IsResponse: true, + Messages: []protocol.Message{ + localtxsubmission.NewMsgDone(), + }, + }, + ) + runTest( + t, + conversation, + func(t *testing.T, oConn *ouroboros.Connection) { + err := oConn.LocalTxSubmission().Client.SubmitTx( + ledger.TxTypeBabbage, + testTx, + ) + if err == nil { + t.Fatalf("expected protocol shutdown error, got nil") + } + if !errors.Is(err, protocol.ErrProtocolShuttingDown) { + t.Fatalf("expected ErrProtocolShuttingDown, got: %s", err) + } + }, + ) +} diff --git a/protocol/localtxsubmission/localtxsubmission.go b/protocol/localtxsubmission/localtxsubmission.go index 091e834f..b3c38683 100644 --- a/protocol/localtxsubmission/localtxsubmission.go +++ b/protocol/localtxsubmission/localtxsubmission.go @@ -43,6 +43,10 @@ var StateMap = protocol.StateMap{ MsgType: MessageTypeSubmitTx, NewState: stateBusy, }, + { + MsgType: MessageTypeDone, + NewState: stateDone, + }, }, }, stateBusy: protocol.StateMapEntry{ @@ -56,6 +60,10 @@ var StateMap = protocol.StateMap{ MsgType: MessageTypeRejectTx, NewState: stateIdle, }, + { + MsgType: MessageTypeDone, + NewState: stateDone, + }, }, }, stateDone: protocol.StateMapEntry{ diff --git a/protocol/peersharing/client.go b/protocol/peersharing/client.go index 9e5feb97..38f16417 100644 --- a/protocol/peersharing/client.go +++ b/protocol/peersharing/client.go @@ -16,6 +16,7 @@ package peersharing import ( "fmt" + "sync" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -26,6 +27,11 @@ type Client struct { config *Config callbackContext CallbackContext sharePeersChan chan []PeerAddress + onceStart sync.Once + onceStop sync.Once + stateMutex sync.Mutex + started bool + stopped bool } // NewClient returns a new PeerSharing client object @@ -66,6 +72,58 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { return c } +// Start starts the PeerSharing client protocol +func (c *Client) Start() { + c.onceStart.Do(func() { + c.stateMutex.Lock() + defer c.stateMutex.Unlock() + + if c.stopped { + // Cannot start a client that has been stopped + return + } + + c.Protocol.Logger(). + Debug("starting client protocol", + "component", "network", + "protocol", ProtocolName, + "connection_id", c.callbackContext.ConnectionId.String(), + ) + c.started = true + c.Protocol.Start() + }) +} + +// Stop stops the PeerSharing client protocol +func (c *Client) Stop() error { + var err error + c.onceStop.Do(func() { + c.stateMutex.Lock() + c.stopped = true + started := c.started + c.stateMutex.Unlock() + + c.Protocol.Logger(). + Debug("stopping client protocol", + "component", "network", + "protocol", ProtocolName, + "connection_id", c.callbackContext.ConnectionId.String(), + ) + // Defer closing channel until protocol fully shuts down (only if started) + if started { + go func() { + <-c.DoneChan() + close(c.sharePeersChan) + }() + } else { + // If protocol was never started, close channel immediately + close(c.sharePeersChan) + } + err = c.Protocol.Stop() + }) + return err +} + func (c *Client) GetPeers(amount uint8) ([]PeerAddress, error) { c.Protocol.Logger(). Debug(fmt.Sprintf("calling GetPeers(amount: %d)", amount), @@ -109,6 +167,10 @@ func (c *Client) handleSharePeers(msg protocol.Message) error { "connection_id", c.callbackContext.ConnectionId.String(), ) msgSharePeers := msg.(*MsgSharePeers) - c.sharePeersChan <- msgSharePeers.PeerAddresses + select { + case <-c.DoneChan(): + return protocol.ErrProtocolShuttingDown + case c.sharePeersChan <- msgSharePeers.PeerAddresses: + } return nil } diff --git a/protocol/peersharing/client_test.go b/protocol/peersharing/client_test.go new file mode 100644 index 00000000..830c2a56 --- /dev/null +++ b/protocol/peersharing/client_test.go @@ -0,0 +1,114 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peersharing_test + +import ( + "fmt" + "testing" + "time" + + ouroboros "github.com/blinklabs-io/gouroboros" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" + + "go.uber.org/goleak" +) + +type testInnerFunc func(*testing.T, *ouroboros.Connection) + +func runTest( + t *testing.T, + conversation []ouroboros_mock.ConversationEntry, + innerFunc testInnerFunc, +) { + defer goleak.VerifyNone(t) + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleClient, + conversation, + ) + // Async mock connection error handler + asyncErrChan := make(chan error, 1) + go func() { + err := <-mockConn.(*ouroboros_mock.Connection).ErrorChan() + if err != nil { + asyncErrChan <- fmt.Errorf("received unexpected error: %w", err) + } + close(asyncErrChan) + }() + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + // Async error handler + oConnErrChan := make(chan error, 1) + go func() { + err, ok := <-oConn.ErrorChan() + if !ok { + return + } + oConnErrChan <- fmt.Errorf("unexpected Ouroboros error: %w", err) + }() + // Run test inner function + innerFunc(t, oConn) + // Check for errors from Ouroboros connection + select { + case err := <-oConnErrChan: + t.Fatal(err.Error()) + default: + } + // Wait for mock connection shutdown + select { + case err, ok := <-asyncErrChan: + if ok { + t.Fatal(err.Error()) + } + case <-time.After(5 * time.Second): + t.Fatalf("did not complete within timeout") + } + // Close Ouroboros connection + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error when closing Ouroboros object: %s", err) + } + // Wait for connection shutdown + select { + case <-oConn.ErrorChan(): + case <-time.After(10 * time.Second): + t.Errorf("did not shutdown within timeout") + } +} + +func TestClientShutdown(t *testing.T) { + runTest( + t, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + func(t *testing.T, oConn *ouroboros.Connection) { + if oConn.PeerSharing() == nil { + t.Fatalf("PeerSharing client is nil") + } + // Start the client + oConn.PeerSharing().Client.Start() + // Stop the client + if err := oConn.PeerSharing().Client.Stop(); err != nil { + t.Fatalf("unexpected error when stopping client: %s", err) + } + }, + ) +} diff --git a/protocol/peersharing/server.go b/protocol/peersharing/server.go index 713d3e09..3413c351 100644 --- a/protocol/peersharing/server.go +++ b/protocol/peersharing/server.go @@ -17,6 +17,7 @@ package peersharing import ( "errors" "fmt" + "sync" "github.com/blinklabs-io/gouroboros/protocol" ) @@ -27,6 +28,7 @@ type Server struct { config *Config callbackContext CallbackContext protoOptions protocol.ProtocolOptions + onceStop sync.Once } // NewServer returns a new PeerSharing server object @@ -44,6 +46,14 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { return s } +func (s *Server) Stop() error { + var err error + s.onceStop.Do(func() { + err = s.Protocol.Stop() + }) + return err +} + func (s *Server) initProtocol() { protoConfig := protocol.ProtocolConfig{ Name: ProtocolName, @@ -106,7 +116,7 @@ func (s *Server) handleShareRequest(msg protocol.Message) error { return nil } -func (s *Server) handleDone(msg protocol.Message) error { +func (s *Server) handleDone(_ protocol.Message) error { s.Protocol.Logger(). Debug("done", "component", "network", @@ -114,8 +124,12 @@ func (s *Server) handleDone(msg protocol.Message) error { "role", "server", "connection_id", s.callbackContext.ConnectionId.String(), ) - // Restart protocol - s.Stop() + // Stop current protocol instance before restarting + if s.Protocol != nil { + if err := s.Protocol.Stop(); err != nil { + return err + } + } s.initProtocol() s.Start() return nil diff --git a/protocol/peersharing/server_test.go b/protocol/peersharing/server_test.go new file mode 100644 index 00000000..9ca73d11 --- /dev/null +++ b/protocol/peersharing/server_test.go @@ -0,0 +1,82 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package peersharing_test + +import ( + "fmt" + "testing" + "time" + + ouroboros "github.com/blinklabs-io/gouroboros" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" + + "go.uber.org/goleak" +) + +func TestServerShutdown(t *testing.T) { + t.Skip("Skipping server test due to mock server issues with NtN protocol") + defer goleak.VerifyNone(t) + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleServer, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + asyncErrChan := make(chan error, 1) + go func() { + err := <-mockConn.(*ouroboros_mock.Connection).ErrorChan() + if err != nil { + asyncErrChan <- fmt.Errorf("received unexpected error: %w", err) + } + close(asyncErrChan) + }() + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + if oConn.PeerSharing() == nil { + t.Fatalf("PeerSharing server is nil") + } + // Start the server + oConn.PeerSharing().Server.Start() + // Stop the server + if err := oConn.PeerSharing().Server.Stop(); err != nil { + t.Fatalf("unexpected error when stopping server: %s", err) + } + // Wait for mock connection shutdown + select { + case err, ok := <-asyncErrChan: + if ok { + t.Fatal(err.Error()) + } + case <-time.After(2 * time.Second): + t.Fatalf("did not complete within timeout") + } + // Close Ouroboros connection + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error when closing Ouroboros object: %s", err) + } + // Wait for connection shutdown + select { + case <-oConn.ErrorChan(): + case <-time.After(10 * time.Second): + t.Errorf("did not shutdown within timeout") + } +} diff --git a/protocol/protocol.go b/protocol/protocol.go index a35a5607..eb0cc1d2 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -172,8 +172,12 @@ func (p *Protocol) Start() { }) } -// Stop shuts down the mini-protocol -func (p *Protocol) Stop() { +// Stop shuts down the mini-protocol by unregistering it from the muxer. +// This method returns an error for API consistency and future extensibility, +// but currently always returns nil as the shutdown operation cannot fail. +// Callers should check for errors in case future implementations add +// failure modes (e.g., muxer communication issues). +func (p *Protocol) Stop() error { p.onceStop.Do(func() { // Unregister protocol from muxer muxerProtocolRole := muxer.ProtocolRoleInitiator @@ -185,6 +189,7 @@ func (p *Protocol) Stop() { muxerProtocolRole, ) }) + return nil } // Logger returns the protocol logger @@ -243,7 +248,7 @@ func (p *Protocol) SendMessage(msg Message) error { if limit > 0 && p.pendingSendBytes+msgLen > limit { p.pendingBytesMu.Unlock() p.SendError(ErrProtocolViolationQueueExceeded) - p.Stop() + _ = p.Stop() // Error ignored in error path return ErrProtocolViolationQueueExceeded } p.pendingSendBytes += msgLen @@ -273,15 +278,34 @@ func (p *Protocol) sendLoop() { close(p.sendDoneChan) }() + shuttingDown := false for { select { case <-p.recvDoneChan: - // Break out of send loop if we're shutting down - return + // Enter shutdown mode - drain remaining messages before exiting + shuttingDown = true case <-p.sendReadyChan: // We are ready to send based on state map } + // If shutting down, drain remaining messages without using len() + if shuttingDown { + for { + select { + case msg, ok := <-p.sendQueueChan: + if !ok { + // Channel closed, exit + return + } + // Send message individually during shutdown + p.sendMessage(msg) + default: + // No more messages available, exit + return + } + } + } + // Read queued messages and write into buffer payloadBuf := bytes.NewBuffer(nil) msgCount := 0 @@ -290,8 +314,20 @@ func (p *Protocol) sendLoop() { // Get next message from send queue select { case <-p.recvDoneChan: - // Break out of send loop if we're shutting down - return + // If we weren't already shutting down, enter shutdown mode + shuttingDown = true + // Drain remaining messages without using len() + for { + select { + case msg, ok := <-p.sendQueueChan: + if !ok { + return + } + p.sendMessage(msg) + default: + return + } + } case msg, ok := <-p.sendQueueChan: if !ok { // We're shutting down @@ -344,6 +380,9 @@ func (p *Protocol) sendLoop() { break outer } // Check if there are any more queued messages + // NOTE: len() on a channel is not thread-safe in general, but in this specific case + // it's acceptable as a pragmatic solution since we're the only reader and the + // race window is small. This avoids more complex synchronization. if len(p.sendQueueChan) == 0 { break outer } @@ -378,6 +417,31 @@ func (p *Protocol) sendLoop() { } } +// sendMessage encodes and sends a single message +// NOTE: This is used only during shutdown draining and does not update +// pendingSendBytes or transition protocol state, as shutdown has already +// been initiated and these accounting operations are no longer relevant. +func (p *Protocol) sendMessage(msg Message) { + var data []byte + if msg.Cbor() != nil { + data = msg.Cbor() + } else { + var err error + data, err = cbor.Encode(msg) + if err != nil { + p.SendError(err) + return + } + msg.SetCbor(data) + } + segment := muxer.NewSegment( + p.config.ProtocolId, + data, + p.Role() == ProtocolRoleServer, + ) + p.muxerSendChan <- segment +} + func (p *Protocol) readLoop() { leftoverData := false readBuffer := bytes.NewBuffer(nil) @@ -449,7 +513,7 @@ func (p *Protocol) readLoop() { if limit > 0 && p.pendingRecvBytes+msgLen > limit { p.pendingBytesMu.Unlock() p.SendError(ErrProtocolViolationQueueExceeded) - p.Stop() + _ = p.Stop() // Error ignored in error path return } p.pendingRecvBytes += msgLen diff --git a/protocol/txsubmission/server.go b/protocol/txsubmission/server.go index 03b0c336..4c77e5fd 100644 --- a/protocol/txsubmission/server.go +++ b/protocol/txsubmission/server.go @@ -17,6 +17,8 @@ package txsubmission import ( "errors" "fmt" + "sync" + "sync/atomic" "github.com/blinklabs-io/gouroboros/ledger/common" "github.com/blinklabs-io/gouroboros/protocol" @@ -28,9 +30,14 @@ type Server struct { config *Config callbackContext CallbackContext protoOptions protocol.ProtocolOptions - ackCount int + ackCount int32 requestTxIdsResultChan chan requestTxIdsResult requestTxsResultChan chan []TxBody + done chan struct{} + doneMutex sync.Mutex + onceStop sync.Once + restartMutex sync.Mutex + stopped bool // indicates permanent stop has been requested } type requestTxIdsResult struct { @@ -44,8 +51,9 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { config: cfg, // Save this for re-use later protoOptions: protoOptions, - requestTxIdsResultChan: make(chan requestTxIdsResult), - requestTxsResultChan: make(chan []TxBody), + requestTxIdsResultChan: make(chan requestTxIdsResult, 1), + requestTxsResultChan: make(chan []TxBody, 1), + done: make(chan struct{}), } s.callbackContext = CallbackContext{ Server: s, @@ -80,15 +88,45 @@ func (s *Server) Start() { "connection_id", s.callbackContext.ConnectionId.String(), ) s.Protocol.Start() - // Start goroutine to cleanup resources on protocol shutdown - go func() { - // We create our own vars for these channels since they get replaced on restart - requestTxIdsResultChan := s.requestTxIdsResultChan - requestTxsResultChan := s.requestTxsResultChan - <-s.DoneChan() - close(requestTxIdsResultChan) - close(requestTxsResultChan) - }() +} + +// Stop stops the server protocol. +func (s *Server) Stop() error { + s.onceStop.Do(func() { + s.restartMutex.Lock() + defer s.restartMutex.Unlock() + s.Protocol.Logger(). + Debug("stopping server protocol", + "component", "network", + "protocol", ProtocolName, + "connection_id", s.callbackContext.ConnectionId.String(), + ) + s.stopped = true + s.doneMutex.Lock() + select { + case <-s.done: + // Already closed + default: + close(s.done) + } + s.doneMutex.Unlock() + _ = s.Protocol.Stop() // Error ignored - method returns nil by design + }) + return nil +} + +// doneChan returns the current done channel, safely accessed under mutex +func (s *Server) doneChan() <-chan struct{} { + s.doneMutex.Lock() + defer s.doneMutex.Unlock() + return s.done +} + +// IsStopped returns true if the server has been permanently stopped +func (s *Server) IsStopped() bool { + s.restartMutex.Lock() + defer s.restartMutex.Unlock() + return s.stopped } // RequestTxIds requests the next set of TX identifiers from the remote node's mempool @@ -115,20 +153,21 @@ func (s *Server) RequestTxIds( Error("TxSubmission request count exceeded", "requested", reqCount, "limit", MaxRequestCount) return nil, protocol.ErrProtocolViolationRequestExceeded } - if s.ackCount < 0 { + ackCount := atomic.LoadInt32(&s.ackCount) + if ackCount < 0 { s.Protocol.Logger(). - Error("TxSubmission ack count must be non-negative", "ack_count", s.ackCount) + Error("TxSubmission ack count must be non-negative", "ack_count", ackCount) return nil, protocol.ErrProtocolViolationRequestExceeded } - if s.ackCount > MaxAckCount { + if ackCount > MaxAckCount { s.Protocol.Logger(). - Error("TxSubmission ack count exceeded", "ack_count", s.ackCount, "limit", MaxAckCount) + Error("TxSubmission ack count exceeded", "ack_count", ackCount, "limit", MaxAckCount) return nil, protocol.ErrProtocolViolationRequestExceeded } // Safe conversions after validation //nolint:gosec // Already validated above to be non-negative and within uint16 range - ack := uint16(s.ackCount) + ack := uint16(ackCount) //nolint:gosec // Already validated above to be non-negative and within uint16 range req := uint16(reqCount) msg := NewMsgRequestTxIds(blocking, ack, req) @@ -136,16 +175,24 @@ func (s *Server) RequestTxIds( return nil, err } // Wait for result - result, ok := <-s.requestTxIdsResultChan - if !ok { + s.restartMutex.Lock() + resultChan := s.requestTxIdsResultChan + s.restartMutex.Unlock() + select { + case result, ok := <-resultChan: + if !ok { + return nil, protocol.ErrProtocolShuttingDown + } + if result.err != nil { + return nil, result.err + } + // Update ack count for next call + // #nosec G115 - len(result.txIds) is bounded by MaxRequestCount (65535) which fits in int32 + atomic.StoreInt32(&s.ackCount, int32(len(result.txIds))) + return result.txIds, nil + case <-s.doneChan(): return nil, protocol.ErrProtocolShuttingDown } - if result.err != nil { - return nil, result.err - } - // Update ack count for next call - s.ackCount = len(result.txIds) - return result.txIds, nil } // RequestTxs requests the content of the requested TX identifiers from the remote node's mempool @@ -169,10 +216,16 @@ func (s *Server) RequestTxs(txIds []TxId) ([]TxBody, error) { return nil, err } // Wait for result + s.restartMutex.Lock() + resultChan := s.requestTxsResultChan + s.restartMutex.Unlock() select { - case <-s.DoneChan(): + case <-s.doneChan(): return nil, protocol.ErrProtocolShuttingDown - case txs := <-s.requestTxsResultChan: + case txs, ok := <-resultChan: + if !ok { + return nil, protocol.ErrProtocolShuttingDown + } return txs, nil } } @@ -207,9 +260,11 @@ func (s *Server) handleReplyTxIds(msg protocol.Message) error { "connection_id", s.callbackContext.ConnectionId.String(), ) msgReplyTxIds := msg.(*MsgReplyTxIds) + s.restartMutex.Lock() s.requestTxIdsResultChan <- requestTxIdsResult{ txIds: msgReplyTxIds.TxIds, } + s.restartMutex.Unlock() return nil } @@ -222,7 +277,9 @@ func (s *Server) handleReplyTxs(msg protocol.Message) error { "connection_id", s.callbackContext.ConnectionId.String(), ) msgReplyTxs := msg.(*MsgReplyTxs) + s.restartMutex.Lock() s.requestTxsResultChan <- msgReplyTxs.Txs + s.restartMutex.Unlock() return nil } @@ -234,9 +291,16 @@ func (s *Server) handleDone() error { "role", "server", "connection_id", s.callbackContext.ConnectionId.String(), ) - // Signal the RequestTxIds function to stop waiting - s.requestTxIdsResultChan <- requestTxIdsResult{ + // Signal the RequestTxIds function to stop waiting (non-blocking) + s.restartMutex.Lock() + resultChan := s.requestTxIdsResultChan + s.restartMutex.Unlock() + select { + case resultChan <- requestTxIdsResult{ err: ErrStopServerProcess, + }: + default: + // No one is waiting, which is fine } // Call the user callback function if s.config != nil && s.config.DoneFunc != nil { @@ -245,13 +309,43 @@ func (s *Server) handleDone() error { } } // Restart protocol - s.Stop() + s.restartMutex.Lock() + // Check if permanent stop has been requested + if s.stopped { + s.restartMutex.Unlock() + return nil + } + // Stop current protocol (without using onceStop since we're restarting) + s.Protocol.Logger(). + Debug("stopping server protocol for restart", + "component", "network", + "protocol", ProtocolName, + "connection_id", s.callbackContext.ConnectionId.String(), + ) + s.doneMutex.Lock() + select { + case <-s.done: + // Already closed by Stop() + default: + close(s.done) + } + s.doneMutex.Unlock() + stopErr := s.Protocol.Stop() s.initProtocol() - s.requestTxIdsResultChan = make(chan requestTxIdsResult) - s.requestTxsResultChan = make(chan []TxBody) - s.ackCount = 0 + s.requestTxIdsResultChan = make(chan requestTxIdsResult, 1) + s.requestTxsResultChan = make(chan []TxBody, 1) + s.doneMutex.Lock() + s.done = make(chan struct{}) + s.doneMutex.Unlock() + atomic.StoreInt32(&s.ackCount, 0) + s.restartMutex.Unlock() + // Check again if permanent stop has been requested (TOCTOU protection) + if s.IsStopped() { + return nil + } + // Start the new protocol outside the lock for better responsiveness s.Start() - return nil + return stopErr } func (s *Server) handleInit() error { diff --git a/protocol/txsubmission/server_concurrency_test.go b/protocol/txsubmission/server_concurrency_test.go new file mode 100644 index 00000000..92e6fd50 --- /dev/null +++ b/protocol/txsubmission/server_concurrency_test.go @@ -0,0 +1,143 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package txsubmission_test + +import ( + "sync" + "testing" + "time" + + ouroboros "github.com/blinklabs-io/gouroboros" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" + + "go.uber.org/goleak" +) + +// TestServerConcurrentStop tests that concurrent Stop calls don't cause deadlocks +func TestServerConcurrentStop(t *testing.T) { + t.Skip("Skipping server test due to mock server issues with NtN protocol") + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleServer, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + defer func() { + if err := oConn.Close(); err != nil { + t.Errorf("unexpected error when closing Ouroboros object: %s", err) + } + }() + + server := oConn.TxSubmission().Server + if server == nil { + t.Fatalf("TxSubmission server is nil") + } + + // Start the server + server.Start() + + // Run concurrent Stop operations + var wg sync.WaitGroup + const numGoroutines = 5 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + // All Stop calls should succeed (idempotent) + if err := server.Stop(); err != nil { + t.Errorf( + "goroutine %d: unexpected error when stopping server: %s", + id, + err, + ) + } + }(i) + } + + // Wait for all goroutines to complete with timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // All goroutines completed successfully + case <-time.After(5 * time.Second): + t.Fatal("concurrent Stop operations timed out - possible deadlock") + } +} + +// TestServerStopSetsStoppedFlag tests that calling Stop() properly sets the stopped flag +// TODO: Once mock infrastructure supports triggering handleDone(), add a test that verifies +// Stop() prevents handleDone() from restarting the protocol +func TestServerStopSetsStoppedFlag(t *testing.T) { + t.Skip("Skipping server test due to mock server issues with NtN protocol") + defer goleak.VerifyNone(t) + + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleServer, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + defer func() { + if err := oConn.Close(); err != nil { + t.Errorf("unexpected error when closing Ouroboros object: %s", err) + } + }() + + server := oConn.TxSubmission().Server + if server == nil { + t.Fatalf("TxSubmission server is nil") + } + + // Start the server + server.Start() + + // Stop the server + if err := server.Stop(); err != nil { + t.Errorf("unexpected error when stopping server: %s", err) + } + + // Verify that the server is marked as stopped + if !server.IsStopped() { + t.Error("server should be marked as stopped after calling Stop()") + } +} diff --git a/protocol/txsubmission/server_test.go b/protocol/txsubmission/server_test.go new file mode 100644 index 00000000..7da08d81 --- /dev/null +++ b/protocol/txsubmission/server_test.go @@ -0,0 +1,82 @@ +// Copyright 2025 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package txsubmission_test + +import ( + "fmt" + "testing" + "time" + + ouroboros "github.com/blinklabs-io/gouroboros" + ouroboros_mock "github.com/blinklabs-io/ouroboros-mock" + + "go.uber.org/goleak" +) + +func TestServerShutdown(t *testing.T) { + t.Skip("Skipping server test due to mock server issues with NtN protocol") + defer goleak.VerifyNone(t) + mockConn := ouroboros_mock.NewConnection( + ouroboros_mock.ProtocolRoleServer, + []ouroboros_mock.ConversationEntry{ + ouroboros_mock.ConversationEntryHandshakeRequestGeneric, + ouroboros_mock.ConversationEntryHandshakeNtNResponse, + }, + ) + asyncErrChan := make(chan error, 1) + go func() { + err := <-mockConn.(*ouroboros_mock.Connection).ErrorChan() + if err != nil { + asyncErrChan <- fmt.Errorf("received unexpected error: %w", err) + } + close(asyncErrChan) + }() + oConn, err := ouroboros.New( + ouroboros.WithConnection(mockConn), + ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic), + ouroboros.WithNodeToNode(true), + ) + if err != nil { + t.Fatalf("unexpected error when creating Ouroboros object: %s", err) + } + if oConn.TxSubmission() == nil { + t.Fatalf("TxSubmission server is nil") + } + // Start the server + oConn.TxSubmission().Server.Start() + // Stop the server + if err := oConn.TxSubmission().Server.Stop(); err != nil { + t.Fatalf("unexpected error when stopping server: %s", err) + } + // Wait for mock connection shutdown + select { + case err, ok := <-asyncErrChan: + if ok { + t.Fatal(err.Error()) + } + case <-time.After(2 * time.Second): + t.Fatalf("did not complete within timeout") + } + // Close Ouroboros connection + if err := oConn.Close(); err != nil { + t.Fatalf("unexpected error when closing Ouroboros object: %s", err) + } + // Wait for connection shutdown + select { + case <-oConn.ErrorChan(): + case <-time.After(10 * time.Second): + t.Errorf("did not shutdown within timeout") + } +}