Skip to content

Commit

Permalink
htlcswitch: face race condition in unit tests by returning invoice
Browse files Browse the repository at this point in the history
In this commit we modify the primary InvoiceRegistry interface within
the package to instead return a direct value for LookupInvoice rather
than a pointer. This fixes an existing race condition wherein a caller
could modify or read the value of the returned invoice.
  • Loading branch information
Roasbeef committed Nov 12, 2017
1 parent 010815e commit b6f6493
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 14 deletions.
2 changes: 1 addition & 1 deletion htlcswitch/interfaces.go
Expand Up @@ -12,7 +12,7 @@ import (
type InvoiceDatabase interface {
// LookupInvoice attempts to look up an invoice according to it's 32
// byte payment hash.
LookupInvoice(chainhash.Hash) (*channeldb.Invoice, error)
LookupInvoice(chainhash.Hash) (channeldb.Invoice, error)

// SettleInvoice attempts to mark an invoice corresponding to the
// passed payment hash as fully settled.
Expand Down
4 changes: 2 additions & 2 deletions htlcswitch/link_test.go
Expand Up @@ -978,7 +978,7 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) {
invoice.Terms.PaymentPreimage[0] ^= byte(255)

// Check who is last in the route and add invoice to server registry.
if err := n.carolServer.registry.AddInvoice(invoice); err != nil {
if err := n.carolServer.registry.AddInvoice(*invoice); err != nil {
t.Fatalf("unable to add invoice in carol registry: %v", err)
}

Expand Down Expand Up @@ -1955,7 +1955,7 @@ func TestChannelRetransmission(t *testing.T) {
// TODO(andrew.shvv) Will be removed if we move the notification center
// to the channel link itself.

var invoice *channeldb.Invoice
var invoice channeldb.Invoice
for i := 0; i < 20; i++ {
select {
case <-time.After(time.Millisecond * 200):
Expand Down
11 changes: 6 additions & 5 deletions htlcswitch/mock.go
Expand Up @@ -397,22 +397,22 @@ var _ ChannelLink = (*mockChannelLink)(nil)

type mockInvoiceRegistry struct {
sync.Mutex
invoices map[chainhash.Hash]*channeldb.Invoice
invoices map[chainhash.Hash]channeldb.Invoice
}

func newMockRegistry() *mockInvoiceRegistry {
return &mockInvoiceRegistry{
invoices: make(map[chainhash.Hash]*channeldb.Invoice),
invoices: make(map[chainhash.Hash]channeldb.Invoice),
}
}

func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoice, error) {
func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
i.Lock()
defer i.Unlock()

invoice, ok := i.invoices[rHash]
if !ok {
return nil, errors.New("can't find mock invoice")
return channeldb.Invoice{}, errors.New("can't find mock invoice")
}

return invoice, nil
Expand All @@ -428,11 +428,12 @@ func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error {
}

invoice.Terms.Settled = true
i.invoices[rhash] = invoice

return nil
}

func (i *mockInvoiceRegistry) AddInvoice(invoice *channeldb.Invoice) error {
func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice) error {
i.Lock()
defer i.Unlock()

Expand Down
2 changes: 1 addition & 1 deletion htlcswitch/test_utils.go
Expand Up @@ -549,7 +549,7 @@ func (n *threeHopNetwork) makePayment(sendingPeer, receivingPeer Peer,
rhash = fastsha256.Sum256(invoice.Terms.PaymentPreimage[:])

// Check who is last in the route and add invoice to server registry.
if err := receiver.registry.AddInvoice(invoice); err != nil {
if err := receiver.registry.AddInvoice(*invoice); err != nil {
paymentErr <- err
return &paymentResponse{
rhash: rhash,
Expand Down
11 changes: 8 additions & 3 deletions invoiceregistry.go
Expand Up @@ -98,7 +98,7 @@ func (i *invoiceRegistry) AddInvoice(invoice *channeldb.Invoice) error {
// lookupInvoice looks up an invoice by its payment hash (R-Hash), if found
// then we're able to pull the funds pending within an HTLC.
// TODO(roasbeef): ignore if settled?
func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoice, error) {
func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
// First check the in-memory debug invoice index to see if this is an
// existing invoice added for debugging.
i.RLock()
Expand All @@ -107,12 +107,17 @@ func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoic

// If found, then simply return the invoice directly.
if ok {
return invoice, nil
return *invoice, nil
}

// Otherwise, we'll check the database to see if there's an existing
// matching invoice.
return i.cdb.LookupInvoice(rHash)
invoice, err := i.cdb.LookupInvoice(rHash)
if err != nil {
return channeldb.Invoice{}, err
}

return *invoice, nil
}

// SettleInvoice attempts to mark an invoice as settled. If the invoice is a
Expand Down
5 changes: 4 additions & 1 deletion lnwallet/channel_test.go
Expand Up @@ -3359,7 +3359,10 @@ func TestChanSyncUnableToSync(t *testing.T) {
}
}

// TestChanAvailableBandwidth...
// TestChanAvailableBandwidth tests the accuracy of the AvailableBalance()
// method. The value returned from this message should reflect the value
// returned within the commitment state of a channel after the transition is
// initiated.
func TestChanAvailableBandwidth(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion rpcserver.go
Expand Up @@ -2029,7 +2029,7 @@ func (r *rpcServer) LookupInvoice(ctx context.Context,
return spew.Sdump(invoice)
}))

rpcInvoice, err := createRPCInvoice(invoice)
rpcInvoice, err := createRPCInvoice(&invoice)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit b6f6493

Please sign in to comment.