Skip to content

Commit

Permalink
Mark used addresses to prevent reuse
Browse files Browse the repository at this point in the history
Address review items:

* add comment that usedAddrBucketName was added after v1 release
* use addrHash instead of addressID for fetchAddressUsed
* define fetchAddressUsed before use
* don't bump version number
* remove unnecessary local var for used
* Added upgrade path from version 1 to 2
* Added test case for manager MarkUsed
* Remove version param from upgradeToVersion2
* Move version declaration after upgrade call
* Remove key from cache after marking used
  • Loading branch information
tuxcanfly committed Mar 7, 2015
1 parent cec3dc3 commit f331a07
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 27 deletions.
3 changes: 3 additions & 0 deletions chainntfns.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ func (w *Wallet) addRedeemingTx(tx *btcutil.Tx, block *txstore.Block) error {
if _, err := txr.AddDebits(); err != nil {
return err
}
if err := w.markAddrsUsed(txr); err != nil {
return err
}

bs, err := w.chainSvr.BlockStamp()
if err == nil {
Expand Down
22 changes: 21 additions & 1 deletion waddrmgr/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ type ManagedAddress interface {

// Compressed returns true if the backing address is compressed.
Compressed() bool

// Used returns true if the backing address has been used in a transaction.
Used() bool
}

// ManagedPubKeyAddress extends ManagedAddress and additionally provides the
Expand Down Expand Up @@ -94,6 +97,7 @@ type managedAddress struct {
imported bool
internal bool
compressed bool
used bool
pubKey *btcec.PublicKey
privKeyEncrypted []byte
privKeyCT []byte // non-nil if unlocked
Expand Down Expand Up @@ -184,6 +188,13 @@ func (a *managedAddress) Compressed() bool {
return a.compressed
}

// Used returns true if the address has been used in a transaction.
//
// This is part of the ManagedAddress interface implementation.
func (a *managedAddress) Used() bool {
return a.used
}

// PubKey returns the public key associated with the address.
//
// This is part of the ManagedPubKeyAddress interface implementation.
Expand Down Expand Up @@ -354,6 +365,7 @@ type scriptAddress struct {
scriptEncrypted []byte
scriptCT []byte
scriptMutex sync.Mutex
used bool
}

// Enforce scriptAddress satisfies the ManagedScriptAddress interface.
Expand Down Expand Up @@ -441,6 +453,13 @@ func (a *scriptAddress) Compressed() bool {
return false
}

// Used returns true if the address has been used in a transaction.
//
// This is part of the ManagedAddress interface implementation.
func (a *scriptAddress) Used() bool {
return a.used
}

// Script returns the script associated with the address.
//
// This implements the ScriptAddress interface.
Expand All @@ -465,7 +484,7 @@ func (a *scriptAddress) Script() ([]byte, error) {
}

// newScriptAddress initializes and returns a new pay-to-script-hash address.
func newScriptAddress(m *Manager, account uint32, scriptHash, scriptEncrypted []byte) (*scriptAddress, error) {
func newScriptAddress(m *Manager, account uint32, scriptHash, scriptEncrypted []byte, used bool) (*scriptAddress, error) {
address, err := btcutil.NewAddressScriptHashFromHash(scriptHash,
m.chainParams)
if err != nil {
Expand All @@ -477,5 +496,6 @@ func newScriptAddress(m *Manager, account uint32, scriptHash, scriptEncrypted []
account: account,
address: address,
scriptEncrypted: scriptEncrypted,
used: used,
}, nil
}
78 changes: 77 additions & 1 deletion waddrmgr/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (

const (
// LatestMgrVersion is the most recent manager version.
LatestMgrVersion = 1
LatestMgrVersion = 2
)

var (
Expand Down Expand Up @@ -107,6 +107,7 @@ type dbAddressRow struct {
account uint32
addTime uint64
syncStatus syncStatus
used bool
rawData []byte // Varies based on address type field.
}

Expand Down Expand Up @@ -162,6 +163,9 @@ var (

// Account related key names (account bucket).
acctNumAcctsName = []byte("numaccts")

// Used addresses (used bucket)
usedAddrBucketName = []byte("usedaddrs")
)

// uint32ToBytes converts a 32 bit unsigned integer into a 4-byte slice in
Expand Down Expand Up @@ -732,6 +736,17 @@ func serializeScriptAddress(encryptedHash, encryptedScript []byte) []byte {
return rawData
}

// fetchAddressUsed returns true if the provided address hash was flagged as used.
func fetchAddressUsed(tx walletdb.Tx, addrHash [32]byte) bool {
bucket := tx.RootBucket().Bucket(usedAddrBucketName)

val := bucket.Get(addrHash[:])
if val != nil {
return true
}
return false
}

// fetchAddress loads address information for the provided address id from
// the database. The returned value is one of the address rows for the specific
// address type. The caller should use type assertions to ascertain the type.
Expand All @@ -749,6 +764,7 @@ func fetchAddress(tx walletdb.Tx, addressID []byte) (interface{}, error) {
if err != nil {
return nil, err
}
row.used = fetchAddressUsed(tx, addrHash)

switch row.addrType {
case adtChain:
Expand All @@ -763,6 +779,23 @@ func fetchAddress(tx walletdb.Tx, addressID []byte) (interface{}, error) {
return nil, managerError(ErrDatabase, str, nil)
}

// markAddressUsed flags the provided address id as used in the database.
func markAddressUsed(tx walletdb.Tx, addressID []byte) error {
bucket := tx.RootBucket().Bucket(usedAddrBucketName)

addrHash := fastsha256.Sum256(addressID)
val := bucket.Get(addrHash[:])
if val != nil {
return nil
}
err := bucket.Put(addrHash[:], []byte{0})
if err != nil {
str := fmt.Sprintf("failed to mark address used %x", addressID)
return managerError(ErrDatabase, str, err)
}
return nil
}

// putAddress stores the provided address information to the database. This
// is used a common base for storing the various address types.
func putAddress(tx walletdb.Tx, addressID []byte, row *dbAddressRow) error {
Expand Down Expand Up @@ -1243,6 +1276,13 @@ func createManagerNS(namespace walletdb.Namespace) error {
return managerError(ErrDatabase, str, err)
}

// usedAddrBucketName bucket was added after manager version 1 release
_, err = rootBucket.CreateBucketIfNotExists(usedAddrBucketName)
if err != nil {
str := "failed to create used addresses bucket"
return managerError(ErrDatabase, str, err)
}

if err := putManagerVersion(tx, latestMgrVersion); err != nil {
return err
}
Expand Down Expand Up @@ -1312,6 +1352,16 @@ func upgradeManager(namespace walletdb.Namespace) error {
// version = 3
// }

if version < 2 {
// Upgrade from version 1 to 2.
if err := upgradeToVersion2(namespace); err != nil {
return err
}

// The manager is now at version 2.
version = 2
}

// Ensure the manager is upraded to the latest version. This check is
// to intentionally cause a failure if the manager version is updated
// without writing code to handle the upgrade.
Expand All @@ -1324,3 +1374,29 @@ func upgradeManager(namespace walletdb.Namespace) error {

return nil
}

// upgradeToVersion2 upgrades the database from version 1 to version 2
// 'usedAddrBucketName' a bucket for storing addrs flagged as marked is
// initialized and it will be updated on the next rescan.
func upgradeToVersion2(namespace walletdb.Namespace) error {
err := namespace.Update(func(tx walletdb.Tx) error {
currentMgrVersion := uint32(2)
rootBucket := tx.RootBucket()

_, err := rootBucket.CreateBucketIfNotExists(usedAddrBucketName)
if err != nil {
str := "failed to create used addresses bucket"
return managerError(ErrDatabase, str, err)
}

if err := putManagerVersion(tx, currentMgrVersion); err != nil {
return err
}

return nil
})
if err != nil {
return maybeConvertDbError(err)
}
return nil
}
27 changes: 21 additions & 6 deletions waddrmgr/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func (m *Manager) Close() error {
// The passed derivedKey is zeroed after the new address is created.
//
// This function MUST be called with the manager lock held for writes.
func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, branch, index uint32) (ManagedAddress, error) {
func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, branch, index uint32, used bool) (ManagedAddress, error) {
// Create a new managed address based on the public or private key
// depending on whether the passed key is private. Also, zero the
// key after creating the managed address from it.
Expand All @@ -372,6 +372,7 @@ func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, bran
if branch == internalBranch {
ma.internal = true
}
ma.used = used

return ma, nil
}
Expand Down Expand Up @@ -486,7 +487,7 @@ func (m *Manager) loadAccountInfo(account uint32) (*accountInfo, error) {
if err != nil {
return nil, err
}
lastExtAddr, err := m.keyToManaged(lastExtKey, account, branch, index)
lastExtAddr, err := m.keyToManaged(lastExtKey, account, branch, index, false)
if err != nil {
return nil, err
}
Expand All @@ -501,7 +502,7 @@ func (m *Manager) loadAccountInfo(account uint32) (*accountInfo, error) {
if err != nil {
return nil, err
}
lastIntAddr, err := m.keyToManaged(lastIntKey, account, branch, index)
lastIntAddr, err := m.keyToManaged(lastIntKey, account, branch, index, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -537,7 +538,7 @@ func (m *Manager) chainAddressRowToManaged(row *dbChainAddressRow) (ManagedAddre
return nil, err
}

return m.keyToManaged(addressKey, row.account, row.branch, row.index)
return m.keyToManaged(addressKey, row.account, row.branch, row.index, row.used)
}

// importedAddressRowToManaged returns a new managed address based on imported
Expand All @@ -564,6 +565,7 @@ func (m *Manager) importedAddressRowToManaged(row *dbImportedAddressRow) (Manage
}
ma.privKeyEncrypted = row.encryptedPrivKey
ma.imported = true
ma.used = row.used

return ma, nil
}
Expand All @@ -578,7 +580,7 @@ func (m *Manager) scriptAddressRowToManaged(row *dbScriptAddressRow) (ManagedAdd
return nil, managerError(ErrCrypto, str, err)
}

return newScriptAddress(m, row.account, scriptHash, row.encryptedScript)
return newScriptAddress(m, row.account, scriptHash, row.encryptedScript, row.used)
}

// rowInterfaceToManaged returns a new managed address based on the given
Expand Down Expand Up @@ -1126,7 +1128,7 @@ func (m *Manager) ImportScript(script []byte, bs *BlockStamp) (ManagedScriptAddr
// since it will be cleared on lock and the script the caller passed
// should not be cleared out from under the caller.
scriptAddr, err := newScriptAddress(m, ImportedAddrAccount, scriptHash,
encryptedScript)
encryptedScript, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1290,6 +1292,19 @@ func (m *Manager) Unlock(passphrase []byte) error {
return nil
}

// MarkUsed updates the used flag for the provided address id.
func (m *Manager) MarkUsed(addressID []byte) error {
err := m.namespace.Update(func(tx walletdb.Tx) error {
return markAddressUsed(tx, addressID)
})
if err != nil {
return maybeConvertDbError(err)
}
// 'used' flag has become stale so remove key from cache
delete(m.addrs, addrKey(addressID))
return nil
}

// ChainParams returns the chain parameters for this address manager.
func (m *Manager) ChainParams() *chaincfg.Params {
// NOTE: No need for mutex here since the net field does not change
Expand Down
28 changes: 28 additions & 0 deletions waddrmgr/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ type expectedAddr struct {
addressHash []byte
internal bool
compressed bool
used bool
imported bool
pubKey []byte
privKey []byte
Expand Down Expand Up @@ -1016,6 +1017,32 @@ func testImportScript(tc *testContext) bool {
return true
}

// testMarkUsed ensures used addresses are flagged as such.
func testMarkUsed(tc *testContext) bool {
expectedAddr1 := expectedAddr{
addressHash: hexToBytes("2ef94abb9ee8f785d087c3ec8d6ee467e92d0d0a"),
used: true,
}
prefix := "MarkUsed"
chainParams := tc.manager.ChainParams()
addrHash := expectedAddr1.addressHash
addr, err := btcutil.NewAddressPubKeyHash(addrHash, chainParams)

err = tc.manager.MarkUsed(addrHash)
if err != nil {
tc.t.Errorf("%s: unexpected error: %v", prefix, err)
}
maddr, err := tc.manager.Address(addr)
if err != nil {
tc.t.Errorf("%s: unexpected error: %v", prefix, err)
}
if maddr.Used() != expectedAddr1.used {
tc.t.Errorf("%v: unexpected used flag -- got "+
"%v, want %v", prefix, maddr.Used(), expectedAddr1.used)
}
return true
}

// testChangePassphrase ensures changes both the public and privte passphrases
// works as intended.
func testChangePassphrase(tc *testContext) bool {
Expand Down Expand Up @@ -1129,6 +1156,7 @@ func testManagerAPI(tc *testContext) {
testInternalAddresses(tc)
testImportPrivateKey(tc)
testImportScript(tc)
testMarkUsed(tc)
testChangePassphrase(tc)
}

Expand Down
Loading

0 comments on commit f331a07

Please sign in to comment.