-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
We'll use this AMP-specific ShardTracker for AMP payments. It will be used to derive hashes for each HTLC attempt using the underlying AMP derivation scheme.
- Loading branch information
Showing
2 changed files
with
260 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
package amp | ||
|
||
import ( | ||
"crypto/rand" | ||
"encoding/binary" | ||
"fmt" | ||
"sync" | ||
|
||
"github.com/lightningnetwork/lnd/lntypes" | ||
"github.com/lightningnetwork/lnd/lnwire" | ||
"github.com/lightningnetwork/lnd/record" | ||
"github.com/lightningnetwork/lnd/routing/shards" | ||
) | ||
|
||
// Shard is an implementation of the shards.PaymentShards interface specific | ||
// to AMP payments. | ||
type Shard struct { | ||
child *Child | ||
mpp *record.MPP | ||
amp *record.AMP | ||
} | ||
|
||
// A compile time check to ensure Shard implements the shards.PaymentShard | ||
// interface. | ||
var _ shards.PaymentShard = (*Shard)(nil) | ||
|
||
// Hash returns the hash used for the HTLC representing this AMP shard. | ||
func (s *Shard) Hash() lntypes.Hash { | ||
return s.child.Hash | ||
} | ||
|
||
// MPP returns any extra MPP records that should be set for the final hop on | ||
// the route used by this shard. | ||
func (s *Shard) MPP() *record.MPP { | ||
return s.mpp | ||
} | ||
|
||
// AMP returns any extra AMP records that should be set for the final hop on | ||
// the route used by this shard. | ||
func (s *Shard) AMP() *record.AMP { | ||
return s.amp | ||
} | ||
|
||
// ShardTracker is an implementation of the shards.ShardTracker interface | ||
// that is able to generate payment shards according to the AMP splitting | ||
// algorithm. It can be used to generate new hashes to use for HTLCs, and also | ||
// cancel shares used for failed payment shards. | ||
type ShardTracker struct { | ||
setID [32]byte | ||
paymentAddr [32]byte | ||
totalAmt lnwire.MilliSatoshi | ||
|
||
sharer Sharer | ||
|
||
shards map[uint64]*Child | ||
sync.Mutex | ||
} | ||
|
||
// A compile time check to ensure ShardTracker implements the | ||
// shards.ShardTracker interface. | ||
var _ shards.ShardTracker = (*ShardTracker)(nil) | ||
|
||
// NewShardTracker creates a new shard tracker to use for AMP payments. The | ||
// root shard, setID, payment address and total amount must be correctly set in | ||
// order for the TLV options to include with each shard to be created | ||
// correctly. | ||
func NewShardTracker(root, setID, payAddr [32]byte, | ||
totalAmt lnwire.MilliSatoshi) *ShardTracker { | ||
|
||
// Create a new seed sharer from this root. | ||
rootShare := Share(root) | ||
rootSharer := SeedSharerFromRoot(&rootShare) | ||
|
||
return &ShardTracker{ | ||
setID: setID, | ||
paymentAddr: payAddr, | ||
totalAmt: totalAmt, | ||
sharer: rootSharer, | ||
shards: make(map[uint64]*Child), | ||
} | ||
} | ||
|
||
// NewShard registers a new attempt with the ShardTracker and returns a | ||
// new shard representing this attempt. This attempt's shard should be canceled | ||
// if it ends up not being used by the overall payment, i.e. if the attempt | ||
// fails. | ||
func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard, | ||
error) { | ||
|
||
s.Lock() | ||
defer s.Unlock() | ||
|
||
// Use a random child index. | ||
var childIndex [4]byte | ||
if _, err := rand.Read(childIndex[:]); err != nil { | ||
return nil, err | ||
} | ||
idx := binary.BigEndian.Uint32(childIndex[:]) | ||
|
||
// Depending on whether we are requesting the last shard or not, either | ||
// split the current share into two, or get a Child directly from the | ||
// current sharer. | ||
var child *Child | ||
if last { | ||
child = s.sharer.Child(idx) | ||
|
||
// If this was the last shard, set the current share to the | ||
// zero share to indicate we cannot split it further. | ||
s.sharer = s.sharer.Zero() | ||
} else { | ||
left, sharer, err := s.sharer.Split() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
s.sharer = sharer | ||
child = left.Child(idx) | ||
} | ||
|
||
// Track the new child and return the shard. | ||
s.shards[pid] = child | ||
|
||
mpp := record.NewMPP(s.totalAmt, s.paymentAddr) | ||
amp := record.NewAMP( | ||
child.ChildDesc.Share, s.setID, child.ChildDesc.Index, | ||
) | ||
|
||
return &Shard{ | ||
child: child, | ||
mpp: mpp, | ||
amp: amp, | ||
}, nil | ||
} | ||
|
||
// CancelShard cancel's the shard corresponding to the given attempt ID. | ||
func (s *ShardTracker) CancelShard(pid uint64) error { | ||
s.Lock() | ||
defer s.Unlock() | ||
|
||
c, ok := s.shards[pid] | ||
if !ok { | ||
return fmt.Errorf("pid not found") | ||
} | ||
delete(s.shards, pid) | ||
|
||
// Now that we are canceling this shard, we XOR the share back into our | ||
// current share. | ||
s.sharer = s.sharer.Merge(c) | ||
return nil | ||
} | ||
|
||
// GetHash retrieves the hash used by the shard of the given attempt ID. This | ||
// will return an error if the attempt ID is unknown. | ||
func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) { | ||
s.Lock() | ||
defer s.Unlock() | ||
|
||
c, ok := s.shards[pid] | ||
if !ok { | ||
return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+ | ||
"not found", pid) | ||
} | ||
|
||
return c.Hash, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
package amp_test | ||
|
||
import ( | ||
"crypto/rand" | ||
"testing" | ||
|
||
"github.com/lightningnetwork/lnd/amp" | ||
"github.com/lightningnetwork/lnd/lnwire" | ||
"github.com/lightningnetwork/lnd/routing/shards" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
// TestAMPShardTracker tests that we can derive and cancel shards at will using | ||
// the AMP shard tracker. | ||
func TestAMPShardTracker(t *testing.T) { | ||
var root, setID, payAddr [32]byte | ||
_, err := rand.Read(root[:]) | ||
require.NoError(t, err) | ||
|
||
_, err = rand.Read(setID[:]) | ||
require.NoError(t, err) | ||
|
||
_, err = rand.Read(payAddr[:]) | ||
require.NoError(t, err) | ||
|
||
var totalAmt lnwire.MilliSatoshi = 1000 | ||
|
||
// Create an AMP shard tracker using the random data we just generated. | ||
tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt) | ||
|
||
// Trying to retrieve a hash for id 0 should result in an error. | ||
_, err = tracker.GetHash(0) | ||
require.Error(t, err) | ||
|
||
// We start by creating 20 shards. | ||
const numShards = 20 | ||
|
||
var shards []shards.PaymentShard | ||
for i := uint64(0); i < numShards; i++ { | ||
s, err := tracker.NewShard(i, i == numShards-1) | ||
require.NoError(t, err) | ||
|
||
// Check that the shards have their payloads set as expected. | ||
require.Equal(t, setID, s.AMP().SetID()) | ||
require.Equal(t, totalAmt, s.MPP().TotalMsat()) | ||
require.Equal(t, payAddr, s.MPP().PaymentAddr()) | ||
|
||
shards = append(shards, s) | ||
} | ||
|
||
// Make sure we can retrieve the hash for all of them. | ||
for i := uint64(0); i < numShards; i++ { | ||
hash, err := tracker.GetHash(i) | ||
require.NoError(t, err) | ||
require.Equal(t, shards[i].Hash(), hash) | ||
} | ||
|
||
// Now cancel half of the shards. | ||
j := 0 | ||
for i := uint64(0); i < numShards; i++ { | ||
if i%2 == 0 { | ||
err := tracker.CancelShard(i) | ||
require.NoError(t, err) | ||
continue | ||
} | ||
|
||
// Keep shard. | ||
shards[j] = shards[i] | ||
j++ | ||
} | ||
shards = shards[:j] | ||
|
||
// Get a new last shard. | ||
s, err := tracker.NewShard(numShards, true) | ||
require.NoError(t, err) | ||
shards = append(shards, s) | ||
|
||
// Finally make sure these shards together can be used to reconstruct | ||
// the children. | ||
childDescs := make([]amp.ChildDesc, len(shards)) | ||
for i, s := range shards { | ||
childDescs[i] = amp.ChildDesc{ | ||
Share: s.AMP().RootShare(), | ||
Index: s.AMP().ChildIndex(), | ||
} | ||
} | ||
|
||
// Using the child descriptors, reconstruct the children. | ||
children := amp.ReconstructChildren(childDescs...) | ||
|
||
// Validate that the derived child preimages match the hash of each shard. | ||
for i, child := range children { | ||
require.Equal(t, shards[i].Hash(), child.Hash) | ||
} | ||
} |