Skip to content

Commit

Permalink
Mark used addresses to prevent reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
tuxcanfly committed Mar 3, 2015
1 parent 3c5d165 commit 6feb9dc
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 35 deletions.
3 changes: 3 additions & 0 deletions chainntfns.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ func (w *Wallet) addRedeemingTx(tx *btcutil.Tx, block *txstore.Block) error {
if _, err := txr.AddDebits(); err != nil {
return err
}
if err := w.markAddrsUsed(txr); err != nil {
return err
}

bs, err := w.chainSvr.BlockStamp()
if err == nil {
Expand Down
22 changes: 21 additions & 1 deletion waddrmgr/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type ManagedAddress interface {

// Compressed returns true if the backing address is compressed.
Compressed() bool

// Used returns true if the backing address has been used in a transaction.
Used() bool
}

// ManagedPubKeyAddress extends ManagedAddress and additionally provides the
Expand Down Expand Up @@ -119,6 +122,7 @@ type managedAddress struct {
imported bool
internal bool
compressed bool
used bool
pubKey *btcec.PublicKey
privKeyEncrypted []byte
privKeyCT []byte // non-nil if unlocked
Expand Down Expand Up @@ -209,6 +213,13 @@ func (a *managedAddress) Compressed() bool {
return a.compressed
}

// Used returns true if the address has been used in a transaction.
//
// This is part of the ManagedAddress interface implementation.
func (a *managedAddress) Used() bool {
return a.used
}

// PubKey returns the public key associated with the address.
//
// This is part of the ManagedPubKeyAddress interface implementation.
Expand Down Expand Up @@ -379,6 +390,7 @@ type scriptAddress struct {
scriptEncrypted []byte
scriptCT []byte
scriptMutex sync.Mutex
used bool
}

// Enforce scriptAddress satisfies the ManagedScriptAddress interface.
Expand Down Expand Up @@ -466,6 +478,13 @@ func (a *scriptAddress) Compressed() bool {
return false
}

// Used returns true if the address has been used in a transaction.
//
// This is part of the ManagedAddress interface implementation.
func (a *scriptAddress) Used() bool {
return a.used
}

// Script returns the script associated with the address.
//
// This implements the ScriptAddress interface.
Expand All @@ -490,7 +509,7 @@ func (a *scriptAddress) Script() ([]byte, error) {
}

// newScriptAddress initializes and returns a new pay-to-script-hash address.
func newScriptAddress(m *Manager, account uint32, scriptHash, scriptEncrypted []byte) (*scriptAddress, error) {
func newScriptAddress(m *Manager, account uint32, scriptHash, scriptEncrypted []byte, used bool) (*scriptAddress, error) {
address, err := btcutil.NewAddressScriptHashFromHash(scriptHash,
m.chainParams)
if err != nil {
Expand All @@ -502,5 +521,6 @@ func newScriptAddress(m *Manager, account uint32, scriptHash, scriptEncrypted []
account: account,
address: address,
scriptEncrypted: scriptEncrypted,
used: used,
}, nil
}
107 changes: 98 additions & 9 deletions waddrmgr/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,25 @@ import (

const (
// LatestMgrVersion is the most recent manager version.
LatestMgrVersion = 1
LatestMgrVersion = 2
)

const (
falseByte byte = iota
trueByte
)

func byteAsBool(b byte) bool {
return b != 0
}

func boolAsByte(b bool) byte {
if b {
return trueByte
}
return falseByte
}

// maybeConvertDbError converts the passed error to a ManagerError with an
// error code of ErrDatabase if it is not already a ManagerError. This is
// useful for potential errors returned from managed transaction an other parts
Expand Down Expand Up @@ -101,6 +117,7 @@ type dbAddressRow struct {
account uint32
addTime uint64
syncStatus syncStatus
used bool
rawData []byte // Varies based on address type field.
}

Expand Down Expand Up @@ -156,6 +173,9 @@ var (

// Account related key names (account bucket).
acctNumAcctsName = []byte("numaccts")

// Used addresses (used bucket)
usedAddrBucketName = []byte("usedaddrs")
)

// fetchMasterKeyParams loads the master key parameters needed to derive them
Expand Down Expand Up @@ -526,14 +546,14 @@ func putNumAccounts(tx walletdb.Tx, numAccounts uint32) error {
// the common parts.
func deserializeAddressRow(addressID, serializedAddress []byte) (*dbAddressRow, error) {
// The serialized address format is:
// <addrType><account><addedTime><syncStatus><rawdata>
// <addrType><account><addedTime><syncStatus><used><rawdata>
//
// 1 byte addrType + 4 bytes account + 8 bytes addTime + 1 byte
// syncStatus + 4 bytes raw data length + raw data
// syncStatus + 1 byte used + 4 bytes raw data length + raw data

// Given the above, the length of the entry must be at a minimum
// the constant value sizes.
if len(serializedAddress) < 18 {
if len(serializedAddress) < 19 {
str := fmt.Sprintf("malformed serialized address for key %s",
addressID)
return nil, managerError(ErrDatabase, str, nil)
Expand All @@ -554,13 +574,12 @@ func deserializeAddressRow(addressID, serializedAddress []byte) (*dbAddressRow,
// serializeAddressRow returns the serialization of the passed address row.
func serializeAddressRow(row *dbAddressRow) []byte {
// The serialized address format is:
// <addrType><account><addedTime><syncStatus><commentlen><comment>
// <rawdata>
// <addrType><account><addedTime><syncStatus><used><rawdata>
//
// 1 byte addrType + 4 bytes account + 8 bytes addTime + 1 byte
// syncStatus + 4 bytes raw data length + raw data
// syncStatus + 1 byte used + 4 bytes raw data length + raw data
rdlen := len(row.rawData)
buf := make([]byte, 18+rdlen)
buf := make([]byte, 19+rdlen)
buf[0] = byte(row.addrType)
binary.LittleEndian.PutUint32(buf[1:5], row.account)
binary.LittleEndian.PutUint64(buf[5:13], row.addTime)
Expand Down Expand Up @@ -730,6 +749,8 @@ func fetchAddress(tx walletdb.Tx, addressID []byte) (interface{}, error) {
if err != nil {
return nil, err
}
used := fetchAddrUsed(tx, addressID)
row.used = used

switch row.addrType {
case adtChain:
Expand All @@ -744,6 +765,35 @@ func fetchAddress(tx walletdb.Tx, addressID []byte) (interface{}, error) {
return nil, managerError(ErrDatabase, str, nil)
}

// markAddressUsed flags the provided address id as used in the database.
func markAddressUsed(tx walletdb.Tx, addressID []byte) error {
bucket := tx.RootBucket().Bucket(usedAddrBucketName)

addrHash := fastsha256.Sum256(addressID)
val := bucket.Get(addrHash[:])
if val != nil {
return nil
}
err := bucket.Put(addrHash[:], []byte{0})
if err != nil {
str := fmt.Sprintf("failed to mark address used %x", addressID)
return managerError(ErrDatabase, str, err)
}
return nil
}

// fetchAddrUsed returns true if the provided address id was flagged as used.
func fetchAddrUsed(tx walletdb.Tx, addressID []byte) bool {
bucket := tx.RootBucket().Bucket(usedAddrBucketName)

addrHash := fastsha256.Sum256(addressID)
val := bucket.Get(addrHash[:])
if val != nil {
return true
}
return false
}

// putAddress stores the provided address information to the database. This
// is used a common base for storing the various address types.
func putAddress(tx walletdb.Tx, addressID []byte, row *dbAddressRow) error {
Expand Down Expand Up @@ -1228,6 +1278,12 @@ func upgradeManager(namespace walletdb.Namespace) error {
return managerError(ErrDatabase, str, err)
}

_, err = rootBucket.CreateBucketIfNotExists(usedAddrBucketName)
if err != nil {
str := "failed to create used addresses bucket"
return managerError(ErrDatabase, str, err)
}

// Save the most recent database version if it isn't already
// there, otherwise keep track of it for potential upgrades.
verBytes := mainBucket.Get(mgrVersionName)
Expand Down Expand Up @@ -1268,8 +1324,41 @@ func upgradeManager(namespace walletdb.Namespace) error {

// Upgrade the manager as needed.
if version < LatestMgrVersion {
// No upgrades yet.
// Upgrade addresses used flag
upgradeVersion1to2(namespace)
}

return nil
}

// upgradeVersion1to2 upgrades the database from version 1 to version 2
// 'usedAddrBucketName' a bucket for storing addrs flagged as marked is initialized
// and it will be updated on the next rescan
func upgradeVersion1to2(namespace walletdb.Namespace) error {
err := namespace.Update(func(tx walletdb.Tx) error {
rootBucket := tx.RootBucket()
mainBucket := tx.RootBucket().Bucket(mainBucketName)

_, err := rootBucket.CreateBucketIfNotExists(usedAddrBucketName)
if err != nil {
str := "failed to create used addresses bucket"
return managerError(ErrDatabase, str, err)
}

var version uint32
var buf [4]byte
version = LatestMgrVersion
binary.LittleEndian.PutUint32(buf[:], version)
err = mainBucket.Put(mgrVersionName, buf[:])
if err != nil {
str := "failed to store latest database version"
return managerError(ErrDatabase, str, err)
}
return nil
})
if err != nil {
str := "failed to upgrade version 1 to version 2"
return managerError(ErrDatabase, str, err)
}
return nil
}
25 changes: 19 additions & 6 deletions waddrmgr/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ func (m *Manager) Close() error {
// The passed derivedKey is zeroed after the new address is created.
//
// This function MUST be called with the manager lock held for writes.
func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, branch, index uint32) (ManagedAddress, error) {
func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, branch, index uint32, used bool) (ManagedAddress, error) {
// Create a new managed address based on the public or private key
// depending on whether the passed key is private. Also, zero the
// key after creating the managed address from it.
Expand All @@ -371,6 +371,7 @@ func (m *Manager) keyToManaged(derivedKey *hdkeychain.ExtendedKey, account, bran
if branch == internalBranch {
ma.internal = true
}
ma.used = used

return ma, nil
}
Expand Down Expand Up @@ -485,7 +486,7 @@ func (m *Manager) loadAccountInfo(account uint32) (*accountInfo, error) {
if err != nil {
return nil, err
}
lastExtAddr, err := m.keyToManaged(lastExtKey, account, branch, index)
lastExtAddr, err := m.keyToManaged(lastExtKey, account, branch, index, false)
if err != nil {
return nil, err
}
Expand All @@ -500,7 +501,7 @@ func (m *Manager) loadAccountInfo(account uint32) (*accountInfo, error) {
if err != nil {
return nil, err
}
lastIntAddr, err := m.keyToManaged(lastIntKey, account, branch, index)
lastIntAddr, err := m.keyToManaged(lastIntKey, account, branch, index, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -536,7 +537,7 @@ func (m *Manager) chainAddressRowToManaged(row *dbChainAddressRow) (ManagedAddre
return nil, err
}

return m.keyToManaged(addressKey, row.account, row.branch, row.index)
return m.keyToManaged(addressKey, row.account, row.branch, row.index, row.used)
}

// importedAddressRowToManaged returns a new managed address based on imported
Expand All @@ -563,6 +564,7 @@ func (m *Manager) importedAddressRowToManaged(row *dbImportedAddressRow) (Manage
}
ma.privKeyEncrypted = row.encryptedPrivKey
ma.imported = true
ma.used = row.used

return ma, nil
}
Expand All @@ -577,7 +579,7 @@ func (m *Manager) scriptAddressRowToManaged(row *dbScriptAddressRow) (ManagedAdd
return nil, managerError(ErrCrypto, str, err)
}

return newScriptAddress(m, row.account, scriptHash, row.encryptedScript)
return newScriptAddress(m, row.account, scriptHash, row.encryptedScript, row.used)
}

// rowInterfaceToManaged returns a new managed address based on the given
Expand Down Expand Up @@ -1125,7 +1127,7 @@ func (m *Manager) ImportScript(script []byte, bs *BlockStamp) (ManagedScriptAddr
// since it will be cleared on lock and the script the caller passed
// should not be cleared out from under the caller.
scriptAddr, err := newScriptAddress(m, ImportedAddrAccount, scriptHash,
encryptedScript)
encryptedScript, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1289,6 +1291,17 @@ func (m *Manager) Unlock(passphrase []byte) error {
return nil
}

// MarkUsed updates the used flag for the provided address id.
func (m *Manager) MarkUsed(addressID []byte) error {
err := m.namespace.Update(func(tx walletdb.Tx) error {
return markAddressUsed(tx, addressID)
})
if err != nil {
return maybeConvertDbError(err)
}
return nil
}

// ChainParams returns the chain parameters for this address manager.
func (m *Manager) ChainParams() *chaincfg.Params {
// NOTE: No need for mutex here since the net field does not change
Expand Down
Loading

0 comments on commit 6feb9dc

Please sign in to comment.