Skip to content

Commit

Permalink
Improve typing of schedule/sensor and hooks (#7560)
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Apr 26, 2022
1 parent b49aadc commit 69546e9
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 180 deletions.
18 changes: 14 additions & 4 deletions python_modules/dagster/dagster/check/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,25 +93,35 @@ def bool_elem(ddict: Dict, key: str) -> bool:
# ##### CALLABLE
# ########################

T_Callable = TypeVar("T_Callable", bound=Callable)
U_Callable = TypeVar("U_Callable", bound=Callable)

def callable_param(obj: object, param_name: str) -> Callable:

def callable_param(obj: T_Callable, param_name: str) -> T_Callable:
if not callable(obj):
raise _param_not_callable_exception(obj, param_name)
return obj


@overload
def opt_callable_param(obj: object, param_name: str, default: Callable) -> Callable:
def opt_callable_param(obj: None, param_name: str, default: None = ...) -> None:
...


@overload
def opt_callable_param(obj: None, param_name: str, default: T_Callable) -> T_Callable:
...


@overload
def opt_callable_param(obj: object, param_name: str) -> Optional[Callable]:
def opt_callable_param(
obj: T_Callable, param_name: str, default: Optional[U_Callable] = ...
) -> T_Callable:
...


def opt_callable_param(
obj: object, param_name: str, default: Optional[Callable] = None
obj: Optional[Callable], param_name: str, default: Optional[Callable] = None
) -> Optional[Callable]:
if obj is not None and not callable(obj):
raise _param_not_callable_exception(obj, param_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,22 @@ def slack_on_materializations(context, event_list):


@overload
def success_hook(name: SuccessOrFailureHookFn) -> Union[HookDefinition, _Hook]:
def success_hook(name: SuccessOrFailureHookFn) -> HookDefinition:
...


@overload
def success_hook(
name: Optional[str] = ...,
required_resource_keys: Optional[AbstractSet[str]] = ...,
) -> Callable[[SuccessOrFailureHookFn], Union[HookDefinition, _Hook]]:
) -> Callable[[SuccessOrFailureHookFn], HookDefinition]:
...


def success_hook(
name: Optional[Union[SuccessOrFailureHookFn, str]] = None,
required_resource_keys: Optional[AbstractSet[str]] = None,
) -> Union[HookDefinition, _Hook, Callable[[SuccessOrFailureHookFn], Union[HookDefinition, _Hook]]]:
) -> Union[HookDefinition, Callable[[SuccessOrFailureHookFn], HookDefinition]]:
"""Create a hook on step success events with the specified parameters from the decorated function.
Args:
Expand All @@ -171,7 +171,7 @@ def do_something_on_success(context):
"""

def wrapper(fn: Callable[["HookContext"], Any]) -> Union[HookDefinition, _Hook]:
def wrapper(fn: SuccessOrFailureHookFn) -> HookDefinition:

check.callable_param(fn, "fn")

Expand Down Expand Up @@ -206,22 +206,22 @@ def _success_hook(


@overload
def failure_hook(name: SuccessOrFailureHookFn) -> Union[HookDefinition, _Hook]:
def failure_hook(name: SuccessOrFailureHookFn) -> HookDefinition:
...


@overload
def failure_hook(
name: Optional[str] = ...,
required_resource_keys: Optional[AbstractSet[str]] = ...,
) -> Callable[[SuccessOrFailureHookFn], Union[HookDefinition, _Hook]]:
) -> Callable[[SuccessOrFailureHookFn], HookDefinition]:
...


def failure_hook(
name: Optional[Union[SuccessOrFailureHookFn, str]] = None,
required_resource_keys: Optional[AbstractSet[str]] = None,
) -> Union[HookDefinition, _Hook, Callable[[SuccessOrFailureHookFn], Union[HookDefinition, _Hook]]]:
) -> Union[HookDefinition, Callable[[SuccessOrFailureHookFn], HookDefinition]]:
"""Create a hook on step failure events with the specified parameters from the decorated function.
Args:
Expand All @@ -245,7 +245,7 @@ def do_something_on_failure(context):
"""

def wrapper(fn: Callable[["HookContext"], Any]) -> Union[HookDefinition, _Hook]:
def wrapper(fn: Callable[["HookContext"], Any]) -> HookDefinition:
check.callable_param(fn, "fn")

