Skip to content

Commit

Permalink
Merge 0758b93 into 27430f8
Browse files Browse the repository at this point in the history
  • Loading branch information
bhandras committed Dec 13, 2019
2 parents 27430f8 + 0758b93 commit becbbac
Show file tree
Hide file tree
Showing 16 changed files with 1,147 additions and 239 deletions.
109 changes: 109 additions & 0 deletions channeldb/invoice_test.go
Expand Up @@ -2,6 +2,7 @@ package channeldb

import (
"crypto/rand"
mrand "math/rand"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -393,6 +394,114 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
}
}

// Tests that FetchAllInvoicesWithPaymentHash returns all invoices with their
// corresponding payment hashes.
func TestFetchAllInvoicesWithPaymentHash(t *testing.T) {
t.Parallel()

db, cleanup, err := makeTestDB()
defer cleanup()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}

// With an empty DB we expect to return no error and an empty list.
empty, err := db.FetchAllInvoicesWithPaymentHash(false)
if err != nil {
t.Fatalf("failed to call FetchAllInvoicesWithPaymentHash on empty DB: %v",
err)
}

if len(empty) != 0 {
t.Fatalf("expected empty list as a result, got: %v", empty)
}

// Now populate the DB and check if we can get all invoices with their
// payment hashes as expected.
const numInvoices = 20
testPendingInvoices := make(map[lntypes.Hash]*Invoice)
testAllInvoices := make(map[lntypes.Hash]*Invoice)

states := []ContractState{
ContractOpen, ContractSettled, ContractCanceled, ContractAccepted,
}

for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ {
invoice, err := randInvoice(i)
if err != nil {
t.Fatalf("unable to create invoice: %v", err)
}

invoice.State = states[mrand.Intn(len(states))]
paymentHash := invoice.Terms.PaymentPreimage.Hash()

if invoice.State != ContractSettled && invoice.State != ContractCanceled {
testPendingInvoices[paymentHash] = invoice
}

testAllInvoices[paymentHash] = invoice

if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
}

pendingInvoices, err := db.FetchAllInvoicesWithPaymentHash(true)
if err != nil {
t.Fatalf("can't fetch invoices with payment hash: %v", err)
}

if len(testPendingInvoices) != len(pendingInvoices) {
t.Fatalf("expected %v pending invoices, got: %v",
len(testPendingInvoices), len(pendingInvoices))
}

allInvoices, err := db.FetchAllInvoicesWithPaymentHash(false)
if err != nil {
t.Fatalf("can't fetch invoices with payment hash: %v", err)
}

if len(testAllInvoices) != len(allInvoices) {
t.Fatalf("expected %v invoices, got: %v",
len(testAllInvoices), len(allInvoices))
}

for i := range pendingInvoices {
expected, ok := testPendingInvoices[pendingInvoices[i].PaymentHash]
if !ok {
t.Fatalf("coulnd't find invoice with hash: %v",
pendingInvoices[i].PaymentHash)
}

// Zero out add index to not confuse DeepEqual.
pendingInvoices[i].Invoice.AddIndex = 0
expected.AddIndex = 0

if !reflect.DeepEqual(*expected, pendingInvoices[i].Invoice) {
t.Fatalf("expected: %v, got: %v",
spew.Sdump(expected), spew.Sdump(pendingInvoices[i].Invoice))
}
}

for i := range allInvoices {
expected, ok := testAllInvoices[allInvoices[i].PaymentHash]
if !ok {
t.Fatalf("coulnd't find invoice with hash: %v",
allInvoices[i].PaymentHash)
}

// Zero out add index to not confuse DeepEqual.
allInvoices[i].Invoice.AddIndex = 0
expected.AddIndex = 0

if !reflect.DeepEqual(*expected, allInvoices[i].Invoice) {
t.Fatalf("expected: %v, got: %v",
spew.Sdump(expected), spew.Sdump(allInvoices[i].Invoice))
}
}

}

// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it
// twice, then the second time we also receive the invoice that we settled as a
// return argument.
Expand Down
77 changes: 77 additions & 0 deletions channeldb/invoices.go
Expand Up @@ -565,6 +565,83 @@ func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
return invoice, nil
}

// InvoiceWithPaymentHash is used to store an invoice and its corresponding
// payment hash. This struct is only used to store results of
// ChannelDB.FetchAllInvoicesWithPaymentHash() call.
type InvoiceWithPaymentHash struct {
// Invoice holds the invoice as selected from the invoices bucket.
Invoice Invoice

// PaymentHash is the payment hash for the Invoice.
PaymentHash lntypes.Hash
}

// FetchAllInvoicesWithPaymentHash returns all invoices and their payment hashes
// currently stored within the database. If the pendingOnly param is true, then
// only unsettled invoices and their payment hashes will be returned, skipping
// all invoices that are fully settled or canceled. Note that the returned
// array is not ordered by add index.
func (d *DB) FetchAllInvoicesWithPaymentHash(pendingOnly bool) (
[]InvoiceWithPaymentHash, error) {

var result []InvoiceWithPaymentHash

err := d.View(func(tx *bbolt.Tx) error {
invoices := tx.Bucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}

invoiceIndex := invoices.Bucket(invoiceIndexBucket)
if invoiceIndex == nil {
// Mask the error if there's no invoice
// index as that simply means there are no
// invoices added yet to the DB. In this case
// we simply return an empty list.
return nil
}

return invoiceIndex.ForEach(func(k, v []byte) error {
// Skip the special numInvoicesKey as that does not
// point to a valid invoice.
if bytes.Equal(k, numInvoicesKey) {
return nil
}

if v == nil {
return nil
}

invoice, err := fetchInvoice(v, invoices)
if err != nil {
return err
}

if pendingOnly &&
(invoice.State == ContractSettled ||
invoice.State == ContractCanceled) {

return nil
}

invoiceWithPaymentHash := InvoiceWithPaymentHash{
Invoice: invoice,
}

copy(invoiceWithPaymentHash.PaymentHash[:], k)
result = append(result, invoiceWithPaymentHash)

return nil
})
})

if err != nil {
return nil, err
}

return result, nil
}

// FetchAllInvoices returns all invoices currently stored within the database.
// If the pendingOnly param is true, then only unsettled invoices will be
// returned, skipping all invoices that are fully settled.
Expand Down
24 changes: 24 additions & 0 deletions clock/default_clock.go
@@ -0,0 +1,24 @@
package clock

import (
"time"
)

// DefaultClock implements Clock interface by simply calling the appropriate
// time functions.
type DefaultClock struct{}

// NewDefaultClock constructs a new DefaultClock.
func NewDefaultClock() Clock {
return &DefaultClock{}
}

// Now simply returns time.Now().
func (DefaultClock) Now() time.Time {
return time.Now()
}

// TickAfter simply wraps time.After().
func (DefaultClock) TickAfter(duration time.Duration) <-chan time.Time {
return time.After(duration)
}
16 changes: 16 additions & 0 deletions clock/interface.go
@@ -0,0 +1,16 @@
package clock

import (
"time"
)

// Clock is an interface that provides a time functions for LND packages.
// This is useful during testing when a concrete time reference is needed.
type Clock interface {
// Now returns the current local time (as defined by the Clock).
Now() time.Time

// TickAfter returns a channel that will receive a tick after the specified
// duration has passed.
TickAfter(duration time.Duration) <-chan time.Time
}
28 changes: 13 additions & 15 deletions invoices/clock_test.go → clock/test_clock.go
@@ -1,42 +1,40 @@
package invoices
package clock

import (
"sync"
"time"
)

