Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added elfpy/eth/errors/__init__.py
Empty file.
51 changes: 51 additions & 0 deletions elfpy/eth/errors/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Error handling for the hyperdrive ecosystem"""

from eth_utils.conversions import to_hex
from eth_utils.crypto import keccak
from web3.contract.contract import Contract

from .types import ABIError


def decode_error_selector_for_contract(error_selector: str, contract: Contract) -> str:
"""Decode the error selector for a contract,

Arguments
---------

error_selector : str
A 3 byte hex string obtained from a keccak256 has of the error signature, i.e.
'InvalidToken()' would yield '0xc1ab6dc1'.
contract: Contract
A web3.py Contract interface, the abi is required for this function to work.

Returns
-------
str
The name of the error. If the error is not found, returns UnknownError.
"""

abi = contract.abi
if not abi:
raise ValueError("Contract does not have an abi, cannot decode the error selector.")

errors = [
ABIError(name=err.get("name"), inputs=err.get("inputs"), type="error") # type: ignore
for err in abi
if err.get("type") == "error"
]

error_name = "UnknownError"

for error in errors:
error_inputs = error.get("inputs")
# build a list of argument types like 'uint256,bytes,bool'
input_types_csv = ",".join([input_type.get("type") or "" for input_type in error_inputs])
# create an error signature, i.e. CustomError(uint256,bool)
error_signature = f"{error.get('name')}({input_types_csv})"
decoded_error_selector = str(to_hex(primitive=keccak(text=error_signature)))[:10]
if decoded_error_selector == error_selector:
error_name = error.get("name")
break

return error_name
48 changes: 48 additions & 0 deletions elfpy/eth/errors/test_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Tests for errors.py"""
import pytest

from elfpy.eth.errors.errors import decode_error_selector_for_contract


class TestDecodeErrorSelector:
"""Tests for decode_error_selector_for_contract."""

@pytest.fixture
def mock_contract(self):
"""Fixture that returns a MockContract."""

class MockContract:
"""Mock contract for testing."""

abi = [
{"name": "InvalidToken", "inputs": [], "type": "error"},
{"name": "OutOfGas", "inputs": [], "type": "error"},
{"name": "CustomError", "inputs": [{"type": "uint256"}, {"type": "bool"}], "type": "error"},
]

return MockContract()

def test_decode_error_selector_for_contract_error_found(self, mock_contract):
"""Test happy path."""
# Test no inputs
error_selector = "0xc1ab6dc1"
result = decode_error_selector_for_contract(error_selector, mock_contract)
assert result == "InvalidToken"

# Test with inputs
error_selector = "0x659c1f59"
result = decode_error_selector_for_contract(error_selector, mock_contract)
assert result == "CustomError"

def test_decode_error_selector_for_contract_error_not_found(self, mock_contract):
"""Test unhappy path."""
error_selector = "0xdeadbeef"
result = decode_error_selector_for_contract(error_selector, mock_contract)
assert result == "UnknownError"

def test_decode_error_selector_for_contract_no_abi(self, mock_contract):
"""Test bad abi."""
mock_contract.abi = []
error_selector = "0xdeadbeef"
with pytest.raises(ValueError):
decode_error_selector_for_contract(error_selector, mock_contract)
12 changes: 12 additions & 0 deletions elfpy/eth/errors/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Literal, Sequence, TypedDict

from web3.types import ABIFunctionParams


# TODO: add this to web3.py
class ABIError(TypedDict, total=True):
"""ABI error definition."""

name: str
inputs: Sequence[ABIFunctionParams]
type: Literal["error"]
26 changes: 21 additions & 5 deletions elfpy/eth/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,31 @@
from web3.exceptions import ContractCustomError, ContractLogicError
from web3.types import ABI, ABIFunctionComponents, ABIFunctionParams, TxReceipt

from elfpy.hyperdrive_interface.errors import decode_hyperdrive_errors
from elfpy.hyperdrive_interface.errors import lookup_hyperdrive_error_selector

from .accounts import EthAccount


def smart_contract_read(contract: Contract, function_name: str, *fn_args, **fn_kwargs) -> dict[str, Any]:
"""Return from a smart contract read call

Arguments
---------
contract : web3.contract.contract.Contract
The contract that we are reading from.
function_name : str
The name of the function
*fn_args : Unknown
The arguments passed to the contract method.
**fn_kwargs : Unknown
The keyword arguments passed to the contract method.

Returns
-------
dict[str, Any]
A dictionary of value names
.. todo::
Add better typing to the return value
function to recursively find component names & types
function to dynamically assign types to output variables
would be cool if this also put stuff into FixedPoint
Expand Down Expand Up @@ -56,11 +72,11 @@ def smart_contract_transact(
web3 : Web3
web3 provider object
contract : Contract
function_name : str
signer : EthAccount
the EthAccount that will be used to pay for the gas & sign the transaction
function_name_or_signature : str
any compiled web3 contract
this function must exist in the compiled contract's ABI
from_account : EthAccount
the EthAccount that will be used to pay for the gas & sign the transaction
fn_args : unordered list
all remaining arguments will be passed to the contract function in the order received

Expand All @@ -87,7 +103,7 @@ def smart_contract_transact(
except ContractCustomError as err:
logging.error(
"ContractCustomError %s raised.\n function name: %s\nfunction args: %s",
decode_hyperdrive_errors(err.args[0]),
lookup_hyperdrive_error_selector(err.args[0]),
function_name_or_signature,
fn_args,
)
Expand Down
17 changes: 15 additions & 2 deletions elfpy/hyperdrive_interface/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,21 @@

# TODO: get error names from the ABI, encode to get the selector, match selector with name. For now
# this is hard coded list of errors in all the contracts we use.
def decode_hyperdrive_errors(error_selector: str) -> str:
"""Get the error name for a given error selector."""
def lookup_hyperdrive_error_selector(error_selector: str) -> str:
"""Get the error name for a given error selector.

Arguments
---------

error_selector : str
A 3 byte hex string obtained from a keccak256 has of the error signature, i.e.
'InvalidToken()' would yield '0xc1ab6dc1'.

Returns
-------
str
The name of the error.
"""
return getattr(_hyperdrive_errors, error_selector)


Expand Down