expected_positionals = ["context"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,7 @@
import datetime
import warnings
from functools import update_wrapper
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
NamedTuple,
Optional,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast

from dagster import check
from dagster.core.definitions.partition import (
Expand All @@ -36,56 +25,43 @@
create_offset_partition_selector,
)

from ...decorator_utils import get_function_params
from ...storage.tags import check_tags
from ..graph_definition import GraphDefinition
from ..mode import DEFAULT_MODE_NAME
from ..pipeline_definition import PipelineDefinition
from ..run_request import RunRequest, SkipReason
from ..schedule_definition import DefaultScheduleStatus, ScheduleDefinition, is_context_provided
from ..schedule_definition import (
DecoratedScheduleFunction,
DefaultScheduleStatus,
RawScheduleEvaluationFunction,
RunRequestIterator,
ScheduleDefinition,
ScheduleEvaluationContext,
is_context_provided,
)

if TYPE_CHECKING:
from dagster import Partition, ScheduleEvaluationContext
from dagster import Partition

# Error messages are long
# pylint: disable=C0301

RunConfig = Dict[str, Any]
RunRequestGenerator = Generator[Union[RunRequest, SkipReason], None, None]


class DecoratedScheduleFunction(NamedTuple):
"""Wrapper around the decorated schedule function. Keeps track of both to better support the
optimal return value for direct invocation of the evaluation function"""

decorated_fn: Callable[..., Union[RunRequest, SkipReason, RunConfig, RunRequestGenerator]]
wrapped_fn: Callable[["ScheduleEvaluationContext"], RunRequestGenerator]
has_context_arg: bool


def schedule(
cron_schedule: str,
pipeline_name: Optional[str] = None,
name: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
tags_fn: Optional[Callable[["ScheduleEvaluationContext"], Optional[Dict[str, str]]]] = None,
tags_fn: Optional[Callable[[ScheduleEvaluationContext], Optional[Dict[str, str]]]] = None,
solid_selection: Optional[List[str]] = None,
mode: Optional[str] = "default",
should_execute: Optional[Callable[["ScheduleEvaluationContext"], bool]] = None,
should_execute: Optional[Callable[[ScheduleEvaluationContext], bool]] = None,
environment_vars: Optional[Dict[str, str]] = None,
execution_timezone: Optional[str] = None,
description: Optional[str] = None,
job: Optional[Union[PipelineDefinition, GraphDefinition]] = None,
default_status: DefaultScheduleStatus = DefaultScheduleStatus.STOPPED,
) -> Callable[
[
Callable[
...,
Union[RunRequest, SkipReason, RunConfig, RunRequestGenerator],
]
],
ScheduleDefinition,
]:
) -> Callable[[RawScheduleEvaluationFunction], ScheduleDefinition]:
"""
Creates a schedule following the provided cron schedule and requests runs for the provided job.
Expand Down Expand Up @@ -118,7 +94,7 @@ def schedule(
mode (Optional[str]): The pipeline mode in which to execute this schedule.
(Default: 'default')
should_execute (Optional[Callable[[ScheduleEvaluationContext], bool]]): A function that runs at
schedule execution tie to determine whether a schedule should execute or skip. Takes a
schedule execution time to determine whether a schedule should execute or skip. Takes a
:py:class:`~dagster.ScheduleEvaluationContext` and returns a boolean (``True`` if the
schedule should execute). Defaults to a function that always returns ``True``.
environment_vars (Optional[Dict[str, str]]): Any environment variables to set when executing
Expand All @@ -133,33 +109,21 @@ def schedule(
status can be overridden from Dagit or via the GraphQL API.
"""

def inner(
fn: Callable[
...,
Union[RunRequest, SkipReason, RunConfig, RunRequestGenerator],
]
) -> ScheduleDefinition:
def inner(fn: RawScheduleEvaluationFunction) -> ScheduleDefinition:
check.callable_param(fn, "fn")

schedule_name = name or fn.__name__

# perform upfront validation of schedule tags
_tags_fn: Optional[Callable[["ScheduleEvaluationContext"], Dict[str, str]]] = None
if tags_fn and tags:
raise DagsterInvalidDefinitionError(
"Attempted to provide both tags_fn and tags as arguments"
" to ScheduleDefinition. Must provide only one of the two."
)
elif tags:
check_tags(tags, "tags")
_tags_fn = cast(Callable[["ScheduleEvaluationContext"], Dict[str, str]], lambda _: tags)
elif tags_fn:
_tags_fn = cast(
Callable[["ScheduleEvaluationContext"], Dict[str, str]],
lambda context: tags_fn(context) or {},
)

def _wrapped_fn(context: "ScheduleEvaluationContext"):
def _wrapped_fn(context: ScheduleEvaluationContext) -> RunRequestIterator:
if should_execute:
with user_code_error_boundary(
ScheduleExecutionError,
Expand All @@ -175,22 +139,26 @@ def _wrapped_fn(context: "ScheduleEvaluationContext"):
ScheduleExecutionError,
lambda: f"Error occurred during the evaluation of schedule {schedule_name}",
):
result = fn(context) if has_context_arg else fn()
if is_context_provided(fn):
result = fn(context)
else:
result = fn() # type: ignore

