Skip to content

Commit

Permalink
compare parsed address in OwnsAddress
Browse files Browse the repository at this point in the history
The string address can have different formattings, so parse the input
string into a common.Address and compare that to the account address.

Also use direct array comparison for address equality instead of
bytes.Equal.
  • Loading branch information
chappjc committed Nov 26, 2021
1 parent ca16f89 commit 3991a0a
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 23 deletions.
27 changes: 16 additions & 11 deletions client/asset/eth/eth.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func CreateWallet(createWalletParams *asset.CreateWalletParams) error {
return err
}

importKeyToNode(node, privateKey, createWalletParams.Pass)
err = importKeyToNode(node, privateKey, createWalletParams.Pass)
if err != nil {
return err
}
Expand Down Expand Up @@ -329,11 +329,17 @@ func (eth *ExchangeWallet) Connect(ctx context.Context) (*sync.WaitGroup, error)
return &wg, nil
}

// OwnsAddress indicates if an address belongs to the wallet.
// OwnsAddress indicates if an address belongs to the wallet. The address need
// not be a EIP55-compliant formatted address. It may or may not have a 0x
// prefix, and case is not important.
//
// In Ethereum, an address is an account.
func (eth *ExchangeWallet) OwnsAddress(address string) (bool, error) {
return strings.ToLower(eth.acct.Address.String()) == strings.ToLower(address), nil
if !common.IsHexAddress(address) {
return false, errors.New("invalid address")
}
addr := common.HexToAddress(address)
return addr == eth.acct.Address, nil
}

