diff --git a/channeldb/db.go b/channeldb/db.go index 076b1c2c1637..c421fd4908c8 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -47,6 +47,12 @@ var ( number: 1, migration: migrateNodeAndEdgeUpdateIndex, }, + { + // The version with added payment statuses + // for each existing payment + number: 2, + migration: paymentStatusesMigration, + }, } // Big endian is the preferred byte order, due to cursor scans over diff --git a/channeldb/db_test.go b/channeldb/db_test.go index f3e3c96e0e1a..7edae3c58d34 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -5,6 +5,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/go-errors/errors" ) func TestOpenWithCreate(t *testing.T) { @@ -33,3 +35,56 @@ func TestOpenWithCreate(t *testing.T) { t.Fatalf("channeldb failed to create data directory") } } + +// applyMigration is a helper test function that encapsulates the general steps +// which are needed to properly check the result of applying migration function. +func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), + migrationFunc migration, shouldFail bool) { + + cdb, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatal(err) + } + + // beforeMigration usually used for populating the database + // with test data. + beforeMigration(cdb) + + // Create test meta info with zero database version and put it on disk. + // Than creating the version list pretending that new version was added. + meta := &Meta{DbVersionNumber: 0} + if err := cdb.PutMeta(meta); err != nil { + t.Fatalf("unable to store meta data: %v", err) + } + + versions := []version{ + { + number: 0, + migration: nil, + }, + { + number: 1, + migration: migrationFunc, + }, + } + + defer func() { + if r := recover(); r != nil { + err = errors.New(r) + } + + if err == nil && shouldFail { + t.Fatal("error wasn't received on migration stage") + } else if err != nil && !shouldFail { + t.Fatal("error was received on migration stage") + } + + // afterMigration usually used for checking the database state and + // throwing the error if something went wrong. + afterMigration(cdb) + }() + + // Sync with the latest version - applying migration function. + err = cdb.syncVersions(versions) +} diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index 5890d6692b28..4ef54c675bd9 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -117,59 +117,6 @@ func TestGlobalVersionList(t *testing.T) { } } -// applyMigration is a helper test function that encapsulates the general steps -// which are needed to properly check the result of applying migration function. -func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), - migrationFunc migration, shouldFail bool) { - - cdb, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatal(err) - } - - // beforeMigration usually used for populating the database - // with test data. - beforeMigration(cdb) - - // Create test meta info with zero database version and put it on disk. - // Than creating the version list pretending that new version was added. - meta := &Meta{DbVersionNumber: 0} - if err := cdb.PutMeta(meta); err != nil { - t.Fatalf("unable to store meta data: %v", err) - } - - versions := []version{ - { - number: 0, - migration: nil, - }, - { - number: 1, - migration: migrationFunc, - }, - } - - defer func() { - if r := recover(); r != nil { - err = errors.New(r) - } - - if err == nil && shouldFail { - t.Fatal("error wasn't received on migration stage") - } else if err != nil && !shouldFail { - t.Fatal("error was received on migration stage") - } - - // afterMigration usually used for checking the database state and - // throwing the error if something went wrong. - afterMigration(cdb) - }() - - // Sync with the latest version - applying migration function. - err = cdb.syncVersions(versions) -} - func TestMigrationWithPanic(t *testing.T) { t.Parallel() diff --git a/channeldb/migrations.go b/channeldb/migrations.go index 50ceec790954..8db32e440180 100644 --- a/channeldb/migrations.go +++ b/channeldb/migrations.go @@ -2,6 +2,7 @@ package channeldb import ( "bytes" + "crypto/sha256" "fmt" "github.com/coreos/bbolt" @@ -112,3 +113,46 @@ func migrateNodeAndEdgeUpdateIndex(tx *bolt.Tx) error { return nil } + +// paymentStatusesMigration is a database migration intended for adding payment +// statuses for each existing payment entity in bucket to be able control +// transitions of statuses and prevent cases such as double payment +func paymentStatusesMigration(tx *bolt.Tx) error { + // Get the bucket dedicated to storing payments + bucket := tx.Bucket(paymentBucket) + if bucket == nil { + return ErrNoPaymentsCreated + } + + // Get the bucket dedicated to storing statuses of payments, + // where a key is payment hash, value is payment status + paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) + if err != nil { + return err + } + + log.Infof("Migration database adds to all existing payments " + + "statuses as Completed") + + // For each payment in the bucket, fetch all data. + return bucket.ForEach(func(k, v []byte) error { + // ignores if it is sub-bucket + if v == nil { + return nil + } + + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + // calculate payment hash for current payment + paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) + + // tries to update status for current payment to completed + // if it fails - migration abort transaction and return payment bucket + // to previous state + return paymentStatuses.Put(paymentHash[:], StatusCompleted.Bytes()) + }) +} diff --git a/channeldb/migrations_test.go b/channeldb/migrations_test.go new file mode 100644 index 000000000000..bfc5baf87729 --- /dev/null +++ b/channeldb/migrations_test.go @@ -0,0 +1,71 @@ +package channeldb + +import ( + "crypto/sha256" + "testing" +) + +func TestPaymentStatusesMigration(t *testing.T) { + t.Parallel() + + fakePayment := makeFakePayment() + paymentHash := sha256.Sum256(fakePayment.PaymentPreimage[:]) + + // Add fake payment to the test database and verifies that it was created + // and there is only one payment and its status is not "Completed". + beforeMigrationFunc := func(d *DB) { + if err := d.AddPayment(fakePayment); err != nil { + t.Fatalf("unable to add payment: %v", err) + } + + payments, err := d.FetchAllPayments() + if err != nil { + t.Fatalf("unable to fetch payments: %v", err) + } + + if len(payments) != 1 { + t.Fatalf("wrong qty of paymets: expected 1, got %v", + len(payments)) + } + + paymentStatus, err := d.FetchPaymentStatus(paymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + // we should receive default status if we have any in database + if paymentStatus != StatusGrounded { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusGrounded.String(), paymentStatus.String()) + } + } + + // Verify that was created payment status "Completed" for our one fake + // payment. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'paymentStatusesMigration' wasn't applied") + } + + paymentStatus, err := d.FetchPaymentStatus(paymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusCompleted { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusCompleted.String(), paymentStatus.String()) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + paymentStatusesMigration, + false) +} diff --git a/channeldb/payments.go b/channeldb/payments.go index 2a0cd914a492..bc285e782435 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -3,6 +3,7 @@ package channeldb import ( "bytes" "encoding/binary" + "errors" "io" "github.com/coreos/bbolt" @@ -17,8 +18,64 @@ var ( // which is a monotonically increasing uint64. BoltDB's sequence // feature is used for generating monotonically increasing id. paymentBucket = []byte("payments") + + // paymentStatusBucket is the name of the bucket within the database that + // stores the status of a payment indexed by the payment's preimage. + paymentStatusBucket = []byte("payment-status") +) + +// PaymentStatus represent current status of payment +type PaymentStatus byte + +const ( + // StatusGrounded is status where payment is initiated and received + // an intermittent failure + StatusGrounded PaymentStatus = 0 + + // StatusInFlight is status where payment is initiated, but a response + // has not been received + StatusInFlight PaymentStatus = 1 + + // StatusCompleted is status where payment is initiated and complete + // a payment successfully + StatusCompleted PaymentStatus = 2 ) +// Bytes returns status as slice of bytes +func (ps PaymentStatus) Bytes() []byte { + return []byte{byte(ps)} +} + +// FromBytes sets status from slice of bytes +func (ps *PaymentStatus) FromBytes(status []byte) error { + if len(status) != 1 { + return errors.New("payment status is empty") + } + + switch PaymentStatus(status[0]) { + case StatusGrounded, StatusInFlight, StatusCompleted: + *ps = PaymentStatus(status[0]) + default: + return errors.New("unknown payment status") + } + + return nil +} + +// String returns readable representation of payment status +func (ps PaymentStatus) String() string { + switch ps { + case StatusGrounded: + return "Grounded" + case StatusInFlight: + return "In Flight" + case StatusCompleted: + return "Completed" + default: + return "Unknown" + } +} + // OutgoingPayment represents a successful payment between the daemon and a // remote node. Details such as the total fee paid, and the time of the payment // are stored. @@ -129,6 +186,45 @@ func (db *DB) DeleteAllPayments() error { }) } +// UpdatePaymentStatus sets status for outgoing/finished payment to store status in +// local database. +func (db *DB) UpdatePaymentStatus(paymentHash [32]byte, status PaymentStatus) error { + return db.Batch(func(tx *bolt.Tx) error { + paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) + if err != nil { + return err + } + + return paymentStatuses.Put(paymentHash[:], status.Bytes()) + }) +} + +// FetchPaymentStatus returns payment status for outgoing payment +// if status of the payment isn't found it set to default status "StatusGrounded". +func (db *DB) FetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { + // default status for all payments that wasn't recorded in database + paymentStatus := StatusGrounded + + err := db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(paymentStatusBucket) + if bucket == nil { + return nil + } + + paymentStatusBytes := bucket.Get(paymentHash[:]) + if paymentStatusBytes == nil { + return nil + } + + return paymentStatus.FromBytes(paymentStatusBytes) + }) + if err != nil { + return StatusGrounded, err + } + + return paymentStatus, nil +} + func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error { var scratch [8]byte diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 450b4acff104..d13e039d0273 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -40,6 +40,14 @@ func makeFakePayment() *OutgoingPayment { return fakePayment } +func makeFakePaymentHash() [32]byte { + var paymentHash [32]byte + rBytes, _ := randomBytes(0, 32) + copy(paymentHash[:], rBytes) + + return paymentHash +} + // randomBytes creates random []byte with length in range [minLen, maxLen) func randomBytes(minLen, maxLen int) ([]byte, error) { randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen)) @@ -195,3 +203,51 @@ func TestOutgoingPaymentWorkflow(t *testing.T) { len(paymentsAfterDeletion), 0) } } + +func TestPaymentStatusWorkflow(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + testCases := []struct { + paymentHash [32]byte + status PaymentStatus + }{ + { + paymentHash: makeFakePaymentHash(), + status: StatusGrounded, + }, + { + paymentHash: makeFakePaymentHash(), + status: StatusInFlight, + }, + { + paymentHash: makeFakePaymentHash(), + status: StatusCompleted, + }, + } + + for _, testCase := range testCases { + err := db.UpdatePaymentStatus(testCase.paymentHash, testCase.status) + if err != nil { + t.Fatalf("unable to put payment in DB: %v", err) + } + + status, err := db.FetchPaymentStatus(testCase.paymentHash) + if err != nil { + t.Fatalf("unable to fetch payments from DB: %v", err) + } + + if status != testCase.status { + t.Fatalf("Wrong payments status after reading from DB."+ + "Got %v, want %v", + spew.Sdump(status), + spew.Sdump(testCase.status), + ) + } + } +} diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index e0f3c07e8501..48588b60c8ae 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -3604,8 +3604,8 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { // as it's a duplicate request. _, err = n.aliceServer.htlcSwitch.SendHTLC(n.bobServer.PubKey(), htlc, newMockDeobfuscator()) - if err != nil { - t.Fatalf("error shouldn't have been received got: %v", err) + if err != ErrAlreadyPaid { + t.Fatalf("ErrAlreadyPaid should have been received got: %v", err) } } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 0ed96408a286..64a781f4090b 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -121,14 +121,25 @@ type mockServer struct { var _ Peer = (*mockServer)(nil) +func initDB() (*channeldb.DB, error) { + tempPath, err := ioutil.TempDir("", "switchdb") + if err != nil { + return nil, err + } + + db, err := channeldb.Open(tempPath) + if err != nil { + return nil, err + } + + return db, err +} + func initSwitchWithDB(db *channeldb.DB) (*Switch, error) { - if db == nil { - tempPath, err := ioutil.TempDir("", "switchdb") - if err != nil { - return nil, err - } + var err error - db, err = channeldb.Open(tempPath) + if db == nil { + db, err = initDB() if err != nil { return nil, err } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index f11d1fd39d76..4d80ff41cdc4 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2,25 +2,24 @@ package htlcswitch import ( "bytes" + "crypto/sha256" "fmt" "sync" "sync/atomic" "time" - "crypto/sha256" - "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" + "github.com/go-errors/errors" "github.com/roasbeef/btcd/btcec" + "github.com/roasbeef/btcd/wire" + "github.com/roasbeef/btcutil" - "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" - "github.com/roasbeef/btcd/wire" - "github.com/roasbeef/btcutil" ) var ( @@ -171,6 +170,9 @@ type Switch struct { paymentSequencer Sequencer + // control provides verification of sending htlc mesages + control ControlTower + // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator. circuits CircuitMap @@ -246,10 +248,13 @@ func New(cfg Config) (*Switch, error) { return nil, err } + pControl := NewPaymentControl(cfg.DB) + return &Switch{ cfg: &cfg, circuits: circuitMap, paymentSequencer: sequencer, + control: pControl, linkIndex: make(map[lnwire.ChannelID]ChannelLink), mailOrchestrator: newMailOrchestrator(), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), @@ -304,6 +309,11 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) ([sha256.Size]byte, error) { + // Verify message by ControlTower implementation. + if err := s.control.CheckSend(htlc); err != nil { + return zeroPreimage, err + } + // Create payment and add to the map of payment in order later to be // able to retrieve it and return response to the user. payment := &pendingPayment{ @@ -336,6 +346,10 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, if err := s.forward(packet); err != nil { s.removePendingPayment(paymentID) + if err := s.control.Fail(htlc.PaymentHash); err != nil { + return zeroPreimage, err + } + return zeroPreimage, err } @@ -805,6 +819,10 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { payment.preimage <- htlc.PaymentPreimage s.removePendingPayment(pkt.incomingHTLCID) + if err := s.control.Success(pkt.circuit.PaymentHash); err != nil { + return err + } + // We've just received a fail update which means we can finalize the // user payment and return fail response. case *lnwire.UpdateFailHTLC: @@ -869,6 +887,10 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, FailureMessage: lnwire.FailPermanentChannelFailure{}, } + if err := s.control.Fail(pkt.circuit.PaymentHash); err != nil { + log.Error(err) + } + // A regular multi-hop payment error that we'll need to // decrypt. default: @@ -885,6 +907,10 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, ExtraMsg: userErr, FailureMessage: lnwire.NewTemporaryChannelFailure(nil), } + + if err := s.control.Fail(pkt.circuit.PaymentHash); err != nil { + log.Error(err) + } } } diff --git a/htlcswitch/switch_control.go b/htlcswitch/switch_control.go new file mode 100644 index 000000000000..e9b65abca11e --- /dev/null +++ b/htlcswitch/switch_control.go @@ -0,0 +1,142 @@ +package htlcswitch + +import ( + "errors" + "sync" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // ErrAlreadyPaid is used when we have already paid + ErrAlreadyPaid = errors.New("invoice is already paid") + + // ErrPaymentInFlight returns in case if payment is already "in flight" + ErrPaymentInFlight = errors.New("payment is in transition") + + // ErrPaymentNotInitiated returns in case if payment wasn't initiated + // in switch + ErrPaymentNotInitiated = errors.New("payment isn't initiated") + + // ErrPaymentAlreadyCompleted returns in case of attempt to complete + // completed payment + ErrPaymentAlreadyCompleted = errors.New("payment is already completed") +) + +// ControlTower is a controller interface of sending HTLC messages to switch +type ControlTower interface { + // CheckSend intercepts incoming message to provide checks + // and fail if specific message is not allowed by implementation + CheckSend(htlc *lnwire.UpdateAddHTLC) error + + // Success marks message transition as successful + Success(paymentHash [32]byte) error + + // Fail marks message transition as failed + Fail(paymentHash [32]byte) error +} + +// paymentControl is implementation of ControlTower to restrict double payment +// sending. +type paymentControl struct { + mx sync.Mutex + + db *channeldb.DB +} + +// NewPaymentControl creates a new instance of the paymentControl. +func NewPaymentControl(db *channeldb.DB) ControlTower { + return &paymentControl{ + db: db, + } +} + +// CheckSend checks that a sending htlc wasn't triggered before for specific +// payment hash, if so, should trigger error depends on current status +func (p *paymentControl) CheckSend(htlc *lnwire.UpdateAddHTLC) error { + p.mx.Lock() + defer p.mx.Unlock() + + // Retrieve current status of payment from local database. + paymentStatus, err := p.db.FetchPaymentStatus(htlc.PaymentHash) + if err != nil { + return err + } + + switch paymentStatus { + case channeldb.StatusGrounded: + // It is safe to reattempt a payment if we know that we haven't + // left one in flight prior to restarting and switch. + return p.db.UpdatePaymentStatus(htlc.PaymentHash, + channeldb.StatusInFlight) + + case channeldb.StatusInFlight: + // Not clear if it's safe to reinitiate a payment if there + // is already a payment in flight, so we should withhold any + // additional attempts to send to that payment hash. + return ErrPaymentInFlight + + case channeldb.StatusCompleted: + // It has been already paid and don't want to pay again. + return ErrAlreadyPaid + } + + return nil +} + +// Success proceed status changing of payment to next successful status +func (p *paymentControl) Success(paymentHash [32]byte) error { + p.mx.Lock() + defer p.mx.Unlock() + + paymentStatus, err := p.db.FetchPaymentStatus(paymentHash) + if err != nil { + return err + } + + switch paymentStatus { + case channeldb.StatusGrounded: + // Payment isn't initiated but received. + return ErrPaymentNotInitiated + + case channeldb.StatusInFlight: + // Successful transition from InFlight transition to Completed. + return p.db.UpdatePaymentStatus(paymentHash, channeldb.StatusCompleted) + + case channeldb.StatusCompleted: + // Payment is completed before in should be ignored. + return ErrPaymentAlreadyCompleted + } + + return nil +} + +// Fail proceed status changing of payment to initial status in case of failure +func (p *paymentControl) Fail(paymentHash [32]byte) error { + p.mx.Lock() + defer p.mx.Unlock() + + paymentStatus, err := p.db.FetchPaymentStatus(paymentHash) + if err != nil { + return err + } + + switch paymentStatus { + case channeldb.StatusGrounded: + // Unpredictable behavior when payment wasn't transited to + // StatusInFlight status and was failed. + return ErrPaymentNotInitiated + + case channeldb.StatusInFlight: + // If payment wasn't processed by some reason should return to + // default status to unlock retrying option for the same payment hash. + return p.db.UpdatePaymentStatus(paymentHash, channeldb.StatusGrounded) + + case channeldb.StatusCompleted: + // Payment is completed before and can't be moved to another status. + return ErrPaymentAlreadyCompleted + } + + return nil +} diff --git a/htlcswitch/switch_control_test.go b/htlcswitch/switch_control_test.go new file mode 100644 index 000000000000..46ff179b6af4 --- /dev/null +++ b/htlcswitch/switch_control_test.go @@ -0,0 +1,206 @@ +package htlcswitch + +import ( + "fmt" + "testing" + + "github.com/btcsuite/fastsha256" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +func genHtlc() (*lnwire.UpdateAddHTLC, error) { + preimage, err := genPreimage() + if err != nil { + return nil, fmt.Errorf("unable to generate preimage: %v", err) + } + + rhash := fastsha256.Sum256(preimage[:]) + htlc := &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + } + + return htlc, nil +} + +// TestPaymentControlSwitch checks the ability of payment control +// change states of payments +func TestPaymentControlSwitch(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // sends base htlc message which initiate base status + // and move it to StatusInFlight and verifies that it + // was changed + if err := pControl.CheckSend(htlc); err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + pStatus, err := db.FetchPaymentStatus(htlc.PaymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if pStatus != channeldb.StatusInFlight { + t.Fatalf("payment status mismatch: expected %v, got %v", + channeldb.StatusInFlight, pStatus) + } + + // verifies that status was changed to StatusCompleted + if err := pControl.Success(htlc.PaymentHash); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + pStatus, err = db.FetchPaymentStatus(htlc.PaymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if pStatus != channeldb.StatusCompleted { + t.Fatalf("payment status mismatch: expected %v, got %v", + channeldb.StatusCompleted, pStatus) + } +} + +// TestPaymentControlSwitchFail checks that payment status returns +// to initial status after fail +func TestPaymentControlSwitchFail(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // sends base htlc message which initiate StatusInFlight + if err := pControl.CheckSend(htlc); err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + pStatus, err := db.FetchPaymentStatus(htlc.PaymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if pStatus != channeldb.StatusInFlight { + t.Fatalf("payment status mismatch: expected %v, got %v", + channeldb.StatusInFlight, pStatus) + } + + // move payment to completed status, second payment should return error + pControl.Fail(htlc.PaymentHash) + + pStatus, err = db.FetchPaymentStatus(htlc.PaymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if pStatus != channeldb.StatusGrounded { + t.Fatalf("payment status mismatch: expected %v, got %v", + channeldb.StatusGrounded, pStatus) + } +} + +// TestPaymentControlSwitchDoubleSend checks the ability of payment control +// to prevent double sending of htlc message, when message is in StatusInFlight +func TestPaymentControlSwitchDoubleSend(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // sends base htlc message which initiate base status + // and move it to StatusInFlight and verifies that it + // was changed + if err := pControl.CheckSend(htlc); err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + pStatus, err := db.FetchPaymentStatus(htlc.PaymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if pStatus != channeldb.StatusInFlight { + t.Fatalf("payment status mismatch: expected %v, got %v", + channeldb.StatusInFlight, pStatus) + } + + // tries to initiate double sending of htlc message with the same + // payment hash + if err := pControl.CheckSend(htlc); err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") + } +} + +// TestPaymentControlSwitchDoublePay checks the ability of payment control +// to prevent double payment +func TestPaymentControlSwitchDoublePay(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // sends base htlc message which initiate StatusInFlight + if err := pControl.CheckSend(htlc); err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + pStatus, err := db.FetchPaymentStatus(htlc.PaymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if pStatus != channeldb.StatusInFlight { + t.Fatalf("payment status mismatch: expected %v, got %v", + channeldb.StatusInFlight, pStatus) + } + + // move payment to completed status, second payment should return error + if err := pControl.Success(htlc.PaymentHash); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + if err := pControl.CheckSend(htlc); err != ErrAlreadyPaid { + t.Fatalf("payment control wrong behaviour:" + + " double payment must trigger ErrAlreadyPaid") + } +} diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 82932d3c04e8..e8686cde07ba 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1619,7 +1619,9 @@ func TestSwitchSendPayment(t *testing.T) { } case err := <-errChan: - t.Fatalf("unable to send payment: %v", err) + if err != ErrPaymentInFlight { + t.Fatalf("unable to send payment: %v", err) + } case <-time.After(time.Second): t.Fatal("request was not propagated to destination") } @@ -1636,11 +1638,11 @@ func TestSwitchSendPayment(t *testing.T) { t.Fatal("request was not propagated to destination") } - if s.numPendingPayments() != 2 { + if s.numPendingPayments() != 1 { t.Fatal("wrong amount of pending payments") } - if s.circuits.NumOpen() != 2 { + if s.circuits.NumOpen() != 1 { t.Fatal("wrong amount of circuits") } @@ -1676,29 +1678,6 @@ func TestSwitchSendPayment(t *testing.T) { t.Fatal("err wasn't received") } - packet = &htlcPacket{ - outgoingChanID: aliceChannelLink.ShortChanID(), - outgoingHTLCID: 1, - htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, - }, - } - - // Send second failure response and check that user were able to - // receive the error. - if err := s.forward(packet); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } - - select { - case err := <-errChan: - if err.Error() != errors.New(lnwire.CodeIncorrectPaymentAmount).Error() { - t.Fatal("err wasn't received") - } - case <-time.After(time.Second): - t.Fatal("err wasn't received") - } - if s.numPendingPayments() != 0 { t.Fatal("wrong amount of pending payments") } diff --git a/lnd_test.go b/lnd_test.go index 96ca8998d8ae..19621a63c72d 100644 --- a/lnd_test.go +++ b/lnd_test.go @@ -440,6 +440,17 @@ func completePaymentRequests(ctx context.Context, client lnrpc.LightningClient, return nil } +// makeFakePayHash creates random pre image hash +func makeFakePayHash(t *harnessTest) []byte { + randBuf := make([]byte, 32) + + if _, err := rand.Read(randBuf); err != nil { + t.Fatalf("internal error, cannot generate random string: %v", err) + } + + return randBuf +} + const ( AddrTypeWitnessPubkeyHash = lnrpc.NewAddressRequest_WITNESS_PUBKEY_HASH AddrTypeNestedPubkeyHash = lnrpc.NewAddressRequest_NESTED_PUBKEY_HASH @@ -1717,13 +1728,13 @@ func testChannelForceClosure(net *lntest.NetworkHarness, t *harnessTest) { if err != nil { t.Fatalf("unable to create payment stream for alice: %v", err) } + carolPubKey := carol.PubKey[:] - payHash := bytes.Repeat([]byte{2}, 32) for i := 0; i < numInvoices; i++ { err = alicePayStream.Send(&lnrpc.SendRequest{ Dest: carolPubKey, Amt: int64(paymentAmt), - PaymentHash: payHash, + PaymentHash: makeFakePayHash(t), FinalCltvDelta: defaultBitcoinTimeLockDelta, }) if err != nil { @@ -3656,9 +3667,16 @@ func testPrivateChannels(net *lntest.NetworkHarness, t *harnessTest) { const paymentAmt = 70000 payReqs := make([]string, numPayments) for i := 0; i < numPayments; i++ { + preimage := make([]byte, 32) + _, err := rand.Read(preimage) + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } + invoice := &lnrpc.Invoice{ - Memo: "testing", - Value: paymentAmt, + Memo: "testing", + RPreimage: preimage, + Value: paymentAmt, } resp, err := net.Bob.AddInvoice(ctxb, invoice) if err != nil { @@ -3719,9 +3737,16 @@ func testPrivateChannels(net *lntest.NetworkHarness, t *harnessTest) { const paymentAmt60k = 60000 payReqs = make([]string, numPayments) for i := 0; i < numPayments; i++ { + preimage := make([]byte, 32) + _, err := rand.Read(preimage) + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } + invoice := &lnrpc.Invoice{ - Memo: "testing", - Value: paymentAmt60k, + Memo: "testing", + RPreimage: preimage, + Value: paymentAmt60k, } resp, err := carol.AddInvoice(ctxb, invoice) if err != nil { @@ -4223,10 +4248,9 @@ func testInvoiceSubscriptions(net *lntest.NetworkHarness, t *harnessTest) { // TODO(roasbeef): make global list of invoices for each node to re-use // and avoid collisions const paymentAmt = 1000 - preimage := bytes.Repeat([]byte{byte(90)}, 32) invoice := &lnrpc.Invoice{ Memo: "testing", - RPreimage: preimage, + RPreimage: makeFakePayHash(t), Value: paymentAmt, } invoiceResp, err := net.Bob.AddInvoice(ctxb, invoice) @@ -5740,7 +5764,7 @@ out: // stream on payment error. ctxt, _ = context.WithTimeout(ctxb, timeout) sendReq := &lnrpc.SendRequest{ - PaymentHashString: hex.EncodeToString(bytes.Repeat([]byte("Z"), 32)), + PaymentHashString: hex.EncodeToString(makeFakePayHash(t)), DestString: hex.EncodeToString(carol.PubKey[:]), Amt: payAmt, } @@ -5869,6 +5893,12 @@ out: "instead: %v", resp.PaymentError) } + // Generate new invoice to not pay same invoice twice + carolInvoice, err = carol.AddInvoice(ctxb, invoiceReq) + if err != nil { + t.Fatalf("unable to generate carol invoice: %v", err) + } + // For our final test, we'll ensure that if a target link isn't // available for what ever reason then the payment fails accordingly. // @@ -6937,8 +6967,8 @@ func testMultiHopHtlcLocalTimeout(net *lntest.NetworkHarness, t *harnessTest) { // We'll create two random payment hashes unknown to carol, then send // each of them by manually specifying the HTLC details. carolPubKey := carol.PubKey[:] - dustPayHash := bytes.Repeat([]byte{1}, 32) - payHash := bytes.Repeat([]byte{2}, 32) + dustPayHash := makeFakePayHash(t) + payHash := makeFakePayHash(t) err = alicePayStream.Send(&lnrpc.SendRequest{ Dest: carolPubKey, Amt: int64(dustHtlcAmt), @@ -7392,7 +7422,7 @@ func testMultiHopLocalForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, // We'll now send a single HTLC across our multi-hop network. carolPubKey := carol.PubKey[:] - payHash := bytes.Repeat([]byte{2}, 32) + payHash := makeFakePayHash(t) err = alicePayStream.Send(&lnrpc.SendRequest{ Dest: carolPubKey, Amt: int64(htlcAmt), @@ -7647,7 +7677,7 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, // We'll now send a single HTLC across our multi-hop network. carolPubKey := carol.PubKey[:] - payHash := bytes.Repeat([]byte{2}, 32) + payHash := makeFakePayHash(t) err = alicePayStream.Send(&lnrpc.SendRequest{ Dest: carolPubKey, Amt: int64(htlcAmt),