Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uses Self Type in contract factory methods #2997

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion ethpm/contract.py
Expand Up @@ -27,6 +27,9 @@
from ethpm.validation.misc import (
validate_empty_bytes,
)
from web3._utils.compat import (
Self,
)
from web3._utils.validation import (
validate_address,
)
Expand Down Expand Up @@ -59,7 +62,7 @@ def __init__(self, address: bytes, **kwargs: Any) -> None:
super().__init__(address=address, **kwargs) # type: ignore

@classmethod
def factory(cls, w3: "Web3", class_name: str = None, **kwargs: Any) -> Contract:
def factory(cls, w3: "Web3", class_name: str = None, **kwargs: Any) -> Type[Self]:
dep_link_refs = kwargs.get("unlinked_references")
bytecode = kwargs.get("bytecode")
needs_bytecode_linking = False
Expand Down
1 change: 1 addition & 0 deletions newsfragments/2997.misc.rst
@@ -0,0 +1 @@
Uses `Self` Type in contract factory methods
5 changes: 2 additions & 3 deletions setup.py
Expand Up @@ -13,7 +13,7 @@
"black>=22.1.0",
"flake8==3.8.3",
"isort>=5.11.0",
"mypy==0.910",
"mypy>=1.0.0",
"types-setuptools>=57.4.4",
"types-requests>=2.26.1",
"types-protobuf==3.19.13",
Expand Down Expand Up @@ -80,8 +80,7 @@
"protobuf>=4.21.6",
"pywin32>=223;platform_system=='Windows'",
"requests>=2.16.0",
# remove typing_extensions after python_requires>=3.8, see web3._utils.compat
"typing-extensions>=3.7.4.1,<5;python_version<'3.8'",
"typing-extensions>=4.0.1,<5",
"websockets>=10.0.0",
],
python_requires=">=3.7.2",
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Expand Up @@ -53,7 +53,7 @@ commands=
flake8 {toxinidir}/web3 {toxinidir}/ens {toxinidir}/ethpm {toxinidir}/tests --exclude {toxinidir}/ethpm/ethpm-spec,{toxinidir}/**/*_pb2.py
black {toxinidir}/ens {toxinidir}/ethpm {toxinidir}/web3 {toxinidir}/tests {toxinidir}/setup.py --exclude /ethpm/ethpm-spec/|/ethpm/_utils/protobuf/ipfs_file_pb2\.py --check
isort --check-only --diff {toxinidir}/web3/ {toxinidir}/ens/ {toxinidir}/ethpm/ {toxinidir}/tests/
mypy -p web3 -p ethpm -p ens --config-file {toxinidir}/mypy.ini
mypy -p web3 -p ens --config-file {toxinidir}/mypy.ini

[testenv:lint]
basepython: python
Expand Down
22 changes: 7 additions & 15 deletions web3/_utils/compat/__init__.py
@@ -1,16 +1,8 @@
import sys
# Changelog for `typing_extensions` for checking which types were added when
# https://github.com/python/typing_extensions/blob/main/CHANGELOG.md

# remove once web3 supports python>=3.8
# Types was added to typing in 3.8
if sys.version_info >= (3, 8):
from typing import (
Literal,
Protocol,
TypedDict,
)
else:
from typing_extensions import ( # noqa: F401
Literal,
Protocol,
TypedDict,
)
# Note that we do not need to explicitly check for python version here,
# because `typing_extensions` will do it for us and either import from `typing`
# or use the back-ported version of the type.

from typing_extensions import Literal, Protocol, TypedDict, Self # noqa: F401
10 changes: 6 additions & 4 deletions web3/_utils/contract_sources/compile_contracts.py
Expand Up @@ -48,6 +48,8 @@
import re
from typing import (
Any,
Dict,
List,
)

import solcx
Expand Down Expand Up @@ -79,7 +81,7 @@
files_to_compile = [user_filename] if user_filename else all_dot_sol_files


def _compile_dot_sol_files(dot_sol_filename: str) -> dict[str, Any]:
def _compile_dot_sol_files(dot_sol_filename: str) -> Dict[str, Any]:
compiled = solcx.compile_files(
[f"./{dot_sol_filename}"],
output_values=["abi", "bin", "bin-runtime"],
Expand All @@ -88,10 +90,10 @@ def _compile_dot_sol_files(dot_sol_filename: str) -> dict[str, Any]:


def _get_compiled_contract_data(
sol_file_output: dict[str, dict[str, str]],
sol_file_output: Dict[str, Dict[str, str]],
dot_sol_filename: str,
contract_name: str = None,
) -> dict[str, str]:
) -> Dict[str, str]:
if not contract_name:
contract_name = dot_sol_filename.replace(".sol", "")