// testClock can be used in tests to mock time.
type testClock struct {
// TestClock can be used in tests to mock time.
type TestClock struct {
currentTime time.Time
timeChanMap map[time.Time][]chan time.Time
timeLock sync.Mutex
}

// newTestClock returns a new test clock.
func newTestClock(startTime time.Time) *testClock {
return &testClock{
// NewTestClock returns a new test clock.
func NewTestClock(startTime time.Time) *TestClock {
return &TestClock{
currentTime: startTime,
timeChanMap: make(map[time.Time][]chan time.Time),
}
}

// now returns the current (test) time.
func (c *testClock) now() time.Time {
// Now returns the current (test) time.
func (c *TestClock) Now() time.Time {
c.timeLock.Lock()
defer c.timeLock.Unlock()

return c.currentTime
}

// tickAfter returns a channel that will receive a tick at the specified time.
func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
// TickAfter returns a channel that will receive a tick after the specified
// duration has passed passed by the user set test time.
func (c *TestClock) TickAfter(duration time.Duration) <-chan time.Time {
c.timeLock.Lock()
defer c.timeLock.Unlock()

triggerTime := c.currentTime.Add(duration)
log.Debugf("tickAfter called: duration=%v, trigger_time=%v",
duration, triggerTime)

ch := make(chan time.Time, 1)

// If already expired, tick immediately.
Expand All @@ -53,8 +51,8 @@ func (c *testClock) tickAfter(duration time.Duration) <-chan time.Time {
return ch
}

// setTime sets the (test) time and triggers tick channels when they expire.
func (c *testClock) setTime(now time.Time) {
// SetTime sets the (test) time and triggers tick channels when they expire.
func (c *TestClock) SetTime(now time.Time) {
c.timeLock.Lock()
defer c.timeLock.Unlock()

Expand Down
63 changes: 63 additions & 0 deletions clock/test_clock_test.go
@@ -0,0 +1,63 @@
package clock

import (
"testing"
"time"
)

var (
testTime = time.Date(2009, time.January, 3, 12, 0, 0, 0, time.UTC)
)

func TestNow(t *testing.T) {
c := NewTestClock(testTime)
now := c.Now()

if now != testTime {
t.Fatalf("expected: %v, got: %v", testTime, now)
}

now = now.Add(time.Hour)
c.SetTime(now)
if c.Now() != now {
t.Fatalf("epected: %v, got: %v", now, c.Now())
}
}

func TestTickAfter(t *testing.T) {
c := NewTestClock(testTime)

// Should be ticking immediately.
ticker0 := c.TickAfter(0)

// Both should be ticking after SetTime
ticker1 := c.TickAfter(time.Hour)
ticker2 := c.TickAfter(time.Hour)

// We don't expect this one to tick.
ticker3 := c.TickAfter(2 * time.Hour)

tickOrTimeOut := func(ticker <-chan time.Time, expectTick bool) {
tick := false
select {
case <-ticker:
tick = true
case <-time.After(time.Millisecond):
}

if tick != expectTick {
t.Fatalf("expected tick: %v, ticked: %v", expectTick, tick)
}
}

tickOrTimeOut(ticker0, true)
tickOrTimeOut(ticker1, false)
tickOrTimeOut(ticker2, false)
tickOrTimeOut(ticker3, false)

c.SetTime(c.Now().Add(time.Hour))

tickOrTimeOut(ticker1, true)
tickOrTimeOut(ticker2, true)
tickOrTimeOut(ticker3, false)
}
2 changes: 2 additions & 0 deletions htlcswitch/mock.go
Expand Up @@ -22,6 +22,7 @@ import (
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input"
Expand Down Expand Up @@ -792,6 +793,7 @@ func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {

registry := invoices.NewRegistry(
cdb,
invoices.NewInvoiceExpiryWatcher(clock.NewDefaultClock()),
&invoices.RegistryConfig{
FinalCltvRejectDelta: 5,
},
Expand Down

0 comments on commit becbbac

Please sign in to comment.