Skip to content

Commit

Permalink
routing: add cancelable context to interrupt payment attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
hieblmi committed May 7, 2024
1 parent 7f05ba5 commit 02072fa
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 22 deletions.
6 changes: 4 additions & 2 deletions lnrpc/routerrpc/router_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
}

// Send the payment asynchronously.
s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
s.cfg.Router.SendPaymentAsync(
stream.Context(), payment, paySession, shardTracker,
)

// Track the payment and return.
return s.trackPayment(
Expand Down Expand Up @@ -987,7 +989,7 @@ func (s *Server) SetMissionControlConfig(ctx context.Context,
req.Config.HopProbability,
),
AprioriWeight: float64(req.Config.Weight),
CapacityFraction: routing.DefaultCapacityFraction,
CapacityFraction: routing.DefaultCapacityFraction, //nolint:lll
}
}

Expand Down
43 changes: 34 additions & 9 deletions routing/payment_lifecycle.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"errors"
"fmt"
"time"
Expand Down Expand Up @@ -167,7 +168,9 @@ func (p *paymentLifecycle) decideNextStep(
}

// resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
*route.Route, error) {

// When the payment lifecycle loop exits, we make sure to signal any
// sub goroutine of the HTLC attempt to exit, then wait for them to
// return.
Expand Down Expand Up @@ -222,20 +225,27 @@ lifecycle:
// We now proceed our lifecycle with the following tasks in
// order,
// 1. check timeout.
// 2. request route.
// 3. create HTLC attempt.
// 4. send HTLC attempt.
// 5. collect HTLC attempt result.
// 2. check context.
// 3. request route.
// 4. create HTLC attempt.
// 5. send HTLC attempt.
// 6. collect HTLC attempt result.
//
// Before we attempt any new shard, we'll check to see if
// either we've gone past the payment attempt timeout, or the
// router is exiting. In either case, we'll stop this payment
// attempt short. If a timeout is not applicable, timeoutChan
// will be nil.
// we've gone past the payment attempt timeout, or the context
// was cancelled, or the router is exiting. In any of these
// cases, we'll stop this payment attempt short. If a timeout is
// not applicable, timeoutChan will be nil.
if err := p.checkTimeout(); err != nil {
return exitWithErr(err)
}

// Check the cancellation status of the context before
// proceeding.
if err := p.checkContext(ctx); err != nil {
return exitWithErr(err)
}

// Now decide the next step of the current lifecycle.
step, err := p.decideNextStep(payment)
if err != nil {
Expand Down Expand Up @@ -346,6 +356,21 @@ func (p *paymentLifecycle) checkTimeout() error {
return nil
}

func (p *paymentLifecycle) checkContext(ctx context.Context) error {
if ctx == nil {
return nil
}

select {
case <-ctx.Done():
return ctx.Err()

default:
}

return nil
}

// requestRoute is responsible for finding a route to be used to create an HTLC
// attempt.
func (p *paymentLifecycle) requestRoute(
Expand Down
4 changes: 2 additions & 2 deletions routing/payment_lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func sendPaymentAndAssertFailed(t *testing.T,
// We now make a call to `resumePayment` and expect it to return the
// error.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(nil)
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
Expand Down Expand Up @@ -189,7 +189,7 @@ func sendPaymentAndAssertSucceeded(t *testing.T,
// We now make a call to `resumePayment` and expect it to return the
// preimage.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(nil)
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
Expand Down
19 changes: 10 additions & 9 deletions routing/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package routing

import (
"bytes"
"context"
"fmt"
"math"
"runtime"
Expand Down Expand Up @@ -720,7 +721,7 @@ func (r *ChannelRouter) Start() error {
// also set a zero fee limit, as no more routes should
// be tried.
_, _, err := r.sendPayment(
0, payment.Info.PaymentIdentifier, 0,
nil, 0, payment.Info.PaymentIdentifier, 0,
paySession, shardTracker,
)
if err != nil {
Expand Down Expand Up @@ -2408,15 +2409,15 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
return r.sendPayment(
payment.FeeLimit, payment.Identifier(),
nil, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker,
)
}

// SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
ps PaymentSession, st shards.ShardTracker) {
func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
payment *LightningPayment, ps PaymentSession, st shards.ShardTracker) {

// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
Expand All @@ -2428,7 +2429,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
spewPayment(payment))

_, _, err := r.sendPayment(
payment.FeeLimit, payment.Identifier(),
ctx, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, ps, st,
)
if err != nil {
Expand Down Expand Up @@ -2698,9 +2699,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// carry out its execution. After restarts, it is safe, and assumed, that the
// router will call this method for every payment still in-flight according to
// the ControlTower.
func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, timeout time.Duration,
paySession PaymentSession,
func (r *ChannelRouter) sendPayment(ctx context.Context,
feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
timeout time.Duration, paySession PaymentSession,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {

// We'll also fetch the current block height, so we can properly
Expand All @@ -2717,7 +2718,7 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
shardTracker, timeout, currentHeight,
)

return p.resumePayment()
return p.resumePayment(ctx)
}

// extractChannelUpdate examines the error and extracts the channel update.
Expand Down

0 comments on commit 02072fa

Please sign in to comment.