Skip to content

Commit

Permalink
feat(fw): label addresses from code
Browse files Browse the repository at this point in the history
  • Loading branch information
marioevz committed Jun 5, 2024
1 parent f4cb5b1 commit ffb2fbc
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 14 deletions.
6 changes: 5 additions & 1 deletion src/ethereum_test_tools/common/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def __new__(cls, input: BytesConvertible):
"""
Creates a new Bytes object.
"""
if type(input) is cls:
return input
return super(Bytes, cls).__new__(cls, to_bytes(input))

def __hash__(self) -> int:
Expand Down Expand Up @@ -234,6 +236,8 @@ def __new__(cls, input: FixedSizeBytesConvertible | T):
"""
Creates a new FixedSizeBytes object.
"""
if type(input) is cls:
return input
return super(FixedSizeBytes, cls).__new__(cls, to_fixed_size_bytes(input, cls.byte_length))

def __hash__(self) -> int:
Expand Down Expand Up @@ -277,7 +281,7 @@ class Address(FixedSizeBytes[20]): # type: ignore
Class that helps represent Ethereum addresses in tests.
"""

pass
label: str | None = None


class Hash(FixedSizeBytes[32]): # type: ignore
Expand Down
54 changes: 41 additions & 13 deletions src/ethereum_test_tools/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Useful types for generating Ethereum tests.
"""

import inspect
from dataclasses import dataclass
from enum import IntEnum
from functools import cached_property
Expand Down Expand Up @@ -193,8 +194,11 @@ def __init__(self, address: Address, key: int, want: int, got: int, *args):

def __str__(self):
"""Print exception string"""
label_str = ""
if self.address.label is not None:
label_str = f" ({self.address.label})"
return (
f"incorrect value in address {self.address} for "
f"incorrect value in address {self.address}{label_str} for "
+ f"key {Hash(self.key)}:"
+ f" want {HexNumber(self.want)} (dec:{self.want}),"
+ f" got {HexNumber(self.got)} (dec:{self.got})"
Expand Down Expand Up @@ -369,8 +373,11 @@ def __init__(self, address: Address, want: int | None, got: int | None, *args):

def __str__(self):
"""Print exception string"""
label_str = ""
if self.address.label is not None:
label_str = f" ({self.address.label})"
return (
f"unexpected nonce for account {self.address}: "
f"unexpected nonce for account {self.address}{label_str}: "
+ f"want {self.want}, got {self.got}"
)

Expand All @@ -393,8 +400,11 @@ def __init__(self, address: Address, want: int | None, got: int | None, *args):

def __str__(self):
"""Print exception string"""
label_str = ""
if self.address.label is not None:
label_str = f" ({self.address.label})"
return (
f"unexpected balance for account {self.address}: "
f"unexpected balance for account {self.address}{label_str}: "
+ f"want {self.want}, got {self.got}"
)

Expand All @@ -417,8 +427,11 @@ def __init__(self, address: Address, want: bytes | None, got: bytes | None, *arg

def __str__(self):
"""Print exception string"""
label_str = ""
if self.address.label is not None:
label_str = f" ({self.address.label})"
return (
f"unexpected code for account {self.address}: "
f"unexpected code for account {self.address}{label_str}: "
+ f"want {self.want}, got {self.got}"
)

Expand Down Expand Up @@ -521,6 +534,12 @@ def __new__(
instance.nonce = Number(nonce)
return instance

def copy(self) -> "Sender":
"""
Returns a copy of the sender.
"""
return Sender(Address(self), key=self.key, nonce=self.nonce)


MAX_SENDERS = 50
SENDERS_ITER = iter(
Expand Down Expand Up @@ -699,6 +718,7 @@ def deploy_contract(
balance: NumberConvertible = 0,
nonce: NumberConvertible = 1,
address: Address | None = None,
label: str | None = None,
) -> Address:
"""
Deploy a contract to the allocation.
Expand All @@ -725,10 +745,22 @@ def deploy_contract(
code=code,
storage=storage,
)

return Address(contract_address)

def fund_sender(self, amount: NumberConvertible) -> Sender:
if label is None:
# Try to deduce the label from the code
frame = inspect.currentframe()
if frame is not None:
caller_frame = frame.f_back
if caller_frame is not None:
code_context = inspect.getframeinfo(caller_frame).code_context
if code_context is not None:
line = code_context[0].strip()
if "=" in line:
label = line.split("=")[0].strip()

contract_address.label = label
return contract_address

def fund_sender(self, amount: NumberConvertible, label: str | None = None) -> Sender:
"""
Fund a sender with a given amount to be able to afford transactions.
"""
Expand All @@ -738,11 +770,7 @@ def fund_sender(self, amount: NumberConvertible) -> Sender:
nonce=0,
balance=amount,
)
return Sender(
Address(sender),
key=sender.key,
nonce=0,
)
return sender.copy()
raise ValueError("no more senders available")

def fund_address(self, address: Address, amount: NumberConvertible):
Expand Down
2 changes: 2 additions & 0 deletions whitelist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ controlflow
cp
CPUs
crypto
currentframe
customizations
Customizations
danceratopz
Expand Down Expand Up @@ -152,6 +153,7 @@ gaslimit
gasprice
GeneralStateTestsFiller
gentest
getframeinfo
geth
geth's
getitem
Expand Down

0 comments on commit ffb2fbc

Please sign in to comment.