Skip to content

Commit

Permalink
[pyright] [core] misc (#11362)
Browse files Browse the repository at this point in the history
### Summary & Motivation

Miscellaneous assortment of typing fixes required to pass `pyright` in
core `dagster`. This PR contains all of the fixes that did not fit into
any more coherent theme:

- Adding/updating assorted annotations
- Facotring out certain callback types into reusable type aliases
- Guaranteeing certain local variables are defined on access
- A few calls to e.g. `check.not_none` to fix type inference

### How I Tested These Changes

Pyright, BK
  • Loading branch information
smackesey committed Jan 15, 2023
1 parent 585cff5 commit ac05d59
Show file tree
Hide file tree
Showing 57 changed files with 485 additions and 282 deletions.
6 changes: 4 additions & 2 deletions python_modules/dagster/dagster/_check/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,15 @@ def is_callable(obj: object, additional_message: Optional[str] = None) -> Callab
# ##### CLASS
# ########################

T_Type = TypeVar("T_Type", bound=type)


def class_param(
obj: object,
obj: T_Type,
param_name: str,
superclass: Optional[type] = None,
additional_message: Optional[str] = None,
) -> type:
) -> T_Type:
if not isinstance(obj, type):
raise _param_class_mismatch_exception(
obj, param_name, superclass, False, additional_message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def _get_code_pointer_dict_from_kwargs(kwargs: ClickArgMapping) -> Mapping[str,
)
package_name = check.opt_str_elem(kwargs, "package_name")
working_directory = get_working_directory_from_kwargs(kwargs)
attribute = kwargs.get("attribute")
attribute = check.opt_str_elem(kwargs, "attribute")
if python_file:
_check_cli_arguments_none(kwargs, "module_name", "package_name")
return {
Expand Down
6 changes: 4 additions & 2 deletions python_modules/dagster/dagster/_config/config_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,9 @@ def __init__(
):
from .field import resolve_to_config_type

self.scalar_type = resolve_to_config_type(scalar_type)
self.scalar_type = check.inst(
cast(ConfigType, resolve_to_config_type(scalar_type)), ConfigType
)
self.non_scalar_type = resolve_to_config_type(non_scalar_schema)

check.param_invariant(self.scalar_type.kind == ConfigTypeKind.SCALAR, "scalar_type")
Expand All @@ -426,7 +428,7 @@ def __init__(
)

def type_iterator(self) -> Iterator["ConfigType"]:
yield from self.scalar_type.type_iterator()
yield from self.scalar_type.type_iterator() # type: ignore
yield from self.non_scalar_type.type_iterator()
yield from super().type_iterator()

Expand Down
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/_config/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def resolve_to_config_type(obj: object) -> Union[ConfigType, bool]:
)
)