Expand All @@ -111,7 +113,7 @@ def _get_compiled_contract_data(
contracts_in_file = {}


def compile_files(file_list: list[str]) -> None:
def compile_files(file_list: List[str]) -> None:
for filename in file_list:
with open(os.path.join(os.getcwd(), filename), "r") as f:
dot_sol_file = f.readlines()
Expand Down
14 changes: 7 additions & 7 deletions web3/_utils/events.py
Expand Up @@ -140,7 +140,7 @@ def construct_event_topic_set(
for arg, arg_options in zipped_abi_and_args
]

topics = list(normalize_topic_list([event_topic] + encoded_args)) # type: ignore
topics = list(normalize_topic_list([event_topic] + encoded_args))
return topics


Expand Down Expand Up @@ -394,12 +394,12 @@ def address(self, value: ChecksumAddress) -> None:
def ordered_args(self) -> Tuple[Any, ...]:
return tuple(map(self.args.__getitem__, self._ordered_arg_names))

@property # type: ignore
@property
@to_tuple
def indexed_args(self) -> Tuple[Any, ...]:
return tuple(filter(is_indexed, self.ordered_args))

@property # type: ignore
@property
@to_tuple
def data_args(self) -> Tuple[Any, ...]:
return tuple(filter(is_not_indexed, self.ordered_args))
Expand Down Expand Up @@ -432,8 +432,8 @@ def deploy(self, w3: "Web3") -> "LogFilter":
if not isinstance(w3, web3.Web3):
raise ValueError(f"Invalid web3 argument: got: {w3!r}")

for arg in AttributeDict.values(self.args):
arg._immutable = True
for arg in AttributeDict.values(self.args): # type: ignore[arg-type]
arg._immutable = True # type: ignore[attr-defined]
self._immutable = True

log_filter = cast("LogFilter", w3.eth.filter(self.filter_params))
Expand All @@ -450,8 +450,8 @@ async def deploy(self, async_w3: "AsyncWeb3") -> "AsyncLogFilter":
if not isinstance(async_w3, web3.AsyncWeb3):
raise ValueError(f"Invalid web3 argument: got: {async_w3!r}")

for arg in AttributeDict.values(self.args):
arg._immutable = True
for arg in AttributeDict.values(self.args): # type: ignore[arg-type]
arg._immutable = True # type: ignore[attr-defined]
self._immutable = True

log_filter = await async_w3.eth.filter(self.filter_params)
Expand Down
2 changes: 1 addition & 1 deletion web3/_utils/module_testing/web3_module.py
Expand Up @@ -229,7 +229,7 @@ def test_solidity_keccak(
expected: HexBytes,
) -> None:
if isinstance(expected, type) and issubclass(expected, Exception):
with pytest.raises(expected): # type: ignore
with pytest.raises(expected):
w3.solidity_keccak(types, values)
return

Expand Down
14 changes: 9 additions & 5 deletions web3/contract/async_contract.py
Expand Up @@ -9,6 +9,7 @@
List,
Optional,
Sequence,
Type,
cast,
)

