diff --git a/channeldb/codec.go b/channeldb/codec.go index 78d6169476..f6903175f8 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -192,6 +192,11 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case paymentIndexType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + case lnwire.FundingFlag: if err := binary.Write(w, byteOrder, e); err != nil { return err @@ -406,6 +411,11 @@ func ReadElement(r io.Reader, element interface{}) error { return err } + case *paymentIndexType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + case *lnwire.FundingFlag: if err := binary.Read(r, byteOrder, e); err != nil { return err diff --git a/channeldb/db.go b/channeldb/db.go index 1347a58fbb..564dbb7be2 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -16,6 +16,7 @@ import ( mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" + "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" @@ -144,6 +145,19 @@ var ( number: 14, migration: mig.CreateTLB(payAddrIndexBucket), }, + { + // Initialize payment index bucket which will be used + // to index payments by sequence number. This index will + // be used to allow more efficient ListPayments queries. + number: 15, + migration: mig.CreateTLB(paymentsIndexBucket), + }, + { + // Add our existing payments to the index bucket created + // in migration 15. + number: 16, + migration: migration16.MigrateSequenceIndex, + }, } // Big endian is the preferred byte order, due to cursor scans over @@ -257,6 +271,7 @@ var topLevelBuckets = [][]byte{ fwdPackagesKey, invoiceBucket, payAddrIndexBucket, + paymentsIndexBucket, nodeInfoBucket, nodeBucket, edgeBucket, diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index eea9df034a..e0ec219128 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1007,6 +1007,18 @@ func TestQueryInvoices(t *testing.T) { // still pending. expected: pendingInvoices[len(pendingInvoices)-15:], }, + // Fetch all invoices paginating backwards, with an index offset + // that is beyond our last offset. We expect all invoices to be + // returned. + { + query: InvoiceQuery{ + IndexOffset: numInvoices * 2, + PendingOnly: false, + Reversed: true, + NumMaxInvoices: numInvoices, + }, + expected: invoices, + }, } for i, testCase := range testCases { diff --git a/channeldb/invoices.go b/channeldb/invoices.go index a7ed432454..07de2add23 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -839,85 +839,47 @@ func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { if invoices == nil { return ErrNoInvoicesCreated } + + // Get the add index bucket which we will use to iterate through + // our indexed invoices. invoiceAddIndex := invoices.NestedReadBucket(addIndexBucket) if invoiceAddIndex == nil { return ErrNoInvoicesCreated } - // keyForIndex is a helper closure that retrieves the invoice - // key for the given add index of an invoice. - keyForIndex := func(c kvdb.RCursor, index uint64) []byte { - var keyIndex [8]byte - byteOrder.PutUint64(keyIndex[:], index) - _, invoiceKey := c.Seek(keyIndex[:]) - return invoiceKey - } - - // nextKey is a helper closure to determine what the next - // invoice key is when iterating over the invoice add index. - nextKey := func(c kvdb.RCursor) ([]byte, []byte) { - if q.Reversed { - return c.Prev() - } - return c.Next() - } - - // We'll be using a cursor to seek into the database and return - // a slice of invoices. We'll need to determine where to start - // our cursor depending on the parameters set within the query. - c := invoiceAddIndex.ReadCursor() - invoiceKey := keyForIndex(c, q.IndexOffset+1) - - // If the query is specifying reverse iteration, then we must - // handle a few offset cases. - if q.Reversed { - switch q.IndexOffset { - - // This indicates the default case, where no offset was - // specified. In that case we just start from the last - // invoice. - case 0: - _, invoiceKey = c.Last() - - // This indicates the offset being set to the very - // first invoice. Since there are no invoices before - // this offset, and the direction is reversed, we can - // return without adding any invoices to the response. - case 1: - return nil - - // Otherwise we start iteration at the invoice prior to - // the offset. - default: - invoiceKey = keyForIndex(c, q.IndexOffset-1) - } - } - - // If we know that a set of invoices exists, then we'll begin - // our seek through the bucket in order to satisfy the query. - // We'll continue until either we reach the end of the range, or - // reach our max number of invoices. - for ; invoiceKey != nil; _, invoiceKey = nextKey(c) { - // If our current return payload exceeds the max number - // of invoices, then we'll exit now. - if uint64(len(resp.Invoices)) >= q.NumMaxInvoices { - break - } + // Create a paginator which reads from our add index bucket with + // the parameters provided by the invoice query. + paginator := newPaginator( + invoiceAddIndex.ReadCursor(), q.Reversed, q.IndexOffset, + q.NumMaxInvoices, + ) - invoice, err := fetchInvoice(invoiceKey, invoices) + // accumulateInvoices looks up an invoice based on the index we + // are given, adds it to our set of invoices if it has the right + // characteristics for our query and returns the number of items + // we have added to our set of invoices. + accumulateInvoices := func(_, indexValue []byte) (bool, error) { + invoice, err := fetchInvoice(indexValue, invoices) if err != nil { - return err + return false, err } - // Skip any settled or canceled invoices if the caller is - // only interested in pending ones. + // Skip any settled or canceled invoices if the caller + // is only interested in pending ones. if q.PendingOnly && !invoice.IsPending() { - continue + return false, nil } // At this point, we've exhausted the offset, so we'll // begin collecting invoices found within the range. resp.Invoices = append(resp.Invoices, invoice) + return true, nil + } + + // Query our paginator using accumulateInvoices to build up a + // set of invoices. + if err := paginator.query(accumulateInvoices); err != nil { + return err } // If we iterated through the add index in reverse order, then diff --git a/channeldb/log.go b/channeldb/log.go index f59426f0a9..75ba2a5f7e 100644 --- a/channeldb/log.go +++ b/channeldb/log.go @@ -6,6 +6,7 @@ import ( mig "github.com/lightningnetwork/lnd/channeldb/migration" "github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration13" + "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" ) @@ -33,4 +34,5 @@ func UseLogger(logger btclog.Logger) { migration_01_to_11.UseLogger(logger) migration12.UseLogger(logger) migration13.UseLogger(logger) + migration16.UseLogger(logger) } diff --git a/channeldb/migration16/log.go b/channeldb/migration16/log.go new file mode 100644 index 0000000000..cb946854cf --- /dev/null +++ b/channeldb/migration16/log.go @@ -0,0 +1,14 @@ +package migration16 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/channeldb/migration16/migration.go b/channeldb/migration16/migration.go new file mode 100644 index 0000000000..b984f08378 --- /dev/null +++ b/channeldb/migration16/migration.go @@ -0,0 +1,191 @@ +package migration16 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb/kvdb" +) + +var ( + paymentsRootBucket = []byte("payments-root-bucket") + + paymentSequenceKey = []byte("payment-sequence-key") + + duplicatePaymentsBucket = []byte("payment-duplicate-bucket") + + paymentsIndexBucket = []byte("payments-index-bucket") + + byteOrder = binary.BigEndian +) + +// paymentIndexType indicates the type of index we have recorded in the payment +// indexes bucket. +type paymentIndexType uint8 + +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +const paymentIndexTypeHash paymentIndexType = 0 + +// paymentIndex stores all the information we require to create an index by +// sequence number for a payment. +type paymentIndex struct { + // paymentHash is the hash of the payment, which is its key in the + // payment root bucket. + paymentHash []byte + + // sequenceNumbers is the set of sequence numbers associated with this + // payment hash. There will be more than one sequence number in the + // case where duplicate payments are present. + sequenceNumbers [][]byte +} + +// MigrateSequenceIndex migrates the payments db to contain a new bucket which +// provides an index from sequence number to payment hash. This is required +// for more efficient sequential lookup of payments, which are keyed by payment +// hash before this migration. +func MigrateSequenceIndex(tx kvdb.RwTx) error { + log.Infof("Migrating payments to add sequence number index") + + // Get a list of indices we need to write. + indexList, err := getPaymentIndexList(tx) + if err != nil { + return err + } + + // Create the top level bucket that we will use to index payments in. + bucket, err := tx.CreateTopLevelBucket(paymentsIndexBucket) + if err != nil { + return err + } + + // Write an index for each of our payments. + for _, index := range indexList { + // Write indexes for each of our sequence numbers. + for _, seqNr := range index.sequenceNumbers { + err := putIndex(bucket, seqNr, index.paymentHash) + if err != nil { + return err + } + } + } + + return nil +} + +// putIndex performs a sanity check that ensures we are not writing duplicate +// indexes to disk then creates the index provided. +func putIndex(bucket kvdb.RwBucket, sequenceNr, paymentHash []byte) error { + // Add a sanity check that we do not already have an entry with + // this sequence number. + existingEntry := bucket.Get(sequenceNr) + if existingEntry != nil { + return fmt.Errorf("sequence number: %x duplicated", + sequenceNr) + } + + bytes, err := serializePaymentIndexEntry(paymentHash) + if err != nil { + return err + } + + return bucket.Put(sequenceNr, bytes) +} + +// serializePaymentIndexEntry serializes a payment hash typed index. The value +// produced contains a payment index type (which can be used in future to +// signal different payment index types) and the payment hash. +func serializePaymentIndexEntry(hash []byte) ([]byte, error) { + var b bytes.Buffer + + err := binary.Write(&b, byteOrder, paymentIndexTypeHash) + if err != nil { + return nil, err + } + + if err := wire.WriteVarBytes(&b, 0, hash); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// getPaymentIndexList gets a list of indices we need to write for our current +// set of payments. +func getPaymentIndexList(tx kvdb.RTx) ([]paymentIndex, error) { + // Iterate over all payments and store their indexing keys. This is + // needed, because no modifications are allowed inside a Bucket.ForEach + // loop. + paymentsBucket := tx.ReadBucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil, nil + } + + var indexList []paymentIndex + err := paymentsBucket.ForEach(func(k, v []byte) error { + // Get the bucket which contains the payment, fail if the key + // does not have a bucket. + bucket := paymentsBucket.NestedReadBucket(k) + if bucket == nil { + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return fmt.Errorf("nil sequence number bytes") + } + + seqNrs, err := fetchSequenceNumbers(bucket) + if err != nil { + return err + } + + // Create an index object with our payment hash and sequence + // numbers and append it to our set of indexes. + index := paymentIndex{ + paymentHash: k, + sequenceNumbers: seqNrs, + } + + indexList = append(indexList, index) + return nil + }) + if err != nil { + return nil, err + } + + return indexList, nil +} + +// fetchSequenceNumbers fetches all the sequence numbers associated with a +// payment, including those belonging to any duplicate payments. +func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) { + seqNum := paymentBucket.Get(paymentSequenceKey) + if seqNum == nil { + return nil, errors.New("expected sequence number") + } + + sequenceNumbers := [][]byte{seqNum} + + // Get the duplicate payments bucket, if it has no duplicates, just + // return early with the payment sequence number. + duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket) + if duplicates == nil { + return sequenceNumbers, nil + } + + // If we do have duplicated, they are keyed by sequence number, so we + // iterate through the duplicates bucket and add them to our set of + // sequence numbers. + if err := duplicates.ForEach(func(k, v []byte) error { + sequenceNumbers = append(sequenceNumbers, k) + return nil + }); err != nil { + return nil, err + } + + return sequenceNumbers, nil +} diff --git a/channeldb/migration16/migration_test.go b/channeldb/migration16/migration_test.go new file mode 100644 index 0000000000..626bedcb50 --- /dev/null +++ b/channeldb/migration16/migration_test.go @@ -0,0 +1,144 @@ +package migration16 + +import ( + "encoding/hex" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + "github.com/lightningnetwork/lnd/channeldb/migtest" +) + +var ( + hexStr = migtest.Hex + + hash1Str = "02acee76ebd53d00824410cf6adecad4f50334dac702bd5a2d3ba01b91709f0e" + hash1 = hexStr(hash1Str) + paymentID1 = hexStr("0000000000000001") + + hash2Str = "62eb3f0a48f954e495d0c14ac63df04a67cefa59dafdbcd3d5046d1f5647840c" + hash2 = hexStr(hash2Str) + paymentID2 = hexStr("0000000000000002") + + paymentID3 = hexStr("0000000000000003") + + // pre is the data in the payments root bucket in database version 13 format. + pre = map[string]interface{}{ + // A payment without duplicates. + hash1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + }, + + // A payment with a duplicate. + hash2: map[string]interface{}{ + "payment-sequence-key": paymentID2, + "payment-duplicate-bucket": map[string]interface{}{ + paymentID3: map[string]interface{}{ + "payment-sequence-key": paymentID3, + }, + }, + }, + } + + preFails = map[string]interface{}{ + // A payment without duplicates. + hash1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + "payment-duplicate-bucket": map[string]interface{}{ + paymentID1: map[string]interface{}{ + "payment-sequence-key": paymentID1, + }, + }, + }, + } + + // post is the expected data after migration. + post = map[string]interface{}{ + paymentID1: paymentHashIndex(hash1Str), + paymentID2: paymentHashIndex(hash2Str), + paymentID3: paymentHashIndex(hash2Str), + } +) + +// paymentHashIndex produces a string that represents the value we expect for +// our payment indexes from a hex encoded payment hash string. +func paymentHashIndex(hashStr string) string { + hash, err := hex.DecodeString(hashStr) + if err != nil { + panic(err) + } + + bytes, err := serializePaymentIndexEntry(hash) + if err != nil { + panic(err) + } + + return string(bytes) +} + +// MigrateSequenceIndex asserts that the database is properly migrated to +// contain a payments index. +func TestMigrateSequenceIndex(t *testing.T) { + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + post map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + post: post, + }, + { + name: "duplicate sequence number", + shouldFail: true, + pre: preFails, + post: post, + }, + { + name: "no payments", + shouldFail: false, + pre: nil, + post: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Before the migration we have a payments bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, paymentsRootBucket, test.pre, + ) + } + + // After the migration, we should have an untouched + // payments bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + if err := migtest.VerifyDB( + tx, paymentsRootBucket, test.pre, + ); err != nil { + return err + } + + // If we expect our migration to fail, we don't + // expect an index bucket. + if test.shouldFail { + return nil + } + + return migtest.VerifyDB( + tx, paymentsIndexBucket, test.post, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateSequenceIndex, + test.shouldFail, + ) + }) + } +} diff --git a/channeldb/paginate.go b/channeldb/paginate.go new file mode 100644 index 0000000000..22ec4fb465 --- /dev/null +++ b/channeldb/paginate.go @@ -0,0 +1,140 @@ +package channeldb + +import "github.com/lightningnetwork/lnd/channeldb/kvdb" + +type paginator struct { + // cursor is the cursor which we are using to iterate through a bucket. + cursor kvdb.RCursor + + // reversed indicates whether we are paginating forwards or backwards. + reversed bool + + // indexOffset is the index from which we will begin querying. + indexOffset uint64 + + // totalItems is the total number of items we allow in our response. + totalItems uint64 +} + +// newPaginator returns a struct which can be used to query an indexed bucket +// in pages. +func newPaginator(c kvdb.RCursor, reversed bool, + indexOffset, totalItems uint64) paginator { + + return paginator{ + cursor: c, + reversed: reversed, + indexOffset: indexOffset, + totalItems: totalItems, + } +} + +// keyValueForIndex seeks our cursor to a given index and returns the key and +// value at that position. +func (p paginator) keyValueForIndex(index uint64) ([]byte, []byte) { + var keyIndex [8]byte + byteOrder.PutUint64(keyIndex[:], index) + return p.cursor.Seek(keyIndex[:]) +} + +// lastIndex returns the last value in our index, if our index is empty it +// returns 0. +func (p paginator) lastIndex() uint64 { + keyIndex, _ := p.cursor.Last() + if keyIndex == nil { + return 0 + } + + return byteOrder.Uint64(keyIndex) +} + +// nextKey is a helper closure to determine what key we should use next when +// we are iterating, depending on whether we are iterating forwards or in +// reverse. +func (p paginator) nextKey() ([]byte, []byte) { + if p.reversed { + return p.cursor.Prev() + } + return p.cursor.Next() +} + +// cursorStart gets the index key and value for the first item we are looking +// up, taking into account that we may be paginating in reverse. The index +// offset provided is *excusive* so we will start with the item after the offset +// for forwards queries, and the item before the index for backwards queries. +func (p paginator) cursorStart() ([]byte, []byte) { + indexKey, indexValue := p.keyValueForIndex(p.indexOffset + 1) + + // If the query is specifying reverse iteration, then we must + // handle a few offset cases. + if p.reversed { + switch { + + // This indicates the default case, where no offset was + // specified. In that case we just start from the last + // entry. + case p.indexOffset == 0: + indexKey, indexValue = p.cursor.Last() + + // This indicates the offset being set to the very + // first entry. Since there are no entries before + // this offset, and the direction is reversed, we can + // return without adding any invoices to the response. + case p.indexOffset == 1: + return nil, nil + + // If we have been given an index offset that is beyond our last + // index value, we just return the last indexed value in our set + // since we are querying in reverse. We do not cover the case + // where our index offset equals our last index value, because + // index offset is exclusive, so we would want to start at the + // value before our last index. + case p.indexOffset > p.lastIndex(): + return p.cursor.Last() + + // Otherwise we have an index offset which is within our set of + // indexed keys, and we want to start at the item before our + // offset. We seek to our index offset, then return the element + // before it. We do this rather than p.indexOffset-1 to account + // for indexes that have gaps. + default: + p.keyValueForIndex(p.indexOffset) + indexKey, indexValue = p.cursor.Prev() + } + } + + return indexKey, indexValue +} + +// query gets the start point for our index offset and iterates through keys +// in our index until we reach the total number of items required for the query +// or we run out of cursor values. This function takes a fetchAndAppend function +// which is responsible for looking up the entry at that index, adding the entry +// to its set of return items (if desired) and return a boolean which indicates +// whether the item was added. This is required to allow the paginator to +// determine when the response has the maximum number of required items. +func (p paginator) query(fetchAndAppend func(k, v []byte) (bool, error)) error { + indexKey, indexValue := p.cursorStart() + + var totalItems int + for ; indexKey != nil; indexKey, indexValue = p.nextKey() { + // If our current return payload exceeds the max number + // of invoices, then we'll exit now. + if uint64(totalItems) >= p.totalItems { + break + } + + added, err := fetchAndAppend(indexKey, indexValue) + if err != nil { + return err + } + + // If we added an item to our set in the latest fetch and append + // we increment our total count. + if added { + totalItems++ + } + } + + return nil +} diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index 99d2000c43..5a538134ef 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "io" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" @@ -74,6 +75,11 @@ var ( // errNoAttemptInfo is returned when no attempt info is stored yet. errNoAttemptInfo = errors.New("unable to find attempt info for " + "inflight payment") + + // errNoSequenceNrIndex is returned when an attempt to lookup a payment + // index is made for a sequence number that is not indexed. + errNoSequenceNrIndex = errors.New("payment sequence number index " + + "does not exist") ) // PaymentControl implements persistence for payments and payment attempts. @@ -152,6 +158,27 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return err } + // Before we set our new sequence number, we check whether this + // payment has a previously set sequence number and remove its + // index entry if it exists. This happens in the case where we + // have a previously attempted payment which was left in a state + // where we can retry. + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes != nil { + indexBucket := tx.ReadWriteBucket(paymentsIndexBucket) + if err := indexBucket.Delete(seqBytes); err != nil { + return err + } + } + + // Once we have obtained a sequence number, we add an entry + // to our index bucket which will map the sequence number to + // our payment hash. + err = createPaymentIndexEntry(tx, sequenceNum, info.PaymentHash) + if err != nil { + return err + } + err = bucket.Put(paymentSequenceKey, sequenceNum) if err != nil { return err @@ -183,6 +210,58 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, return updateErr } +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +type paymentIndexType uint8 + +// paymentIndexTypeHash is a payment index type which indicates that we have +// created an index of payment sequence number to payment hash. +const paymentIndexTypeHash paymentIndexType = 0 + +// createPaymentIndexEntry creates a payment hash typed index for a payment. The +// index produced contains a payment index type (which can be used in future to +// signal different payment index types) and the payment hash. +func createPaymentIndexEntry(tx kvdb.RwTx, sequenceNumber []byte, + hash lntypes.Hash) error { + + var b bytes.Buffer + if err := WriteElements(&b, paymentIndexTypeHash, hash[:]); err != nil { + return err + } + + indexes := tx.ReadWriteBucket(paymentsIndexBucket) + return indexes.Put(sequenceNumber, b.Bytes()) +} + +// deserializePaymentIndex deserializes a payment index entry. This function +// currently only supports deserialization of payment hash indexes, and will +// fail for other types. +func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) { + var ( + indexType paymentIndexType + paymentHash []byte + ) + + if err := ReadElements(r, &indexType, &paymentHash); err != nil { + return lntypes.Hash{}, err + } + + // While we only have on payment index type, we do not need to use our + // index type to deserialize the index. However, we sanity check that + // this type is as expected, since we had to read it out anyway. + if indexType != paymentIndexTypeHash { + return lntypes.Hash{}, fmt.Errorf("unknown payment index "+ + "type: %v", indexType) + } + + hash, err := lntypes.MakeHash(paymentHash) + if err != nil { + return lntypes.Hash{}, err + } + + return hash, nil +} + // RegisterAttempt atomically records the provided HTLCAttemptInfo to the // DB. func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index c470a8f5f1..147e54525b 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -1,6 +1,7 @@ package channeldb import ( + "bytes" "crypto/rand" "crypto/sha256" "fmt" @@ -9,9 +10,13 @@ import ( "testing" "time" + "github.com/btcsuite/btcwallet/walletdb" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/record" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func genPreimage() ([32]byte, error) { @@ -70,6 +75,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -88,6 +94,11 @@ func TestPaymentControlSwitchFail(t *testing.T) { t, pControl, info.PaymentHash, info, &failReason, nil, ) + // Lookup the payment so we can get its old sequence number before it is + // overwritten. + payment, err := pControl.FetchPayment(info.PaymentHash) + assert.NoError(t, err) + // Sends the htlc again, which should succeed since the prior payment // failed. err = pControl.InitPayment(info.PaymentHash, info) @@ -95,6 +106,11 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + // Check that our index has been updated, and the old index has been + // removed. + assertPaymentIndex(t, pControl, info.PaymentHash) + assertNoIndex(t, pControl, payment.SequenceNum) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -145,7 +161,6 @@ func TestPaymentControlSwitchFail(t *testing.T) { // Settle the attempt and verify that status was changed to // StatusSucceeded. - var payment *MPPayment payment, err = pControl.SettleAttempt( info.PaymentHash, attempt.AttemptID, &HTLCSettleInfo{ @@ -209,6 +224,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -326,7 +342,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown) } -// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only +// TestPaymentControlDeleteNonInFlight checks that calling DeletePayments only // deletes payments from the database that are not in-flight. func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Parallel() @@ -338,23 +354,37 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Fatalf("unable to init db: %v", err) } + // Create a sequence number for duplicate payments that will not collide + // with the sequence numbers for the payments we create. These values + // start at 1, so 9999 is a safe bet for this test. + var duplicateSeqNr = 9999 + pControl := NewPaymentControl(db) payments := []struct { - failed bool - success bool + failed bool + success bool + hasDuplicate bool }{ { - failed: true, - success: false, + failed: true, + success: false, + hasDuplicate: false, + }, + { + failed: false, + success: true, + hasDuplicate: false, }, { - failed: false, - success: true, + failed: false, + success: false, + hasDuplicate: false, }, { - failed: false, - success: false, + failed: false, + success: true, + hasDuplicate: true, }, } @@ -430,6 +460,16 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t, pControl, info.PaymentHash, info, nil, htlc, ) } + + // If the payment is intended to have a duplicate payment, we + // add one. + if p.hasDuplicate { + appendDuplicatePayment( + t, pControl.db, info.PaymentHash, + uint64(duplicateSeqNr), + ) + duplicateSeqNr++ + } } // Delete payments. @@ -451,6 +491,21 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { if status != StatusInFlight { t.Fatalf("expected in-fligth status, got %v", status) } + + // Finally, check that we only have a single index left in the payment + // index bucket. + var indexCount int + err = kvdb.View(db, func(tx walletdb.ReadTx) error { + index := tx.ReadBucket(paymentsIndexBucket) + + return index.ForEach(func(k, v []byte) error { + indexCount++ + return nil + }) + }) + require.NoError(t, err) + + require.Equal(t, 1, indexCount) } // TestPaymentControlMultiShard checks the ability of payment control to @@ -495,6 +550,7 @@ func TestPaymentControlMultiShard(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } + assertPaymentIndex(t, pControl, info.PaymentHash) assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( t, pControl, info.PaymentHash, info, nil, nil, @@ -910,3 +966,55 @@ func assertPaymentInfo(t *testing.T, p *PaymentControl, hash lntypes.Hash, t.Fatal("expected no settle info") } } + +// fetchPaymentIndexEntry gets the payment hash for the sequence number provided +// from our payment indexes bucket. +func fetchPaymentIndexEntry(_ *testing.T, p *PaymentControl, + sequenceNumber uint64) (*lntypes.Hash, error) { + + var hash lntypes.Hash + + if err := kvdb.View(p.db, func(tx walletdb.ReadTx) error { + indexBucket := tx.ReadBucket(paymentsIndexBucket) + key := make([]byte, 8) + byteOrder.PutUint64(key, sequenceNumber) + + indexValue := indexBucket.Get(key) + if indexValue == nil { + return errNoSequenceNrIndex + } + + r := bytes.NewReader(indexValue) + + var err error + hash, err = deserializePaymentIndex(r) + return err + + }); err != nil { + return nil, err + } + + return &hash, nil +} + +// assertPaymentIndex looks up the index for a payment in the db and checks +// that its payment hash matches the expected hash passed in. +func assertPaymentIndex(t *testing.T, p *PaymentControl, + expectedHash lntypes.Hash) { + + // Lookup the payment so that we have its sequence number and check + // that is has correctly been indexed in the payment indexes bucket. + pmt, err := p.FetchPayment(expectedHash) + require.NoError(t, err) + + hash, err := fetchPaymentIndexEntry(t, p, pmt.SequenceNum) + require.NoError(t, err) + assert.Equal(t, expectedHash, *hash) +} + +// assertNoIndex checks that an index for the sequence number provided does not +// exist. +func assertNoIndex(t *testing.T, p *PaymentControl, seqNr uint64) { + _, err := fetchPaymentIndexEntry(t, p, seqNr) + require.Equal(t, errNoSequenceNrIndex, err) +} diff --git a/channeldb/payments.go b/channeldb/payments.go index d1ec60706d..5c2475bd2b 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -3,9 +3,9 @@ package channeldb import ( "bytes" "encoding/binary" + "errors" "fmt" "io" - "math" "sort" "time" @@ -92,6 +92,35 @@ var ( // paymentFailInfoKey is a key used in the payment's sub-bucket to // store information about the reason a payment failed. paymentFailInfoKey = []byte("payment-fail-info") + + // paymentsIndexBucket is the name of the top-level bucket within the + // database that stores an index of payment sequence numbers to its + // payment hash. + // payments-sequence-index-bucket + // |--: + // |--... + // |--: + paymentsIndexBucket = []byte("payments-index-bucket") +) + +var ( + // ErrNoSequenceNumber is returned if we lookup a payment which does + // not have a sequence number. + ErrNoSequenceNumber = errors.New("sequence number not found") + + // ErrDuplicateNotFound is returned when we lookup a payment by its + // index and cannot find a payment with a matching sequence number. + ErrDuplicateNotFound = errors.New("duplicate payment not found") + + // ErrNoDuplicateBucket is returned when we expect to find duplicates + // when looking up a payment from its index, but the payment does not + // have any. + ErrNoDuplicateBucket = errors.New("expected duplicate bucket") + + // ErrNoDuplicateNestedBucket is returned if we do not find duplicate + // payments in their own sub-bucket. + ErrNoDuplicateNestedBucket = errors.New("nested duplicate bucket not " + + "found") ) // FailureReason encodes the reason a payment ultimately failed. @@ -481,62 +510,70 @@ type PaymentsResponse struct { func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { var resp PaymentsResponse - allPayments, err := db.FetchPayments() - if err != nil { - return resp, err - } + if err := kvdb.View(db, func(tx kvdb.RTx) error { + // Get the root payments bucket. + paymentsBucket := tx.ReadBucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil + } - if len(allPayments) == 0 { - return resp, nil - } + // Get the index bucket which maps sequence number -> payment + // hash and duplicate bool. If we have a payments bucket, we + // should have an indexes bucket as well. + indexes := tx.ReadBucket(paymentsIndexBucket) + if indexes == nil { + return fmt.Errorf("index bucket does not exist") + } - indexExclusiveLimit := query.IndexOffset - // In backward pagination, if the index limit is the default 0 value, - // we set our limit to maxint to include all payments from the highest - // sequence number on. - if query.Reversed && indexExclusiveLimit == 0 { - indexExclusiveLimit = math.MaxInt64 - } + // accumulatePayments gets payments with the sequence number + // and hash provided and adds them to our list of payments if + // they meet the criteria of our query. It returns the number + // of payments that were added. + accumulatePayments := func(sequenceKey, hash []byte) (bool, + error) { - for i := range allPayments { - var payment *MPPayment + r := bytes.NewReader(hash) + paymentHash, err := deserializePaymentIndex(r) + if err != nil { + return false, err + } - // If we have the max number of payments we want, exit. - if uint64(len(resp.Payments)) == query.MaxPayments { - break - } + payment, err := fetchPaymentWithSequenceNumber( + tx, paymentHash, sequenceKey, + ) + if err != nil { + return false, err + } - if query.Reversed { - payment = allPayments[len(allPayments)-1-i] + // To keep compatibility with the old API, we only + // return non-succeeded payments if requested. + if payment.Status != StatusSucceeded && + !query.IncludeIncomplete { - // In the reversed direction, skip over all payments - // that have sequence numbers greater than or equal to - // the index offset. We skip payments with equal index - // because the offset is exclusive. - if payment.SequenceNum >= indexExclusiveLimit { - continue - } - } else { - payment = allPayments[i] - - // In the forward direction, skip over all payments that - // have sequence numbers less than or equal to the index - // offset. We skip payments with equal indexes because - // the index offset is exclusive. - if payment.SequenceNum <= indexExclusiveLimit { - continue + return false, err } + + // At this point, we've exhausted the offset, so we'll + // begin collecting invoices found within the range. + resp.Payments = append(resp.Payments, payment) + return true, nil } - // To keep compatibility with the old API, we only return - // non-succeeded payments if requested. - if payment.Status != StatusSucceeded && - !query.IncludeIncomplete { + // Create a paginator which reads from our sequence index bucket + // with the parameters provided by the payments query. + paginator := newPaginator( + indexes.ReadCursor(), query.Reversed, query.IndexOffset, + query.MaxPayments, + ) - continue + // Run a paginated query, adding payments to our response. + if err := paginator.query(accumulatePayments); err != nil { + return err } - resp.Payments = append(resp.Payments, payment) + return nil + }); err != nil { + return resp, err } // Need to swap the payments slice order if reversed order. @@ -555,7 +592,84 @@ func (db *DB) QueryPayments(query PaymentsQuery) (PaymentsResponse, error) { resp.Payments[len(resp.Payments)-1].SequenceNum } - return resp, err + return resp, nil +} + +// fetchPaymentWithSequenceNumber get the payment which matches the payment hash +// *and* sequence number provided from the database. This is required because +// we previously had more than one payment per hash, so we have multiple indexes +// pointing to a single payment; we want to retrieve the correct one. +func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash, + sequenceNumber []byte) (*MPPayment, error) { + + // We can now lookup the payment keyed by its hash in + // the payments root bucket. + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err != nil { + return nil, err + } + + // A single payment hash can have multiple payments associated with it. + // We lookup our sequence number first, to determine whether this is + // the payment we are actually looking for. + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return nil, ErrNoSequenceNumber + } + + // If this top level payment has the sequence number we are looking for, + // return it. + if bytes.Equal(seqBytes, sequenceNumber) { + return fetchPayment(bucket) + } + + // If we were not looking for the top level payment, we are looking for + // one of our duplicate payments. We need to iterate through the seq + // numbers in this bucket to find the correct payments. If we do not + // find a duplicate payments bucket here, something is wrong. + dup := bucket.NestedReadBucket(duplicatePaymentsBucket) + if dup == nil { + return nil, ErrNoDuplicateBucket + } + + var duplicatePayment *MPPayment + err = dup.ForEach(func(k, v []byte) error { + subBucket := dup.NestedReadBucket(k) + if subBucket == nil { + // We one bucket for each duplicate to be found. + return ErrNoDuplicateNestedBucket + } + + seqBytes := subBucket.Get(duplicatePaymentSequenceKey) + if seqBytes == nil { + return err + } + + // If this duplicate payment is not the sequence number we are + // looking for, we can continue. + if !bytes.Equal(seqBytes, sequenceNumber) { + return nil + } + + duplicatePayment, err = fetchDuplicatePayment(subBucket) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err + } + + // If none of the duplicate payments matched our sequence number, we + // failed to find the payment with this sequence number; something is + // wrong. + if duplicatePayment == nil { + return nil, ErrDuplicateNotFound + } + + return duplicatePayment, nil } // DeletePayments deletes all completed and failed payments from the DB. @@ -566,7 +680,15 @@ func (db *DB) DeletePayments() error { return nil } - var deleteBuckets [][]byte + var ( + // deleteBuckets is the set of payment buckets we need + // to delete. + deleteBuckets [][]byte + + // deleteIndexes is the set of indexes pointing to these + // payments that need to be deleted. + deleteIndexes [][]byte + ) err := payments.ForEach(func(k, _ []byte) error { bucket := payments.NestedReadWriteBucket(k) if bucket == nil { @@ -589,7 +711,18 @@ func (db *DB) DeletePayments() error { return nil } + // Add the bucket to the set of buckets we can delete. deleteBuckets = append(deleteBuckets, k) + + // Get all the sequence number associated with the + // payment, including duplicates. + seqNrs, err := fetchSequenceNumbers(bucket) + if err != nil { + return err + } + + deleteIndexes = append(deleteIndexes, seqNrs...) + return nil }) if err != nil { @@ -602,10 +735,49 @@ func (db *DB) DeletePayments() error { } } + // Get our index bucket and delete all indexes pointing to the + // payments we are deleting. + indexBucket := tx.ReadWriteBucket(paymentsIndexBucket) + for _, k := range deleteIndexes { + if err := indexBucket.Delete(k); err != nil { + return err + } + } + return nil }) } +// fetchSequenceNumbers fetches all the sequence numbers associated with a +// payment, including those belonging to any duplicate payments. +func fetchSequenceNumbers(paymentBucket kvdb.RBucket) ([][]byte, error) { + seqNum := paymentBucket.Get(paymentSequenceKey) + if seqNum == nil { + return nil, errors.New("expected sequence number") + } + + sequenceNumbers := [][]byte{seqNum} + + // Get the duplicate payments bucket, if it has no duplicates, just + // return early with the payment sequence number. + duplicates := paymentBucket.NestedReadBucket(duplicatePaymentsBucket) + if duplicates == nil { + return sequenceNumbers, nil + } + + // If we do have duplicated, they are keyed by sequence number, so we + // iterate through the duplicates bucket and add them to our set of + // sequence numbers. + if err := duplicates.ForEach(func(k, v []byte) error { + sequenceNumbers = append(sequenceNumbers, k) + return nil + }); err != nil { + return nil, err + } + + return sequenceNumbers, nil +} + // nolint: dupl func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { var scratch [8]byte diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 2f0d88bcd9..9e790c3e39 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -9,11 +9,13 @@ import ( "time" "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcwallet/walletdb" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" ) var ( @@ -163,18 +165,24 @@ func TestRouteSerialization(t *testing.T) { } // deletePayment removes a payment with paymentHash from the payments database. -func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) { +func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, seqNr uint64) { t.Helper() err := kvdb.Update(db, func(tx kvdb.RwTx) error { payments := tx.ReadWriteBucket(paymentsRootBucket) + // Delete the payment bucket. err := payments.DeleteNestedBucket(paymentHash[:]) if err != nil { return err } - return nil + key := make([]byte, 8) + byteOrder.PutUint64(key, seqNr) + + // Delete the index that references this payment. + indexes := tx.ReadWriteBucket(paymentsIndexBucket) + return indexes.Delete(key) }) if err != nil { @@ -188,6 +196,10 @@ func deletePayment(t *testing.T, db *DB, paymentHash lntypes.Hash) { func TestQueryPayments(t *testing.T) { // Define table driven test for QueryPayments. // Test payments have sequence indices [1, 3, 4, 5, 6, 7]. + // Note that the payment with index 7 has the same payment hash as 6, + // and is stored in a nested bucket within payment 6 rather than being + // its own entry in the payments bucket. We do this to test retrieval + // of legacy payments. tests := []struct { name string query PaymentsQuery @@ -344,6 +356,42 @@ func TestQueryPayments(t *testing.T) { lastIndex: 7, expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, }, + { + name: "query payments reverse before index gap", + query: PaymentsQuery{ + IndexOffset: 3, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments reverse on index gap", + query: PaymentsQuery{ + IndexOffset: 2, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments forward on index gap", + query: PaymentsQuery{ + IndexOffset: 2, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 4, + expectedSeqNrs: []uint64{3, 4}, + }, } for _, tt := range tests { @@ -352,17 +400,28 @@ func TestQueryPayments(t *testing.T) { t.Parallel() db, cleanup, err := makeTestDB() - defer cleanup() - if err != nil { t.Fatalf("unable to init db: %v", err) } + defer cleanup() + + // Make a preliminary query to make sure it's ok to + // query when we have no payments. + resp, err := db.QueryPayments(tt.query) + require.NoError(t, err) + require.Len(t, resp.Payments, 0) // Populate the database with a set of test payments. - numberOfPayments := 7 + // We create 6 original payments, deleting the payment + // at index 2 so that we cover the case where sequence + // numbers are missing. We also add a duplicate payment + // to the last payment added to test the legacy case + // where we have duplicates in the nested duplicates + // bucket. + nonDuplicatePayments := 6 pControl := NewPaymentControl(db) - for i := 0; i < numberOfPayments; i++ { + for i := 0; i < nonDuplicatePayments; i++ { // Generate a test payment. info, _, _, err := genInfo() if err != nil { @@ -379,7 +438,29 @@ func TestQueryPayments(t *testing.T) { // Immediately delete the payment with index 2. if i == 1 { - deletePayment(t, db, info.PaymentHash) + pmt, err := pControl.FetchPayment( + info.PaymentHash, + ) + require.NoError(t, err) + + deletePayment(t, db, info.PaymentHash, + pmt.SequenceNum) + } + + // If we are on the last payment entry, add a + // duplicate payment with sequence number equal + // to the parent payment + 1. + if i == (nonDuplicatePayments - 1) { + pmt, err := pControl.FetchPayment( + info.PaymentHash, + ) + require.NoError(t, err) + + appendDuplicatePayment( + t, pControl.db, + info.PaymentHash, + pmt.SequenceNum+1, + ) } } @@ -424,3 +505,210 @@ func TestQueryPayments(t *testing.T) { }) } } + +// TestFetchPaymentWithSequenceNumber tests lookup of payments with their +// sequence number. It sets up one payment with no duplicates, and another with +// two duplicates in its duplicates bucket then uses these payments to test the +// case where a specific duplicate is not found and the duplicates bucket is not +// present when we expect it to be. +func TestFetchPaymentWithSequenceNumber(t *testing.T) { + db, cleanup, err := makeTestDB() + require.NoError(t, err) + + defer cleanup() + + pControl := NewPaymentControl(db) + + // Generate a test payment which does not have duplicates. + noDuplicates, _, _, err := genInfo() + require.NoError(t, err) + + // Create a new payment entry in the database. + err = pControl.InitPayment(noDuplicates.PaymentHash, noDuplicates) + require.NoError(t, err) + + // Fetch the payment so we can get its sequence nr. + noDuplicatesPayment, err := pControl.FetchPayment( + noDuplicates.PaymentHash, + ) + require.NoError(t, err) + + // Generate a test payment which we will add duplicates to. + hasDuplicates, _, _, err := genInfo() + require.NoError(t, err) + + // Create a new payment entry in the database. + err = pControl.InitPayment(hasDuplicates.PaymentHash, hasDuplicates) + require.NoError(t, err) + + // Fetch the payment so we can get its sequence nr. + hasDuplicatesPayment, err := pControl.FetchPayment( + hasDuplicates.PaymentHash, + ) + require.NoError(t, err) + + // We declare the sequence numbers used here so that we can reference + // them in tests. + var ( + duplicateOneSeqNr = hasDuplicatesPayment.SequenceNum + 1 + duplicateTwoSeqNr = hasDuplicatesPayment.SequenceNum + 2 + ) + + // Add two duplicates to our second payment. + appendDuplicatePayment( + t, db, hasDuplicates.PaymentHash, duplicateOneSeqNr, + ) + appendDuplicatePayment( + t, db, hasDuplicates.PaymentHash, duplicateTwoSeqNr, + ) + + tests := []struct { + name string + paymentHash lntypes.Hash + sequenceNumber uint64 + expectedErr error + }{ + { + name: "lookup payment without duplicates", + paymentHash: noDuplicates.PaymentHash, + sequenceNumber: noDuplicatesPayment.SequenceNum, + expectedErr: nil, + }, + { + name: "lookup payment with duplicates", + paymentHash: hasDuplicates.PaymentHash, + sequenceNumber: hasDuplicatesPayment.SequenceNum, + expectedErr: nil, + }, + { + name: "lookup first duplicate", + paymentHash: hasDuplicates.PaymentHash, + sequenceNumber: duplicateOneSeqNr, + expectedErr: nil, + }, + { + name: "lookup second duplicate", + paymentHash: hasDuplicates.PaymentHash, + sequenceNumber: duplicateTwoSeqNr, + expectedErr: nil, + }, + { + name: "lookup non-existent duplicate", + paymentHash: hasDuplicates.PaymentHash, + sequenceNumber: 999999, + expectedErr: ErrDuplicateNotFound, + }, + { + name: "lookup duplicate, no duplicates bucket", + paymentHash: noDuplicates.PaymentHash, + sequenceNumber: duplicateTwoSeqNr, + expectedErr: ErrNoDuplicateBucket, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + err := kvdb.Update(db, + func(tx walletdb.ReadWriteTx) error { + + var seqNrBytes [8]byte + byteOrder.PutUint64( + seqNrBytes[:], test.sequenceNumber, + ) + + _, err := fetchPaymentWithSequenceNumber( + tx, test.paymentHash, seqNrBytes[:], + ) + return err + }) + require.Equal(t, test.expectedErr, err) + }) + } +} + +// appendDuplicatePayment adds a duplicate payment to an existing payment. Note +// that this function requires a unique sequence number. +// +// This code is *only* intended to replicate legacy duplicate payments in lnd, +// our current schema does not allow duplicates. +func appendDuplicatePayment(t *testing.T, db *DB, paymentHash lntypes.Hash, + seqNr uint64) { + + err := kvdb.Update(db, func(tx walletdb.ReadWriteTx) error { + bucket, err := fetchPaymentBucketUpdate( + tx, paymentHash, + ) + if err != nil { + return err + } + + // Create the duplicates bucket if it is not + // present. + dup, err := bucket.CreateBucketIfNotExists( + duplicatePaymentsBucket, + ) + if err != nil { + return err + } + + var sequenceKey [8]byte + byteOrder.PutUint64(sequenceKey[:], seqNr) + + // Create duplicate payments for the two dup + // sequence numbers we've setup. + putDuplicatePayment(t, dup, sequenceKey[:], paymentHash) + + // Finally, once we have created our entry we add an index for + // it. + err = createPaymentIndexEntry(tx, sequenceKey[:], paymentHash) + require.NoError(t, err) + + return nil + }) + if err != nil { + t.Fatalf("could not create payment: %v", err) + } +} + +// putDuplicatePayment creates a duplicate payment in the duplicates bucket +// provided with the minimal information required for successful reading. +func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, + sequenceKey []byte, paymentHash lntypes.Hash) { + + paymentBucket, err := duplicateBucket.CreateBucketIfNotExists( + sequenceKey, + ) + require.NoError(t, err) + + err = paymentBucket.Put(duplicatePaymentSequenceKey, sequenceKey) + require.NoError(t, err) + + // Generate fake information for the duplicate payment. + info, _, _, err := genInfo() + require.NoError(t, err) + + // Write the payment info to disk under the creation info key. This code + // is copied rather than using serializePaymentCreationInfo to ensure + // we always write in the legacy format used by duplicate payments. + var b bytes.Buffer + var scratch [8]byte + _, err = b.Write(paymentHash[:]) + require.NoError(t, err) + + byteOrder.PutUint64(scratch[:], uint64(info.Value)) + _, err = b.Write(scratch[:]) + require.NoError(t, err) + + err = serializeTime(&b, info.CreationTime) + require.NoError(t, err) + + byteOrder.PutUint32(scratch[:4], 0) + _, err = b.Write(scratch[:4]) + require.NoError(t, err) + + // Get the PaymentCreationInfo. + err = paymentBucket.Put(duplicatePaymentCreationInfoKey, b.Bytes()) + require.NoError(t, err) +}