if isinstance(result, dict):
# this is the run-config based decorated function, wrap the evaluated run config
# and tags in a RunRequest
evaluated_run_config = copy.deepcopy(result)
evaluated_tags = _tags_fn(context) if _tags_fn else None
evaluated_tags = tags or (tags_fn and tags_fn(context)) or None
yield RunRequest(
run_key=None,
run_config=evaluated_run_config,
tags=evaluated_tags,
)
else:
# this is a run-request based decorated function
yield from ensure_gen(result)
yield from cast(RunRequestIterator, ensure_gen(result))

has_context_arg = is_context_provided(get_function_params(fn))
has_context_arg = is_context_provided(fn)
evaluation_fn = DecoratedScheduleFunction(
decorated_fn=fn,
wrapped_fn=_wrapped_fn,
Expand Down Expand Up @@ -543,7 +511,7 @@ def daily_schedule(
tags_fn_for_date: Optional[Callable[[datetime.datetime], Optional[Dict[str, str]]]] = None,
solid_selection: Optional[List[str]] = None,
mode: Optional[str] = "default",
should_execute: Optional[Callable[["ScheduleEvaluationContext"], bool]] = None,
should_execute: Optional[Callable[[ScheduleEvaluationContext], bool]] = None,
environment_vars: Optional[Dict[str, str]] = None,
end_date: Optional[datetime.datetime] = None,
execution_timezone: Optional[str] = None,
Expand Down Expand Up @@ -686,7 +654,7 @@ def hourly_schedule(
tags_fn_for_date: Optional[Callable[[datetime.datetime], Optional[Dict[str, str]]]] = None,
solid_selection: Optional[List[str]] = None,
mode: Optional[str] = "default",
should_execute: Optional[Callable[["ScheduleEvaluationContext"], bool]] = None,
should_execute: Optional[Callable[[ScheduleEvaluationContext], bool]] = None,
environment_vars: Optional[Dict[str, str]] = None,
end_date: Optional[datetime.datetime] = None,
execution_timezone: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from functools import update_wrapper
from typing import TYPE_CHECKING, Callable, Generator, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union

from dagster import check
from dagster.core.definitions.sensor_definition import (
Expand All @@ -16,10 +16,11 @@
from ..graph_definition import GraphDefinition
from ..job_definition import JobDefinition
from ..sensor_definition import (
AssetMaterializationFunction,
AssetSensorDefinition,
RawSensorEvaluationFunction,
RunRequest,
SensorDefinition,
SensorEvaluationContext,
SkipReason,
)

Expand All @@ -37,15 +38,7 @@ def sensor(
job: Optional[Union[GraphDefinition, JobDefinition]] = None,
jobs: Optional[Sequence[Union[GraphDefinition, JobDefinition]]] = None,
default_status: DefaultSensorStatus = DefaultSensorStatus.STOPPED,
) -> Callable[
[
Callable[
[SensorEvaluationContext],
Union[Generator[Union[RunRequest, SkipReason], None, None], RunRequest, SkipReason],
]
],
SensorDefinition,
]:
) -> Callable[[RawSensorEvaluationFunction], SensorDefinition]:
"""
Creates a sensor where the decorated function is used as the sensor's evaluation function. The
decorated function may:
Expand Down Expand Up @@ -80,12 +73,7 @@ def sensor(
"""
check.opt_str_param(name, "name")

def inner(
fn: Callable[
["SensorEvaluationContext"],
Union[Generator[Union[SkipReason, RunRequest], None, None], SkipReason, RunRequest],
]
) -> SensorDefinition:
def inner(fn: RawSensorEvaluationFunction) -> SensorDefinition:
check.callable_param(fn, "fn")

sensor_def = SensorDefinition(
Expand Down Expand Up @@ -119,18 +107,7 @@ def asset_sensor(
job: Optional[Union[GraphDefinition, JobDefinition]] = None,
jobs: Optional[Sequence[Union[GraphDefinition, JobDefinition]]] = None,
default_status: DefaultSensorStatus = DefaultSensorStatus.STOPPED,
) -> Callable[
[
Callable[
[
"SensorEvaluationContext",
"EventLogEntry",
],
Union[Generator[Union[RunRequest, SkipReason], None, None], RunRequest, SkipReason],
]
],
AssetSensorDefinition,
]:
) -> Callable[[AssetMaterializationFunction,], AssetSensorDefinition,]:
"""
Creates an asset sensor where the decorated function is used as the asset sensor's evaluation
function. The decorated function may:
Expand Down Expand Up @@ -167,15 +144,7 @@ def asset_sensor(

check.opt_str_param(name, "name")

def inner(
fn: Callable[
[
"SensorEvaluationContext",
"EventLogEntry",
],
Union[Generator[Union[SkipReason, RunRequest], None, None], SkipReason, RunRequest],
]
) -> AssetSensorDefinition:
def inner(fn: AssetMaterializationFunction) -> AssetSensorDefinition:
check.callable_param(fn, "fn")
sensor_name = name or fn.__name__

Expand Down

0 comments on commit 69546e9

Please sign in to comment.