Expand All @@ -33,6 +34,9 @@
from web3._utils.async_transactions import (
fill_transaction_defaults as async_fill_transaction_defaults,
)
from web3._utils.compat import (
Self,
)
from web3._utils.contracts import (
async_parse_block_identifier,
parse_block_identifier_no_extra_call,
Expand Down Expand Up @@ -239,7 +243,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractFunction":
return clone

@classmethod
def factory(cls, class_name: str, **kwargs: Any) -> "AsyncContractFunction":
def factory(cls, class_name: str, **kwargs: Any) -> Self:
return PropertyCheckingFactory(class_name, (cls,), kwargs)(kwargs.get("abi"))

async def call(
Expand Down Expand Up @@ -449,7 +453,7 @@ def __init__(self, address: Optional[ChecksumAddress] = None) -> None:
@classmethod
def factory(
cls, w3: "AsyncWeb3", class_name: Optional[str] = None, **kwargs: Any
) -> "AsyncContract":
) -> Type[Self]:
kwargs["w3"] = w3

normalizers = {
Expand All @@ -460,7 +464,7 @@ def factory(
}

contract = cast(
AsyncContract,
Type[Self],
PropertyCheckingFactory(
class_name or cls.__name__,
(cls,),
Expand Down Expand Up @@ -491,7 +495,7 @@ def factory(
return contract

@classmethod
def constructor(cls, *args: Any, **kwargs: Any) -> "AsyncContractConstructor":
def constructor(cls, *args: Any, **kwargs: Any) -> Self:
"""
:param args: The contract constructor arguments as positional arguments
:param kwargs: The contract constructor arguments as keyword arguments
Expand Down Expand Up @@ -549,7 +553,7 @@ def __init__(

self._functions = filter_by_type("function", self.abi)
for func in self._functions:
fn: AsyncContractFunction = AsyncContractFunction.factory(
fn = AsyncContractFunction.factory(
func["name"],
w3=self.w3,
contract_abi=self.abi,
Expand Down
13 changes: 7 additions & 6 deletions web3/contract/base_contract.py
Expand Up @@ -454,8 +454,9 @@ def _get_call_txparams(self, transaction: Optional[TxParams] = None) -> TxParams
call_transaction.setdefault("to", self.address)
if self.w3.eth.default_account is not empty:
# type ignored b/c check prevents an empty default_account
call_transaction.setdefault(
"from", self.w3.eth.default_account # type: ignore
call_transaction.setdefault( # type: ignore
"from",
self.w3.eth.default_account, # type: ignore
)

if "to" not in call_transaction:
Expand Down Expand Up @@ -485,7 +486,7 @@ def _transact(self, transaction: Optional[TxParams] = None) -> TxParams:
transact_transaction.setdefault("to", self.address)
if self.w3.eth.default_account is not empty:
# type ignored b/c check prevents an empty default_account
transact_transaction.setdefault(
transact_transaction.setdefault( # type: ignore
"from", self.w3.eth.default_account # type: ignore
)

Expand Down Expand Up @@ -516,7 +517,7 @@ def _estimate_gas(self, transaction: Optional[TxParams] = None) -> TxParams:
estimate_gas_transaction.setdefault("to", self.address)
if self.w3.eth.default_account is not empty:
# type ignored b/c check prevents an empty default_account
estimate_gas_transaction.setdefault(
estimate_gas_transaction.setdefault( # type: ignore
"from", self.w3.eth.default_account # type: ignore
)

Expand Down Expand Up @@ -1041,7 +1042,7 @@ def _estimate_gas(self, transaction: Optional[TxParams] = None) -> TxParams:

if self.w3.eth.default_account is not empty:
# type ignored b/c check prevents an empty default_account
estimate_gas_transaction.setdefault(
estimate_gas_transaction.setdefault( # type: ignore
"from", self.w3.eth.default_account # type: ignore
)

Expand All @@ -1060,7 +1061,7 @@ def _get_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:

if self.w3.eth.default_account is not empty:
# type ignored b/c check prevents an empty default_account
transact_transaction.setdefault(
transact_transaction.setdefault( # type: ignore
"from", self.w3.eth.default_account # type: ignore
)

Expand Down
12 changes: 8 additions & 4 deletions web3/contract/contract.py
Expand Up @@ -8,6 +8,7 @@
List,
Optional,
Sequence,
Type,
cast,
)

Expand All @@ -29,6 +30,9 @@
filter_by_type,
receive_func_abi_exists,
)
from web3._utils.compat import (
Self,
)
from web3._utils.contracts import (
parse_block_identifier,
)
Expand Down Expand Up @@ -235,7 +239,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "ContractFunction":
return clone

@classmethod
def factory(cls, class_name: str, **kwargs: Any) -> "ContractFunction":
def factory(cls, class_name: str, **kwargs: Any) -> Self:
return PropertyCheckingFactory(class_name, (cls,), kwargs)(kwargs.get("abi"))

def call(
Expand Down Expand Up @@ -448,7 +452,7 @@ def __init__(self, address: Optional[ChecksumAddress] = None) -> None:
@classmethod
def factory(
cls, w3: "Web3", class_name: Optional[str] = None, **kwargs: Any
) -> "Contract":
) -> Type[Self]:
kwargs["w3"] = w3

normalizers = {
Expand All @@ -459,7 +463,7 @@ def factory(
}

contract = cast(
Contract,
Type[Self],
PropertyCheckingFactory(
class_name or cls.__name__,
(cls,),
Expand Down Expand Up @@ -549,7 +553,7 @@ def __init__(

self._functions = filter_by_type("function", self.abi)
for func in self._functions:
fn: ContractFunction = ContractFunction.factory(
fn = ContractFunction.factory(
func["name"],
w3=self.w3,
contract_abi=self.abi,
Expand Down
2 changes: 1 addition & 1 deletion web3/main.py
Expand Up @@ -335,7 +335,7 @@ def pm(self) -> "PM":
if hasattr(self, "_pm"):
# ignored b/c property is dynamically set
# via enable_unstable_package_management_api
return self._pm # type: ignore
return self._pm
else:
raise AttributeError(
"The Package Management feature is disabled by default until "
Expand Down