From 5fd481be1159ea4a4f47054af856b3cd7764d44a Mon Sep 17 00:00:00 2001
From: Dirk McCormick
Date: Mon, 29 Mar 2021 12:08:56 +0200
Subject: [PATCH] feat: respect pausing in pull channel monitor

---
 channelmonitor/channelmonitor.go      |  42 ++++++++
 channelmonitor/channelmonitor_test.go | 147 ++++++++++++++++++++++++++
 impl/integration_test.go              | 136 ++++++++++++++++++------
 3 files changed, 290 insertions(+), 35 deletions(-)

diff --git a/channelmonitor/channelmonitor.go b/channelmonitor/channelmonitor.go
index f79fb2b4..2c7037b8 100644
--- a/channelmonitor/channelmonitor.go
+++ b/channelmonitor/channelmonitor.go
@@ -53,6 +53,8 @@ type Config struct {
 	// Max time to wait for the responder to send a Complete message once all
 	// data has been sent
 	CompleteTimeout time.Duration
+	// Max time to wait for a resume after the responder paused the channel - for pull channels
+	PullPauseTimeout time.Duration
 }
 
 func NewMonitor(mgr monitorAPI, cfg *Config) *Monitor {
@@ -91,6 +93,9 @@ func checkConfig(cfg *Config) {
 	if cfg.CompleteTimeout <= 0 {
 		panic(fmt.Sprintf(prefix+"CompleteTimeout is %s but must be > 0", cfg.CompleteTimeout))
 	}
+	if cfg.MonitorPullChannels && cfg.PullPauseTimeout <= 0 {
+		panic(fmt.Sprintf(prefix+"PullPauseTimeout is %s but must be > 0", cfg.PullPauseTimeout))
+	}
 }
 
 // This interface just makes it easier to abstract some methods between the
@@ -563,6 +568,7 @@ type monitoredPullChannel struct {
 	statsLk        sync.RWMutex
 	received       uint64
 	dataRatePoints chan uint64
+	pausedAt       time.Time
 }
 
 func newMonitoredPullChannel(
@@ -585,6 +591,30 @@ func (mc *monitoredPullChannel) checkDataRate() {
 	mc.statsLk.Lock()
 	defer mc.statsLk.Unlock()
 
+	// If the channel is currently paused
+	if !mc.pausedAt.IsZero() {
+		// Check if the channel has been paused for too long
+		pausedFor := time.Since(mc.pausedAt)
+		if pausedFor > mc.cfg.PullPauseTimeout {
+			log.Warn(fmt.Sprintf("%s: paused for too long, restarting channel: ", mc.chid) +
+				fmt.Sprintf("paused since %s: %s is longer than pause timeout %s",
+					mc.pausedAt, pausedFor, mc.cfg.PullPauseTimeout))
+
+			// Reset pause so we can continue checking the data rate
+			mc.pausedAt = time.Time{}
+
+			// Restart the channel
+			go mc.restartChannel()
+
+			return
+		}
+
+		// Channel is paused, so wait for a resume before checking the data rate
+		log.Debugf("%s: is paused since %s; waiting for resume to continue checking data rate",
+			mc.chid, mc.pausedAt)
+		return
+	}
+
 	// Before returning, add the current data rate stats to the queue
 	defer func() {
 		mc.dataRatePoints <- mc.received
@@ -624,5 +654,17 @@ func (mc *monitoredPullChannel) onDTEvent(event datatransfer.Event, channelState
 		// Some data was received so reset the consecutive restart counter
 		mc.resetConsecutiveRestarts()
+
+	case datatransfer.PauseResponder:
+		// The sender has paused sending data
+		mc.statsLk.Lock()
+		mc.pausedAt = time.Now()
+		mc.statsLk.Unlock()
+
+	case datatransfer.ResumeResponder:
+		// The sender has resumed sending data
+		mc.statsLk.Lock()
+		mc.pausedAt = time.Time{}
+		mc.statsLk.Unlock()
 	}
 }
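
The new PullPauseTimeout field only applies to pull channels, and checkConfig now rejects a missing value whenever MonitorPullChannels is set. As a minimal sketch of how a consumer of the channelmonitor package might set the field (the durations are illustrative, not recommended defaults, and `mgr` stands in for whatever monitorAPI implementation the caller already wires in; in this repo that is the data transfer manager):

    cfg := &channelmonitor.Config{
        MonitorPullChannels:    true,
        AcceptTimeout:          30 * time.Second,
        Interval:               10 * time.Second,
        ChecksPerInterval:      10,
        MinBytesTransferred:    1024,
        MaxConsecutiveRestarts: 5,
        CompleteTimeout:        time.Minute,
        // New: restart the channel if the responder stays paused longer than this
        PullPauseTimeout: 2 * time.Minute,
    }
    m := channelmonitor.NewMonitor(mgr, cfg)
    m.Start()
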
diff --git a/channelmonitor/channelmonitor_test.go b/channelmonitor/channelmonitor_test.go
index 885d8d4c..ff54e6d7 100644
--- a/channelmonitor/channelmonitor_test.go
+++ b/channelmonitor/channelmonitor_test.go
@@ -66,6 +66,7 @@ func TestPushChannelMonitorAutoRestart(t *testing.T) {
 		MinBytesTransferred:    1,
 		MaxConsecutiveRestarts: 3,
 		CompleteTimeout:        time.Hour,
+		PullPauseTimeout:       time.Hour,
 	})
 	m.Start()
 	mch := m.AddPushChannel(ch1).(*monitoredPushChannel)
@@ -152,6 +153,7 @@ func TestPullChannelMonitorAutoRestart(t *testing.T) {
 		MinBytesTransferred:    1,
 		MaxConsecutiveRestarts: 3,
 		CompleteTimeout:        time.Hour,
+		PullPauseTimeout:       time.Hour,
 	})
 	m.Start()
 	mch := m.AddPullChannel(ch1).(*monitoredPullChannel)
@@ -315,6 +317,7 @@ func TestPushChannelMonitorDataRate(t *testing.T) {
 				MinBytesTransferred:    tc.minBytesSent,
 				MaxConsecutiveRestarts: 3,
 				CompleteTimeout:        time.Hour,
+				PullPauseTimeout:       time.Hour,
 			})
 
 			// Note: Don't start monitor, we'll call checkDataRate() manually
@@ -384,6 +387,7 @@ func TestPullChannelMonitorDataRate(t *testing.T) {
 				MinBytesTransferred:    tc.minBytesTransferred,
 				MaxConsecutiveRestarts: 3,
 				CompleteTimeout:        time.Hour,
+				PullPauseTimeout:       time.Hour,
 			})
 
 			// Note: Don't start monitor, we'll call checkDataRate() manually
@@ -414,6 +418,139 @@
 	}
 }
 
+func TestPullChannelMonitorPausing(t *testing.T) {
+	ch := &mockChannelState{chid: ch1}
+	mockAPI := newMockMonitorAPI(ch, false)
+
+	checkIfRestarted := func(expectRestart bool) {
+		select {
+		case <-time.After(5 * time.Millisecond):
+			if expectRestart {
+				require.Fail(t, "failed to restart channel")
+			}
+		case <-mockAPI.restarts:
+			if !expectRestart {
+				require.Fail(t, "expected no channel restart")
+			}
+		}
+	}
+
+	minBytesTransferred := uint64(10)
+	m := NewMonitor(mockAPI, &Config{
+		MonitorPullChannels:    true,
+		AcceptTimeout:          time.Hour,
+		Interval:               time.Hour,
+		ChecksPerInterval:      1,
+		MinBytesTransferred:    minBytesTransferred,
+		MaxConsecutiveRestarts: 3,
+		CompleteTimeout:        time.Hour,
+		PullPauseTimeout:       time.Hour,
+	})
+
+	// Note: Don't start monitor, we'll call checkDataRate() manually
+
+	m.AddPullChannel(ch1)
+
+	lastRcvd := uint64(5)
+	mockAPI.dataReceived(lastRcvd)
+	m.checkDataRate()
+
+	// Some data received, but less than required amount
+	mockAPI.dataReceived(lastRcvd + minBytesTransferred/2)
+
+	// If responder is paused, the monitor should ignore data
+	// rate checking until responder resumes
+	mockAPI.pauseResponder()
+	m.checkDataRate() // Should be ignored because responder is paused
+	m.checkDataRate()
+	m.checkDataRate()
+
+	// Should not restart
+	checkIfRestarted(false)
+
+	// Resume the responder
+	mockAPI.resumeResponder()
+
+	// Receive some data
+	lastRcvd = 100
+	mockAPI.dataReceived(lastRcvd)
+
+	// Should not restart because received data exceeds minimum required
+	m.checkDataRate()
+	checkIfRestarted(false)
+
+	// Pause responder again
+	mockAPI.pauseResponder()
+	m.checkDataRate() // Should be ignored because responder is paused
+	m.checkDataRate()
+	m.checkDataRate()
+
+	// Resume responder
+	mockAPI.resumeResponder()
+
+	// Not enough data received, should restart
+	mockAPI.dataReceived(lastRcvd + minBytesTransferred/2)
+	m.checkDataRate()
+	checkIfRestarted(true)
+}
+
+func TestPullChannelMonitorPauseTimeout(t *testing.T) {
+	ch := &mockChannelState{chid: ch1}
+	mockAPI := newMockMonitorAPI(ch, false)
+
+	checkIfRestarted := func(expectRestart bool) {
+		select {
+		case <-time.After(5 * time.Millisecond):
+			if expectRestart {
+				require.Fail(t, "failed to restart channel")
+			}
+		case <-mockAPI.restarts:
+			if !expectRestart {
+				require.Fail(t, "expected no channel restart")
+			}
+		}
+	}
+
+	minBytesTransferred := uint64(10)
+	pullPauseTimeout := 50 * time.Millisecond
+	m := NewMonitor(mockAPI, &Config{
+		MonitorPullChannels:    true,
+		AcceptTimeout:          time.Hour,
+		Interval:               time.Hour,
+		ChecksPerInterval:      1,
+		MinBytesTransferred:    minBytesTransferred,
+		MaxConsecutiveRestarts: 3,
+		CompleteTimeout:        time.Hour,
+		PullPauseTimeout:       pullPauseTimeout,
+	})
+
+	// Note: Don't start monitor, we'll call checkDataRate() manually
+
+	m.AddPullChannel(ch1)
+
+	lastRcvd := uint64(5)
+	mockAPI.dataReceived(lastRcvd)
+	m.checkDataRate()
+
+	// Some data received, but less than required amount
+	mockAPI.dataReceived(lastRcvd + minBytesTransferred/2)
+
+	// If responder is paused, the monitor should ignore data
+	// rate checking until responder resumes
+	mockAPI.pauseResponder()
+	m.checkDataRate() // Should be ignored because responder is paused
+
+	// Should not restart
+	checkIfRestarted(false)
+
+	// Pause timeout elapses
+	time.Sleep(pullPauseTimeout * 2)
+
+	// Should detect timeout has elapsed and restart
+	m.checkDataRate()
+	checkIfRestarted(true)
+}
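
The two tests above exercise what amounts to a three-way decision at the top of checkDataRate(). Distilled into a hypothetical standalone function (the real method additionally holds statsLk and tracks data-rate points), the gate looks like this:

    // pauseGate reports whether the data-rate check should be skipped and
    // whether the channel should be restarted, given the time the channel
    // was paused (the zero time if not paused) and the configured timeout.
    func pauseGate(pausedAt time.Time, timeout time.Duration) (skip, restart bool) {
        if pausedAt.IsZero() {
            return false, false // not paused: run the normal data-rate check
        }
        if time.Since(pausedAt) > timeout {
            return true, true // paused too long: restart the channel
        }
        return true, false // paused: defer rate checks until a resume arrives
    }
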
+
 func TestChannelMonitorMaxConsecutiveRestarts(t *testing.T) {
 	runTest := func(name string, isPush bool) {
 		t.Run(name, func(t *testing.T) {
@@ -430,6 +567,7 @@
 				MinBytesTransferred:    2,
 				MaxConsecutiveRestarts: uint32(maxConsecutiveRestarts),
 				CompleteTimeout:        time.Hour,
+				PullPauseTimeout:       time.Hour,
 			})
 
 			// Note: Don't start monitor, we'll call checkDataRate() manually
@@ -550,6 +688,7 @@ func TestChannelMonitorTimeouts(t *testing.T) {
 				MinBytesTransferred:    1,
 				MaxConsecutiveRestarts: 1,
 				CompleteTimeout:        completeTimeout,
+				PullPauseTimeout:       time.Hour,
 			})
 			m.Start()
 
@@ -715,6 +854,14 @@ func (m *mockMonitorAPI) sendDataErrorEvent() {
 	m.callSubscriber(datatransfer.Event{Code: datatransfer.SendDataError}, m.ch)
 }
 
+func (m *mockMonitorAPI) pauseResponder() {
+	m.callSubscriber(datatransfer.Event{Code: datatransfer.PauseResponder}, m.ch)
+}
+
+func (m *mockMonitorAPI) resumeResponder() {
+	m.callSubscriber(datatransfer.Event{Code: datatransfer.ResumeResponder}, m.ch)
+}
+
 type mockChannelState struct {
 	chid   datatransfer.ChannelID
 	queued uint64
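
The integration changes below feed this configuration through the existing ChannelRestartConfig option. A sketch of client-side wiring, assuming ds, dir, net and transport are already constructed (those names and the durations are placeholders, not recommended values):

    restartConf := ChannelRestartConfig(channelmonitor.Config{
        MonitorPullChannels:    true,
        AcceptTimeout:          time.Minute,
        Interval:               10 * time.Second,
        ChecksPerInterval:      2,
        MinBytesTransferred:    1,
        RestartBackoff:         5 * time.Second,
        MaxConsecutiveRestarts: 5,
        CompleteTimeout:        time.Minute,
        PullPauseTimeout:       time.Minute,
    })
    dt, err := NewDataTransfer(ds, dir, net, transport, restartConf)
    if err != nil {
        return err
    }
    if err := dt.Start(ctx); err != nil {
        return err
    }
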
diff --git a/impl/integration_test.go b/impl/integration_test.go
index 77bd3ade..6b0a07f0 100644
--- a/impl/integration_test.go
+++ b/impl/integration_test.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"math/rand"
 	"os"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -672,6 +673,7 @@ func TestAutoRestart(t *testing.T) {
 				RestartBackoff:         500 * time.Millisecond,
 				MaxConsecutiveRestarts: 5,
 				CompleteTimeout:        100 * time.Millisecond,
+				PullPauseTimeout:       100 * time.Millisecond,
 			})
 			initiator, err := NewDataTransfer(gsData.DtDs1, gsData.TempDir1, gsData.DtNet1, initiatorGSTspt, restartConf)
 			require.NoError(t, err)
@@ -935,20 +937,30 @@ func (r *retrievalRevalidator) OnComplete(chid datatransfer.ChannelID) (bool, da
 func TestSimulatedRetrievalFlow(t *testing.T) {
 	ctx := context.Background()
 	testCases := map[string]struct {
-		unpauseRequestorDelay time.Duration
-		unpauseResponderDelay time.Duration
-		pausePoints           []uint64
+		unpauseResponderDelay      time.Duration
+		channelMonitorPauseTimeout time.Duration
+		restartExpected            bool
 	}{
-		"fast unseal, payment channel ready": {
-			pausePoints: []uint64{1000, 3000, 6000, 10000, 15000},
+		// Simulate a retrieval where the provider pauses the data transfer
+		// while the file is being unsealed, for just a moment
+		"fast unseal": {
+			unpauseResponderDelay:      0,
+			channelMonitorPauseTimeout: 500 * time.Millisecond,
+			restartExpected:            false,
 		},
-		"fast unseal, payment channel not ready": {
-			unpauseRequestorDelay: 100 * time.Millisecond,
-			pausePoints:           []uint64{1000, 3000, 6000, 10000, 15000},
+		// Simulate a retrieval where the provider pauses the data transfer
+		// while the file is being unsealed, for a relatively long time
+		"slow unseal": {
+			unpauseResponderDelay:      200 * time.Millisecond,
+			channelMonitorPauseTimeout: 500 * time.Millisecond,
+			restartExpected:            false,
 		},
-		"slow unseal, payment channel ready": {
-			unpauseResponderDelay: 200 * time.Millisecond,
-			pausePoints:           []uint64{1000, 3000, 6000, 10000, 15000},
+		// Simulate a retrieval where the provider pauses the data transfer
+		// while the file is being unsealed, for longer than the unpause timeout
+		"unseal slower than unpause timeout": {
+			unpauseResponderDelay:      200 * time.Millisecond,
+			channelMonitorPauseTimeout: 100 * time.Millisecond,
+			restartExpected:            true,
 		},
 	}
 	for testCase, config := range testCases {
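
To a first approximation (ignoring the check cadence set by Interval and ChecksPerInterval), the expected outcome of each case above reduces to a single comparison. A hypothetical helper that captures the table's logic:

    // expectRestart mirrors the test table: the monitor restarts the channel
    // only when the responder stays paused (unsealing) for longer than the
    // configured pull pause timeout.
    func expectRestart(unpauseDelay, pauseTimeout time.Duration) bool {
        return unpauseDelay > pauseTimeout
    }
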
@@ -956,61 +968,102 @@
 			ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
 			defer cancel()
 
+			// The provider will pause sending data and wait for a payment
+			// voucher at each of these points in the transfer
+			pausePoints := []uint64{1000, 3000, 6000, 10000, 15000}
+
 			gsData := testutil.NewGraphsyncTestingData(ctx, t, nil, nil)
-			host1 := gsData.Host1 // initiator, data sender
+			provHost := gsData.Host1
 
+			// Load a file into the provider blockstore
 			root := gsData.LoadUnixFSFile(t, false)
 			rootCid := root.(cidlink.Link).Cid
-			tp1 := gsData.SetupGSTransportHost1()
-			tp2 := gsData.SetupGSTransportHost2()
-			dt1, err := NewDataTransfer(gsData.DtDs1, gsData.TempDir1, gsData.DtNet1, tp1)
+
+			tpProv := gsData.SetupGSTransportHost1()
+			tpClient := gsData.SetupGSTransportHost2()
+
+			// Set up restart config for client
+			clientRestartConf := ChannelRestartConfig(channelmonitor.Config{
+				MonitorPullChannels:    true,
+				AcceptTimeout:          100 * time.Millisecond,
+				Interval:               100 * time.Millisecond,
+				MinBytesTransferred:    1,
+				ChecksPerInterval:      2,
+				RestartBackoff:         500 * time.Millisecond,
+				MaxConsecutiveRestarts: 5,
+				CompleteTimeout:        100 * time.Millisecond,
+				PullPauseTimeout:       config.channelMonitorPauseTimeout,
+			})
+
+			// Setup data transfer for client and provider
+			dtClient, err := NewDataTransfer(gsData.DtDs2, gsData.TempDir2, gsData.DtNet2, tpClient, clientRestartConf)
 			require.NoError(t, err)
-			testutil.StartAndWaitForReady(ctx, t, dt1)
-			dt2, err := NewDataTransfer(gsData.DtDs2, gsData.TempDir2, gsData.DtNet2, tp2)
+			testutil.StartAndWaitForReady(ctx, t, dtClient)
+
+			dtProv, err := NewDataTransfer(gsData.DtDs1, gsData.TempDir1, gsData.DtNet1, tpProv)
 			require.NoError(t, err)
-			testutil.StartAndWaitForReady(ctx, t, dt2)
+			testutil.StartAndWaitForReady(ctx, t, dtProv)
+
+			// Setup requester event listener
 			var chid datatransfer.ChannelID
 			errChan := make(chan struct{}, 2)
+			restarted := int32(0)
 			clientPausePoint := 0
 			clientFinished := make(chan struct{}, 1)
 			finalVoucherResult := testutil.NewFakeDTType()
 			encodedFVR, err := encoding.Encode(finalVoucherResult)
 			require.NoError(t, err)
 			var clientSubscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) {
-				if event.Code == datatransfer.Error {
-					errChan <- struct{}{}
-				}
+				// Watch for the final payment request from the provider, and
+				// respond with a payment voucher
 				if event.Code == datatransfer.NewVoucherResult {
 					lastVoucherResult := channelState.LastVoucherResult()
 					encodedLVR, err := encoding.Encode(lastVoucherResult)
 					require.NoError(t, err)
 					if bytes.Equal(encodedLVR, encodedFVR) {
-						_ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType())
+						_ = dtClient.SendVoucher(ctx, chid, testutil.NewFakeDTType())
 					}
 				}
+
+				// When data received exceeds a pause point, the provider will
+				// have paused the transfer while waiting for a payment voucher.
+				// So send a payment voucher from the client.
 				if event.Code == datatransfer.DataReceived &&
-					clientPausePoint < len(config.pausePoints) &&
-					channelState.Received() > config.pausePoints[clientPausePoint] {
-					_ = dt2.SendVoucher(ctx, chid, testutil.NewFakeDTType())
+					clientPausePoint < len(pausePoints) &&
+					channelState.Received() > pausePoints[clientPausePoint] {
+					_ = dtClient.SendVoucher(ctx, chid, testutil.NewFakeDTType())
 					clientPausePoint++
 				}
+
+				if event.Code == datatransfer.Restart {
+					atomic.AddInt32(&restarted, 1)
+				}
+				if event.Code == datatransfer.Error {
+					errChan <- struct{}{}
+				}
 				if channelState.Status() == datatransfer.Completed {
 					clientFinished <- struct{}{}
 				}
 			}
-			dt2.SubscribeToEvents(clientSubscriber)
+			dtClient.SubscribeToEvents(clientSubscriber)
+
+			// Setup responder event listener
 			providerFinished := make(chan struct{}, 1)
 			providerAccepted := false
 			var providerSubscriber datatransfer.Subscriber = func(event datatransfer.Event, channelState datatransfer.ChannelState) {
+				// The provider should immediately pause the channel when it
+				// receives an open channel request.
 				if event.Code == datatransfer.PauseResponder {
 					if !providerAccepted {
 						providerAccepted = true
+
+						// Simulate pausing while data is unsealed
 						timer := time.NewTimer(config.unpauseResponderDelay)
 						go func() {
 							<-timer.C
+
+							// Resume after unseal completes
-							_ = dt1.ResumeDataTransferChannel(ctx, chid)
+							_ = dtProv.ResumeDataTransferChannel(ctx, chid)
 						}()
 					}
 				}
@@ -1021,20 +1074,24 @@
 					providerFinished <- struct{}{}
 				}
 			}
-			dt1.SubscribeToEvents(providerSubscriber)
+			dtProv.SubscribeToEvents(providerSubscriber)
 			voucher := testutil.FakeDTType{Data: "applesauce"}
 			sv := testutil.NewStubbedValidator()
 			sv.ExpectPausePull()
-			require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv))
+			require.NoError(t, dtProv.RegisterVoucherType(&testutil.FakeDTType{}, sv))
 
+			// Set up a revalidator on the provider that will pause at
+			// configurable pause points
 			srv := &retrievalRevalidator{
-				testutil.NewStubbedRevalidator(), 0, 0, config.pausePoints, finalVoucherResult,
+				testutil.NewStubbedRevalidator(), 0, 0, pausePoints, finalVoucherResult,
 			}
 			srv.ExpectSuccessRevalidation()
-			require.NoError(t, dt1.RegisterRevalidator(testutil.NewFakeDTType(), srv))
+			require.NoError(t, dtProv.RegisterRevalidator(testutil.NewFakeDTType(), srv))
+
+			require.NoError(t, dtClient.RegisterVoucherType(&testutil.FakeDTType{}, sv))
+			require.NoError(t, dtClient.RegisterVoucherResultType(testutil.NewFakeDTType()))
 
-			require.NoError(t, dt2.RegisterVoucherResultType(testutil.NewFakeDTType()))
-			chid, err = dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector)
+			chid, err = dtClient.OpenPullDataChannel(ctx, provHost.ID(), &voucher, rootCid, gsData.AllSelector)
 			require.NoError(t, err)
 
 			for providerFinished != nil || clientFinished != nil {
@@ -1049,11 +1106,20 @@
 					t.Fatal("received unexpected error")
 				}
 			}
+
+			// Check if there was a restart, and whether one was expected
+			if atomic.LoadInt32(&restarted) == 0 && config.restartExpected {
+				require.Fail(t, "expected restart but there was none")
+			}
+			if atomic.LoadInt32(&restarted) > 0 && !config.restartExpected {
+				require.Fail(t, "expected no restart but there was one")
+			}
+
 			sv.VerifyExpectations(t)
 			srv.VerifyExpectations(t)
 			gsData.VerifyFileTransferred(t, root, true)
-			require.Equal(t, srv.providerPausePoint, len(config.pausePoints))
-			require.Equal(t, clientPausePoint, len(config.pausePoints))
+			require.Equal(t, srv.providerPausePoint, len(pausePoints))
+			require.Equal(t, clientPausePoint, len(pausePoints))
 		})
 	}
 }
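
Downstream code can observe the same behaviour the test asserts on by subscribing to data transfer events. A sketch (the event codes are the real datatransfer codes used above; the handler itself is illustrative):

    var restarts int32
    dt.SubscribeToEvents(func(event datatransfer.Event, st datatransfer.ChannelState) {
        switch event.Code {
        case datatransfer.PauseResponder:
            fmt.Printf("%s: responder paused\n", st.ChannelID())
        case datatransfer.ResumeResponder:
            fmt.Printf("%s: responder resumed\n", st.ChannelID())
        case datatransfer.Restart:
            // Incremented each time the channel monitor restarts the channel
            atomic.AddInt32(&restarts, 1)
        }
    })
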