if not key_type.kind == ConfigTypeKind.SCALAR:
if not key_type.kind == ConfigTypeKind.SCALAR: # type: ignore
raise DagsterInvalidDefinitionError(
"Non-scalar key in map specification: {key} in map {collection}".format(
key=repr(key), collection=obj
Expand Down
16 changes: 13 additions & 3 deletions python_modules/dagster/dagster/_core/definitions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def __init__(
if is_in_composition():
current_context().add_pending_invocation(self)

def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs) -> Any:
from ..execution.context.invocation import UnboundOpExecutionContext
from .decorators.solid_decorator import DecoratedOpFunction
from .solid_invocation import op_invocation_result
Expand Down Expand Up @@ -374,7 +374,15 @@ def __call__(self, *args, **kwargs):
return op_invocation_result(self, None, *args, **kwargs)

assert_in_composition(node_name, self.node_def)
input_bindings = {}
input_bindings: Dict[
str,
Union[
InvokedNodeOutputHandle,
InputMappingNode,
DynamicFanIn,
List[Union[InvokedNodeOutputHandle, InputMappingNode]],
],
] = {}

# handle *args
for idx, output_node in enumerate(args):
Expand Down Expand Up @@ -448,7 +456,9 @@ def __call__(self, *args, **kwargs):
)

outputs = [output_def for output_def in self.node_def.output_defs]
invoked_output_handles = {}
invoked_output_handles: Dict[
str, Union[InvokedNodeDynamicOutputWrapper, InvokedNodeOutputHandle]
] = {}
for output_def in outputs:
if output_def.is_dynamic:
invoked_output_handles[output_def.name] = InvokedNodeDynamicOutputWrapper(
Expand Down
6 changes: 5 additions & 1 deletion python_modules/dagster/dagster/_core/definitions/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Callable, Mapping, NamedTuple, Optional, Union, cast

from typing_extensions import TypeAlias

import dagster._check as check
from dagster._builtins import BuiltinEnum
from dagster._config import (
Expand All @@ -14,6 +16,8 @@

from .definition_config_schema import convert_user_facing_definition_config_schema

ConfigMappingFn: TypeAlias = Callable[[Any], Any]


def is_callable_valid_config_arg(config: Union[Callable[..., Any], Mapping[str, object]]) -> bool:
return BuiltinEnum.contains(config) or is_supported_config_python_builtin(config)
Expand Down Expand Up @@ -51,7 +55,7 @@ class ConfigMapping(

def __new__(
cls,
config_fn: Callable[[Any], Any],
config_fn: ConfigMappingFn,
config_schema: Optional[Any] = None,
receive_processed_config_values: Optional[bool] = None,
):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Mapping, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union

from typing_extensions import Self

Expand All @@ -8,6 +8,7 @@
_check as check,
)
from dagster._config import EvaluateValueResult
from dagster._config.config_schema import UserConfigSchema

from .definition_config_schema import (
CoercableToConfigSchema,
Expand Down Expand Up @@ -117,7 +118,7 @@ def configured(
self,
config_or_config_fn: Any,
name: str,
config_schema: Optional[Mapping[str, Any]] = None,
config_schema: Optional[UserConfigSchema] = None,
description: Optional[str] = None,
) -> Self: # type: ignore [valid-type] # (until mypy supports Self)
"""
Expand Down Expand Up @@ -196,7 +197,7 @@ def _check_configurable_param(configurable: ConfigurableDefinition) -> None:

def configured(
configurable: T_Configurable,
config_schema: Optional[Mapping[str, Any]] = None,
config_schema: Optional[UserConfigSchema] = None,
**kwargs: Any,
) -> Callable[[object], T_Configurable]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dagster._check as check
from dagster._config import UserConfigSchema

from ..config import ConfigMapping
from ..config import ConfigMapping, ConfigMappingFn


class _ConfigMapping:
Expand All @@ -29,7 +29,7 @@ def __call__(self, fn: Callable[..., Any]) -> ConfigMapping:

@overload
def config_mapping(
config_fn: Callable[..., Any],
config_fn: ConfigMappingFn,
) -> ConfigMapping:
...

Expand All @@ -39,7 +39,7 @@ def config_mapping(
*,
config_schema: UserConfigSchema = ...,
receive_processed_config_values: Optional[bool] = ...,
) -> Union[_ConfigMapping, ConfigMapping]:
) -> Callable[[ConfigMappingFn], ConfigMapping]:
...


Expand All @@ -48,7 +48,7 @@ def config_mapping(
*,
config_schema: Optional[UserConfigSchema] = None,
receive_processed_config_values: Optional[bool] = None,
) -> Union[ConfigMapping, _ConfigMapping]:
) -> Union[Callable[[ConfigMappingFn], ConfigMapping], ConfigMapping]:
"""Create a config mapping with the specified parameters from the decorated function.
The config schema will be inferred from the type signature of the decorated function if not explicitly provided.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from functools import update_wrapper
from typing import Any, Callable, List, Mapping, Optional, Union, overload
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
TypeVar,
Union,
overload,
)

import dagster._check as check
from dagster._core.decorator_utils import get_function_params
Expand All @@ -21,8 +33,10 @@
from ..sensor_definition import SensorDefinition
from ..unresolved_asset_job_definition import UnresolvedAssetJobDefinition

T = TypeVar("T")


def _flatten(items):
def _flatten(items: Iterable[Union[T, List[T]]]) -> Iterator[T]:
for x in items:
if isinstance(x, List):
# switch to `yield from _flatten(x)` to support multiple layers of nesting
Expand All @@ -49,7 +63,7 @@ def __init__(
)

def __call__(
self, fn: Callable[[], Any]
self, fn: Callable[[], Sequence[Any]]
) -> Union[RepositoryDefinition, PendingRepositoryDefinition]:
from dagster._core.definitions import AssetGroup, AssetsDefinition, SourceAsset
from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition
Expand Down Expand Up @@ -157,7 +171,7 @@ def __call__(


@overload
def repository(definitions_fn: Callable[..., Any]) -> RepositoryDefinition:
def repository(definitions_fn: Callable[..., Sequence[Any]]) -> RepositoryDefinition:
...


Expand All @@ -173,7 +187,7 @@ def repository(


def repository(
definitions_fn: Optional[Callable[..., Any]] = None,
definitions_fn: Optional[Callable[..., Sequence[Any]]] = None,
*,
name: Optional[str] = None,
description: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def from_string(handle_str: str) -> "NodeHandle":
return NodeHandle.from_path(path)

@classmethod
def from_dict(cls, dict_repr: Dict[str, Any]) -> Optional["NodeHandle"]:
def from_dict(cls, dict_repr: Mapping[str, Any]) -> "NodeHandle":
"""This method makes it possible to load a potentially nested NodeHandle after a
roundtrip through json.loads(json.dumps(NodeHandle._asdict())).
"""
Expand All @@ -505,14 +505,16 @@ def from_dict(cls, dict_repr: Dict[str, Any]) -> Optional["NodeHandle"]:
)

if isinstance(dict_repr["parent"], (list, tuple)):
dict_repr["parent"] = NodeHandle.from_dict(
parent = NodeHandle.from_dict(
{
"name": dict_repr["parent"][0],
"parent": dict_repr["parent"][1],
}
)
else:
parent = dict_repr["parent"]

return NodeHandle(**{k: dict_repr[k] for k in ["name", "parent"]})
return NodeHandle(name=dict_repr["name"], parent=parent)


class NodeInputHandle(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import update_wrapper
from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, Sequence, Union, overload

from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias

import dagster._check as check
from dagster._annotations import public
Expand Down Expand Up @@ -146,9 +146,9 @@ def configured(
self,
config_or_config_fn: Any,
name: Optional[str] = None,
config_schema: Optional[Mapping[str, Any]] = None,
config_schema: Optional[UserConfigSchema] = None,
description: Optional[str] = None,
):
) -> Self: # type: ignore # fmt: skip
"""
Wraps this object in an object of the same type that provides configuration to the inner
object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,10 @@ def minutes_late(
self.cron_schedule, evaluation_time, ret_type=datetime.datetime, is_prev=True
)
evaluation_tick = next(schedule_ticks)
else:
elif evaluation_time is not None:
evaluation_tick = evaluation_time
else:
check.failed("Must provide an evaluation time if not using a cron schedule")

minutes_late = 0.0
for used_data_time in used_data_times.values():
Expand Down

0 comments on commit ac05d59

Please sign in to comment.