Skip to content

Commit

Permalink
routing: add cancelable context to interrupt payment attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
hieblmi committed May 10, 2024
1 parent ec7df15 commit 0bfb3db
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 31 deletions.
6 changes: 4 additions & 2 deletions lnrpc/routerrpc/router_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
}

// Send the payment asynchronously.
s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
s.cfg.Router.SendPaymentAsync(
stream.Context(), payment, paySession, shardTracker,
)

// Track the payment and return.
return s.trackPayment(
Expand Down Expand Up @@ -987,7 +989,7 @@ func (s *Server) SetMissionControlConfig(ctx context.Context,
req.Config.HopProbability,
),
AprioriWeight: float64(req.Config.Weight),
CapacityFraction: routing.DefaultCapacityFraction,
CapacityFraction: routing.DefaultCapacityFraction, //nolint:lll
}
}

Expand Down
71 changes: 58 additions & 13 deletions routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"errors"
"fmt"
"time"
Expand Down Expand Up @@ -167,7 +168,9 @@ func (p *paymentLifecycle) decideNextStep(
}

// resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
*route.Route, error) {

// When the payment lifecycle loop exits, we make sure to signal any
// sub goroutine of the HTLC attempt to exit, then wait for them to
// return.
Expand Down Expand Up @@ -221,18 +224,18 @@ lifecycle:

// We now proceed our lifecycle with the following tasks in
// order,
// 1. check timeout.
// 1. check timeout and context.
// 2. request route.
// 3. create HTLC attempt.
// 4. send HTLC attempt.
// 5. collect HTLC attempt result.
//
// Before we attempt any new shard, we'll check to see if
// either we've gone past the payment attempt timeout, or the
// router is exiting. In either case, we'll stop this payment
// attempt short. If a timeout is not applicable, timeoutChan
// will be nil.
if err := p.checkTimeout(); err != nil {
// we've gone past the payment attempt timeout, or if the
// context was cancelled, or the router is exiting. In any of
// these cases, we'll stop this payment attempt short. If a
// timeout is not applicable, timeoutChan will be nil.
if err := p.checkTimeoutAndContext(ctx); err != nil {
return exitWithErr(err)
}

Expand Down Expand Up @@ -319,27 +322,69 @@ lifecycle:
}

// checkTimeoutAndContext checks whether the payment has reached its timeout,
// whether its context has been cancelled, or whether the router is shutting
// down.
func (p *paymentLifecycle) checkTimeout() error {
func (p *paymentLifecycle) checkTimeoutAndContext(ctx context.Context) error {
	// markFailed records the given failure reason with the control tower.
	// Once the payment is marked failed its status becomes either
	// `StatusInflight` or `StatusFailed`, depending on whether it still
	// has inflight HTLCs, and no further HTLCs will be attempted.
	markFailed := func(reason channeldb.FailureReason) error {
		err := p.router.cfg.Control.FailPayment(p.identifier, reason)
		if err == nil {
			return nil
		}

		return fmt.Errorf("FailPayment got %w", err)
	}

	// Poll every signal once without blocking.
	select {
	// The payment timed out: mark it failed and surface only an error
	// from that bookkeeping step, if any.
	case <-p.timeoutChan:
		log.Warnf("payment attempt not completed before timeout")

		return markFailed(channeldb.FailureReasonTimeout)

	// The supplied context is done: mark the payment failed and hand the
	// context's own error back to the caller.
	case <-ctx.Done():
		log.Warnf("payment attempt context canceled")

		if err := markFailed(channeldb.FailureReasonError); err != nil {
			return err
		}

		return ctx.Err()

	// The router is shutting down.
	case <-p.router.quit:
		return fmt.Errorf("check payment timeout got: %w",
			ErrRouterShuttingDown)

	// None of the signals fired, so the payment may proceed.
	default:
		return nil
	}
}

// checkContext checks whether the payment context has been cancelled.
func (p *paymentLifecycle) checkContext(ctx context.Context) error {
if ctx == nil {
return nil
}

select {
case <-ctx.Done():
// By marking the payment failed, depending on whether it has
// inflight HTLCs or not, its status will now either be
// `StatusInflight` or `StatusFailed`. In either case, no more
// HTLCs will be attempted.
err := p.router.cfg.Control.FailPayment(
p.identifier, channeldb.FailureReasonTimeout,
p.identifier, channeldb.FailureReasonError,
)
if err != nil {
return fmt.Errorf("FailPayment got %w", err)
}

case <-p.router.quit:
return fmt.Errorf("check payment timeout got: %w",
ErrRouterShuttingDown)
return ctx.Err()

// Fall through if we haven't hit our time limit.
default:
}

Expand Down
92 changes: 85 additions & 7 deletions routing/payment_lifecycle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -153,15 +154,15 @@ type resumePaymentResult struct {

// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error
// is returned.
func sendPaymentAndAssertFailed(t *testing.T,
p *paymentLifecycle, errExpected error) {
func sendPaymentAndAssertFailed(t *testing.T, p *paymentLifecycle,
errExpected error) {

resultChan := make(chan *resumePaymentResult, 1)

// We now make a call to `resumePayment` and expect it to return the
// error.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(context.Background())
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
Expand All @@ -179,6 +180,34 @@ func sendPaymentAndAssertFailed(t *testing.T,
}
}

// sendPaymentAndAssertContextCancelled calls `resumePayment` with the given
// context and asserts that an error is returned and that the returned
// preimage is empty. The test fails if no result arrives within testTimeout.
//
// NOTE(review): `require.Error` only asserts that the error is non-nil; its
// trailing arguments are message args, so errExpected is not compared against
// the returned error here. Consider `require.ErrorIs` once it is confirmed
// that `resumePayment` returns an error matching errExpected.
func sendPaymentAndAssertContextCancelled(t *testing.T,
ctx context.Context, p *paymentLifecycle, errExpected error) {

// The channel has a buffer of one so the goroutine below can deliver its
// result even if the select has already given up waiting.
resultChan := make(chan *resumePaymentResult, 1)

// We now make a call to `resumePayment` and expect it to return the
// error.
go func() {
preimage, _, err := p.resumePayment(ctx)
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
}
}()

// Validate the returned values or timeout.
select {
case r := <-resultChan:
require.Error(t, r.err, errExpected, "expected error")
require.Empty(t, r.preimage, "preimage should be empty")

case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for result")
}
}

// sendPaymentAndAssertSucceeded calls `resumePayment` and asserts that the
// returned preimage is correct.
func sendPaymentAndAssertSucceeded(t *testing.T,
Expand All @@ -189,7 +218,7 @@ func sendPaymentAndAssertSucceeded(t *testing.T,
// We now make a call to `resumePayment` and expect it to return the
// preimage.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(context.Background())
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
Expand Down Expand Up @@ -278,6 +307,9 @@ func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo {
func TestCheckTimeoutTimedOut(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

p := createTestPaymentLifecycle()

// Mock the control tower's `FailPayment` method.
Expand All @@ -295,7 +327,7 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
time.Sleep(1 * time.Millisecond)

// Call the function and expect no error.
err := p.checkTimeout()
err := p.checkTimeoutAndContext(ctx)
require.NoError(t, err)

// Assert that `FailPayment` is called as expected.
Expand All @@ -319,7 +351,7 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
time.Sleep(1 * time.Millisecond)

// Call the function and expect an error.
err = p.checkTimeout()
err = p.checkTimeoutAndContext(ctx)
require.ErrorIs(t, err, errDummy)

// Assert that `FailPayment` is called as expected.
Expand All @@ -331,10 +363,13 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

p := createTestPaymentLifecycle()

close(p.router.quit)
err := p.checkTimeout()
err := p.checkTimeoutAndContext(ctx)
require.ErrorIs(t, err, ErrRouterShuttingDown)
}

Expand Down Expand Up @@ -727,6 +762,49 @@ func TestResumePaymentFailOnTimeoutErr(t *testing.T) {
require.Zero(t, m.collectResultsCount)
}

// TestResumePaymentFailContextCancel checks that the lifecycle fails when the
// context passed to `resumePayment` has been canceled: the payment is marked
// as failed with `FailureReasonError` and no attempt results are collected.
//
// NOTE: No parallel test because it overwrites global variables.
//
//nolint:paralleltest
func TestResumePaymentFailContextCancel(t *testing.T) {
// Create a test paymentLifecycle with the initial two calls mocked.
p, m := setupTestPaymentLifecycle(t)

paymentAmt := lnwire.MilliSatoshi(10000)

// We now enter the payment lifecycle loop.
//
// 1. calls `FetchPayment` and return the payment.
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()

// 2. calls `GetState` and return the state.
ps := &channeldb.MPPaymentState{
RemainingAmt: paymentAmt,
}
m.payment.On("GetState").Return(ps).Once()

// NOTE: GetStatus is only used to populate the logs which is not
// critical, so we loosen the checks on how many times it's been called.
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)

// 3. Expect `FailPayment` to be called once with `FailureReasonError`
// and let it succeed, so that any error surfaced by the lifecycle stems
// from the canceled context rather than from `FailPayment` itself.
m.control.On(
"FailPayment", p.identifier, channeldb.FailureReasonError,
).Return(nil).Once()

// Create a context and cancel it right away so that it is already done
// by the time `resumePayment` is called.
ctx, cancel := context.WithCancel(context.Background())
cancel()

// Send the payment and assert that its context got cancelled.
sendPaymentAndAssertContextCancelled(t, ctx, p, ErrContextCancelled)

// Expect `collectResultAsync` to not have been called.
require.Zero(t, m.collectResultsCount)
}

// TestResumePaymentFailOnStepErr checks that the lifecycle fails when an
// error is returned from `decideNextStep`.
//
Expand Down

0 comments on commit 0bfb3db

Please sign in to comment.