Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions ens/async_ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
AsyncContract,
AsyncContractFunction,
)
from web3.main import AsyncWeb3 # noqa: F401
from web3.main import ( # noqa: F401
AsyncWeb3,
)
from web3.middleware.base import ( # noqa: F401
Middleware,
)
Expand All @@ -96,12 +98,12 @@ class AsyncENS(BaseENS):
"""

# mypy types
w3: "AsyncWeb3"
w3: "AsyncWeb3[Any]"

def __init__(
self,
provider: "AsyncBaseProvider" = None,
addr: ChecksumAddress = None,
provider: Optional["AsyncBaseProvider"] = None,
addr: Optional[ChecksumAddress] = None,
middleware: Optional[Sequence[Tuple["Middleware", str]]] = None,
) -> None:
"""
Expand All @@ -123,7 +125,9 @@ def __init__(
)

@classmethod
def from_web3(cls, w3: "AsyncWeb3", addr: ChecksumAddress = None) -> "AsyncENS":
def from_web3(
cls, w3: "AsyncWeb3[Any]", addr: ChecksumAddress = None
) -> "AsyncENS":
"""
Generate an AsyncENS instance with web3

Expand Down
2 changes: 1 addition & 1 deletion ens/base_ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


class BaseENS:
w3: Union["AsyncWeb3", "Web3"] = None
w3: Union["AsyncWeb3[Any]", "Web3"] = None
ens: Union["Contract", "AsyncContract"] = None
_resolver_contract: Union[Type["Contract"], Type["AsyncContract"]] = None
_reverse_resolver_contract: Union[Type["Contract"], Type["AsyncContract"]] = None
Expand Down
3 changes: 2 additions & 1 deletion ens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def is_valid_ens_name(ens_name: str) -> bool:
def init_async_web3(
provider: "AsyncBaseProvider" = None,
middleware: Optional[Sequence[Tuple["Middleware", str]]] = (),
) -> "AsyncWeb3":
) -> "AsyncWeb3[Any]":
from web3 import (
AsyncWeb3 as AsyncWeb3Main,
)
Expand All @@ -327,6 +327,7 @@ def init_async_web3(
)
)

async_w3: "AsyncWeb3[Any]"
if provider is default:
async_w3 = AsyncWeb3Main(
middleware=middleware, ens=None, modules={"eth": (AsyncEthMain)}
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3761.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `AsyncWeb3` with respect to the provider it is instantiated with, fixing static type issues.
10 changes: 5 additions & 5 deletions tests/core/contracts/test_contract_call_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def test_call_get_byte_array_non_strict(non_strict_arrays_contract, call):
"args,expected",
[
([b"1"], [b"1"]),
(["0xDe"], [b"\xDe"]),
(["0xDe", "0xDe"], [b"\xDe", b"\xDe"]),
(["0xDe"], [b"\xde"]),
(["0xDe", "0xDe"], [b"\xde", b"\xde"]),
],
)
def test_set_byte_array(arrays_contract, call, transact, args, expected):
Expand All @@ -257,8 +257,8 @@ def test_set_byte_array(arrays_contract, call, transact, args, expected):
"args,expected",
[
([b"1"], [b"1"]),
(["0xDe"], [b"\xDe"]),
(["0xDe", "0xDe"], [b"\xDe", b"\xDe"]),
(["0xDe"], [b"\xde"]),
(["0xDe", "0xDe"], [b"\xde", b"\xde"]),
],
)
def test_set_byte_array_non_strict(
Expand Down Expand Up @@ -1503,7 +1503,7 @@ async def test_async_set_byte_array_non_strict(


@pytest.mark.asyncio
@pytest.mark.parametrize("args,expected", [([b"1"], [b"1"]), (["0xDe"], [b"\xDe"])])
@pytest.mark.parametrize("args,expected", [([b"1"], [b"1"]), (["0xDe"], [b"\xde"])])
async def test_async_set_byte_array_strict_by_default(
async_arrays_contract, async_call, async_transact, args, expected
):
Expand Down
6 changes: 3 additions & 3 deletions tests/core/utilities/test_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def test_get_tuple_type_str_parts(
[
(
["bool[2]", "bytes"],
[[True, False], b"\x00\xFF"],
[("bool[2]", [("bool", True), ("bool", False)]), ("bytes", b"\x00\xFF")],
[[True, False], b"\x00\xff"],
[("bool[2]", [("bool", True), ("bool", False)]), ("bytes", b"\x00\xff")],
),
(
["uint256[]"],
Expand Down Expand Up @@ -337,7 +337,7 @@ def test_map_abi_data(

@pytest.mark.parametrize("arg", (6, 7, 9, 12, 20, 30))
def test_exact_length_bytes_encoder_raises_on_non_multiples_of_8_bit_size(
arg: Tuple[int, ...]
arg: Tuple[int, ...],
) -> None:
with pytest.raises(Web3ValueError, match="multiple of 8"):
_ = ExactLengthBytesEncoder(None, data_byte_size=2, value_bit_size=arg)
Expand Down
2 changes: 1 addition & 1 deletion tests/core/web3-module/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_to_int_hexstr(val, expected):
(b"\x01", "0x01"),
(b"\x10", "0x10"),
(b"\x01\x00", "0x0100"),
(b"\x00\x0F", "0x000f"),
(b"\x00\x0f", "0x000f"),
(b"", "0x"),
(0, "0x0"),
(1, "0x1"),
Expand Down
10 changes: 5 additions & 5 deletions web3/_utils/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ def _named_subtree(

def recursive_dict_to_namedtuple(data: Dict[str, Any]) -> Tuple[Any, ...]:
def _dict_to_namedtuple(
value: Union[Dict[str, Any], List[Any]]
value: Union[Dict[str, Any], List[Any]],
) -> Union[Tuple[Any, ...], List[Any]]:
if not isinstance(value, dict):
return value
Expand All @@ -864,7 +864,7 @@ def _dict_to_namedtuple(


def abi_decoded_namedtuple_factory(
fields: Tuple[Any, ...]
fields: Tuple[Any, ...],
) -> Callable[..., Tuple[Any, ...]]:
class ABIDecodedNamedTuple(namedtuple("ABIDecodedNamedTuple", fields, rename=True)): # type: ignore # noqa: E501
def __new__(self, args: Any) -> "ABIDecodedNamedTuple":
Expand All @@ -877,9 +877,9 @@ def __new__(self, args: Any) -> "ABIDecodedNamedTuple":


async def async_data_tree_map(
async_w3: "AsyncWeb3",
async_w3: "AsyncWeb3[Any]",
func: Callable[
["AsyncWeb3", TypeStr, Any], Coroutine[Any, Any, Tuple[TypeStr, Any]]
["AsyncWeb3[Any]", TypeStr, Any], Coroutine[Any, Any, Tuple[TypeStr, Any]]
],
data_tree: Any,
) -> "ABITypedData":
Expand All @@ -902,7 +902,7 @@ async def async_map_to_typed_data(elements: Any) -> "ABITypedData":

@reject_recursive_repeats
async def async_recursive_map(
async_w3: "AsyncWeb3",
async_w3: "AsyncWeb3[Any]",
func: Callable[[Any], Coroutine[Any, Any, TReturn]],
data: Any,
) -> TReturn:
Expand Down
21 changes: 12 additions & 9 deletions web3/_utils/async_transactions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Union,
Expand Down Expand Up @@ -46,13 +47,13 @@

# unused vars present in these funcs because they all need to have the same signature
async def _estimate_gas(
async_w3: "AsyncWeb3", tx: TxParams, _defaults: Dict[str, Union[bytes, int]]
async_w3: "AsyncWeb3[Any]", tx: TxParams, _defaults: Dict[str, Union[bytes, int]]
) -> int:
return await async_w3.eth.estimate_gas(tx)


async def _max_fee_per_gas(
async_w3: "AsyncWeb3", tx: TxParams, defaults: Dict[str, Union[bytes, int]]
async_w3: "AsyncWeb3[Any]", tx: TxParams, defaults: Dict[str, Union[bytes, int]]
) -> Wei:
block = await async_w3.eth.get_block("latest")
max_priority_fee = tx.get(
Expand All @@ -62,13 +63,13 @@ async def _max_fee_per_gas(


async def _max_priority_fee_gas(
async_w3: "AsyncWeb3", _tx: TxParams, _defaults: Dict[str, Union[bytes, int]]
async_w3: "AsyncWeb3[Any]", _tx: TxParams, _defaults: Dict[str, Union[bytes, int]]
) -> Wei:
return await async_w3.eth.max_priority_fee


async def _chain_id(
async_w3: "AsyncWeb3", _tx: TxParams, _defaults: Dict[str, Union[bytes, int]]
async_w3: "AsyncWeb3[Any]", _tx: TxParams, _defaults: Dict[str, Union[bytes, int]]
) -> int:
return await async_w3.eth.chain_id

Expand All @@ -92,7 +93,7 @@ async def get_block_gas_limit(


async def get_buffered_gas_estimate(
async_w3: "AsyncWeb3", transaction: TxParams, gas_buffer: int = 100000
async_w3: "AsyncWeb3[Any]", transaction: TxParams, gas_buffer: int = 100000
) -> int:
gas_estimate_transaction = cast(TxParams, dict(**transaction))

Expand All @@ -110,7 +111,9 @@ async def get_buffered_gas_estimate(
return min(gas_limit, gas_estimate + gas_buffer)


async def async_fill_nonce(async_w3: "AsyncWeb3", transaction: TxParams) -> TxParams:
async def async_fill_nonce(
async_w3: "AsyncWeb3[Any]", transaction: TxParams
) -> TxParams:
if "from" in transaction and "nonce" not in transaction:
tx_count = await async_w3.eth.get_transaction_count(
cast(ChecksumAddress, transaction["from"]),
Expand All @@ -121,7 +124,7 @@ async def async_fill_nonce(async_w3: "AsyncWeb3", transaction: TxParams) -> TxPa


async def async_fill_transaction_defaults(
async_w3: "AsyncWeb3", transaction: TxParams
async_w3: "AsyncWeb3[Any]", transaction: TxParams
) -> TxParams:
"""
If async_w3 is None, fill as much as possible while offline
Expand Down Expand Up @@ -165,7 +168,7 @@ async def async_fill_transaction_defaults(


async def async_get_required_transaction(
async_w3: "AsyncWeb3", transaction_hash: _Hash32
async_w3: "AsyncWeb3[Any]", transaction_hash: _Hash32
) -> TxData:
current_transaction = await async_w3.eth.get_transaction(transaction_hash)
if not current_transaction:
Expand All @@ -176,7 +179,7 @@ async def async_get_required_transaction(


async def async_replace_transaction(
async_w3: "AsyncWeb3", current_transaction: TxData, new_transaction: TxParams
async_w3: "AsyncWeb3[Any]", current_transaction: TxData, new_transaction: TxParams
) -> HexBytes:
new_transaction = prepare_replacement_transaction(
async_w3, current_transaction, new_transaction
Expand Down
2 changes: 1 addition & 1 deletion web3/_utils/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@


class RequestBatcher(Generic[TFunc]):
def __init__(self, web3: Union["AsyncWeb3", "Web3"]) -> None:
def __init__(self, web3: Union["AsyncWeb3[Any]", "Web3"]) -> None:
self.web3 = web3
self._requests_info: List[BatchRequestInformation] = []
self._async_requests_info: List[
Expand Down
4 changes: 2 additions & 2 deletions web3/_utils/caching/caching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _should_cache_response(


def handle_request_caching(
func: Callable[[SYNC_PROVIDER_TYPE, RPCEndpoint, Any], "RPCResponse"]
func: Callable[[SYNC_PROVIDER_TYPE, RPCEndpoint, Any], "RPCResponse"],
) -> Callable[..., "RPCResponse"]:
def wrapper(
provider: SYNC_PROVIDER_TYPE, method: RPCEndpoint, params: Any
Expand Down Expand Up @@ -401,7 +401,7 @@ def async_handle_recv_caching(
func: Callable[
["PersistentConnectionProvider", "RPCRequest"],
Coroutine[Any, Any, "RPCResponse"],
]
],
) -> Callable[..., Coroutine[Any, Any, "RPCResponse"]]:
async def wrapper(
provider: "PersistentConnectionProvider",
Expand Down
10 changes: 5 additions & 5 deletions web3/_utils/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def find_matching_event_abi(


def encode_abi(
w3: Union["AsyncWeb3", "Web3"],
w3: Union["AsyncWeb3[Any]", "Web3"],
abi: ABIElement,
arguments: Sequence[Any],
data: Optional[HexStr] = None,
Expand Down Expand Up @@ -168,7 +168,7 @@ def encode_abi(

def prepare_transaction(
address: ChecksumAddress,
w3: Union["AsyncWeb3", "Web3"],
w3: Union["AsyncWeb3[Any]", "Web3"],
abi_element_identifier: ABIElementIdentifier,
contract_abi: Optional[ABI] = None,
abi_callable: Optional[ABICallable] = None,
Expand Down Expand Up @@ -232,7 +232,7 @@ def prepare_transaction(


def encode_transaction_data(
w3: Union["AsyncWeb3", "Web3"],
w3: Union["AsyncWeb3[Any]", "Web3"],
abi_element_identifier: ABIElementIdentifier,
contract_abi: Optional[ABI] = None,
abi_callable: Optional[ABICallable] = None,
Expand Down Expand Up @@ -363,7 +363,7 @@ def parse_block_identifier_int(w3: "Web3", block_identifier_int: int) -> BlockNu


async def async_parse_block_identifier(
async_w3: "AsyncWeb3", block_identifier: BlockIdentifier
async_w3: "AsyncWeb3[Any]", block_identifier: BlockIdentifier
) -> BlockIdentifier:
if block_identifier is None:
return async_w3.eth.default_block
Expand All @@ -381,7 +381,7 @@ async def async_parse_block_identifier(


async def async_parse_block_identifier_int(
async_w3: "AsyncWeb3", block_identifier_int: int
async_w3: "AsyncWeb3[Any]", block_identifier_int: int
) -> BlockNumber:
if block_identifier_int >= 0:
block_num = block_identifier_int
Expand Down
2 changes: 1 addition & 1 deletion web3/_utils/ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def address(self, name: str) -> ChecksumAddress:

@contextmanager
def ens_addresses(
w3: Union["Web3", "AsyncWeb3"], name_addr_pairs: Dict[str, ChecksumAddress]
w3: Union["Web3", "AsyncWeb3[Any]"], name_addr_pairs: Dict[str, ChecksumAddress]
) -> Iterator[None]:
original_ens = w3.ens
if w3.provider.is_async:
Expand Down
2 changes: 1 addition & 1 deletion web3/_utils/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def deploy(self, w3: "Web3") -> "LogFilter":


class AsyncEventFilterBuilder(BaseEventFilterBuilder):
async def deploy(self, async_w3: "AsyncWeb3") -> "AsyncLogFilter":
async def deploy(self, async_w3: "AsyncWeb3[Any]") -> "AsyncLogFilter":
if not isinstance(async_w3, web3.AsyncWeb3):
raise Web3ValueError(f"Invalid web3 argument: got: {async_w3!r}")

Expand Down
Loading