diff --git a/.env.example b/.env.example index d15587a58..09eb5dcc8 100644 --- a/.env.example +++ b/.env.example @@ -40,3 +40,7 @@ EVM_PRIVATE_KEY=0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff # Because cairo-land generated files used protobuf<=3.20 and web3.py uses protobuf ~4 PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +# To override the default web3 provider uri http://localhost:8545, putting as default +# the default Kakarot RPC URL of kakarot-rpc repo +WEB3_HTTP_PROVIDER_URI="http://0.0.0.0:3030" diff --git a/kakarot_scripts/constants.py b/kakarot_scripts/constants.py index 72bdbea3f..a03ba608c 100644 --- a/kakarot_scripts/constants.py +++ b/kakarot_scripts/constants.py @@ -10,6 +10,7 @@ from eth_keys import keys from starknet_py.net.full_node_client import FullNodeClient from starknet_py.net.models.chains import StarknetChainId +from web3 import Web3 logging.basicConfig() logger = logging.getLogger(__name__) @@ -127,6 +128,7 @@ NETWORK["private_key"] = os.getenv("PRIVATE_KEY") RPC_CLIENT = FullNodeClient(node_url=NETWORK["rpc_url"]) +WEB3 = Web3() try: response = requests.post( diff --git a/kakarot_scripts/utils/kakarot.py b/kakarot_scripts/utils/kakarot.py index 8b9f80b5e..55fcc7797 100644 --- a/kakarot_scripts/utils/kakarot.py +++ b/kakarot_scripts/utils/kakarot.py @@ -5,6 +5,7 @@ from types import MethodType from typing import List, Optional, Union, cast +from eth_abi import decode from eth_abi.exceptions import InsufficientDataBytes from eth_account import Account as EvmAccount from eth_account._utils.typed_transactions import TypedTransaction @@ -13,7 +14,7 @@ from hexbytes import HexBytes from starknet_py.net.account.account import Account from starknet_py.net.client_errors import ClientError -from starknet_py.net.client_models import Call, Event +from starknet_py.net.client_models import Call from starknet_py.net.models.transaction import InvokeV1 from starknet_py.net.signer.stark_curve_signer import KeyPair from starkware.starknet.public.abi import starknet_keccak @@ -26,7 +27,13 @@ from web3.exceptions import LogTopicError, MismatchedABI, NoABIFunctionsFound from web3.types import LogReceipt -from kakarot_scripts.constants import EVM_ADDRESS, EVM_PRIVATE_KEY, NETWORK, RPC_CLIENT +from kakarot_scripts.constants import ( + EVM_ADDRESS, + EVM_PRIVATE_KEY, + NETWORK, + RPC_CLIENT, + WEB3, +) from kakarot_scripts.utils.starknet import call as _call_starknet from kakarot_scripts.utils.starknet import fund_address as _fund_starknet_address from kakarot_scripts.utils.starknet import get_contract as _get_starknet_contract @@ -133,7 +140,7 @@ def get_contract( contract = cast( Web3Contract, - Web3().eth.contract( + WEB3.eth.contract( address=to_checksum_address(address) if address is not None else address, abi=artifacts["abi"], bytecode=artifacts["bytecode"], @@ -146,7 +153,7 @@ def get_contract( setattr(contract, fun, MethodType(_wrap_kakarot(fun, caller_eoa), contract)) except NoABIFunctionsFound: pass - contract.events.parse_starknet_events = MethodType(_parse_events, contract.events) + contract.events.parse_events = MethodType(_parse_events, contract.events) return contract @@ -169,7 +176,13 @@ async def deploy( if success == 0: raise EvmTransactionError(bytes(response)) - starknet_address, evm_address = response + if WEB3.is_connected(): + evm_address = int(receipt.contractAddress or receipt.to, 16) + starknet_address = ( + await _call_starknet("kakarot", "compute_starknet_address", evm_address) + ).contract_address + else: + starknet_address, evm_address = response contract.address = Web3.to_checksum_address(f"0x{evm_address:040x}") contract.starknet_address = starknet_address logger.info(f"✅ {contract_name} deployed at address {contract.address}") @@ -177,14 +190,17 @@ async def deploy( return contract -def _parse_events(cls: ContractEvents, starknet_events: List[Event]): +def get_log_receipts(tx_receipt): + if WEB3.is_connected(): + return tx_receipt.logs + kakarot_address = get_deployments()["kakarot"]["address"] kakarot_events = [ event - for event in starknet_events + for event in tx_receipt.events if event.from_address == kakarot_address and event.keys[0] < 2**160 ] - log_receipts = [ + return [ LogReceipt( address=to_checksum_address(f"0x{event.keys[0]:040x}"), blockHash=bytes(), @@ -209,6 +225,10 @@ def _parse_events(cls: ContractEvents, starknet_events: List[Event]): for log_index, event in enumerate(kakarot_events) ] + +def _parse_events(cls: ContractEvents, tx_receipt): + log_receipts = get_log_receipts(tx_receipt) + return { event_abi.get("name"): _get_matching_logs_for_event(event_abi, log_receipts) for event_abi in cls._events @@ -217,10 +237,9 @@ def _parse_events(cls: ContractEvents, starknet_events: List[Event]): def _get_matching_logs_for_event(event_abi, log_receipts) -> List[dict]: logs = [] - codec = Web3().codec for log_receipt in log_receipts: try: - event_data = get_event_data(codec, event_abi, log_receipt) + event_data = get_event_data(WEB3.codec, event_abi, log_receipt) logs += [event_data["args"]] except (MismatchedABI, LogTopicError, InsufficientDataBytes): pass @@ -242,30 +261,35 @@ async def _wrapper(self, *args, **kwargs): )._encode_transaction_data() if abi["stateMutability"] in ["pure", "view"]: - kakarot_contract = _get_starknet_contract("kakarot") origin = ( int(caller_eoa_.signer.public_key.to_address(), 16) if caller_eoa_ else int(EVM_ADDRESS, 16) ) - result = await kakarot_contract.functions["eth_call"].call( - nonce=0, - origin=origin, - to={ - "is_some": 1, - "value": int(self.address, 16), - }, - gas_limit=gas_limit, - gas_price=gas_price, - value=value, - data=list(HexBytes(calldata)), - access_list=[], - ) - if result.success == 0: - raise EvmTransactionError(bytes(result.return_data)) - codec = Web3().codec + payload = { + "nonce": 0, + "from": Web3.to_checksum_address(f"{origin:040x}"), + "to": self.address, + "gas_limit": gas_limit, + "gas_price": gas_price, + "value": value, + "data": HexBytes(calldata), + "access_list": [], + } + if WEB3.is_connected(): + result = WEB3.eth.call(payload) + else: + kakarot_contract = _get_starknet_contract("kakarot") + payload["to"] = {"is_some": 1, "value": int(payload["to"], 16)} + payload["data"] = list(payload["data"]) + payload["origin"] = int(payload["from"], 16) + del payload["from"] + result = await kakarot_contract.functions["eth_call"].call(**payload) + if result.success == 0: + raise EvmTransactionError(bytes(result.return_data)) + result = result.return_data types = [o["type"] for o in abi["outputs"]] - decoded = codec.decode(types, bytes(result.return_data)) + decoded = decode(types, bytes(result)) normalized = map_abi_data(BASE_RETURN_NORMALIZERS, types, decoded) return normalized[0] if len(normalized) == 1 else normalized @@ -328,7 +352,12 @@ async def eth_send_transaction( ): """Execute the data at the EVM contract to on Kakarot.""" evm_account = caller_eoa or await get_eoa() - nonce = await evm_account.get_nonce() + if WEB3.is_connected(): + nonce = WEB3.eth.get_transaction_count( + evm_account.signer.public_key.to_checksum_address() + ) + else: + nonce = await evm_account.get_nonce() payload = { "type": 0x2, @@ -336,7 +365,7 @@ async def eth_send_transaction( "nonce": nonce, "gas": gas, "maxPriorityFeePerGas": 1, - "maxFeePerGas": 1, + "maxFeePerGas": 100, "to": to_checksum_address(to) if to else None, "value": value, "data": data, @@ -349,6 +378,11 @@ async def eth_send_transaction( hex(evm_account.signer.private_key), ) + if WEB3.is_connected(): + tx_hash = WEB3.eth.send_raw_transaction(evm_tx.rawTransaction) + receipt = WEB3.eth.wait_for_transaction_receipt(tx_hash) + return receipt, [], receipt.status, receipt.gasUsed + encoded_unsigned_tx = rlp_encode_signed_data(typed_transaction.as_dict()) prepared_invoke = await evm_account._prepare_invoke( diff --git a/tests/end_to_end/PlainOpcodes/test_plain_opcodes.py b/tests/end_to_end/PlainOpcodes/test_plain_opcodes.py index 980c6e89e..c035f3ec3 100644 --- a/tests/end_to_end/PlainOpcodes/test_plain_opcodes.py +++ b/tests/end_to_end/PlainOpcodes/test_plain_opcodes.py @@ -1,4 +1,5 @@ import pytest +from eth_abi import decode from web3 import Web3 from tests.utils.errors import evm_error @@ -32,31 +33,28 @@ async def test_should_increase_counter( class TestTimestamp: async def test_should_return_starknet_timestamp( - self, plain_opcodes, block_with_tx_hashes + self, plain_opcodes, block_timestamp ): - timestamp = await plain_opcodes.opcodeTimestamp() - assert timestamp == (await block_with_tx_hashes("pending")).timestamp + assert pytest.approx( + await plain_opcodes.opcodeTimestamp(), abs=10 + ) == await block_timestamp("latest") class TestBlockhash: @pytest.mark.xfail(reason="Need to fix blockhash on real Starknet network") async def test_should_return_blockhash_with_valid_block_number( - self, - plain_opcodes, - block_with_tx_hashes, + self, plain_opcodes, block_number, block_hash ): - latest_block = block_with_tx_hashes("latest") - blockhash = await plain_opcodes.opcodeBlockHash(latest_block.block_number) + blockhash = await plain_opcodes.opcodeBlockHash( + await block_number("latest") + ) - assert int.from_bytes(blockhash, byteorder="big") == latest_block.block_hash + assert int.from_bytes(blockhash, byteorder="big") == await block_hash() async def test_should_return_zero_with_invalid_block_number( - self, - plain_opcodes, - block_with_tx_hashes, + self, plain_opcodes, block_number ): - latest_block = await block_with_tx_hashes("latest") blockhash_invalid_number = await plain_opcodes.opcodeBlockHash( - latest_block.block_number + 1 + await block_number("latest") + 1 ) assert int.from_bytes(blockhash_invalid_number, byteorder="big") == 0 @@ -99,28 +97,28 @@ async def test_should_emit_log0_with_no_data(self, plain_opcodes, owner): receipt = ( await plain_opcodes.opcodeLog0(caller_eoa=owner.starknet_contract) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert events["Log0"] == [{}] async def test_should_emit_log0_with_data(self, plain_opcodes, owner, event): receipt = ( await plain_opcodes.opcodeLog0Value(caller_eoa=owner.starknet_contract) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert events["Log0Value"] == [{"value": event["value"]}] async def test_should_emit_log1(self, plain_opcodes, owner, event): receipt = ( await plain_opcodes.opcodeLog1(caller_eoa=owner.starknet_contract) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert events["Log1"] == [{"value": event["value"]}] async def test_should_emit_log2(self, plain_opcodes, owner, event): receipt = ( await plain_opcodes.opcodeLog2(caller_eoa=owner.starknet_contract) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) del event["spender"] assert events["Log2"] == [event] @@ -128,14 +126,14 @@ async def test_should_emit_log3(self, plain_opcodes, owner, event): receipt = ( await plain_opcodes.opcodeLog3(caller_eoa=owner.starknet_contract) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert events["Log3"] == [event] async def test_should_emit_log4(self, plain_opcodes, owner, event): receipt = ( await plain_opcodes.opcodeLog4(caller_eoa=owner.starknet_contract) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert events["Log4"] == [event] class TestCreate: @@ -163,7 +161,7 @@ async def test_should_create_counters( caller_eoa=owner.starknet_contract, ) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert len(events["CreateAddress"]) == count for create_event in events["CreateAddress"]: deployed_counter = get_solidity_contract( @@ -193,7 +191,7 @@ async def test_should_create_empty_contract_when_creation_code_has_no_return( ) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert len(events["CreateAddress"]) == 1 starknet_address = await compute_starknet_address( events["CreateAddress"][0]["_address"] @@ -209,7 +207,7 @@ async def test_should_create_counter_and_call_in_the_same_tx( get_solidity_contract, ): receipt = (await plain_opcodes.createCounterAndCall())["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) address = events["CreateAddress"][0]["_address"] counter = get_solidity_contract("PlainOpcodes", "Counter", address=address) assert await counter.count() == 0 @@ -220,7 +218,7 @@ async def test_should_create_counter_and_invoke_in_the_same_tx( get_solidity_contract, ): receipt = (await plain_opcodes.createCounterAndInvoke())["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) address = events["CreateAddress"][0]["_address"] counter = get_solidity_contract("PlainOpcodes", "Counter", address=address) assert await counter.count() == 1 @@ -244,7 +242,7 @@ async def test_should_collision_after_selfdestruct_different_tx( caller_eoa=owner.starknet_contract, ) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert len(events["Create2Address"]) == 1 contract_with_selfdestruct = get_solidity_contract( "PlainOpcodes", @@ -265,13 +263,11 @@ async def test_should_collision_after_selfdestruct_different_tx( ) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) # There should be a create2 collision which returns zero assert events["Create2Address"] == [ - { - "_address": "0x0000000000000000000000000000000000000000", - } + {"_address": "0x0000000000000000000000000000000000000000"} ] async def test_should_deploy_bytecode_at_address( @@ -298,7 +294,7 @@ async def test_should_deploy_bytecode_at_address( caller_eoa=owner.starknet_contract, ) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert len(events["Create2Address"]) == 1 deployed_counter = get_solidity_contract( @@ -355,9 +351,7 @@ async def test_should_revert_via_call( "PlainOpcodes", "ContractRevertsOnMethodCall" ) - assert reverting_contract.events.parse_starknet_events(receipt.events) == { - "PartyTime": [] - } + assert reverting_contract.events.parse_events(receipt) == {"PartyTime": []} class TestOriginAndSender: async def test_should_return_owner_as_origin_and_sender( @@ -378,12 +372,10 @@ async def test_should_return_owner_as_origin_and_caller_as_sender( caller_eoa=owner.starknet_contract, ) )["receipt"] - events = caller.events.parse_starknet_events(receipt.events) + events = caller.events.parse_events(receipt) assert len(events["Call"]) == 1 assert events["Call"][0]["success"] - decoded = Web3().codec.decode( - ["address", "address"], events["Call"][0]["returnData"] - ) + decoded = decode(["address", "address"], events["Call"][0]["returnData"]) assert int(owner.address, 16) == int(decoded[0], 16) # tx.origin assert int(caller.address, 16) == int(decoded[1], 16) # msg.sender @@ -431,7 +423,7 @@ async def test_send_some_should_revert_when_amount_exceed_balance( caller_eoa=owner.starknet_contract, ) )["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert events["SentSome"] == [ { "to": other.address, @@ -446,10 +438,10 @@ async def test_send_some_should_revert_when_amount_exceed_balance( class TestMapping: async def test_should_emit_event_and_increase_nonce(self, plain_opcodes): receipt = (await plain_opcodes.incrementMapping())["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) prev_nonce = events["NonceIncreased"][0]["nonce"] receipt = (await plain_opcodes.incrementMapping())["receipt"] - events = plain_opcodes.events.parse_starknet_events(receipt.events) + events = plain_opcodes.events.parse_events(receipt) assert events["NonceIncreased"][0]["nonce"] - prev_nonce == 1 class TestFallbackFunctions: @@ -473,7 +465,11 @@ async def test_should_revert_on_fallbacks( caller_eoa=addresses[2].starknet_contract, ) assert not success - assert f"reverted on {message}".encode() in bytes(response) + assert ( + f"reverted on {message}".encode() in bytes(response) + if response + else True + ) class TestMulmod: async def test_should_return_0(self, plain_opcodes): diff --git a/tests/end_to_end/Solmate/test_erc20.py b/tests/end_to_end/Solmate/test_erc20.py index 05e61bb34..22c314612 100644 --- a/tests/end_to_end/Solmate/test_erc20.py +++ b/tests/end_to_end/Solmate/test_erc20.py @@ -215,7 +215,7 @@ async def test_should_permit(self, erc_20, owner, other): caller_eoa=owner.starknet_contract, ) )["receipt"] - events = erc_20.events.parse_starknet_events(receipt.events) + events = erc_20.events.parse_events(receipt) assert events["Approval"] == [ { diff --git a/tests/end_to_end/UniswapV2/test_uniswap_v2_erc20.py b/tests/end_to_end/UniswapV2/test_uniswap_v2_erc20.py index 9e9e5a7d9..548fdac01 100644 --- a/tests/end_to_end/UniswapV2/test_uniswap_v2_erc20.py +++ b/tests/end_to_end/UniswapV2/test_uniswap_v2_erc20.py @@ -36,7 +36,7 @@ async def test_should_set_allowance(self, token_a, owner, other): other.address, TEST_AMOUNT, caller_eoa=owner.starknet_contract ) )["receipt"] - events = token_a.events.parse_starknet_events(receipt.events) + events = token_a.events.parse_events(receipt) assert events["Approval"] == [ { "owner": owner.address, @@ -56,7 +56,7 @@ async def test_should_transfer_token_when_signer_is_owner( other.address, TEST_AMOUNT, caller_eoa=owner.starknet_contract ) )["receipt"] - events = token_a.events.parse_starknet_events(receipt.events) + events = token_a.events.parse_events(receipt) assert events["Transfer"] == [ { "from": owner.address, @@ -98,7 +98,7 @@ async def test_should_transfer_token_when_signer_is_approved( caller_eoa=other.starknet_contract, ) )["receipt"] - events = token_a.events.parse_starknet_events(receipt.events) + events = token_a.events.parse_events(receipt) assert events["Transfer"] == [ {"from": owner.address, "to": other.address, "value": TEST_AMOUNT} ] @@ -121,7 +121,7 @@ async def test_should_transfer_token_when_signer_is_approved_max_uint( caller_eoa=other.starknet_contract, ) )["receipt"] - events = token_a.events.parse_starknet_events(receipt.events) + events = token_a.events.parse_events(receipt) assert events["Transfer"] == [ {"from": owner.address, "to": other.address, "value": TEST_AMOUNT} ] @@ -158,7 +158,7 @@ async def test_should_update_allowance(self, token_a, owner, other): caller_eoa=owner.starknet_contract, ) )["receipt"] - events = token_a.events.parse_starknet_events(receipt.events) + events = token_a.events.parse_events(receipt) assert events["Approval"] == [ {"owner": owner.address, "spender": other.address, "value": TEST_AMOUNT} ] diff --git a/tests/end_to_end/UniswapV2/test_uniswap_v2_factory.py b/tests/end_to_end/UniswapV2/test_uniswap_v2_factory.py index 5947340c7..049cc65e9 100644 --- a/tests/end_to_end/UniswapV2/test_uniswap_v2_factory.py +++ b/tests/end_to_end/UniswapV2/test_uniswap_v2_factory.py @@ -31,9 +31,7 @@ async def test_should_create_pair_only_once( )["receipt"] token_0, token_1 = sorted(TEST_ADDRESSES) pair_evm_address = await factory.getPair(*TEST_ADDRESSES) - assert factory.events.parse_starknet_events(receipt.events)[ - "PairCreated" - ] == [ + assert factory.events.parse_events(receipt)["PairCreated"] == [ { "token0": token_0, "token1": token_1, diff --git a/tests/end_to_end/conftest.py b/tests/end_to_end/conftest.py index 9b21ac1be..7fabac363 100644 --- a/tests/end_to_end/conftest.py +++ b/tests/end_to_end/conftest.py @@ -293,14 +293,46 @@ def _factory(contract_app, contract_name, *args, **kwargs): @pytest.fixture -def block_with_tx_hashes(starknet): - """ - Not using starknet object because of - https://github.com/software-mansion/starknet.py/issues/1174. - """ +def block_number(starknet): + from kakarot_scripts.constants import WEB3 + + async def _factory(block_number: Optional[Union[int, str]] = "latest"): + if WEB3.is_connected(): + return WEB3.eth.get_block(block_number).number + + return ( + await starknet.get_block_with_tx_hashes(block_number=block_number) + ).block_number + + return _factory + + +@pytest.fixture +def block_timestamp(starknet): + from kakarot_scripts.constants import WEB3 - async def _factory(block_number: Optional[int] = None): - return await starknet.get_block_with_tx_hashes(block_number=block_number) + async def _factory(block_number: Optional[Union[int, str]] = "latest"): + if WEB3.is_connected(): + return WEB3.eth.get_block(block_number).timestamp + + return ( + await starknet.get_block_with_tx_hashes(block_number=block_number) + ).timestamp + + return _factory + + +@pytest.fixture +def block_hash(starknet): + from kakarot_scripts.constants import WEB3 + + async def _factory(block_number: Optional[Union[int, str]] = "latest"): + if WEB3.is_connected(): + return WEB3.eth.get_block(block_number).hash + + return ( + await starknet.get_block_with_tx_hashes(block_number=block_number) + ).block_hash return _factory diff --git a/tests/src/kakarot/test_kakarot.py b/tests/src/kakarot/test_kakarot.py index d654e99bb..17b805dad 100644 --- a/tests/src/kakarot/test_kakarot.py +++ b/tests/src/kakarot/test_kakarot.py @@ -3,11 +3,10 @@ from unittest.mock import PropertyMock, patch import pytest -from eth_abi import encode +from eth_abi import decode, encode from eth_utils import keccak from eth_utils.address import to_checksum_address from starkware.starknet.public.abi import get_storage_var_address -from web3 import Web3 from web3._utils.abi import map_abi_data from web3._utils.normalizers import BASE_RETURN_NORMALIZERS from web3.exceptions import NoABIFunctionsFound @@ -51,9 +50,8 @@ def _wrapper(self, *args, **kwargs): if abi["stateMutability"] not in ["pure", "view"]: return evm, state, gas - codec = Web3().codec types = [o["type"] for o in abi["outputs"]] - decoded = codec.decode(types, bytes(evm["return_data"])) + decoded = decode(types, bytes(evm["return_data"])) normalized = map_abi_data(BASE_RETURN_NORMALIZERS, types, decoded) return normalized[0] if len(normalized) == 1 else normalized diff --git a/tests/utils/coverage.py b/tests/utils/coverage.py index 9afc2beb9..ea0649d7f 100644 --- a/tests/utils/coverage.py +++ b/tests/utils/coverage.py @@ -1,6 +1,7 @@ """ Copied from cairo_coverage. """ + from collections import defaultdict from dataclasses import dataclass from os import get_terminal_size diff --git a/tests/utils/errors.py b/tests/utils/errors.py index eb8545591..df1a030dc 100644 --- a/tests/utils/errors.py +++ b/tests/utils/errors.py @@ -2,6 +2,7 @@ from contextlib import contextmanager import pytest +from web3 import Web3 @contextmanager @@ -13,7 +14,8 @@ def evm_error(message=None): # FIXME: When all the other Kakarot errors are fixed (e.g. Kakarot: StateModificationError) # FIXME: uncomment this # assert e.typename == "EvmTransactionError" - if message is None: + # When Web3 is connected, it's not possible to have the return_data so we skip the check + if message is None or Web3().is_connected(): return revert_reason = bytes(e.value.args[0]) message = message.encode() if isinstance(message, str) else message