Skip to content

Commit

Permalink
Correctly handle both p2sh and p2pkh addrs in wstakemgr. (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrick committed Nov 8, 2016
1 parent deafe41 commit cfb38c4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
24 changes: 13 additions & 11 deletions wstakemgr/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ func deserializeSStxRecord(serializedSStxRecord []byte) (*sstxRecord, error) {
return record, nil
}

// deserializeSStxTicketScriptHash deserializes and returns a 20 byte script
// deserializeSStxTicketHash160 deserializes and returns a 20 byte script
// hash for a ticket's 0th output.
func deserializeSStxTicketScriptHash(serializedSStxRecord []byte) ([]byte, error) {
func deserializeSStxTicketHash160(serializedSStxRecord []byte) (hash160 []byte, p2sh bool, err error) {
dataLen := len(serializedSStxRecord)
curPos := 0

Expand All @@ -238,32 +238,34 @@ func deserializeSStxTicketScriptHash(serializedSStxRecord []byte) ([]byte, error
// Figure out the actual location of the script.
actualLoc := curPos + pkScrLoc
if actualLoc+3 >= dataLen {
return nil, stakeStoreError(ErrDatabase,
return nil, false, stakeStoreError(ErrDatabase,
"bad serialized sstx record size", nil)
}

// Pop off the script prefix, then pop off the 20 bytes
// HASH160 pubkey or script hash.
prefixBytes := serializedSStxRecord[actualLoc : actualLoc+3]
scriptHash := make([]byte, 20, 20)
p2sh = false
switch {
case bytes.Equal(prefixBytes, sstxTicket2PKHPrefix):
scrHashLoc := actualLoc + 4
if scrHashLoc+20 >= dataLen {
return nil, stakeStoreError(ErrDatabase,
return nil, false, stakeStoreError(ErrDatabase,
"bad serialized sstx record size for pubkey hash", nil)
}
copy(scriptHash, serializedSStxRecord[scrHashLoc:scrHashLoc+20])
case bytes.Equal(prefixBytes, sstxTicket2SHPrefix):
scrHashLoc := actualLoc + 3
if scrHashLoc+20 >= dataLen {
return nil, stakeStoreError(ErrDatabase,
return nil, false, stakeStoreError(ErrDatabase,
"bad serialized sstx record size for script hash", nil)
}
copy(scriptHash, serializedSStxRecord[scrHashLoc:scrHashLoc+20])
p2sh = true
}

return scriptHash, nil
return scriptHash, p2sh, nil
}

// serializeSSTxRecord returns the serialization of the passed txrecord row.
Expand Down Expand Up @@ -597,21 +599,21 @@ func fetchSStxRecord(ns walletdb.ReadBucket, hash *chainhash.Hash) (*sstxRecord,
return deserializeSStxRecord(val)
}

// fetchSStxRecordSStxTicketScriptHash retrieves a ticket 0th output script or
// fetchSStxRecordSStxTicketHash160 retrieves a ticket 0th output script or
// pubkeyhash from the sstx records bucket with the given hash.
func fetchSStxRecordSStxTicketScriptHash(ns walletdb.ReadBucket,
hash *chainhash.Hash) ([]byte, error) {
func fetchSStxRecordSStxTicketHash160(ns walletdb.ReadBucket,
hash *chainhash.Hash) (hash160 []byte, p2sh bool, err error) {

bucket := ns.NestedReadBucket(sstxRecordsBucketName)

key := hash.Bytes()
val := bucket.Get(key)
if val == nil {
str := fmt.Sprintf("missing sstx record for hash '%s'", hash.String())
return nil, stakeStoreError(ErrSStxNotFound, str, nil)
return nil, false, stakeStoreError(ErrSStxNotFound, str, nil)
}

return deserializeSStxTicketScriptHash(val)
return deserializeSStxTicketHash160(val)
}

// fetchSStxRecordVoteBits fetches an individual ticket's intended voteBits
Expand Down
23 changes: 16 additions & 7 deletions wstakemgr/stake.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,23 +313,28 @@ func (s *StakeStore) DumpSStxHashes() ([]chainhash.Hash, error) {
func (s *StakeStore) dumpSStxHashesForAddress(ns walletdb.ReadBucket, addr dcrutil.Address) ([]chainhash.Hash, error) {
// Extract the HASH160 script hash; if it's not 20 bytes
// long, return an error.
scriptHash := addr.ScriptAddress()
if len(scriptHash) != 20 {
hash160 := addr.ScriptAddress()
if len(hash160) != 20 {
str := "stake store is closed"
return nil, stakeStoreError(ErrInput, str, nil)
}
_, addrIsP2SH := addr.(*dcrutil.AddressScriptHash)

allTickets := s.dumpSStxHashes()
var ticketsForAddr []chainhash.Hash

// Access the database and store the result locally.
for _, h := range allTickets {
thisScrHash, err := fetchSStxRecordSStxTicketScriptHash(ns, &h)
thisHash160, p2sh, err := fetchSStxRecordSStxTicketHash160(ns, &h)
if err != nil {
str := "failure getting ticket 0th out script hashes from db"
return nil, stakeStoreError(ErrDatabase, str, err)
}
if bytes.Equal(scriptHash, thisScrHash) {
if addrIsP2SH != p2sh {
continue
}

if bytes.Equal(hash160, thisHash160) {
ticketsForAddr = append(ticketsForAddr, h)
}
}
Expand All @@ -354,13 +359,17 @@ func (s *StakeStore) DumpSStxHashesForAddress(ns walletdb.ReadBucket, addr dcrut
// sstxAddress returns the address for a given ticket.
func (s *StakeStore) sstxAddress(ns walletdb.ReadBucket, hash *chainhash.Hash) (dcrutil.Address, error) {
// Access the database and store the result locally.
thisScrHash, err := fetchSStxRecordSStxTicketScriptHash(ns, hash)
thisHash160, p2sh, err := fetchSStxRecordSStxTicketHash160(ns, hash)
if err != nil {
str := "failure getting ticket 0th out script hashes from db"
return nil, stakeStoreError(ErrDatabase, str, err)
}
addr, err := dcrutil.NewAddressScriptHashFromHash(thisScrHash,
s.Params)
var addr dcrutil.Address
if p2sh {
addr, err = dcrutil.NewAddressScriptHashFromHash(thisHash160, s.Params)
} else {
addr, err = dcrutil.NewAddressPubKeyHash(thisHash160, s.Params, chainec.ECTypeSecp256k1)
}
if err != nil {
str := "failure getting ticket 0th out script hashes from db"
return nil, stakeStoreError(ErrDatabase, str, err)
Expand Down

0 comments on commit cfb38c4

Please sign in to comment.