Skip to content

Commit

Permalink
routing: add cancelable context to interrupt payment attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
hieblmi committed May 8, 2024
1 parent 481b761 commit 4fa701d
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 22 deletions.
6 changes: 4 additions & 2 deletions lnrpc/routerrpc/router_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
}

// Send the payment asynchronously.
s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
s.cfg.Router.SendPaymentAsync(
stream.Context(), payment, paySession, shardTracker,
)

// Track the payment and return.
return s.trackPayment(
Expand Down Expand Up @@ -987,7 +989,7 @@ func (s *Server) SetMissionControlConfig(ctx context.Context,
req.Config.HopProbability,
),
AprioriWeight: float64(req.Config.Weight),
CapacityFraction: routing.DefaultCapacityFraction,
CapacityFraction: routing.DefaultCapacityFraction, //nolint:lll
}
}

Expand Down
55 changes: 46 additions & 9 deletions routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"errors"
"fmt"
"time"
Expand Down Expand Up @@ -167,7 +168,9 @@ func (p *paymentLifecycle) decideNextStep(
}

// resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
*route.Route, error) {

// When the payment lifecycle loop exits, we make sure to signal any
// sub goroutine of the HTLC attempt to exit, then wait for them to
// return.
Expand Down Expand Up @@ -222,20 +225,27 @@ lifecycle:
// We now proceed our lifecycle with the following tasks in
// order,
// 1. check timeout.
// 2. request route.
// 3. create HTLC attempt.
// 4. send HTLC attempt.
// 5. collect HTLC attempt result.
// 2. check context.
// 3. request route.
// 4. create HTLC attempt.
// 5. send HTLC attempt.
// 6. collect HTLC attempt result.
//
// Before we attempt any new shard, we'll check to see if
// either we've gone past the payment attempt timeout, or the
// router is exiting. In either case, we'll stop this payment
// attempt short. If a timeout is not applicable, timeoutChan
// will be nil.
// we've gone past the payment attempt timeout, or the context
// was cancelled, or the router is exiting. In any of these
// cases, we'll stop this payment attempt short. If a timeout is
// not applicable, timeoutChan will be nil.
if err := p.checkTimeout(); err != nil {
return exitWithErr(err)
}

// Check the cancellation status of the context before
// proceeding.
if err := p.checkContext(ctx); err != nil {
return exitWithErr(err)
}

// Now decide the next step of the current lifecycle.
step, err := p.decideNextStep(payment)
if err != nil {
Expand Down Expand Up @@ -346,6 +356,33 @@ func (p *paymentLifecycle) checkTimeout() error {
return nil
}

// checkContext checks whether the payment context has been cancelled.
func (p *paymentLifecycle) checkContext(ctx context.Context) error {
if ctx == nil {
return nil
}

select {
case <-ctx.Done():
// By marking the payment failed, depending on whether it has
// inflight HTLCs or not, its status will now either be
// `StatusInflight` or `StatusFailed`. In either case, no more
// HTLCs will be attempted.
err := p.router.cfg.Control.FailPayment(
p.identifier, channeldb.FailureReasonError,
)
if err != nil {
return fmt.Errorf("FailPayment got %w", err)
}

return ctx.Err()

default:
}

return nil
}

// requestRoute is responsible for finding a route to be used to create an HTLC
// attempt.
func (p *paymentLifecycle) requestRoute(
Expand Down
76 changes: 74 additions & 2 deletions routing/payment_lifecycle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -161,7 +162,7 @@ func sendPaymentAndAssertFailed(t *testing.T,
// We now make a call to `resumePayment` and expect it to return the
// error.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(nil)
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
Expand All @@ -179,6 +180,34 @@ func sendPaymentAndAssertFailed(t *testing.T,
}
}

// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error
// is returned.
func sendPaymentAndAssertContextCancelled(t *testing.T,
ctx context.Context, p *paymentLifecycle, errExpected error) {

resultChan := make(chan *resumePaymentResult, 1)

// We now make a call to `resumePayment` and expect it to return the
// error.
go func() {
preimage, _, err := p.resumePayment(ctx)
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
}
}()

// Validate the returned values or timeout.
select {
case r := <-resultChan:
require.Error(t, r.err, errExpected, "expected error")
require.Empty(t, r.preimage, "preimage should be empty")

case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for result")
}
}

// sendPaymentAndAssertSucceeded calls `resumePayment` and asserts that the
// returned preimage is correct.
func sendPaymentAndAssertSucceeded(t *testing.T,
Expand All @@ -189,7 +218,7 @@ func sendPaymentAndAssertSucceeded(t *testing.T,
// We now make a call to `resumePayment` and expect it to return the
// preimage.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(nil)
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
Expand Down Expand Up @@ -727,6 +756,49 @@ func TestResumePaymentFailOnTimeoutErr(t *testing.T) {
require.Zero(t, m.collectResultsCount)
}

// TestResumePaymentFailContextCancel checks that the lifecycle fails when the
// context is canceled and an error is returned from `checkContext`.
//
// NOTE: No parallel test because it overwrites global variables.
//
//nolint:paralleltest
func TestResumePaymentFailContextCancel(t *testing.T) {
// Create a test paymentLifecycle with the initial two calls mocked.
p, m := setupTestPaymentLifecycle(t)

paymentAmt := lnwire.MilliSatoshi(10000)

// We now enter the payment lifecycle loop.
//
// 1. calls `FetchPayment` and return the payment.
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()

// 2. calls `GetState` and return the state.
ps := &channeldb.MPPaymentState{
RemainingAmt: paymentAmt,
}
m.payment.On("GetState").Return(ps).Once()

// NOTE: GetStatus is only used to populate the logs which is not
// critical, so we loosen the checks on how many times it's been called.
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)

// 3. Cancel the context and skip the FailPayment error to trigger the
// context cancellation of the payment.
m.control.On(
"FailPayment", p.identifier, channeldb.FailureReasonError,
).Return(nil).Once()

ctx, cancel := context.WithCancel(context.Background())
cancel()

// Send the payment and assert that its context got cancelled.
sendPaymentAndAssertContextCancelled(t, ctx, p, ErrContextCancelled)

// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
}

// TestResumePaymentFailOnStepErr checks that the lifecycle fails when an
// error is returned from `decideNextStep`.
//
Expand Down
23 changes: 14 additions & 9 deletions routing/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package routing

import (
"bytes"
"context"
"fmt"
"math"
"runtime"
Expand Down Expand Up @@ -92,6 +93,10 @@ var (
// shutting down.
ErrRouterShuttingDown = fmt.Errorf("router shutting down")

// ErrContextCancelled is returned if the context of the payment attempt
// is canceled.
ErrContextCancelled = errors.New("context cancelled")

// ErrSelfIntro is a failure returned when the source node of a
// route request is also the introduction node. This is not yet
// supported because LND does not support blinded forwardingg.
Expand Down Expand Up @@ -720,7 +725,7 @@ func (r *ChannelRouter) Start() error {
// also set a zero fee limit, as no more routes should
// be tried.
_, _, err := r.sendPayment(
0, payment.Info.PaymentIdentifier, 0,
nil, 0, payment.Info.PaymentIdentifier, 0,
paySession, shardTracker,
)
if err != nil {
Expand Down Expand Up @@ -2408,15 +2413,15 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
return r.sendPayment(
payment.FeeLimit, payment.Identifier(),
nil, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker,
)
}

// SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
ps PaymentSession, st shards.ShardTracker) {
func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
payment *LightningPayment, ps PaymentSession, st shards.ShardTracker) {

// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
Expand All @@ -2428,7 +2433,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
spewPayment(payment))

_, _, err := r.sendPayment(
payment.FeeLimit, payment.Identifier(),
ctx, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, ps, st,
)
if err != nil {
Expand Down Expand Up @@ -2698,9 +2703,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// carry out its execution. After restarts, it is safe, and assumed, that the
// router will call this method for every payment still in-flight according to
// the ControlTower.
func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, timeout time.Duration,
paySession PaymentSession,
func (r *ChannelRouter) sendPayment(ctx context.Context,
feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
timeout time.Duration, paySession PaymentSession,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {

// We'll also fetch the current block height, so we can properly
Expand All @@ -2717,7 +2722,7 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
shardTracker, timeout, currentHeight,
)

return p.resumePayment()
return p.resumePayment(ctx)
}

// extractChannelUpdate examines the error and extracts the channel update.
Expand Down

0 comments on commit 4fa701d

Please sign in to comment.