From ffb2fbc5587670361b5d7f70a936262063e0ff2f Mon Sep 17 00:00:00 2001 From: Mario Vega Date: Wed, 5 Jun 2024 18:29:54 +0000 Subject: [PATCH] feat(fw): label addresses from code --- src/ethereum_test_tools/common/base_types.py | 6 ++- src/ethereum_test_tools/common/types.py | 54 +++++++++++++++----- whitelist.txt | 2 + 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/src/ethereum_test_tools/common/base_types.py b/src/ethereum_test_tools/common/base_types.py index c3f0d255d3..b54ba401b0 100644 --- a/src/ethereum_test_tools/common/base_types.py +++ b/src/ethereum_test_tools/common/base_types.py @@ -115,6 +115,8 @@ def __new__(cls, input: BytesConvertible): """ Creates a new Bytes object. """ + if type(input) is cls: + return input return super(Bytes, cls).__new__(cls, to_bytes(input)) def __hash__(self) -> int: @@ -234,6 +236,8 @@ def __new__(cls, input: FixedSizeBytesConvertible | T): """ Creates a new FixedSizeBytes object. """ + if type(input) is cls: + return input return super(FixedSizeBytes, cls).__new__(cls, to_fixed_size_bytes(input, cls.byte_length)) def __hash__(self) -> int: @@ -277,7 +281,7 @@ class Address(FixedSizeBytes[20]): # type: ignore Class that helps represent Ethereum addresses in tests. """ - pass + label: str | None = None class Hash(FixedSizeBytes[32]): # type: ignore diff --git a/src/ethereum_test_tools/common/types.py b/src/ethereum_test_tools/common/types.py index cd28bb3d72..2a4c30970e 100644 --- a/src/ethereum_test_tools/common/types.py +++ b/src/ethereum_test_tools/common/types.py @@ -2,6 +2,7 @@ Useful types for generating Ethereum tests. """ +import inspect from dataclasses import dataclass from enum import IntEnum from functools import cached_property @@ -193,8 +194,11 @@ def __init__(self, address: Address, key: int, want: int, got: int, *args): def __str__(self): """Print exception string""" + label_str = "" + if self.address.label is not None: + label_str = f" ({self.address.label})" return ( - f"incorrect value in address {self.address} for " + f"incorrect value in address {self.address}{label_str} for " + f"key {Hash(self.key)}:" + f" want {HexNumber(self.want)} (dec:{self.want})," + f" got {HexNumber(self.got)} (dec:{self.got})" @@ -369,8 +373,11 @@ def __init__(self, address: Address, want: int | None, got: int | None, *args): def __str__(self): """Print exception string""" + label_str = "" + if self.address.label is not None: + label_str = f" ({self.address.label})" return ( - f"unexpected nonce for account {self.address}: " + f"unexpected nonce for account {self.address}{label_str}: " + f"want {self.want}, got {self.got}" ) @@ -393,8 +400,11 @@ def __init__(self, address: Address, want: int | None, got: int | None, *args): def __str__(self): """Print exception string""" + label_str = "" + if self.address.label is not None: + label_str = f" ({self.address.label})" return ( - f"unexpected balance for account {self.address}: " + f"unexpected balance for account {self.address}{label_str}: " + f"want {self.want}, got {self.got}" ) @@ -417,8 +427,11 @@ def __init__(self, address: Address, want: bytes | None, got: bytes | None, *arg def __str__(self): """Print exception string""" + label_str = "" + if self.address.label is not None: + label_str = f" ({self.address.label})" return ( - f"unexpected code for account {self.address}: " + f"unexpected code for account {self.address}{label_str}: " + f"want {self.want}, got {self.got}" ) @@ -521,6 +534,12 @@ def __new__( instance.nonce = Number(nonce) return instance + def copy(self) -> "Sender": + """ + Returns a copy of the sender. + """ + return Sender(Address(self), key=self.key, nonce=self.nonce) + MAX_SENDERS = 50 SENDERS_ITER = iter( @@ -699,6 +718,7 @@ def deploy_contract( balance: NumberConvertible = 0, nonce: NumberConvertible = 1, address: Address | None = None, + label: str | None = None, ) -> Address: """ Deploy a contract to the allocation. @@ -725,10 +745,22 @@ def deploy_contract( code=code, storage=storage, ) - - return Address(contract_address) - - def fund_sender(self, amount: NumberConvertible) -> Sender: + if label is None: + # Try to deduce the label from the code + frame = inspect.currentframe() + if frame is not None: + caller_frame = frame.f_back + if caller_frame is not None: + code_context = inspect.getframeinfo(caller_frame).code_context + if code_context is not None: + line = code_context[0].strip() + if "=" in line: + label = line.split("=")[0].strip() + + contract_address.label = label + return contract_address + + def fund_sender(self, amount: NumberConvertible, label: str | None = None) -> Sender: """ Fund a sender with a given amount to be able to afford transactions. """ @@ -738,11 +770,7 @@ def fund_sender(self, amount: NumberConvertible) -> Sender: nonce=0, balance=amount, ) - return Sender( - Address(sender), - key=sender.key, - nonce=0, - ) + return sender.copy() raise ValueError("no more senders available") def fund_address(self, address: Address, amount: NumberConvertible): diff --git a/whitelist.txt b/whitelist.txt index d4c60e2e1c..13ec128b33 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -68,6 +68,7 @@ controlflow cp CPUs crypto +currentframe customizations Customizations danceratopz @@ -152,6 +153,7 @@ gaslimit gasprice GeneralStateTestsFiller gentest +getframeinfo geth geth's getitem