diff --git a/impl/graphsync.go b/impl/graphsync.go index c5fa0084..0b1e036d 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -237,8 +237,8 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, incomingRequestHooks.Register(selectorvalidator.SelectorValidator(maxRecursionDepth)) } responseAllocator := allocator.NewAllocator(gsConfig.totalMaxMemoryResponder, gsConfig.maxMemoryPerPeerResponder) - createMessageQueue := func(ctx context.Context, p peer.ID) peermanager.PeerQueue { - return messagequeue.New(ctx, p, network, responseAllocator, gsConfig.messageSendRetries, gsConfig.sendMessageTimeout) + createMessageQueue := func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) peermanager.PeerQueue { + return messagequeue.New(ctx, p, network, responseAllocator, gsConfig.messageSendRetries, gsConfig.sendMessageTimeout, onShutdown) } peerManager := peermanager.NewMessageManager(ctx, createMessageQueue) diff --git a/impl/graphsync_test.go b/impl/graphsync_test.go index d30f8320..b016cddb 100644 --- a/impl/graphsync_test.go +++ b/impl/graphsync_test.go @@ -523,7 +523,7 @@ func TestGraphsyncRoundTripIgnoreNBlocks(t *testing.T) { // create network ctx, collectTracing := testutil.SetupTracing(context.Background()) - ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() td := newGsTestData(ctx, t) @@ -1039,6 +1039,20 @@ func TestNetworkDisconnect(t *testing.T) { drain(requestor) drain(responder) + // verify we can execute a request after disconnection + _, err = td.mn.LinkPeers(td.host1.ID(), td.host2.ID()) + require.NoError(t, err) + _, err = td.mn.ConnectPeers(td.host1.ID(), td.host2.ID()) + require.NoError(t, err) + requestCtx, requestCancel = context.WithTimeout(ctx, 1*time.Second) + defer requestCancel() + progressChan, errChan = requestor.Request(requestCtx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension) + blockChain.VerifyWholeChain(ctx, progressChan) + testutil.VerifyEmptyErrors(ctx, t, errChan) + + drain(requestor) + drain(responder) + tracing := collectTracing(t) traceStrings := tracing.TracesToStrings() diff --git a/messagequeue/messagequeue.go b/messagequeue/messagequeue.go index 11d9c0c6..29f3a7a0 100644 --- a/messagequeue/messagequeue.go +++ b/messagequeue/messagequeue.go @@ -79,10 +79,11 @@ type MessageQueue struct { allocator Allocator maxRetries int sendMessageTimeout time.Duration + onShutdown func(peer.ID) } // New creats a new MessageQueue. -func New(ctx context.Context, p peer.ID, network MessageNetwork, allocator Allocator, maxRetries int, sendMessageTimeout time.Duration) *MessageQueue { +func New(ctx context.Context, p peer.ID, network MessageNetwork, allocator Allocator, maxRetries int, sendMessageTimeout time.Duration, onShutdown func(peer.ID)) *MessageQueue { return &MessageQueue{ ctx: ctx, network: network, @@ -93,6 +94,7 @@ func New(ctx context.Context, p peer.ID, network MessageNetwork, allocator Alloc allocator: allocator, maxRetries: maxRetries, sendMessageTimeout: sendMessageTimeout, + onShutdown: onShutdown, } } @@ -154,6 +156,7 @@ func (mq *MessageQueue) runQueue() { defer func() { _ = mq.allocator.ReleasePeerMemory(mq.p) mq.eventPublisher.Shutdown() + mq.onShutdown(mq.p) }() mq.eventPublisher.Startup() for { diff --git a/messagequeue/messagequeue_test.go b/messagequeue/messagequeue_test.go index 4cad4f14..a7263ffc 100644 --- a/messagequeue/messagequeue_test.go +++ b/messagequeue/messagequeue_test.go @@ -28,7 +28,7 @@ func TestStartupAndShutdown(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - peer := testutil.GeneratePeers(1)[0] + targetPeer := testutil.GeneratePeers(1)[0] messagesSent := make(chan gsmsg.GraphSyncMessage) resetChan := make(chan struct{}, 1) fullClosedChan := make(chan struct{}, 1) @@ -37,7 +37,7 @@ func TestStartupAndShutdown(t *testing.T) { messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup} allocator := allocator2.NewAllocator(1<<30, 1<<30) - messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout) + messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {}) messageQueue.Startup() id := graphsync.NewRequestID() priority := graphsync.Priority(rand.Int31()) @@ -62,7 +62,7 @@ func TestShutdownDuringMessageSend(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - peer := testutil.GeneratePeers(1)[0] + targetPeer := testutil.GeneratePeers(1)[0] messagesSent := make(chan gsmsg.GraphSyncMessage) resetChan := make(chan struct{}, 1) fullClosedChan := make(chan struct{}, 1) @@ -75,7 +75,7 @@ func TestShutdownDuringMessageSend(t *testing.T) { messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup} allocator := allocator2.NewAllocator(1<<30, 1<<30) - messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout) + messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {}) messageQueue.Startup() id := graphsync.NewRequestID() priority := graphsync.Priority(rand.Int31()) @@ -114,7 +114,7 @@ func TestProcessingNotification(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - peer := testutil.GeneratePeers(1)[0] + targetPeer := testutil.GeneratePeers(1)[0] messagesSent := make(chan gsmsg.GraphSyncMessage) resetChan := make(chan struct{}, 1) fullClosedChan := make(chan struct{}, 1) @@ -123,7 +123,7 @@ func TestProcessingNotification(t *testing.T) { messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup} allocator := allocator2.NewAllocator(1<<30, 1<<30) - messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout) + messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {}) messageQueue.Startup() waitGroup.Add(1) blks := testutil.GenerateBlocksOfSize(3, 128) @@ -187,7 +187,7 @@ func TestDedupingMessages(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - peer := testutil.GeneratePeers(1)[0] + targetPeer := testutil.GeneratePeers(1)[0] messagesSent := make(chan gsmsg.GraphSyncMessage) resetChan := make(chan struct{}, 1) fullClosedChan := make(chan struct{}, 1) @@ -196,7 +196,7 @@ func TestDedupingMessages(t *testing.T) { messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup} allocator := allocator2.NewAllocator(1<<30, 1<<30) - messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout) + messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {}) messageQueue.Startup() waitGroup.Add(1) id := graphsync.NewRequestID() @@ -265,7 +265,7 @@ func TestSendsVeryLargeBlocksResponses(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - peer := testutil.GeneratePeers(1)[0] + targetPeer := testutil.GeneratePeers(1)[0] messagesSent := make(chan gsmsg.GraphSyncMessage) resetChan := make(chan struct{}, 1) fullClosedChan := make(chan struct{}, 1) @@ -274,7 +274,7 @@ func TestSendsVeryLargeBlocksResponses(t *testing.T) { messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup} allocator := allocator2.NewAllocator(1<<30, 1<<30) - messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout) + messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {}) messageQueue.Startup() waitGroup.Add(1) @@ -334,7 +334,7 @@ func TestSendsResponsesMemoryPressure(t *testing.T) { // use allocator with very small limit allocator := allocator2.NewAllocator(1000, 1000) - messageQueue := New(ctx, p, messageNetwork, allocator, messageSendRetries, sendMessageTimeout) + messageQueue := New(ctx, p, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {}) messageQueue.Startup() waitGroup.Add(1) @@ -381,7 +381,7 @@ func TestNetworkErrorClearResponses(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - peer := testutil.GeneratePeers(1)[0] + targetPeer := testutil.GeneratePeers(1)[0] messagesSent := make(chan gsmsg.GraphSyncMessage) resetChan := make(chan struct{}, 1) fullClosedChan := make(chan struct{}, 1) @@ -393,7 +393,7 @@ func TestNetworkErrorClearResponses(t *testing.T) { allocator := allocator2.NewAllocator(1<<30, 1<<30) // we use only a retry count of 1 to avoid multiple send attempts for each message - messageQueue := New(ctx, peer, messageNetwork, allocator, 1, sendMessageTimeout) + messageQueue := New(ctx, targetPeer, messageNetwork, allocator, 1, sendMessageTimeout, func(peer.ID) {}) messageQueue.Startup() waitGroup.Add(1) diff --git a/peermanager/peermanager.go b/peermanager/peermanager.go index 4204d4f3..e2ea0ea2 100644 --- a/peermanager/peermanager.go +++ b/peermanager/peermanager.go @@ -16,7 +16,7 @@ type PeerProcess interface { type PeerHandler interface{} // PeerProcessFactory provides a function that will create a PeerQueue. -type PeerProcessFactory func(ctx context.Context, p peer.ID) PeerHandler +type PeerProcessFactory func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerHandler type peerProcessInstance struct { refcnt int @@ -105,7 +105,7 @@ func (pm *PeerManager) GetProcess( func (pm *PeerManager) getOrCreate(p peer.ID) *peerProcessInstance { pqi, ok := pm.peerProcesses[p] if !ok { - pq := pm.createPeerProcess(pm.ctx, p) + pq := pm.createPeerProcess(pm.ctx, p, pm.onQueueShutdown) if pprocess, ok := pq.(PeerProcess); ok { pprocess.Startup() } @@ -114,3 +114,9 @@ func (pm *PeerManager) getOrCreate(p peer.ID) *peerProcessInstance { } return pqi } + +func (pm *PeerManager) onQueueShutdown(p peer.ID) { + pm.peerProcessesLk.Lock() + defer pm.peerProcessesLk.Unlock() + delete(pm.peerProcesses, p) +} diff --git a/peermanager/peermanager_test.go b/peermanager/peermanager_test.go index 2c4cb3d0..8f09e750 100644 --- a/peermanager/peermanager_test.go +++ b/peermanager/peermanager_test.go @@ -17,7 +17,7 @@ func (fp *fakePeerProcess) Shutdown() {} func TestAddingAndRemovingPeers(t *testing.T) { ctx := context.Background() - peerProcessFatory := func(ctx context.Context, p peer.ID) PeerHandler { + peerProcessFatory := func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerHandler { return &fakePeerProcess{} } diff --git a/peermanager/peermessagemanager.go b/peermanager/peermessagemanager.go index f21acc29..dc3f2000 100644 --- a/peermanager/peermessagemanager.go +++ b/peermanager/peermessagemanager.go @@ -15,7 +15,7 @@ type PeerQueue interface { } // PeerQueueFactory provides a function that will create a PeerQueue. -type PeerQueueFactory func(ctx context.Context, p peer.ID) PeerQueue +type PeerQueueFactory func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerQueue // PeerMessageManager manages message queues for peers type PeerMessageManager struct { @@ -25,8 +25,8 @@ type PeerMessageManager struct { // NewMessageManager generates a new manger for sending messages func NewMessageManager(ctx context.Context, createPeerQueue PeerQueueFactory) *PeerMessageManager { return &PeerMessageManager{ - PeerManager: New(ctx, func(ctx context.Context, p peer.ID) PeerHandler { - return createPeerQueue(ctx, p) + PeerManager: New(ctx, func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerHandler { + return createPeerQueue(ctx, p, onShutdown) }), } } diff --git a/peermanager/peermessagemanager_test.go b/peermanager/peermessagemanager_test.go index ac814b14..aa9c4b0b 100644 --- a/peermanager/peermessagemanager_test.go +++ b/peermanager/peermessagemanager_test.go @@ -27,6 +27,7 @@ var _ PeerQueue = (*fakePeer)(nil) type fakePeer struct { p peer.ID messagesSent chan messageSent + onShutdown func(peer.ID) } func (fp *fakePeer) AllocateAndBuildMessage(blkSize uint64, buildMessage func(b *messagequeue.Builder)) { @@ -50,10 +51,11 @@ func (fp *fakePeer) Shutdown() {} //} func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory { - return func(ctx context.Context, p peer.ID) PeerQueue { + return func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerQueue { return &fakePeer{ p: p, messagesSent: messagesSent, + onShutdown: onShutdown, } } } diff --git a/responsemanager/responseassembler/responseassembler.go b/responsemanager/responseassembler/responseassembler.go index 48607ab1..5236174f 100644 --- a/responsemanager/responseassembler/responseassembler.go +++ b/responsemanager/responseassembler/responseassembler.go @@ -69,7 +69,7 @@ type ResponseAssembler struct { // New generates a new ResponseAssembler for sending responses func New(ctx context.Context, peerHandler PeerMessageHandler) *ResponseAssembler { return &ResponseAssembler{ - PeerManager: peermanager.New(ctx, func(ctx context.Context, p peer.ID) peermanager.PeerHandler { + PeerManager: peermanager.New(ctx, func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) peermanager.PeerHandler { return newTracker() }), peerHandler: peerHandler,