diff --git a/batch.go b/batch.go
new file mode 100644
index 0000000..8e87a8c
--- /dev/null
+++ b/batch.go
@@ -0,0 +1,89 @@
+package sphinx
+
+import "errors"
+
+// ErrAlreadyCommitted signals that an entry could not be added to the
+// batch because it has already been persisted.
+var ErrAlreadyCommitted = errors.New("cannot add to batch after committing")
+
+// Batch is an object used to incrementally construct a set of entries to add
+// to the replay log. After construction is completed, it can be added to the
+// log using the PutBatch method.
+type Batch struct {
+	// isCommitted denotes whether or not this batch has been successfully
+	// written to disk.
+	isCommitted bool
+
+	// id is a unique, caller-chosen identifier for this batch.
+	id []byte
+
+	// entries stores the set of all potential entries that might get
+	// written to the replay log. Some entries may be skipped after
+	// examining the on-disk content at the time of commit.
+	entries map[uint16]batchEntry
+
+	// replayCache is an in-memory lookup table, which stores the hash
+	// prefix of entries already added to this batch. This allows a quick
+	// mechanism for intra-batch duplicate detection.
+	replayCache map[HashPrefix]struct{}
+
+	// replaySet contains the sequence numbers of all entries that were
+	// detected as replays. The set is finalized upon writing the batch to
+	// disk, and merges replays detected by the replay cache and on-disk
+	// replay log.
+	replaySet *ReplaySet
+}
+
+// NewBatch initializes an object for constructing a set of entries to
+// atomically add to a replay log. Batches are identified by a byte slice,
+// which allows the caller to safely process the same batch twice and get an
+// idempotent result.
+func NewBatch(id []byte) *Batch {
+	return &Batch{
+		id:          id,
+		entries:     make(map[uint16]batchEntry),
+		replayCache: make(map[HashPrefix]struct{}),
+		replaySet:   NewReplaySet(),
+	}
+}
+
+// Put inserts a hash-prefix/CLTV pair into the current batch. This method only
+// returns an error in the event that the batch was already committed to disk.
+// Whether or not a particular sequence number is a replay is ultimately
+// reported via the batch's ReplaySet after committing to disk.
+func (b *Batch) Put(seqNum uint16, hashPrefix *HashPrefix, cltv uint32) error {
+	// Abort if this batch was already written to disk.
+	if b.isCommitted {
+		return ErrAlreadyCommitted
+	}
+
+	// Check to see if this hash prefix is already included in this batch.
+	// If so, we will opportunistically mark this index as replayed.
+	if _, ok := b.replayCache[*hashPrefix]; ok {
+		b.replaySet.Add(seqNum)
+		return nil
+	}
+
+	// Otherwise, this is a distinct hash prefix for this batch. Add it to
+	// our list of entries that we will try to write to disk. Each of these
+	// entries will be checked again during the commit to see if any other
+	// on-disk entries contain the same hash prefix.
+	b.entries[seqNum] = batchEntry{
+		hashPrefix: *hashPrefix,
+		cltv:       cltv,
+	}
+
+	// Finally, add this hash prefix to our in-memory replay cache; it will
+	// be consulted upon further adds to check for duplicates in the same
+	// batch.
+	b.replayCache[*hashPrefix] = struct{}{}
+
+	return nil
+}
+
+// batchEntry is a tuple of a secret's hash prefix and the corresponding CLTV
+// at which the onion blob from which the secret was derived expires.
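+//
+// As a sketch of how entries accumulate (the identifier and values below are
+// hypothetical): adding two packets whose secrets share a hash prefix yields
+// a single batchEntry, and marks the second sequence number as a replay:
+//
+//	b := NewBatch([]byte("fwd-pkg-42"))
+//	_ = b.Put(0, &prefix, 500000) // queued in b.entries
+//	_ = b.Put(1, &prefix, 500000) // recorded in b.replaySet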
+type batchEntry struct {
+	hashPrefix HashPrefix
+	cltv       uint32
+}
diff --git a/bench_test.go b/bench_test.go
index 3755db2..23d7b20 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -60,19 +60,29 @@
 		b.Fatalf("unable to create test route: %v", err)
 	}
 	b.ReportAllocs()
+	path[0].log.Start()
+	defer shutdown("0", path[0].log)
 	b.StartTimer()

 	var (
 		pkt *ProcessedPacket
 	)
 	for i := 0; i < b.N; i++ {
-		pkt, err = path[0].ProcessOnionPacket(sphinxPacket, nil)
+		pkt, err = path[0].ProcessOnionPacket(sphinxPacket, nil, uint32(i))
 		if err != nil {
-			b.Fatalf("unable to process packet: %v", err)
+			b.Fatalf("unable to process packet %d: %v", i, err)
 		}

 		b.StopTimer()
-		path[0].seenSecrets = make(map[[sharedSecretSize]byte]struct{})
+		router := path[0]
+		shutdown("0", router.log)
+		path[0] = &Router{
+			nodeID:   router.nodeID,
+			nodeAddr: router.nodeAddr,
+			onionKey: router.onionKey,
+			log:      NewDecayedLog("0", nil),
+		}
+		path[0].log.Start()
 		b.StartTimer()
 	}
diff --git a/cmd/main.go b/cmd/main.go
index 60e0d71..0f7b454 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -76,7 +76,7 @@ func main() {
 		}

 		privkey, _ := btcec.PrivKeyFromBytes(btcec.S256(), binKey)
-		s := sphinx.NewRouter(privkey, &chaincfg.TestNet3Params)
+		s := sphinx.NewRouter("", privkey, &chaincfg.TestNet3Params, nil)

 		var packet sphinx.OnionPacket
 		err = packet.Decode(bytes.NewBuffer(binMsg))
diff --git a/decayedlog.go b/decayedlog.go
new file mode 100644
index 0000000..3b986a9
--- /dev/null
+++ b/decayedlog.go
@@ -0,0 +1,442 @@
+package sphinx
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"math"
+	"sync"
+	"sync/atomic"
+
+	"github.com/boltdb/bolt"
+	"github.com/lightningnetwork/lnd/chainntnfs"
+)
+
+const (
+	// defaultDbDirectory is the default directory where our decayed log
+	// will store our (sharedHash, CLTV) key-value pairs.
+	defaultDbDirectory = "sharedhashes"
+
+	// dbPermissions sets the database permissions so that the file is
+	// readable and writable by its owner.
+	dbPermissions = 0600
+
+	// sharedHashSize is the size in bytes of the keys we will be storing
+	// in the DecayedLog. It represents the first 20 bytes of a truncated
+	// sha-256 hash of a secret generated by ECDH.
+	sharedHashSize = 20
+)
+
+var (
+	// sharedHashBucket is a bucket which houses the first sharedHashSize
+	// bytes of a received HTLC's hashed shared secret as the key and the
+	// HTLC's CLTV expiry as the value.
+	sharedHashBucket = []byte("shared-hash")
+
+	// batchReplayBucket is a bucket that maps batch identifiers to
+	// serialized ReplaySets. This is used to give idempotency in the event
+	// that a batch is processed more than once.
+	batchReplayBucket = []byte("batch-replay")
+)
+
+// HashPrefix is a statically sized, 20-byte array containing the prefix
+// of a Hash256, and is used to detect duplicate sphinx packets.
+type HashPrefix [sharedHashSize]byte
+
+var (
+	// ErrDecayedLogInit is used to indicate a decayed log failed to create
+	// the proper bucketing structure on startup.
+	ErrDecayedLogInit = errors.New("unable to initialize decayed log")
+
+	// ErrDecayedLogCorrupted signals that the anticipated bucketing
+	// structure has diverged since initialization.
+	ErrDecayedLogCorrupted = errors.New("decayed log structure corrupted")
+)
+
+// ReplayLog is an interface that defines an on-disk data structure that
+// contains a persistent log to enable strong replay protection. The interface
+// is general to allow implementations near-complete autonomy.
+// All of these calls should be safe for concurrent access.
+type ReplayLog interface {
+	// Start starts up the on-disk persistent log. It returns an error if
+	// one occurs.
+	Start() error
+
+	// Stop safely stops the on-disk persistent log.
+	Stop() error
+
+	// Get retrieves an entry from the persistent log given its hash prefix
+	// key. It returns the stored CLTV value and an error if one occurs.
+	Get([]byte) (uint32, error)
+
+	// Put stores an entry in the persistent log, keyed by the given hash
+	// prefix with the provided CLTV value. It returns an error if the
+	// provided hash prefix already exists in the log.
+	Put(*HashPrefix, uint32) error
+
+	// PutBatch atomically stores a batch of entries in the persistent log,
+	// returning the set of sequence numbers that were detected as replays.
+	PutBatch(*Batch) (*ReplaySet, error)
+
+	// Delete removes an entry from the persistent log given its hash
+	// prefix key.
+	Delete([]byte) error
+}
+
+// DecayedLog implements the ReplayLog interface. It stores the first
+// sharedHashSize bytes of a sha256-hashed shared secret along with a node's
+// CLTV value. It is a decaying log, meaning a garbage collector removes
+// entries whose stored CLTV value has expired relative to the current block
+// height. DecayedLog wraps boltdb for simplicity and batches writes to the
+// database to decrease write contention.
+type DecayedLog struct {
+	started int32
+	stopped int32
+
+	dbPath string
+
+	db *bolt.DB
+
+	notifier chainntnfs.ChainNotifier
+
+	wg   sync.WaitGroup
+	quit chan struct{}
+}
+
+// NewDecayedLog creates a new DecayedLog, which caches recently seen hashed
+// shared secrets. Entries are evicted as their cltv expires using block epochs
+// from the given notifier.
+func NewDecayedLog(dbPath string,
+	notifier chainntnfs.ChainNotifier) *DecayedLog {
+
+	// Use default path for log database
+	if dbPath == "" {
+		dbPath = defaultDbDirectory
+	}
+
+	return &DecayedLog{
+		dbPath:   dbPath,
+		notifier: notifier,
+		quit:     make(chan struct{}),
+	}
+}
+
+// Start opens the database we will be using to store hashed shared secrets.
+// It also starts the garbage collector in a goroutine to remove stale
+// database entries.
+func (d *DecayedLog) Start() error {
+	if !atomic.CompareAndSwapInt32(&d.started, 0, 1) {
+		return nil
+	}
+
+	// Open the boltdb for use.
+	var err error
+	if d.db, err = bolt.Open(d.dbPath, dbPermissions, nil); err != nil {
+		return fmt.Errorf("could not open boltdb: %v", err)
+	}
+
+	// Initialize the primary buckets used by the decayed log.
+	if err := d.initBuckets(); err != nil {
+		return err
+	}
+
+	// Start garbage collector.
+	if d.notifier != nil {
+		epochClient, err := d.notifier.RegisterBlockEpochNtfn()
+		if err != nil {
+			return fmt.Errorf("unable to register for epoch "+
+				"notifications: %v", err)
+		}
+
+		d.wg.Add(1)
+		go d.garbageCollector(epochClient)
+	}
+
+	return nil
+}
+
+// initBuckets initializes the primary buckets used by the decayed log, namely
+// the shared hash bucket and the batch replay bucket.
+func (d *DecayedLog) initBuckets() error {
+	return d.db.Update(func(tx *bolt.Tx) error {
+		_, err := tx.CreateBucketIfNotExists(sharedHashBucket)
+		if err != nil {
+			return ErrDecayedLogInit
+		}
+
+		_, err = tx.CreateBucketIfNotExists(batchReplayBucket)
+		if err != nil {
+			return ErrDecayedLogInit
+		}
+
+		return nil
+	})
+}
+
+// Stop halts the garbage collector and closes boltdb.
+func (d *DecayedLog) Stop() error {
+	if !atomic.CompareAndSwapInt32(&d.stopped, 0, 1) {
+		return nil
+	}
+
+	// Stop garbage collector.
+	close(d.quit)
+
+	d.wg.Wait()
+
+	// Close boltdb.
+	d.db.Close()
+
+	return nil
+}
+
+// garbageCollector deletes entries from sharedHashBucket whose expiry height
+// has already passed. This function MUST be run as a goroutine.
+func (d *DecayedLog) garbageCollector(epochClient *chainntnfs.BlockEpochEvent) {
+	defer d.wg.Done()
+	defer epochClient.Cancel()
+
+	for {
+		select {
+		case epoch, ok := <-epochClient.Epochs:
+			if !ok {
+				// Block epoch was canceled, shutting down.
+				sphxLog.Infof("Block epoch canceled, " +
+					"decaying hash log shutting down")
+				return
+			}
+
+			// Perform a bout of garbage collection using the
+			// epoch's block height.
+			height := uint32(epoch.Height)
+			if err := d.gcExpiredHashes(height); err != nil {
+				sphxLog.Errorf("unable to expire hashes at "+
+					"height=%d: %v", height, err)
+			}
+
+		case <-d.quit:
+			// Received shutdown request.
+			sphxLog.Infof("Decaying hash log received " +
+				"shutdown request")
+			return
+		}
+	}
+}
+
+// gcExpiredHashes purges the decaying log of all entries whose CLTV expires
+// below the provided height.
+func (d *DecayedLog) gcExpiredHashes(height uint32) error {
+	return d.db.Batch(func(tx *bolt.Tx) error {
+		// Grab the shared hash bucket
+		sharedHashes := tx.Bucket(sharedHashBucket)
+		if sharedHashes == nil {
+			return fmt.Errorf("sharedHashBucket is nil")
+		}
+
+		var expiredCltv [][]byte
+		if err := sharedHashes.ForEach(func(k, v []byte) error {
+			// Deserialize the CLTV value for this entry.
+			cltv := binary.BigEndian.Uint32(v)
+
+			if cltv < height {
+				// This CLTV is expired. Add its key to the
+				// list of entries to delete from the db.
+				expiredCltv = append(expiredCltv, k)
+			}
+
+			return nil
+		}); err != nil {
+			return err
+		}
+
+		// Delete every expired key. This must be done explicitly
+		// outside of the ForEach closure, since mutating a bucket
+		// while iterating over it is unsafe.
+		for _, hash := range expiredCltv {
+			err := sharedHashes.Delete(hash)
+			if err != nil {
+				return err
+			}
+		}
+
+		return nil
+	})
+}
+
+// hashSharedSecret Sha-256 hashes the shared secret and returns the first
+// sharedHashSize bytes of the hash.
+func hashSharedSecret(sharedSecret *Hash256) HashPrefix {
+	// Sha256 hash of sharedSecret
+	h := sha256.New()
+	h.Write(sharedSecret[:])
+
+	var sharedHash HashPrefix
+
+	// Copy bytes to sharedHash
+	copy(sharedHash[:], h.Sum(nil))
+	return sharedHash
+}
+
+// Delete removes a key/value pair from the sharedHashBucket.
+func (d *DecayedLog) Delete(hash []byte) error {
+	return d.db.Batch(func(tx *bolt.Tx) error {
+		sharedHashes := tx.Bucket(sharedHashBucket)
+		if sharedHashes == nil {
+			return ErrDecayedLogCorrupted
+		}
+
+		return sharedHashes.Delete(hash)
+	})
+}
+
+// Get retrieves the CLTV of a processed HTLC given the first 20 bytes of the
+// Sha-256 hash of the shared secret.
+func (d *DecayedLog) Get(hash []byte) (uint32, error) {
+	// math.MaxUint32 is returned when Get did not retrieve a value.
+	// This was chosen because it's not feasible for a CLTV to be this high.
+	var value uint32 = math.MaxUint32
+
+	err := d.db.View(func(tx *bolt.Tx) error {
+		// Grab the shared hash bucket which stores the mapping from
+		// truncated sha-256 hashes of shared secrets to CLTV's.
+		sharedHashes := tx.Bucket(sharedHashBucket)
+		if sharedHashes == nil {
+			return fmt.Errorf("sharedHashes is nil, could " +
+				"not retrieve CLTV value")
+		}
+
+		// Retrieve the bytes which represent the CLTV
+		valueBytes := sharedHashes.Get(hash)
+		if valueBytes == nil {
+			return nil
+		}
+
+		// The first 4 bytes represent the CLTV; store it in value.
+		value = binary.BigEndian.Uint32(valueBytes)
+
+		return nil
+	})
+	if err != nil {
+		return value, err
+	}
+
+	return value, nil
+}
+
+// Put stores a shared secret hash as the key and the CLTV as the value.
+func (d *DecayedLog) Put(hash *HashPrefix, cltv uint32) error {
+	// Optimistically serialize the cltv value into the scratch buffer.
+	var scratch [4]byte
+	binary.BigEndian.PutUint32(scratch[:], cltv)
+
+	return d.db.Batch(func(tx *bolt.Tx) error {
+		sharedHashes := tx.Bucket(sharedHashBucket)
+		if sharedHashes == nil {
+			return ErrDecayedLogCorrupted
+		}
+
+		// Check to see if this hash prefix has been recorded before. If
+		// a value is found, this packet is being replayed.
+		valueBytes := sharedHashes.Get(hash[:])
+		if valueBytes != nil {
+			return ErrReplayedPacket
+		}
+
+		return sharedHashes.Put(hash[:], scratch[:])
+	})
+}
+
+// PutBatch accepts a pending batch of hashed secret entries to write to disk.
+// Each hashed secret is inserted with a corresponding time value, dictating
+// when the entry will be evicted from the log.
+//
+// NOTE: This method enforces idempotency by writing the replay set obtained
+// from the first attempt for a particular batch ID, and returning the decoded
+// set on subsequent calls. For the indices of the replay set to be aligned
+// properly, the batch MUST be constructed identically to the first attempt;
+// pruning entries will cause the indices to become invalid.
+func (d *DecayedLog) PutBatch(b *Batch) (*ReplaySet, error) {
+	// Since batched boltdb txns may be executed multiple times before
+	// succeeding, we will create a new replay set for each invocation to
+	// avoid any side-effects. If the txn is successful, this replay set
+	// will be merged with the replay set computed during batch construction
+	// to generate the complete replay set. If this batch was previously
+	// processed, the replay set will be deserialized from disk.
+	var replays *ReplaySet
+	if err := d.db.Batch(func(tx *bolt.Tx) error {
+		sharedHashes := tx.Bucket(sharedHashBucket)
+		if sharedHashes == nil {
+			return ErrDecayedLogCorrupted
+		}
+
+		// Load the batch replay bucket, which will be used to either
+		// retrieve the result of previously processing this batch, or
+		// to write the result of this operation.
+		batchReplayBkt := tx.Bucket(batchReplayBucket)
+		if batchReplayBkt == nil {
+			return ErrDecayedLogCorrupted
+		}
+
+		// Check for the existence of this batch's id in the replay
+		// bucket. If a non-nil value is found, this indicates that we
+		// have already processed this batch before. We deserialize the
+		// resulting replay set and return it to ensure calls to
+		// PutBatch are idempotent.
+		replayBytes := batchReplayBkt.Get(b.id)
+		if replayBytes != nil {
+			replays = &ReplaySet{}
+			return replays.Decode(bytes.NewReader(replayBytes))
+		}
+
+		// The CLTV will be stored into scratch and then stored into the
+		// sharedHashBucket.
+		var scratch [4]byte
+
+		replays = NewReplaySet()
+		for seqNum, entry := range b.entries {
+			// Retrieve the bytes which represent the CLTV
+			valueBytes := sharedHashes.Get(entry.hashPrefix[:])
+			if valueBytes != nil {
+				replays.Add(seqNum)
+				continue
+			}
+
+			// Serialize the cltv value and write an entry keyed by
+			// the hash prefix.
+			binary.BigEndian.PutUint32(scratch[:], entry.cltv)
+			err := sharedHashes.Put(entry.hashPrefix[:], scratch[:])
+			if err != nil {
+				return err
+			}
+		}
+
+		// Merge the replay set computed from checking the on-disk
+		// entries with the in-batch replays computed during this
+		// batch's construction.
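+		// Both replay sources are keyed by the caller-assigned
+		// sequence numbers, so merging the intra-batch duplicates
+		// with the on-disk replays cannot produce conflicting indices.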
+		replays.Merge(b.replaySet)
+
+		// Write the replay set under the batch identifier to the batch
+		// replays bucket. This can be used during recovery to test (1)
+		// that a particular batch was successfully processed and (2)
+		// recover the indices of the adds that were rejected as
+		// replays.
+		var replayBuf bytes.Buffer
+		if err := replays.Encode(&replayBuf); err != nil {
+			return err
+		}
+
+		return batchReplayBkt.Put(b.id, replayBuf.Bytes())
+	}); err != nil {
+		return nil, err
+	}
+
+	b.replaySet = replays
+	b.isCommitted = true
+
+	return replays, nil
+}
+
+// A compile-time check to ensure DecayedLog implements the ReplayLog
+// interface.
+var _ ReplayLog = (*DecayedLog)(nil)
diff --git a/decayedlog_test.go b/decayedlog_test.go
new file mode 100644
index 0000000..91fcc63
--- /dev/null
+++ b/decayedlog_test.go
@@ -0,0 +1,336 @@
+package sphinx
+
+import (
+	"math"
+	"testing"
+	"time"
+
+	"github.com/lightningnetwork/lnd/chainntnfs"
+	"github.com/roasbeef/btcd/btcec"
+	"github.com/roasbeef/btcd/chaincfg/chainhash"
+	"github.com/roasbeef/btcd/wire"
+)
+
+const (
+	cltv uint32 = 100000
+)
+
+var (
+	// Bytes of a private key
+	key = [32]byte{
+		0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
+		0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
+		0x0d, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d,
+		0x1e, 0x0b, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9,
+	}
+)
+
+type mockNotifier struct {
+	confChannel chan *chainntnfs.TxConfirmation
+	epochChan   chan *chainntnfs.BlockEpoch
+}
+
+func (m *mockNotifier) RegisterBlockEpochNtfn() (*chainntnfs.BlockEpochEvent, error) {
+	return &chainntnfs.BlockEpochEvent{
+		Epochs: m.epochChan,
+		Cancel: func() {},
+	}, nil
+}
+
+func (m *mockNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash, numConfs,
+	heightHint uint32) (*chainntnfs.ConfirmationEvent, error) {
+	return nil, nil
+}
+
+func (m *mockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
+	heightHint uint32) (*chainntnfs.SpendEvent, error) {
+	return nil, nil
+}
+
+func (m *mockNotifier) Start() error {
+	return nil
+}
+
+func (m *mockNotifier) Stop() error {
+	return nil
+}
+
+// startup sets up the DecayedLog and possibly the garbage collector.
+func startup(notifier bool) (ReplayLog, *mockNotifier, HashPrefix, error) {
+	var log ReplayLog
+	var chainNotifier *mockNotifier
+	var hashedSecret HashPrefix
+	if notifier {
+		// Create the mock notifier which triggers the garbage
+		// collector.
+		chainNotifier = &mockNotifier{
+			epochChan: make(chan *chainntnfs.BlockEpoch, 1),
+		}
+
+		// Initialize the DecayedLog object
+		log = NewDecayedLog("tempdir", chainNotifier)
+	} else {
+		// Initialize the DecayedLog object
+		log = NewDecayedLog("tempdir", nil)
+	}
+
+	// Open the database and start the garbage collector.
+	err := log.Start()
+	if err != nil {
+		return nil, nil, hashedSecret, err
+	}
+
+	// Create a new private key on elliptic curve secp256k1
+	priv, err := btcec.NewPrivateKey(btcec.S256())
+	if err != nil {
+		return nil, nil, hashedSecret, err
+	}
+
+	// Generate a public key from the key bytes
+	_, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
+
+	// Generate a shared secret with the public and private keys we made
+	secret := generateSharedSecret(testPub, priv)
+
+	// Create the hashedSecret given the shared secret we just generated.
+	// This is the first 20 bytes of the Sha-256 hash of the shared secret.
+	// This is used as a key to retrieve the cltv value.
+	hashedSecret = hashSharedSecret(&secret)
+
+	return log, chainNotifier, hashedSecret, nil
+}
+
+// TestDecayedLogGarbageCollector tests the ability of the garbage collector
+// to delete expired cltv values every time a block is received. Expired cltv
+// values are cltv values that are < current block height.
+func TestDecayedLogGarbageCollector(t *testing.T) {
+	d, notifier, hashedSecret, err := startup(true)
+	if err != nil {
+		t.Fatalf("Unable to start up DecayedLog: %v", err)
+	}
+	defer shutdown("tempdir", d)
+
+	// Store in the sharedHashBucket.
+	err = d.Put(&hashedSecret, cltv)
+	if err != nil {
+		t.Fatalf("Unable to store in decayed log: %v", err)
+	}
+
+	// Wait for database write (GC is in a goroutine)
+	time.Sleep(500 * time.Millisecond)
+
+	// Send block notifications to the garbage collector. The garbage
+	// collector should remove the entry by block 100001.
+
+	// Send block 100000
+	notifier.epochChan <- &chainntnfs.BlockEpoch{
+		Height: 100000,
+	}
+
+	// Assert that hashedSecret is still in the sharedHashBucket
+	val, err := d.Get(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Get failed: %v", err)
+	}
+
+	if val != cltv {
+		t.Fatalf("GC incorrectly deleted CLTV")
+	}
+
+	// Send block 100001 (expiry block)
+	notifier.epochChan <- &chainntnfs.BlockEpoch{
+		Height: 100001,
+	}
+
+	// Wait for database write (GC is in a goroutine)
+	time.Sleep(500 * time.Millisecond)
+
+	// Assert that hashedSecret is no longer in the sharedHashBucket
+	val, err = d.Get(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Get failed: %v", err)
+	}
+
+	if val != math.MaxUint32 {
+		t.Fatalf("CLTV was not deleted")
+	}
+}
+
+// TestDecayedLogPersistentGarbageCollector tests the persistence property of
+// the garbage collector. The garbage collector will be restarted immediately
+// and a block that expires the stored CLTV value will be sent to the
+// ChainNotifier. We test that this causes the (sharedHash, CLTV) pair to be
+// deleted even across GC restarts.
+func TestDecayedLogPersistentGarbageCollector(t *testing.T) {
+	d, _, hashedSecret, err := startup(true)
+	if err != nil {
+		t.Fatalf("Unable to start up DecayedLog: %v", err)
+	}
+	defer shutdown("tempdir", d)
+
+	// Store in the sharedHashBucket
+	if err = d.Put(&hashedSecret, cltv); err != nil {
+		t.Fatalf("Unable to store in decayed log: %v", err)
+	}
+
+	// Wait for database write (GC is in a goroutine)
+	time.Sleep(500 * time.Millisecond)
+
+	// Shut down DecayedLog and the garbage collector along with it.
+	d.Stop()
+
+	d2, notifier2, _, err := startup(true)
+	if err != nil {
+		t.Fatalf("Unable to restart DecayedLog: %v", err)
+	}
+	defer shutdown("tempdir", d2)
+
+	// Send a block notification to the garbage collector that expires
+	// the stored CLTV.
+	notifier2.epochChan <- &chainntnfs.BlockEpoch{
+		Height: 100001,
+	}
+
+	// Wait for database write (GC is in a goroutine)
+	time.Sleep(500 * time.Millisecond)
+
+	// Assert that the original hashedSecret is no longer in the
+	// sharedHashBucket.
+	val, err := d2.Get(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Received an error upon Get: %v", err)
+	}
+
+	if val != math.MaxUint32 {
+		t.Fatalf("cltv was not deleted")
+	}
+}
+
+// TestDecayedLogInsertionAndDeletion inserts a cltv value into the
+// sharedHashBucket, then deletes it, and finally asserts that we can no
+// longer retrieve it.
+func TestDecayedLogInsertionAndDeletion(t *testing.T) {
+	d, _, hashedSecret, err := startup(false)
+	if err != nil {
+		t.Fatalf("Unable to start up DecayedLog: %v", err)
+	}
+	defer shutdown("tempdir", d)
+
+	// Store in the sharedHashBucket.
+	err = d.Put(&hashedSecret, cltv)
+	if err != nil {
+		t.Fatalf("Unable to store in decayed log: %v", err)
+	}
+
+	// Delete hashedSecret from the sharedHashBucket.
+	err = d.Delete(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Unable to delete from decayed log: %v", err)
+	}
+
+	// Assert that hashedSecret is no longer in the sharedHashBucket
+	val, err := d.Get(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Received an error upon Get: %v", err)
+	}
+
+	if val != math.MaxUint32 {
+		t.Fatalf("cltv was not deleted")
+	}
+}
+
+// TestDecayedLogStartAndStop tests for persistence. The DecayedLog is started,
+// a cltv value is stored in the sharedHashBucket, and then the DecayedLog is
+// stopped. The DecayedLog is then started up again and we test that the cltv
+// value is indeed still stored in the sharedHashBucket. We then delete the
+// cltv value and check that the deletion persists upon restart.
+func TestDecayedLogStartAndStop(t *testing.T) {
+	d, _, hashedSecret, err := startup(false)
+	if err != nil {
+		t.Fatalf("Unable to start up DecayedLog: %v", err)
+	}
+	defer shutdown("tempdir", d)
+
+	// Store in the sharedHashBucket.
+	err = d.Put(&hashedSecret, cltv)
+	if err != nil {
+		t.Fatalf("Unable to store in decayed log: %v", err)
+	}
+
+	// Shutdown the DecayedLog's database
+	d.Stop()
+
+	d2, _, _, err := startup(false)
+	if err != nil {
+		t.Fatalf("Unable to restart DecayedLog: %v", err)
+	}
+	defer shutdown("tempdir", d2)
+
+	// Retrieve the stored cltv value given the hashedSecret key.
+	value, err := d2.Get(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Unable to retrieve from decayed log: %v", err)
+	}
+
+	// Check that the original cltv value matches the retrieved cltv
+	// value.
+	if cltv != value {
+		t.Fatalf("Value retrieved doesn't match value stored")
+	}
+
+	// Delete the original hashedSecret from sharedHashBucket.
+	err = d2.Delete(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Unable to delete from decayed log: %v", err)
+	}
+
+	// Shutdown the DecayedLog's database
+	d2.Stop()
+
+	d3, _, _, err := startup(false)
+	if err != nil {
+		t.Fatalf("Unable to restart DecayedLog: %v", err)
+	}
+	defer shutdown("tempdir", d3)
+
+	// Assert that the deletion of hashedSecret persisted across the
+	// restart.
+	val, err := d3.Get(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Get failed: %v", err)
+	}
+
+	if val != math.MaxUint32 {
+		t.Fatalf("cltv was not deleted")
+	}
+}
+
+// TestDecayedLogStorageAndRetrieval stores a cltv value and then retrieves it
+// via the nested sharedHashBucket and finally asserts that the original stored
+// and retrieved cltv values are equal.
+func TestDecayedLogStorageAndRetrieval(t *testing.T) {
+	d, _, hashedSecret, err := startup(false)
+	if err != nil {
+		t.Fatalf("Unable to start up DecayedLog: %v", err)
+	}
+	defer shutdown("tempdir", d)
+
+	// Store in the sharedHashBucket
+	err = d.Put(&hashedSecret, cltv)
+	if err != nil {
+		t.Fatalf("Unable to store in decayed log: %v", err)
+	}
+
+	// Retrieve the stored cltv value given the hashedSecret key.
+	value, err := d.Get(hashedSecret[:])
+	if err != nil {
+		t.Fatalf("Unable to retrieve from decayed log: %v", err)
+	}
+
+	// If the original cltv value does not match the value retrieved,
+	// then the test failed.
+	if cltv != value {
+		t.Fatalf("Value retrieved doesn't match value stored")
+	}
+}
diff --git a/glide.lock b/glide.lock
index f73c478..dd8c50b 100644
--- a/glide.lock
+++ b/glide.lock
@@ -1,33 +1,44 @@
-hash: edf51fbb3ee6f3e3f9e39d5ffc33739e8a2817f6c990042840594e867a8bc94a
-updated: 2017-06-24T21:22:50.234773431+03:00
+hash: 2d47f5b9766af60984cadf5ba0ba9a21e085e4715cafcbdae080a7b620c38b0d
+updated: 2018-02-20T23:10:22.589085-08:00
 imports:
 - name: github.com/aead/chacha20
   version: d31a916ded42d1640b9d89a26f8abd53cc96790c
   subpackages:
   - chacha
-- name: github.com/btcsuite/fastsha256
-  version: 637e656429416087660c84436a2a035d69d54e2e
+- name: github.com/boltdb/bolt
+  version: 2f1ce7a837dcb8da3ec595b1dac9d0632f0f99e8
+- name: github.com/btcsuite/btclog
+  version: 84c8d2346e9fc8c7b947e243b9c24e6df9fd206a
 - name: github.com/btcsuite/golangcrypto
   version: 53f62d9b43e87a6c56975cf862af7edf33a8d0df
   subpackages:
   - ripemd160
 - name: github.com/go-errors/errors
   version: 8fa88b06e5974e97fbf9899a7f86a344bfd1f105
+- name: github.com/lightningnetwork/lnd
+  version: 1c3dbb25434ef9f4d3dedc226dea41755e1621e7
+  subpackages:
+  - chainntnfs
 - name: github.com/roasbeef/btcd
-  version: 707a14a79daeb2440fe92feaeceb0fae68ab3e9b
+  version: e6807bc4dd5ddbb95b4ab163f6dd61e4ad79463a
   subpackages:
   - btcec
   - chaincfg
   - chaincfg/chainhash
   - wire
 - name: github.com/roasbeef/btcutil
-  version: d347e49b656d2a7f6d06cc9e2daebc5acded5728
+  version: c3ff179366044979fb9856c2feb79bd4c2184c7a
   subpackages:
   - base58
+  - bech32
 - name: golang.org/x/crypto
-  version: 459e26527287adbc2adcc5d0d49abff9a5f315a7
+  version: 9419663f5a44be8b34ca85f08abc5fe1be11f8a3
   subpackages:
   - ripemd160
+- name: golang.org/x/sys
+  version: ab9e364efd8b52800ff7ee48a9ffba4e0ed78dfb
+  subpackages:
+  - unix
 testImports:
 - name: github.com/davecgh/go-spew
   version: 346938d642f2ec3594ed81d874461961cd0faa76
diff --git a/glide.yaml b/glide.yaml
index 059d502..16da9db 100644
--- a/glide.yaml
+++ b/glide.yaml
@@ -1,11 +1,17 @@
 package: github.com/lightningnetwork/lightning-onion
 import:
+- package: github.com/boltdb/bolt
+  version: ^1.2.1
- package: github.com/aead/chacha20
   version: d31a916ded42d1640b9d89a26f8abd53cc96790c
+- package: github.com/lightningnetwork/lnd
+  subpackages:
+  - chainntnfs
 - package: github.com/roasbeef/btcd
   subpackages:
   - btcec
   - chaincfg
+  - wire
 - package: github.com/roasbeef/btcutil
 - package: golang.org/x/crypto
   subpackages:
diff --git a/log.go b/log.go
new file mode 100644
index 0000000..c9804c9
--- /dev/null
+++ b/log.go
@@ -0,0 +1,42 @@
+package sphinx
+
+import "github.com/btcsuite/btclog"
+
+// sphxLog is a logger that is initialized with no output filters. This
+// means the package will not perform any logging by default until the caller
+// requests it.
+var sphxLog btclog.Logger
+
+// The default amount of logging is none.
+func init() {
+	DisableLog()
+}
+
+// DisableLog disables all library log output. Logging output is disabled
+// by default until UseLogger is called.
+func DisableLog() {
+	sphxLog = btclog.Disabled
+}
+
+// UseLogger uses a specified Logger to output package logging info.
+// This should be used in preference to SetLogWriter if the caller is also
+// using btclog.
+func UseLogger(logger btclog.Logger) {
+	sphxLog = logger
+}
+
+// logClosure is used to provide a closure over expensive logging operations
+// so that they don't have to be performed when the logging level doesn't
+// warrant it.
+type logClosure func() string
+
+// String invokes the underlying function and returns the result.
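+//
+// For example, a hypothetical call site might defer an expensive dump so it
+// is only rendered when the logging level admits it (spew is illustrative
+// here, not a dependency of this file):
+//
+//	sphxLog.Tracef("processing packet: %v", newLogClosure(func() string {
+//		return spew.Sdump(packet)
+//	}))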
+func (c logClosure) String() string { + return c() +} + +// newLogClosure returns a new closure over a function that returns a string +// which itself provides a Stringer interface so that it can be used with the +// logging system. +func newLogClosure(c func() string) logClosure { + return logClosure(c) +} diff --git a/obfuscation.go b/obfuscation.go index 2d74d93..d7f902b 100644 --- a/obfuscation.go +++ b/obfuscation.go @@ -13,7 +13,7 @@ import ( // onionEncrypt obfuscates the data with compliance with BOLT#4. As we use a // stream cipher, calling onionEncrypt on an already encrypted piece of data // will decrypt it. -func onionEncrypt(sharedSecret [sha256.Size]byte, data []byte) []byte { +func onionEncrypt(sharedSecret *Hash256, data []byte) []byte { p := make([]byte, len(data)) @@ -27,7 +27,7 @@ func onionEncrypt(sharedSecret [sha256.Size]byte, data []byte) []byte { // OnionErrorEncrypter is a struct that's used to implement onion error // encryption as defined within BOLT0004. type OnionErrorEncrypter struct { - sharedSecret [sha256.Size]byte + sharedSecret Hash256 } // NewOnionErrorEncrypter creates new instance of the onion encryper backed by @@ -59,14 +59,14 @@ func NewOnionErrorEncrypter(router *Router, // failure and its origin. func (o *OnionErrorEncrypter) EncryptError(initial bool, data []byte) []byte { if initial { - umKey := generateKey("um", o.sharedSecret) + umKey := generateKey("um", &o.sharedSecret) hash := hmac.New(sha256.New, umKey[:]) hash.Write(data) h := hash.Sum(nil) data = append(h, data...) } - return onionEncrypt(o.sharedSecret, data) + return onionEncrypt(&o.sharedSecret, data) } // Encode writes the encrypter's shared secret to the provided io.Writer. @@ -180,7 +180,7 @@ func (o *OnionErrorDecrypter) DecryptError(encryptedData []byte) (*btcec.PublicK var ( sender *btcec.PublicKey msg []byte - dummySecret [sha256.Size]byte + dummySecret Hash256 ) copy(dummySecret[:], bytes.Repeat([]byte{1}, 32)) @@ -188,7 +188,7 @@ func (o *OnionErrorDecrypter) DecryptError(encryptedData []byte) (*btcec.PublicK // away an timing information pertaining to the position in the route // that the error emanated from. for i := 0; i < NumMaxHops; i++ { - var sharedSecret [sha256.Size]byte + var sharedSecret Hash256 // If we've already found the sender, then we'll use our dummy // secret to continue decryption attempts to fill out the rest @@ -202,7 +202,7 @@ func (o *OnionErrorDecrypter) DecryptError(encryptedData []byte) (*btcec.PublicK // With the shared secret, we'll now strip off a layer of // encryption from the encrypted error payload. - encryptedData = onionEncrypt(sharedSecret, encryptedData) + encryptedData = onionEncrypt(&sharedSecret, encryptedData) // Next, we'll need to separate the data, from the MAC itself // so we can reconstruct and verify it. @@ -211,7 +211,7 @@ func (o *OnionErrorDecrypter) DecryptError(encryptedData []byte) (*btcec.PublicK // With the data split, we'll now re-generate the MAC using its // specified key. - umKey := generateKey("um", sharedSecret) + umKey := generateKey("um", &sharedSecret) h := hmac.New(sha256.New, umKey[:]) h.Write(data) diff --git a/replay_set.go b/replay_set.go new file mode 100644 index 0000000..a631a53 --- /dev/null +++ b/replay_set.go @@ -0,0 +1,81 @@ +package sphinx + +import ( + "encoding/binary" + "io" +) + +// ReplaySet is a data structure used to efficiently record the occurrence of +// replays, identified by sequence number, when processing a Batch. 
Its primary
+// functionality includes set construction, membership queries, and merging of
+// replay sets.
+type ReplaySet struct {
+	replays map[uint16]struct{}
+}
+
+// NewReplaySet initializes an empty replay set.
+func NewReplaySet() *ReplaySet {
+	return &ReplaySet{
+		replays: make(map[uint16]struct{}),
+	}
+}
+
+// Size returns the number of elements in the replay set.
+func (rs *ReplaySet) Size() int {
+	return len(rs.replays)
+}
+
+// Add inserts the provided index into the replay set.
+func (rs *ReplaySet) Add(idx uint16) {
+	rs.replays[idx] = struct{}{}
+}
+
+// Contains queries the contents of the replay set for membership of a
+// particular index.
+func (rs *ReplaySet) Contains(idx uint16) bool {
+	_, ok := rs.replays[idx]
+	return ok
+}
+
+// Merge adds the contents of the provided replay set to the receiver's set.
+func (rs *ReplaySet) Merge(rs2 *ReplaySet) {
+	for seqNum := range rs2.replays {
+		rs.Add(seqNum)
+	}
+}
+
+// Encode serializes the replay set into an io.Writer suitable for storage. The
+// replay set can be recovered using Decode.
+func (rs *ReplaySet) Encode(w io.Writer) error {
+	for seqNum := range rs.replays {
+		err := binary.Write(w, binary.BigEndian, seqNum)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// Decode reconstructs a replay set from an io.Reader. The encoded stream is
+// assumed to contain an even number of bytes; otherwise decoding fails.
+func (rs *ReplaySet) Decode(r io.Reader) error {
+	for {
+		// seqNum provides a buffer into which the next uint16 index is
+		// read.
+		var seqNum uint16
+
+		err := binary.Read(r, binary.BigEndian, &seqNum)
+		switch err {
+		case nil:
+			// Successful read, proceed.
+		case io.EOF:
+			return nil
+		default:
+			// Can return ErrShortBuffer or ErrUnexpectedEOF.
+			return err
+		}
+
+		// Add this decoded sequence number to the set.
+		rs.Add(seqNum)
+	}
+}
diff --git a/sphinx.go b/sphinx.go
index 96027c4..05ec024 100644
--- a/sphinx.go
+++ b/sphinx.go
@@ -9,9 +9,9 @@ import (
 	"io"
 	"io/ioutil"
 	"math/big"
-	"sync"

 	"github.com/aead/chacha20"
+	"github.com/lightningnetwork/lnd/chainntnfs"
 	"github.com/roasbeef/btcd/btcec"
 	"github.com/roasbeef/btcd/chaincfg"
 	"github.com/roasbeef/btcutil"
@@ -68,6 +68,10 @@ const (
 	baseVersion = 0
 )

+// Hash256 is a statically sized, 32-byte array, typically containing
+// the output of a SHA256 hash.
+type Hash256 [sha256.Size]byte
+
 var (
 	// paddingBytes are the padding bytes used to fill out the remainder of the
 	// unused portion of the per-hop payload.
@@ -205,14 +209,14 @@ func (hd *HopData) Decode(r io.Reader) error {
 // generateSharedSecrets by the given nodes pubkeys, generates the shared
 // secrets.
 func generateSharedSecrets(paymentPath []*btcec.PublicKey,
-	sessionKey *btcec.PrivateKey) [][sha256.Size]byte {
+	sessionKey *btcec.PrivateKey) []Hash256 {
 	// Each hop performs ECDH with our ephemeral key pair to arrive at a
 	// shared secret. Additionally, each hop randomizes the group element
 	// for the next hop by multiplying it by the blinding factor. This way
 	// we only need to transmit a single group element, and hops can't link
 	// a session back to us if they have several nodes in the path.
 	numHops := len(paymentPath)
-	hopSharedSecrets := make([][sha256.Size]byte, numHops)
+	hopSharedSecrets := make([]Hash256, numHops)

 	// Compute the triplet for the first hop outside of the main loop.
// Within the loop each new triplet will be computed recursively based @@ -299,8 +303,8 @@ func NewOnionPacket(paymentPath []*btcec.PublicKey, sessionKey *btcec.PrivateKey // We'll derive the two keys we need for each hop in order to: // generate our stream cipher bytes for the mixHeader, and // calculate the MAC over the entire constructed packet. - rhoKey := generateKey("rho", hopSharedSecrets[i]) - muKey := generateKey("mu", hopSharedSecrets[i]) + rhoKey := generateKey("rho", &hopSharedSecrets[i]) + muKey := generateKey("mu", &hopSharedSecrets[i]) // The HMAC for the final hop is simply zeroes. This allows the // last hop to recognize that it is the destination for a @@ -377,13 +381,13 @@ func rightShift(slice []byte, num int) { // "filler" bytes produced by this function at the last hop. Using this // methodology, the size of the field stays constant at each hop. func generateHeaderPadding(key string, numHops int, hopSize int, - sharedSecrets [][sharedSecretSize]byte) []byte { + sharedSecrets []Hash256) []byte { filler := make([]byte, (numHops-1)*hopSize) for i := 1; i < numHops; i++ { totalFillerSize := ((NumMaxHops - i) + 1) * hopSize - streamKey := generateKey(key, sharedSecrets[i-1]) + streamKey := generateKey(key, &sharedSecrets[i-1]) streamBytes := generateCipherStream(streamKey, numStreamBytes) xor(filler, filler, streamBytes[totalFillerSize:totalFillerSize+i*hopSize]) @@ -484,7 +488,7 @@ func xor(dst, a, b []byte) int { // construction/processing based off of the denoted keyType. Within Sphinx // various keys are used within the same onion packet for padding generation, // MAC generation, and encryption/decryption. -func generateKey(keyType string, sharedKey [sharedSecretSize]byte) [keyLen]byte { +func generateKey(keyType string, sharedKey *Hash256) [keyLen]byte { mac := hmac.New(sha256.New, []byte(keyType)) mac.Write(sharedKey[:]) h := mac.Sum(nil) @@ -515,12 +519,14 @@ func generateCipherStream(key [keyLen]byte, numBytes uint) []byte { // computeBlindingFactor for the next hop given the ephemeral pubKey and // sharedSecret for this hop. The blinding factor is computed as the // sha-256(pubkey || sharedSecret). -func computeBlindingFactor(hopPubKey *btcec.PublicKey, hopSharedSecret []byte) [sha256.Size]byte { +func computeBlindingFactor(hopPubKey *btcec.PublicKey, + hopSharedSecret []byte) Hash256 { + sha := sha256.New() sha.Write(hopPubKey.SerializeCompressed()) sha.Write(hopSharedSecret) - var hash [sha256.Size]byte + var hash Hash256 copy(hash[:], sha.Sum(nil)) return hash } @@ -545,7 +551,7 @@ func blindBaseElement(blindingFactor []byte) *btcec.PublicKey { // key. We then take the _entire_ point generated by the ECDH operation, // serialize that using a compressed format, then feed the raw bytes through a // single SHA256 invocation. The resulting value is the shared secret. -func generateSharedSecret(pub *btcec.PublicKey, priv *btcec.PrivateKey) [32]byte { +func generateSharedSecret(pub *btcec.PublicKey, priv *btcec.PrivateKey) Hash256 { s := &btcec.PublicKey{} x, y := btcec.S256().ScalarMult(pub.X, pub.Y, priv.D.Bytes()) s.X = x @@ -620,14 +626,13 @@ type Router struct { onionKey *btcec.PrivateKey - sync.RWMutex - - seenSecrets map[[sharedSecretSize]byte]struct{} + log ReplayLog } // NewRouter creates a new instance of a Sphinx onion Router given the node's // currently advertised onion private key, and the target Bitcoin network. 
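+//
+// A minimal construction might look like the following, where the database
+// path is caller-chosen and a nil notifier disables garbage collection
+// (illustrative only, not a prescribed configuration):
+//
+//	router := NewRouter("sphinxreplay.db", nodeKey,
+//		&chaincfg.TestNet3Params, nil)
+//	if err := router.Start(); err != nil {
+//		// handle startup failure
+//	}
+//	defer router.Stop()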
-func NewRouter(nodeKey *btcec.PrivateKey, net *chaincfg.Params) *Router {
+func NewRouter(dbPath string, nodeKey *btcec.PrivateKey, net *chaincfg.Params,
+	chainNotifier chainntnfs.ChainNotifier) *Router {

 	var nodeID [addressSize]byte
 	copy(nodeID[:], btcutil.Hash160(nodeKey.PubKey().SerializeCompressed()))
@@ -647,10 +652,22 @@
 		},
 		// TODO(roasbeef): replace instead with bloom filter?
 		// * https://moderncrypto.org/mail-archive/messaging/2015/001911.html
-		seenSecrets: make(map[[sharedSecretSize]byte]struct{}),
+		log: NewDecayedLog(dbPath, chainNotifier),
 	}
 }

+// Start opens the replay log's backing database and launches its accompanying
+// garbage collector goroutine.
+func (r *Router) Start() error {
+	return r.log.Start()
+}
+
+// Stop halts the replay log's garbage collector goroutine and closes its
+// backing database.
+func (r *Router) Stop() {
+	r.log.Stop()
+}
+
 // ProcessOnionPacket processes an incoming onion packet which has been forward
 // to the target Sphinx router. If the encoded ephemeral key isn't on the
 // target Elliptic Curve, then the packet is rejected. Similarly, if the
@@ -660,25 +677,60 @@
 // In the case of a successful packet processing, and ProcessedPacket struct is
 // returned which houses the newly parsed packet, along with instructions on
 // what to do next.
-func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte) (*ProcessedPacket, error) {
-	dhKey := onionPkt.EphemeralKey
-	routeInfo := onionPkt.RoutingInfo
-	headerMac := onionPkt.HeaderMAC
+func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket,
+	assocData []byte, incomingCltv uint32) (*ProcessedPacket, error) {

+	// Compute the shared secret for this onion packet.
 	sharedSecret, err := r.generateSharedSecret(onionPkt.EphemeralKey)
 	if err != nil {
 		return nil, err
 	}

-	// In order to mitigate replay attacks, if we've seen this particular
-	// shared secret before, cease processing and just drop this forwarding
-	// message.
-	r.RLock()
-	if _, ok := r.seenSecrets[sharedSecret]; ok {
-		r.RUnlock()
-		return nil, ErrReplayedPacket
+	// Additionally, compute the hash prefix of the shared secret, which
+	// will serve as an identifier for detecting replayed packets.
+	hashPrefix := hashSharedSecret(&sharedSecret)
+
+	// Continue to optimistically process this packet, deferring replay
+	// protection until the end to reduce the penalty of multiple IO
+	// operations.
+	packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
+	if err != nil {
+		return nil, err
 	}
-	r.RUnlock()
+
+	// Atomically compare this hash prefix with the contents of the on-disk
+	// log, persisting it only if this entry was not detected as a replay.
+	if err := r.log.Put(&hashPrefix, incomingCltv); err != nil {
+		return nil, err
+	}
+
+	return packet, nil
+}
+
+// ReconstructOnionPacket rederives the subsequent onion packet.
+//
+// NOTE: This method does not do any sort of replay protection, and should only
+// be used to reconstruct packets that were successfully processed previously.
+func (r *Router) ReconstructOnionPacket(onionPkt *OnionPacket,
+	assocData []byte) (*ProcessedPacket, error) {
+
+	// Compute the shared secret for this onion packet.
+ sharedSecret, err := r.generateSharedSecret(onionPkt.EphemeralKey) + if err != nil { + return nil, err + } + + return processOnionPacket(onionPkt, &sharedSecret, assocData) +} + +// processOnionPacket performs the primary key derivation and handling of onion +// packets. The processed packets returned from this method should only be used +// if the packet was not flagged as a replayed packet. +func processOnionPacket(onionPkt *OnionPacket, + sharedSecret *Hash256, assocData []byte) (*ProcessedPacket, error) { + + dhKey := onionPkt.EphemeralKey + routeInfo := onionPkt.RoutingInfo + headerMac := onionPkt.HeaderMAC // Using the derived shared secret, ensure the integrity of the routing // information by checking the attached MAC without leaking timing @@ -689,24 +741,17 @@ func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte) (*P return nil, ErrInvalidOnionHMAC } - // The MAC checks out, mark this current shared secret as processed in - // order to mitigate future replay attacks. We need to check to see if - // we already know the secret again since a replay might have happened - // while we were checking the MAC. - r.Lock() - if _, ok := r.seenSecrets[sharedSecret]; ok { - r.RUnlock() - return nil, ErrReplayedPacket - } - r.seenSecrets[sharedSecret] = struct{}{} - r.Unlock() - // Attach the padding zeroes in order to properly strip an encryption // layer off the routing info revealing the routing information for the // next hop. + streamBytes := generateCipherStream( + generateKey("rho", sharedSecret), + numStreamBytes, + ) + headerWithPadding := append(routeInfo[:], + bytes.Repeat([]byte{0}, hopDataSize)...) + var hopInfo [numStreamBytes]byte - streamBytes := generateCipherStream(generateKey("rho", sharedSecret), numStreamBytes) - headerWithPadding := append(routeInfo[:], bytes.Repeat([]byte{0}, hopDataSize)...) xor(hopInfo[:], headerWithPadding, streamBytes) // Randomize the DH group element for the next hop using the @@ -737,7 +782,7 @@ func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte) (*P // However if the uncovered 'nextMac' is all zeroes, then this // indicates that we're the final hop in the route. var action ProcessCode = MoreHops - if bytes.Compare(bytes.Repeat([]byte{0x00}, hmacSize), hopData.HMAC[:]) == 0 { + if bytes.Compare(zeroHMAC[:], hopData.HMAC[:]) == 0 { action = ExitNode } @@ -749,9 +794,9 @@ func (r *Router) ProcessOnionPacket(onionPkt *OnionPacket, assocData []byte) (*P } // generateSharedSecret generates the shared secret by given ephemeral key. -func (r *Router) generateSharedSecret(dhKey *btcec.PublicKey) ([sha256.Size]byte, +func (r *Router) generateSharedSecret(dhKey *btcec.PublicKey) (Hash256, error) { - var sharedSecret [sha256.Size]byte + var sharedSecret Hash256 // Ensure that the public key is on our curve. if !btcec.S256().IsOnCurve(dhKey.X, dhKey.Y) { @@ -762,3 +807,96 @@ func (r *Router) generateSharedSecret(dhKey *btcec.PublicKey) ([sha256.Size]byte sharedSecret = generateSharedSecret(dhKey, r.onionKey) return sharedSecret, nil } + +// Tx is a transaction consisting of a number of sphinx packets to be atomically +// written to the replay log. This structure helps to coordinate construction of +// the underlying Batch object, and to ensure that the result of the processing +// is idempotent. 
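+//
+// The intended lifecycle, sketched with hypothetical inputs (pkts, cltvs and
+// batchID are assumed to be supplied by the caller), is to begin a
+// transaction, process each packet under a unique sequence number, and then
+// commit:
+//
+//	tx := router.BeginTxn(batchID, len(pkts))
+//	for i, pkt := range pkts {
+//		// An error here means only this packet failed processing.
+//		_ = tx.ProcessOnionPacket(uint16(i), pkt, nil, cltvs[i])
+//	}
+//	packets, replaySet, err := tx.Commit()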
+type Tx struct {
+	// batch is the set of packets to be incrementally processed and
+	// ultimately committed in this transaction
+	batch *Batch
+
+	// router is a reference to the sphinx router that created this
+	// transaction. Committing this transaction will utilize this router's
+	// replay log.
+	router *Router
+
+	// packets contains a potentially sparse list of optimistically
+	// processed packets for this batch. The packet at a particular index
+	// should only be accessed if that index is *not* included in the
+	// replay set and the packet did not fail any other stage of
+	// processing.
+	packets []ProcessedPacket
+}
+
+// BeginTxn creates a new transaction that can later be committed back to the
+// sphinx router's replay log.
+//
+// NOTE: The nels parameter should represent the maximum number of elements
+// that could be added to the batch; using sequence numbers that match or
+// exceed this value will result in an out-of-bounds panic.
+func (r *Router) BeginTxn(id []byte, nels int) *Tx {
+	return &Tx{
+		batch:   NewBatch(id),
+		router:  r,
+		packets: make([]ProcessedPacket, nels),
+	}
+}
+
+// ProcessOnionPacket processes an incoming onion packet which has been
+// forwarded to the target Sphinx router. If the encoded ephemeral key isn't on
+// the target Elliptic Curve, then the packet is rejected. Similarly, if the
+// derived shared secret has been seen before the packet is rejected. Finally
+// if the MAC doesn't check out the packet is again rejected.
+//
+// In the case of a successful packet processing, a ProcessedPacket struct is
+// returned which houses the newly parsed packet, along with instructions on
+// what to do next.
+func (t *Tx) ProcessOnionPacket(seqNum uint16, onionPkt *OnionPacket,
+	assocData []byte, incomingCltv uint32) error {
+
+	// Compute the shared secret for this onion packet.
+	sharedSecret, err := t.router.generateSharedSecret(
+		onionPkt.EphemeralKey)
+	if err != nil {
+		return err
+	}
+
+	// Additionally, compute the hash prefix of the shared secret, which
+	// will serve as an identifier for detecting replayed packets.
+	hashPrefix := hashSharedSecret(&sharedSecret)
+
+	// Continue to optimistically process this packet, deferring replay
+	// protection until the end to reduce the penalty of multiple IO
+	// operations.
+	packet, err := processOnionPacket(onionPkt, &sharedSecret, assocData)
+	if err != nil {
+		return err
+	}
+
+	// Add the hash prefix to the pending batch of shared secrets that will
+	// be written later via Commit().
+	err = t.batch.Put(seqNum, &hashPrefix, incomingCltv)
+	if err != nil {
+		return err
+	}
+
+	// If we successfully added this packet to the batch, cache the
+	// processed packet within the Tx; it can be accessed after committing
+	// if this sequence number does not appear in the replay set.
+	t.packets[seqNum] = *packet
+
+	return nil
+}
+
+// Commit writes this transaction's batch of sphinx packets to the replay log,
+// performing a final check against the log for replays.
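+//
+// A hypothetical consumer should consult the returned ReplaySet before using
+// any of the returned packets (forward below is assumed, not part of this
+// package):
+//
+//	packets, replays, err := tx.Commit()
+//	if err != nil {
+//		// handle commit failure
+//	}
+//	for i := range packets {
+//		if replays.Contains(uint16(i)) {
+//			continue // replayed packet, drop it
+//		}
+//		forward(&packets[i])
+//	}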
+func (t *Tx) Commit() ([]ProcessedPacket, *ReplaySet, error) { + if t.batch.isCommitted { + return t.packets, t.batch.replaySet, nil + } + + rs, err := t.router.log.PutBatch(t.batch) + + return t.packets, rs, err +} diff --git a/sphinx_test.go b/sphinx_test.go index 1595dfe..b1002dd 100644 --- a/sphinx_test.go +++ b/sphinx_test.go @@ -4,7 +4,9 @@ import ( "bytes" "encoding/hex" "fmt" + "os" "reflect" + "strconv" "testing" "github.com/davecgh/go-spew/spew" @@ -97,7 +99,9 @@ func newTestRoute(numHops int) ([]*Router, *[]HopData, *OnionPacket, error) { " random key for sphinx node: %v", err) } - nodes[i] = NewRouter(privKey, &chaincfg.MainNetParams) + dbPath := strconv.Itoa(i) + + nodes[i] = NewRouter(dbPath, privKey, &chaincfg.MainNetParams, nil) } // Gather all the pub keys in the path. @@ -177,6 +181,13 @@ func TestBolt4Packet(t *testing.T) { } } +// shutdown deletes the temporary directory that the test database uses +// and handles closing the database. +func shutdown(dir string, d ReplayLog) { + d.Stop() + os.RemoveAll(dir) +} + func TestSphinxCorrectness(t *testing.T) { nodes, hopDatas, fwdMsg, err := newTestRoute(NumMaxHops) if err != nil { @@ -186,10 +197,15 @@ func TestSphinxCorrectness(t *testing.T) { // Now simulate the message propagating through the mix net eventually // reaching the final destination. for i := 0; i < len(nodes); i++ { + // Start each node's DecayedLog and defer shutdown + tempDir := strconv.Itoa(i) + nodes[i].log.Start() + defer shutdown(tempDir, nodes[i].log) + hop := nodes[i] t.Logf("Processing at hop: %v \n", i) - onionPacket, err := hop.ProcessOnionPacket(fwdMsg, nil) + onionPacket, err := hop.ProcessOnionPacket(fwdMsg, nil, uint32(i)+1) if err != nil { t.Fatalf("Node %v was unable to process the "+ "forwarding message: %v", i, err) @@ -246,9 +262,13 @@ func TestSphinxSingleHop(t *testing.T) { t.Fatalf("unable to create test route: %v", err) } + // Start the DecayedLog and defer shutdown + nodes[0].log.Start() + defer shutdown("0", nodes[0].log) + // Simulating a direct single-hop payment, send the sphinx packet to // the destination node, making it process the packet fully. - processedPacket, err := nodes[0].ProcessOnionPacket(fwdMsg, nil) + processedPacket, err := nodes[0].ProcessOnionPacket(fwdMsg, nil, 1) if err != nil { t.Fatalf("unable to process sphinx packet: %v", err) } @@ -269,19 +289,165 @@ func TestSphinxNodeRelpay(t *testing.T) { t.Fatalf("unable to create test route: %v", err) } + // Start the DecayedLog and defer shutdown + nodes[0].log.Start() + defer shutdown("0", nodes[0].log) + // Allow the node to process the initial packet, this should proceed // without any failures. - if _, err := nodes[0].ProcessOnionPacket(fwdMsg, nil); err != nil { + if _, err := nodes[0].ProcessOnionPacket(fwdMsg, nil, 1); err != nil { t.Fatalf("unable to process sphinx packet: %v", err) } // Now, force the node to process the packet a second time, this should // fail with a detected replay error. - if _, err := nodes[0].ProcessOnionPacket(fwdMsg, nil); err != ErrReplayedPacket { + if _, err := nodes[0].ProcessOnionPacket(fwdMsg, nil, 1); err != ErrReplayedPacket { t.Fatalf("sphinx packet replay should be rejected, instead error is %v", err) } } +func TestSphinxNodeRelpaySameBatch(t *testing.T) { + // We'd like to ensure that the sphinx node itself rejects all replayed + // packets which share the same shared secret. 
+	nodes, _, fwdMsg, err := newTestRoute(NumMaxHops)
+	if err != nil {
+		t.Fatalf("unable to create test route: %v", err)
+	}
+
+	// Start the DecayedLog and defer shutdown
+	nodes[0].log.Start()
+	defer shutdown("0", nodes[0].log)
+
+	tx := nodes[0].BeginTxn([]byte("0"), 2)
+
+	// Allow the node to process the initial packet, this should proceed
+	// without any failures.
+	if err := tx.ProcessOnionPacket(0, fwdMsg, nil, 1); err != nil {
+		t.Fatalf("unable to process sphinx packet: %v", err)
+	}
+
+	// Now, force the node to process the packet a second time, this call
+	// should not fail, even though the batch has internally recorded this
+	// as a duplicate.
+	err = tx.ProcessOnionPacket(1, fwdMsg, nil, 1)
+	if err != nil {
+		t.Fatalf("adding duplicate sphinx packet to batch should not "+
+			"result in an error, instead got: %v", err)
+	}
+
+	// Commit the batch to disk, then we will inspect the replay set to
+	// ensure the duplicate entry was properly included.
+	_, replaySet, err := tx.Commit()
+	if err != nil {
+		t.Fatalf("unable to commit batch of sphinx packets: %v", err)
+	}
+
+	if replaySet.Contains(0) {
+		t.Fatalf("index 0 was not expected to be in replay set")
+	}
+
+	if !replaySet.Contains(1) {
+		t.Fatalf("expected replay set to contain duplicate packet " +
+			"at index 1")
+	}
+}
+
+func TestSphinxNodeRelpayLaterBatch(t *testing.T) {
+	// We'd like to ensure that the sphinx node itself rejects a replayed
+	// packet that arrives in a later batch.
+	nodes, _, fwdMsg, err := newTestRoute(NumMaxHops)
+	if err != nil {
+		t.Fatalf("unable to create test route: %v", err)
+	}
+
+	// Start the DecayedLog and defer shutdown
+	nodes[0].log.Start()
+	defer shutdown("0", nodes[0].log)
+
+	tx := nodes[0].BeginTxn([]byte("0"), 1)
+
+	// Allow the node to process the initial packet, this should proceed
+	// without any failures.
+	if err := tx.ProcessOnionPacket(uint16(0), fwdMsg, nil, 1); err != nil {
+		t.Fatalf("unable to process sphinx packet: %v", err)
+	}
+
+	_, _, err = tx.Commit()
+	if err != nil {
+		t.Fatalf("unable to commit sphinx batch: %v", err)
+	}
+
+	tx2 := nodes[0].BeginTxn([]byte("1"), 1)
+
+	// Now, force the node to process the packet a second time. This call
+	// should succeed, with the replay only being reported once the second
+	// batch is committed.
+	err = tx2.ProcessOnionPacket(uint16(0), fwdMsg, nil, 1)
+	if err != nil {
+		t.Fatalf("sphinx packet replay should not have been rejected, "+
+			"instead error is %v", err)
+	}
+
+	_, replays, err := tx2.Commit()
+	if err != nil {
+		t.Fatalf("unable to commit second sphinx batch: %v", err)
+	}
+
+	if !replays.Contains(0) {
+		t.Fatalf("expected replay set to contain index: %v", 0)
+	}
+}
+
+func TestSphinxNodeRelpayBatchIdempotency(t *testing.T) {
+	// We'd like to ensure that committing the same batch twice yields an
+	// identical replay set and set of processed packets.
+	nodes, _, fwdMsg, err := newTestRoute(NumMaxHops)
+	if err != nil {
+		t.Fatalf("unable to create test route: %v", err)
+	}
+
+	// Start the DecayedLog and defer shutdown
+	nodes[0].log.Start()
+	defer shutdown("0", nodes[0].log)
+
+	tx := nodes[0].BeginTxn([]byte("0"), 1)
+
+	// Allow the node to process the initial packet, this should proceed
+	// without any failures.
+	if err := tx.ProcessOnionPacket(uint16(0), fwdMsg, nil, 1); err != nil {
+		t.Fatalf("unable to process sphinx packet: %v", err)
+	}
+
+	packets, replays, err := tx.Commit()
+	if err != nil {
+		t.Fatalf("unable to commit sphinx batch: %v", err)
+	}
+
+	tx2 := nodes[0].BeginTxn([]byte("0"), 1)
+
+	// Now, force the node to process the packet a second time; this should
+	// not fail with a detected replay error.
+	err = tx2.ProcessOnionPacket(uint16(0), fwdMsg, nil, 1)
+	if err != nil {
+		t.Fatalf("sphinx packet replay should not have been rejected, "+
+			"instead error is %v", err)
+	}
+
+	packets2, replays2, err := tx2.Commit()
+	if err != nil {
+		t.Fatalf("unable to commit second sphinx batch: %v", err)
+	}
+
+	if replays.Size() != replays2.Size() {
+		t.Fatalf("expected replay set to be %v, instead got %v",
+			replays, replays2)
+	}
+
+	if !reflect.DeepEqual(packets, packets2) {
+		t.Fatalf("expected packets to be %v, instead got %v",
+			packets, packets2)
+	}
+}
+
 func TestSphinxAssocData(t *testing.T) {
 	// We want to make sure that the associated data is considered in the
 	// HMAC creation
@@ -290,7 +456,12 @@ func TestSphinxAssocData(t *testing.T) {
 		t.Fatalf("unable to create random onion packet: %v", err)
 	}

-	if _, err := nodes[0].ProcessOnionPacket(fwdMsg, []byte("somethingelse")); err == nil {
+	// Start the DecayedLog and defer shutdown
+	nodes[0].log.Start()
+	defer shutdown("0", nodes[0].log)
+
+	_, err = nodes[0].ProcessOnionPacket(fwdMsg, []byte("somethingelse"), 1)
+	if err == nil {
 		t.Fatalf("we should fail when associated data changes")
 	}