Skip to content

Commit

Permalink
Refactor KeyWallet (#81)
Browse files Browse the repository at this point in the history
* Support compressed public key format
* Add to_dict() and from_dict() to KeyWallet
* Introduce public_key_to_address() and convert_public_key_format()
  • Loading branch information
goldworm-icon committed Apr 26, 2023
1 parent a90a887 commit 6982c92
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 21 deletions.
78 changes: 57 additions & 21 deletions iconsdk/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import warnings
from abc import ABCMeta, abstractmethod
from hashlib import sha3_256
from typing import Dict, Union, Any

from coincurve import PrivateKey
from eth_keyfile import create_keyfile_json, extract_key_from_keyfile
from coincurve import PrivateKey, PublicKey
from eth_keyfile import create_keyfile_json, extract_key_from_keyfile, decode_keyfile_json
from multimethod import multimethod

from iconsdk import logger
Expand Down Expand Up @@ -55,8 +56,15 @@ class KeyWallet(Wallet):
"""KeyWallet class implements Wallet."""

def __init__(self, private_key_object: PrivateKey):
self.__private_key: bytes = private_key_object.secret
self.public_key: bytes = private_key_object.public_key.format(compressed=False)
self._private_key_object: PrivateKey = private_key_object

@property
def private_key(self) -> bytes:
return self._private_key_object.secret

@property
def public_key(self) -> bytes:
return self._private_key_object.public_key.format(compressed=False)

@staticmethod
def create() -> 'KeyWallet':
Expand All @@ -77,8 +85,7 @@ def load(private_key: bytes) -> KeyWallet:
:return: An instance of Wallet class.
"""
try:
private_key_object = PrivateKey(private_key)
wallet = KeyWallet(private_key_object)
wallet = KeyWallet(PrivateKey(private_key))
logger.info(f"Loaded Wallet by the private key. Address: {wallet.get_address()}")
return wallet
except TypeError:
Expand Down Expand Up @@ -119,15 +126,7 @@ def store(self, file_path: PathLikeObject, password: str):
type(str)
"""
try:
key_store_contents = create_keyfile_json(
self.__private_key,
bytes(password, 'utf-8'),
iterations=16384,
kdf="scrypt"
)
key_store_contents['address'] = self.get_address()
key_store_contents['coinType'] = 'icx'

key_store_contents = self.to_dict(password)
# validate the contents of a keystore file.
if is_keystore_file(key_store_contents):
json_string_keystore_data = json.dumps(key_store_contents)
Expand All @@ -142,27 +141,64 @@ def store(self, file_path: PathLikeObject, password: str):
except IsADirectoryError:
raise KeyStoreException("Directory is invalid.")

def get_private_key(self) -> str:
def to_dict(self, password: str) -> Dict[str, Any]:
ret: Dict[str, Any] = create_keyfile_json(
self.private_key,
bytes(password, 'utf-8'),
iterations=16384,
kdf="scrypt"
)
ret['address'] = self.get_address()
ret['coinType'] = 'icx'
return ret

@classmethod
def from_dict(cls, jso: Dict[str, Any], password: str) -> KeyWallet:
private_key: bytes = decode_keyfile_json(jso, password)
return KeyWallet.load(private_key)

def get_private_key(self, hexadecimal: bool = True) -> Union[str, bytes]:
"""Returns the private key of the wallet.
:return a private_key in hexadecimal.
"""
return self.__private_key.hex()
pri_key: bytes = self._private_key_object.secret
return pri_key.hex() if hexadecimal else pri_key

def get_public_key(self, compressed: bool = True, hexadecimal: bool = True) -> Union[str, bytes]:
pub_key: bytes = self._private_key_object.public_key.format(compressed)
return pub_key.hex() if hexadecimal else pub_key

def get_address(self) -> str:
"""Returns an EOA address.
:return address: An EOA address
"""
return f'hx{sha3_256(self.public_key[1:]).digest()[-20:].hex()}'
return public_key_to_address(self._private_key_object.public_key.format(compressed=False))

def sign(self, data: bytes) -> bytes:
"""Generates signature from input data which is transaction data
:param data: data to be signed
:return signature: signature made from input
"""
return sign(data, self.__private_key)
return sign(data, self.private_key)

def __eq__(self, other: KeyWallet) -> bool:
return self.private_key == other.private_key

def __ne__(self, other: KeyWallet) -> bool:
return not self.__eq__(other)


def public_key_to_address(public_key: bytes) -> str:
if not (len(public_key) == 65 and public_key[0] == 4):
pub_key = PublicKey(public_key)
public_key: bytes = pub_key.format(compressed=False)
return f'hx{sha3_256(public_key[1:]).digest()[-20:].hex()}'


def convert_public_key_format(public_key: bytes, compressed: bool) -> bytes:
pub_key = PublicKey(public_key)
return pub_key.format(compressed)


def get_public_key(private_key_object: PrivateKey):
Expand Down
92 changes: 92 additions & 0 deletions tests/wallet/test_wallet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Union, Dict

import pytest

from iconsdk.wallet.wallet import KeyWallet, public_key_to_address, convert_public_key_format


class TestKeyWallet:

@pytest.mark.parametrize(
"compressed,hexadecimal,ret_type,size",
(
(True, True, str, 33),
(True, False, bytes, 33),
(False, True, str, 65),
(False, False, bytes, 65),
)
)
def test_get_public_key(self, compressed: bool, hexadecimal: bool, ret_type: type, size: int):
wallet: KeyWallet = KeyWallet.create()
public_key: Union[str, bytes] = wallet.get_public_key(compressed, hexadecimal)
assert isinstance(public_key, ret_type)
if hexadecimal:
print(public_key)
pub_key: bytes = bytes.fromhex(public_key)
assert len(pub_key) == size
else:
assert len(public_key) == size

def test_get_private_key(self):
wallet: KeyWallet = KeyWallet.create()
private_key: str = wallet.get_private_key()
assert isinstance(private_key, str)
assert not private_key.startswith("0x")

private_key: bytes = wallet.get_private_key(hexadecimal=False)
assert isinstance(private_key, bytes)
assert wallet.private_key == private_key

def test_private_key(self):
wallet: KeyWallet = KeyWallet.create()
wallet2: KeyWallet = KeyWallet.load(wallet.private_key)
assert wallet == wallet2

wallet3 = KeyWallet.create()
assert wallet != wallet3

def test_public_key(self):
wallet: KeyWallet = KeyWallet.create()
public_key: bytes = wallet.public_key
assert isinstance(public_key, bytes)
assert len(public_key) == 65

def test_to_dict(self):
password = "1234"
wallet: KeyWallet = KeyWallet.create()
jso: Dict[str, str] = wallet.to_dict(password)
assert jso["address"] == wallet.get_address()
assert jso["coinType"] == "icx"

wallet2 = KeyWallet.from_dict(jso, password)
assert wallet2 == wallet


def test_public_key_to_address():
wallet: KeyWallet = KeyWallet.create()
address: str = public_key_to_address(wallet.public_key)
assert address == wallet.get_address()

compressed_public_key: bytes = wallet.get_public_key(compressed=True, hexadecimal=False)
address2: str = public_key_to_address(compressed_public_key)
assert address2 == wallet.get_address()


@pytest.mark.parametrize(
"iformat,oformat,size",
(
(True, True, 33),
(True, False, 65),
(False, True, 33),
(False, False, 65),
)
)
def test_convert_public_key_format(iformat: bool, oformat: bool, size: int):
wallet: KeyWallet = KeyWallet.create()

public_key: bytes = wallet.get_public_key(compressed=iformat, hexadecimal=False)
ret: bytes = convert_public_key_format(public_key, compressed=oformat)
if iformat == oformat:
assert ret == public_key
assert len(ret) == size
assert public_key_to_address(public_key) == public_key_to_address(ret)

0 comments on commit 6982c92

Please sign in to comment.