Skip to content

Commit

Permalink
channeldb: update QueryPayments to use sequence nr index and paginator
Browse files Browse the repository at this point in the history
Use the new paginatior strcut for payments. Add some tests which will
specifically test cases on and around the missing index we force in our
test to ensure that we properly handle this case. We also add a sanity
check in the test that checks that we can query when we have no
payments.
  • Loading branch information
carlaKC committed Jun 10, 2020
1 parent 38624e8 commit ab594ea
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 51 deletions.
99 changes: 53 additions & 46 deletions channeldb/payments.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"math"
"sort"
"time"

Expand Down Expand Up @@ -511,62 +510,70 @@ type PaymentsResponse struct {
func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
var resp PaymentsResponse

allPayments, err := db.FetchPayments()
if err != nil {
return resp, err
}
if err := kvdb.View(db, func(tx kvdb.RTx) error {
// Get the root payments bucket.
paymentsBucket := tx.ReadBucket(paymentsRootBucket)
if paymentsBucket == nil {
return nil
}

if len(allPayments) == 0 {
return resp, nil
}
// Get the index bucket which maps sequence number -> payment
// hash and duplicate bool. If we have a payments bucket, we
// should have an indexes bucket as well.
indexes := tx.ReadBucket(paymentsIndexBucket)
if indexes == nil {
return fmt.Errorf("index bucket does not exist")
}

indexExclusiveLimit := query.IndexOffset
// In backward pagination, if the index limit is the default 0 value,
// we set our limit to maxint to include all payments from the highest
// sequence number on.
if query.Reversed && indexExclusiveLimit == 0 {
indexExclusiveLimit = math.MaxInt64
}
// accumulatePayments gets payments with the sequence number
// and hash provided and adds them to our list of payments if
// they meet the criteria of our query. It returns the number
// of payments that were added.
accumulatePayments := func(sequenceKey, hash []byte) (bool,
error) {

for i := range allPayments {
var payment *MPPayment
r := bytes.NewReader(hash)
paymentHash, err := deserializePaymentIndex(r)
if err != nil {
return false, err
}

// If we have the max number of payments we want, exit.
if uint64(len(resp.Payments)) == query.MaxPayments {
break
}
payment, err := fetchPaymentWithSequenceNumber(
tx, paymentHash, sequenceKey,
)
if err != nil {
return false, err
}

if query.Reversed {
payment = allPayments[len(allPayments)-1-i]
// To keep compatibility with the old API, we only
// return non-succeeded payments if requested.
if payment.Status != StatusSucceeded &&
!query.IncludeIncomplete {

// In the reversed direction, skip over all payments
// that have sequence numbers greater than or equal to
// the index offset. We skip payments with equal index
// because the offset is exclusive.
if payment.SequenceNum >= indexExclusiveLimit {
continue
}
} else {
payment = allPayments[i]

// In the forward direction, skip over all payments that
// have sequence numbers less than or equal to the index
// offset. We skip payments with equal indexes because
// the index offset is exclusive.
if payment.SequenceNum <= indexExclusiveLimit {
continue
return false, err
}

// At this point, we've exhausted the offset, so we'll
// begin collecting invoices found within the range.
resp.Payments = append(resp.Payments, payment)
return true, nil
}

// To keep compatibility with the old API, we only return
// non-succeeded payments if requested.
if payment.Status != StatusSucceeded &&
!query.IncludeIncomplete {
// Create a paginator which reads from our sequence index bucket
// with the parameters provided by the payments query.
paginator := newPaginator(
indexes.ReadCursor(), query.Reversed, query.IndexOffset,
query.MaxPayments,
)

continue
// Run a paginated query, adding payments to our response.
if err := paginator.query(accumulatePayments); err != nil {
return err
}

resp.Payments = append(resp.Payments, payment)
return nil
}); err != nil {
return resp, err
}

// Need to swap the payments slice order if reversed order.
Expand All @@ -585,7 +592,7 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) {
resp.Payments[len(resp.Payments)-1].SequenceNum
}

return resp, err
return resp, nil
}

// fetchPaymentWithSequenceNumber get the payment which matches the payment hash
Expand Down
63 changes: 58 additions & 5 deletions channeldb/payments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,24 @@ func TestRouteSerialization(t *testing.T) {
}

// deletePayment removes a payment with paymentHash from the payments database.
func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) {
func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, seqNr uint64) {
t.Helper()

err := kvdb.Update(db, func(tx kvdb.RwTx) error {
payments := tx.ReadWriteBucket(paymentsRootBucket)

// Delete the payment bucket.
err := payments.DeleteNestedBucket(paymentHash[:])
if err != nil {
return err
}

return nil
key := make([]byte, 8)
byteOrder.PutUint64(key, seqNr)

// Delete the index that references this payment.
indexes := tx.ReadWriteBucket(paymentsIndexBucket)
return indexes.Delete(key)
})

if err != nil {
Expand Down Expand Up @@ -350,6 +356,42 @@ func TestQueryPayments(t *testing.T) {
lastIndex: 7,
expectedSeqNrs: []uint64{3, 4, 5, 6, 7},
},
{
name: "query payments reverse before index gap",
query: PaymentsQuery{
IndexOffset: 3,
MaxPayments: 7,
Reversed: true,
IncludeIncomplete: true,
},
firstIndex: 1,
lastIndex: 1,
expectedSeqNrs: []uint64{1},
},
{
name: "query payments reverse on index gap",
query: PaymentsQuery{
IndexOffset: 2,
MaxPayments: 7,
Reversed: true,
IncludeIncomplete: true,
},
firstIndex: 1,
lastIndex: 1,
expectedSeqNrs: []uint64{1},
},
{
name: "query payments forward on index gap",
query: PaymentsQuery{
IndexOffset: 2,
MaxPayments: 2,
Reversed: false,
IncludeIncomplete: true,
},
firstIndex: 3,
lastIndex: 4,
expectedSeqNrs: []uint64{3, 4},
},
}

for _, tt := range tests {
Expand All @@ -358,11 +400,16 @@ func TestQueryPayments(t *testing.T) {
t.Parallel()

db, cleanup, err := makeTestDB()
defer cleanup()

if err != nil {
t.Fatalf("unable to init db: %v", err)
}
defer cleanup()

// Make a preliminary query to make sure it's ok to
// query when we have no payments.
resp, err := db.QueryPayments(tt.query)
require.NoError(t, err)
require.Len(t, resp.Payments, 0)

// Populate the database with a set of test payments.
// We create 6 original payments, deleting the payment
Expand Down Expand Up @@ -391,7 +438,13 @@ func TestQueryPayments(t *testing.T) {

// Immediately delete the payment with index 2.
if i == 1 {
deletePayment(t, db, info.PaymentHash)
pmt, err := pControl.FetchPayment(
info.PaymentHash,
)
require.NoError(t, err)

deletePayment(t, db, info.PaymentHash,
pmt.SequenceNum)
}

// If we are on the last payment entry, add a
Expand Down

0 comments on commit ab594ea

Please sign in to comment.