Skip to content

Commit

Permalink
routing: add cancelable context to interrupt payment attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
hieblmi committed May 10, 2024
1 parent f27393b commit 2ab2bd9
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 30 deletions.
4 changes: 3 additions & 1 deletion lnrpc/routerrpc/router_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
}

// Send the payment asynchronously.
s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
s.cfg.Router.SendPaymentAsync(
stream.Context(), payment, paySession, shardTracker,
)

// Track the payment and return.
return s.trackPayment(
Expand Down
71 changes: 58 additions & 13 deletions routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"errors"
"fmt"
"time"
Expand Down Expand Up @@ -167,7 +168,9 @@ func (p *paymentLifecycle) decideNextStep(
}

// resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
*route.Route, error) {

// When the payment lifecycle loop exits, we make sure to signal any
// sub goroutine of the HTLC attempt to exit, then wait for them to
// return.
Expand Down Expand Up @@ -221,18 +224,18 @@ lifecycle:

// We now proceed our lifecycle with the following tasks in
// order,
// 1. check timeout.
// 1. check timeout and context.
// 2. request route.
// 3. create HTLC attempt.
// 4. send HTLC attempt.
// 5. collect HTLC attempt result.
//
// Before we attempt any new shard, we'll check to see if
// either we've gone past the payment attempt timeout, or the
// router is exiting. In either case, we'll stop this payment
// attempt short. If a timeout is not applicable, timeoutChan
// will be nil.
if err := p.checkTimeout(); err != nil {
// we've gone past the payment attempt timeout, or if the
// context was cancelled, or the router is exiting. In any of
// these cases, we'll stop this payment attempt short. If a
// timeout is not applicable, timeoutChan will be nil.
if err := p.checkTimeoutAndContext(ctx); err != nil {
return exitWithErr(err)
}

Expand Down Expand Up @@ -319,27 +322,69 @@ lifecycle:
}

// checkTimeout checks whether the payment has reached its timeout.
func (p *paymentLifecycle) checkTimeout() error {
func (p *paymentLifecycle) checkTimeoutAndContext(ctx context.Context) error {
failPayment := func(reason channeldb.FailureReason) error {
// By marking the payment failed, depending on whether it has
// inflight HTLCs or not, its status will now either be
// `StatusInflight` or `StatusFailed`. In either case, no more
// HTLCs will be attempted.
err := p.router.cfg.Control.FailPayment(p.identifier, reason)
if err != nil {
return fmt.Errorf("FailPayment got %w", err)
}

return nil
}

select {
case <-p.timeoutChan:
log.Warnf("payment attempt not completed before timeout")
err := failPayment(channeldb.FailureReasonTimeout)
if err != nil {
return err
}

case <-ctx.Done():
log.Warnf("payment attempt context canceled")
err := failPayment(channeldb.FailureReasonError)
if err != nil {
return err
}

return ctx.Err()

case <-p.router.quit:
return fmt.Errorf("check payment timeout got: %w",
ErrRouterShuttingDown)

// Fall through if we haven't hit our time limit.
default:
}

return nil
}

// checkContext checks whether the payment context has been cancelled.
func (p *paymentLifecycle) checkContext(ctx context.Context) error {

Check failure on line 368 in routing/payment_lifecycle.go

View workflow job for this annotation

GitHub Actions / lint code

func `(*paymentLifecycle).checkContext` is unused (unused)
if ctx == nil {
return nil
}

select {
case <-ctx.Done():
// By marking the payment failed, depending on whether it has
// inflight HTLCs or not, its status will now either be
// `StatusInflight` or `StatusFailed`. In either case, no more
// HTLCs will be attempted.
err := p.router.cfg.Control.FailPayment(
p.identifier, channeldb.FailureReasonTimeout,
p.identifier, channeldb.FailureReasonError,
)
if err != nil {
return fmt.Errorf("FailPayment got %w", err)
}

case <-p.router.quit:
return fmt.Errorf("check payment timeout got: %w",
ErrRouterShuttingDown)
return ctx.Err()

// Fall through if we haven't hit our time limit.
default:
}

Expand Down
92 changes: 85 additions & 7 deletions routing/payment_lifecycle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -153,15 +154,15 @@ type resumePaymentResult struct {

// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error
// is returned.
func sendPaymentAndAssertFailed(t *testing.T,
p *paymentLifecycle, errExpected error) {
func sendPaymentAndAssertFailed(t *testing.T, p *paymentLifecycle,
errExpected error) {

resultChan := make(chan *resumePaymentResult, 1)

// We now make a call to `resumePayment` and expect it to return the
// error.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(context.Background())
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
Expand All @@ -179,6 +180,34 @@ func sendPaymentAndAssertFailed(t *testing.T,
}
}

// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error
// is returned.
func sendPaymentAndAssertContextCancelled(t *testing.T,
ctx context.Context, p *paymentLifecycle, errExpected error) {

resultChan := make(chan *resumePaymentResult, 1)

// We now make a call to `resumePayment` and expect it to return the
// error.
go func() {
preimage, _, err := p.resumePayment(ctx)
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
}
}()

// Validate the returned values or timeout.
select {
case r := <-resultChan:
require.Error(t, r.err, errExpected, "expected error")
require.Empty(t, r.preimage, "preimage should be empty")

case <-time.After(testTimeout):
require.Fail(t, "timeout waiting for result")
}
}

// sendPaymentAndAssertSucceeded calls `resumePayment` and asserts that the
// returned preimage is correct.
func sendPaymentAndAssertSucceeded(t *testing.T,
Expand All @@ -189,7 +218,7 @@ func sendPaymentAndAssertSucceeded(t *testing.T,
// We now make a call to `resumePayment` and expect it to return the
// preimage.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(context.Background())
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
Expand Down Expand Up @@ -278,6 +307,9 @@ func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo {
func TestCheckTimeoutTimedOut(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

p := createTestPaymentLifecycle()

// Mock the control tower's `FailPayment` method.
Expand All @@ -295,7 +327,7 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
time.Sleep(1 * time.Millisecond)

// Call the function and expect no error.
err := p.checkTimeout()
err := p.checkTimeoutAndContext(ctx)
require.NoError(t, err)

// Assert that `FailPayment` is called as expected.
Expand All @@ -319,7 +351,7 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
time.Sleep(1 * time.Millisecond)

// Call the function and expect an error.
err = p.checkTimeout()
err = p.checkTimeoutAndContext(ctx)
require.ErrorIs(t, err, errDummy)

// Assert that `FailPayment` is called as expected.
Expand All @@ -331,10 +363,13 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

p := createTestPaymentLifecycle()

close(p.router.quit)
err := p.checkTimeout()
err := p.checkTimeoutAndContext(ctx)
require.ErrorIs(t, err, ErrRouterShuttingDown)
}

Expand Down Expand Up @@ -727,6 +762,49 @@ func TestResumePaymentFailOnTimeoutErr(t *testing.T) {
require.Zero(t, m.collectResultsCount)
}

// TestResumePaymentFailContextCancel checks that the lifecycle fails when the
// context is canceled and an error is returned from `checkContext`.
//
// NOTE: No parallel test because it overwrites global variables.
//
//nolint:paralleltest
func TestResumePaymentFailContextCancel(t *testing.T) {
// Create a test paymentLifecycle with the initial two calls mocked.
p, m := setupTestPaymentLifecycle(t)

paymentAmt := lnwire.MilliSatoshi(10000)

// We now enter the payment lifecycle loop.
//
// 1. calls `FetchPayment` and return the payment.
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()

// 2. calls `GetState` and return the state.
ps := &channeldb.MPPaymentState{
RemainingAmt: paymentAmt,
}
m.payment.On("GetState").Return(ps).Once()

// NOTE: GetStatus is only used to populate the logs which is not
// critical, so we loosen the checks on how many times it's been called.
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)

// 3. Cancel the context and skip the FailPayment error to trigger the
// context cancellation of the payment.
m.control.On(
"FailPayment", p.identifier, channeldb.FailureReasonError,
).Return(nil).Once()

ctx, cancel := context.WithCancel(context.Background())
cancel()

// Send the payment and assert that its context got cancelled.
sendPaymentAndAssertContextCancelled(t, ctx, p, ErrContextCancelled)

// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
}

// TestResumePaymentFailOnStepErr checks that the lifecycle fails when an
// error is returned from `decideNextStep`.
//
Expand Down
31 changes: 22 additions & 9 deletions routing/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package routing

import (
"bytes"
"context"
"fmt"
"math"
"runtime"
Expand Down Expand Up @@ -92,6 +93,10 @@ var (
// shutting down.
ErrRouterShuttingDown = fmt.Errorf("router shutting down")

// ErrContextCancelled is returned if the context of the payment attempt
// is canceled.
ErrContextCancelled = errors.New("context cancelled")

// ErrSelfIntro is a failure returned when the source node of a
// route request is also the introduction node. This is not yet
// supported because LND does not support blinded forwardingg.
Expand Down Expand Up @@ -680,6 +685,10 @@ func (r *ChannelRouter) Start() error {
go func(payment *channeldb.MPPayment) {
defer r.wg.Done()

// Cancelable context for sendPayment.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Get the hashes used for the outstanding HTLCs.
htlcs := make(map[uint64]lntypes.Hash)
for _, a := range payment.HTLCs {
Expand Down Expand Up @@ -720,7 +729,7 @@ func (r *ChannelRouter) Start() error {
// also set a zero fee limit, as no more routes should
// be tried.
_, _, err := r.sendPayment(
0, payment.Info.PaymentIdentifier, 0,
ctx, 0, payment.Info.PaymentIdentifier, 0,
paySession, shardTracker,
)
if err != nil {
Expand Down Expand Up @@ -2397,6 +2406,10 @@ func (l *LightningPayment) Identifier() [32]byte {
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
*route.Route, error) {

// Cancelable context for sendPayment.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

paySession, shardTracker, err := r.PreparePayment(payment)
if err != nil {
return [32]byte{}, nil, err
Expand All @@ -2408,15 +2421,15 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
return r.sendPayment(
payment.FeeLimit, payment.Identifier(),
ctx, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker,
)
}

// SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
ps PaymentSession, st shards.ShardTracker) {
func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
payment *LightningPayment, ps PaymentSession, st shards.ShardTracker) {

// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
Expand All @@ -2428,7 +2441,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
spewPayment(payment))

_, _, err := r.sendPayment(
payment.FeeLimit, payment.Identifier(),
ctx, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, ps, st,
)
if err != nil {
Expand Down Expand Up @@ -2698,9 +2711,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// carry out its execution. After restarts, it is safe, and assumed, that the
// router will call this method for every payment still in-flight according to
// the ControlTower.
func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, timeout time.Duration,
paySession PaymentSession,
func (r *ChannelRouter) sendPayment(ctx context.Context,
feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
timeout time.Duration, paySession PaymentSession,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {

// We'll also fetch the current block height, so we can properly
Expand All @@ -2717,7 +2730,7 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
shardTracker, timeout, currentHeight,
)

return p.resumePayment()
return p.resumePayment(ctx)
}

// extractChannelUpdate examines the error and extracts the channel update.
Expand Down

0 comments on commit 2ab2bd9

Please sign in to comment.