// Balance returns the total available funds in the account. The eth node
Expand Down Expand Up @@ -593,7 +599,7 @@ func (eth *ExchangeWallet) FundingCoins(ids []dex.Bytes) (asset.Coins, error) {
if err != nil {
return nil, err
}
if !bytes.Equal(coin.id.Address.Bytes(), eth.acct.Address.Bytes()) {
if coin.id.Address != eth.acct.Address {
return nil, fmt.Errorf("FundingCoins: coin address %v != wallet address %v",
coin.id.Address, eth.acct.Address)
}
Expand Down Expand Up @@ -685,18 +691,18 @@ func (*ExchangeWallet) Redeem(form *asset.RedeemForm) ([]dex.Bytes, asset.Coin,
// SignMessage signs the message with the private key associated with the
// specified funding Coin. Only a coin that came from the address this wallet
// is initialized with can be used to sign.
func (e *ExchangeWallet) SignMessage(coin asset.Coin, msg dex.Bytes) (pubkeys, sigs []dex.Bytes, err error) {
func (eth *ExchangeWallet) SignMessage(coin asset.Coin, msg dex.Bytes) (pubkeys, sigs []dex.Bytes, err error) {
ethCoin, err := decodeCoinID(coin.ID())
if err != nil {
return nil, nil, err
}

if !bytes.Equal(ethCoin.id.Address.Bytes(), e.acct.Address.Bytes()) {
if ethCoin.id.Address != eth.acct.Address {
return nil, nil, fmt.Errorf("SignMessage: coin address: %v != wallet address: %v",
ethCoin.id.Address, e.acct.Address)
ethCoin.id.Address, eth.acct.Address)
}

sig, err := e.node.signData(e.acct.Address, msg)
sig, err := eth.node.signData(eth.acct.Address, msg)
if err != nil {
return nil, nil, fmt.Errorf("SignMessage: error signing data: %v", err)
}
Expand Down Expand Up @@ -764,7 +770,7 @@ func (eth *ExchangeWallet) Locked() bool {
findWallet := func() bool {
for _, w := range wallets {
for _, a := range w.Accounts {
if bytes.Equal(a.Address[:], eth.acct.Address[:]) {
if a.Address == eth.acct.Address {
wallet = w
return true
}
Expand Down Expand Up @@ -894,8 +900,7 @@ func (eth *ExchangeWallet) checkForNewBlocks() {
eth.tipMtx.RLock()
currentTipHash := eth.currentTip.Hash()
eth.tipMtx.RUnlock()
sameTip := bytes.Equal(currentTipHash[:], bestHash[:])
if sameTip {
if currentTipHash == bestHash {
return
}

Expand Down
92 changes: 89 additions & 3 deletions client/asset/eth/eth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"encoding/hex"
"errors"
"math/big"
"math/rand"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -81,9 +83,6 @@ func (n *testNode) balance(ctx context.Context, acct *common.Address) (*big.Int,
func (n *testNode) sendTransaction(ctx context.Context, tx map[string]string) (common.Hash, error) {
return common.Hash{}, nil
}
func (n *testNode) syncStatus(ctx context.Context) (bool, float32, error) {
return false, 0, nil
}
func (n *testNode) unlock(ctx context.Context, pw string, acct *accounts.Account) error {
return nil
}
Expand Down Expand Up @@ -1057,6 +1056,93 @@ func TestMaxOrder(t *testing.T) {
}
}

func TestOwnsAddress(t *testing.T) {
address := "0b84C791b79Ee37De042AD2ffF1A253c3ce9bc27" // no "0x" prefix
if !common.IsHexAddress(address) {
t.Fatalf("bad test address")
}

var otherAddress common.Address
rand.Read(otherAddress[:])

eth := &ExchangeWallet{
acct: &accounts.Account{
Address: common.HexToAddress(address),
},
}

tests := []struct {
name string
address string
wantOwns bool
wantErr bool
}{
{
name: "same (exact)",
address: address,
wantOwns: true,
wantErr: false,
},
{
name: "same (lower)",
address: strings.ToLower(address),
wantOwns: true,
wantErr: false,
},
{
name: "same (upper)",
address: strings.ToUpper(address),
wantOwns: true,
wantErr: false,
},
{
name: "same (0x prefix)",
address: "0x" + address,
wantOwns: true,
wantErr: false,
},
{
name: "different (valid canonical)",
address: otherAddress.String(),
wantOwns: false,
wantErr: false,
},
{
name: "different (valid hex)",
address: otherAddress.Hex(),
wantOwns: false,
wantErr: false,
},
{
name: "error (bad hex char)",
address: strings.Replace(address, "b", "r", 1),
wantOwns: false,
wantErr: true,
},
{
name: "error (bad length)",
address: "ababababababab",
wantOwns: false,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
owns, err := eth.OwnsAddress(tt.address)
if (err == nil) && tt.wantErr {
t.Error("expected error")
}
if (err != nil) && !tt.wantErr {
t.Errorf("unexpected error: %v", err)
}
if owns != tt.wantOwns {
t.Errorf("got %v, want %v", owns, tt.wantOwns)
}
})
}
}

func TestSignMessage(t *testing.T) {
node := &testNode{}
ctx, cancel := context.WithCancel(context.Background())
Expand Down
2 changes: 1 addition & 1 deletion client/asset/eth/rpcclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func (c *rpcclient) swap(ctx context.Context, from *accounts.Account, secretHash
func (c *rpcclient) wallet(acct accounts.Account) (accounts.Wallet, error) {
wallet, err := c.n.AccountManager().Find(acct)
if err != nil {
return nil, fmt.Errorf("error finding wallet for account %s: %v \n", acct.Address, err)
return nil, fmt.Errorf("error finding wallet for account %s: %w", acct.Address, err)
}
return wallet, nil
}
Expand Down
31 changes: 23 additions & 8 deletions client/asset/eth/rpcclient_harness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,19 @@ func TestMain(m *testing.M) {
// testing harness.
err := os.MkdirAll(simnetWalletDir, 0755)
if err != nil {
return 1, fmt.Errorf("error creating simnet wallet dir dir: %v\n", err)
return 1, fmt.Errorf("error creating simnet wallet dir dir: %v", err)
}
err = os.MkdirAll(participantWalletDir, 0755)
if err != nil {
return 1, fmt.Errorf("error creating participant wallet dir: %v\n", err)
return 1, fmt.Errorf("error creating participant wallet dir: %v", err)
}
addrBytes, err := os.ReadFile(contractAddrFile)
if err != nil {
return 1, fmt.Errorf("error reading contract address: %v\n", err)
return 1, fmt.Errorf("error reading contract address: %v", err)
}
addrLen := len(addrBytes)
if addrLen == 0 {
return 1, fmt.Errorf("no contract address found at %v\n", contractAddrFile)
return 1, fmt.Errorf("no contract address found at %v", contractAddrFile)
}
addrStr := string(addrBytes[:addrLen-1])
contractAddr = common.HexToAddress(addrStr)
Expand All @@ -163,16 +163,19 @@ func TestMain(m *testing.M) {
simnetWallet.internalNode.Wait()
}()
participantWallet, err := setupWallet(participantWalletDir, participantWalletSeed, "localhost:30356")
if err != nil {
return 1, err
}
defer func() {
participantWallet.internalNode.Close()
participantWallet.internalNode.Wait()
}()
addr := common.HexToAddress(addrStr)
if err := ethClient.connect(ctx, simnetWallet.internalNode, &addr); err != nil {
return 1, fmt.Errorf("connect error: %v\n", err)
return 1, fmt.Errorf("connect error: %v", err)
}
if err := participantEthClient.connect(ctx, participantWallet.internalNode, &addr); err != nil {
return 1, fmt.Errorf("connect error: %v\n", err)
return 1, fmt.Errorf("connect error: %v", err)
}
accts, err := exportAccountsFromNode(simnetWallet.internalNode)
if err != nil {
Expand Down Expand Up @@ -217,11 +220,11 @@ func setupWallet(walletDir, seed, listenAddress string) (*ExchangeWallet, error)
}
err := CreateWallet(&createWalletParams)
if err != nil {
return nil, fmt.Errorf("error creating node: %v\n", err)
return nil, fmt.Errorf("error creating node: %v", err)
}
wallet, err := NewWallet(&walletConfig, tLogger, dex.Simnet)
if err != nil {
return nil, fmt.Errorf("error starting node: %v\n", err)
return nil, fmt.Errorf("error starting node: %v", err)
}
return wallet, nil
}
Expand Down Expand Up @@ -289,6 +292,9 @@ func TestBlock(t *testing.T) {

func TestAccounts(t *testing.T) {
accts := ethClient.accounts()
if len(accts) == 0 {
t.Errorf("Found no accounts")
}
spew.Dump(accts)
}

Expand All @@ -297,6 +303,9 @@ func TestBalance(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if bal == nil {
t.Fatalf("empty balance")
}
spew.Dump(bal)
}

Expand All @@ -319,6 +328,9 @@ func TestListWallets(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if len(wallets) == 0 {
t.Fatalf("no wallets")
}
spew.Dump(wallets)
}

Expand Down Expand Up @@ -1381,6 +1393,9 @@ func TestSignMessage(t *testing.T) {
t.Fatalf("error signing text: %v", err)
}
pubKey, err := secp256k1.RecoverPubkey(crypto.Keccak256(msg), signature)
if err != nil {
t.Fatalf("RecoverPubkey: %v", err)
}
x, y := elliptic.Unmarshal(secp256k1.S256(), pubKey)
recoveredAddress := crypto.PubkeyToAddress(ecdsa.PublicKey{
Curve: secp256k1.S256(),
Expand Down

0 comments on commit 3991a0a

Please sign in to comment.