Skip to content

Commit

Permalink
function type annotations (batch 2) (#6933)
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Mar 7, 2022
1 parent d1d6003 commit 3c2a0b5
Show file tree
Hide file tree
Showing 14 changed files with 190 additions and 107 deletions.
43 changes: 35 additions & 8 deletions python_modules/dagster/dagster/check/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def opt_str_elem(ddict: Dict, key: str) -> Optional[str]:
def tuple_param(
obj: object,
param_name: str,
of_type: TypeOrTupleOfTypes = None,
of_type: Optional[TypeOrTupleOfTypes] = None,
of_shape: Optional[Tuple[TypeOrTupleOfTypes, ...]] = None,
) -> Tuple:
"""Ensure param is a tuple and is of a specified type. `of_type` defines a variadic tuple type--
Expand All @@ -927,11 +927,33 @@ def tuple_param(
return _check_tuple_items(obj, of_type, of_shape)


@overload
def opt_tuple_param(
obj: object,
param_name: str,
default: Tuple = None,
of_type: TypeOrTupleOfTypes = None,
default: Tuple,
of_type: Optional[TypeOrTupleOfTypes] = None,
of_shape: Optional[Tuple[TypeOrTupleOfTypes, ...]] = None,
) -> Tuple:
...


@overload
def opt_tuple_param(
obj: object,
param_name: str,
default: None = ...,
of_type: TypeOrTupleOfTypes = ...,
of_shape: Optional[Tuple[TypeOrTupleOfTypes, ...]] = ...,
) -> Optional[Tuple]:
...


def opt_tuple_param(
obj: object,
param_name: str,
default: Optional[Tuple] = None,
of_type: Optional[TypeOrTupleOfTypes] = None,
of_shape: Optional[Tuple[TypeOrTupleOfTypes, ...]] = None,
) -> Optional[Tuple]:
"""Ensure optional param is a tuple and is of a specified type. `default` is returned if `obj`
Expand All @@ -957,9 +979,9 @@ def opt_tuple_param(

def is_tuple(
obj: object,
of_type: TypeOrTupleOfTypes = None,
of_type: Optional[TypeOrTupleOfTypes] = None,
of_shape: Optional[Tuple[TypeOrTupleOfTypes, ...]] = None,
desc: str = None,
desc: Optional[str] = None,
) -> Tuple:
"""Ensure target is a tuple and is of a specified type. `of_type` defines a variadic tuple
type-- `obj` may be of any length, but each element must match the `of_type` argmument.
Expand Down Expand Up @@ -1090,7 +1112,10 @@ def _element_check_error(


def _param_type_mismatch_exception(
obj: object, ttype: TypeOrTupleOfTypes, param_name: str, additional_message: str = None
obj: object,
ttype: TypeOrTupleOfTypes,
param_name: str,
additional_message: Optional[str] = None,
) -> ParameterCheckError:
additional_message = " " + additional_message if additional_message else ""
if isinstance(ttype, tuple):
Expand Down Expand Up @@ -1119,7 +1144,9 @@ def _param_class_mismatch_exception(
)


def _type_mismatch_error(obj: object, ttype: TypeOrTupleOfTypes, desc: str = None) -> CheckError:
def _type_mismatch_error(
obj: object, ttype: TypeOrTupleOfTypes, desc: Optional[str] = None
) -> CheckError:
type_message = (
f"not one of {sorted([t.__name__ for t in ttype])}"
if isinstance(ttype, tuple)
Expand All @@ -1138,7 +1165,7 @@ def _param_not_callable_exception(obj: Any, param_name: str) -> ParameterCheckEr
)


def _param_invariant_exception(param_name: str, desc: str = None) -> ParameterCheckError:
def _param_invariant_exception(param_name: str, desc: Optional[str] = None) -> ParameterCheckError:
return ParameterCheckError(
f"Invariant violation for parameter {param_name}. Description: {desc}"
)
25 changes: 16 additions & 9 deletions python_modules/dagster/dagster/cli/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import sys
from typing import Any, Callable, Optional

import click

Expand All @@ -9,8 +10,9 @@
get_working_directory_from_kwargs,
python_origin_target_argument,
)
from dagster.core.definitions.reconstructable import ReconstructablePipeline
from dagster.core.errors import DagsterExecutionInterruptedError
from dagster.core.events import DagsterEventType, EngineEventData
from dagster.core.events import DagsterEvent, DagsterEventType, EngineEventData
from dagster.core.execution.api import create_execution_plan, execute_plan_iterator
from dagster.core.execution.run_cancellation_thread import start_run_cancellation_thread
from dagster.core.instance import DagsterInstance
Expand Down Expand Up @@ -77,18 +79,19 @@ def send_to_buffer(event):


def _execute_run_command_body(
recon_pipeline, pipeline_run_id, instance, write_stream_fn, set_exit_code_on_failure
):
recon_pipeline: ReconstructablePipeline,
pipeline_run_id: str,
instance: DagsterInstance,
write_stream_fn: Callable[[DagsterEvent], Any],
set_exit_code_on_failure: bool,
) -> int:
if instance.should_start_background_run_thread:
cancellation_thread, cancellation_thread_shutdown_event = start_run_cancellation_thread(
instance, pipeline_run_id
)

pipeline_run = instance.get_run_by_id(pipeline_run_id)

check.inst(
pipeline_run,
PipelineRun,
pipeline_run: PipelineRun = check.not_none(
instance.get_run_by_id(pipeline_run_id),
"Pipeline run with id '{}' not found for run execution.".format(pipeline_run_id),
)

Expand Down Expand Up @@ -171,7 +174,11 @@ def send_to_buffer(event):


def _resume_run_command_body(
recon_pipeline, pipeline_run_id, instance, write_stream_fn, set_exit_code_on_failure
recon_pipeline: ReconstructablePipeline,
pipeline_run_id: Optional[str],
instance: DagsterInstance,
write_stream_fn: Callable[[DagsterEvent], Any],
set_exit_code_on_failure: bool,
):
if instance.should_start_background_run_thread:
cancellation_thread, cancellation_thread_shutdown_event = start_run_cancellation_thread(
Expand Down
18 changes: 11 additions & 7 deletions python_modules/dagster/dagster/cli/config_scaffolder.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,43 @@
from typing import Optional

from dagster import PipelineDefinition, check
from dagster.config.config_type import ConfigType, ConfigTypeKind
from dagster.core.definitions import create_run_config_schema


def scaffold_pipeline_config(pipeline_def, skip_non_required=True, mode=None):
def scaffold_pipeline_config(
pipeline_def: PipelineDefinition, skip_non_required: bool = True, mode: Optional[str] = None
):
check.inst_param(pipeline_def, "pipeline_def", PipelineDefinition)
check.bool_param(skip_non_required, "skip_non_required")

env_config_type = create_run_config_schema(pipeline_def, mode=mode).config_type

env_dict = {}

for env_field_name, env_field in env_config_type.fields.items():
for env_field_name, env_field in env_config_type.fields.items(): # type: ignore
if skip_non_required and not env_field.is_required:
continue

# unfortunately we have to treat this special for now
if env_field_name == "context":
if skip_non_required and not env_config_type.fields["context"].is_required:
if skip_non_required and not env_config_type.fields["context"].is_required: # type: ignore
continue

env_dict[env_field_name] = scaffold_type(env_field.config_type, skip_non_required)

return env_dict


def scaffold_type(config_type, skip_non_required=True):
def scaffold_type(config_type: ConfigType, skip_non_required: bool = True):
check.inst_param(config_type, "config_type", ConfigType)
check.bool_param(skip_non_required, "skip_non_required")

# Right now selectors and composites have the same
# scaffolding logic, which might not be wise.
if ConfigTypeKind.has_fields(config_type.kind):
default_dict = {}
for field_name, field in config_type.fields.items():
for field_name, field in config_type.fields.items(): # type: ignore
if skip_non_required and not field.is_required:
continue

Expand All @@ -44,13 +48,13 @@ def scaffold_type(config_type, skip_non_required=True):
elif config_type.kind == ConfigTypeKind.SCALAR:
defaults = {"String": "", "Int": 0, "Bool": True}

return defaults[config_type.given_name]
return defaults[config_type.given_name] # type: ignore
elif config_type.kind == ConfigTypeKind.ARRAY:
return []
elif config_type.kind == ConfigTypeKind.MAP:
return {}
elif config_type.kind == ConfigTypeKind.ENUM:
return "|".join(sorted(map(lambda v: v.config_value, config_type.enum_values)))
return "|".join(sorted(map(lambda v: v.config_value, config_type.enum_values))) # type: ignore
else:
check.failed(
"Do not know how to scaffold {type_name}".format(type_name=config_type.given_name)
Expand Down
10 changes: 5 additions & 5 deletions python_modules/dagster/dagster/cli/debug.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from gzip import GzipFile
from typing import Tuple

import click
from tqdm import tqdm

from dagster import DagsterInstance, check
from dagster import DagsterInstance
from dagster.core.debug import DebugRunPayload
from dagster.core.storage.pipeline_run import PipelineRunStatus, RunsFilter
from dagster.serdes import deserialize_json_to_dagster_namedtuple
from dagster.serdes import deserialize_as


def _recent_failed_runs_text(instance):
Expand Down Expand Up @@ -59,13 +60,12 @@ def export_command(run_id, output_file):
name="import", help="Import the relevant artifacts for a pipeline/job run from a file."
)
@click.argument("input_files", nargs=-1, type=click.Path(exists=True))
def import_command(input_files):
def import_command(input_files: Tuple[str, ...]):
debug_payloads = []
for input_file in input_files:
with GzipFile(input_file, "rb") as file:
blob = file.read().decode("utf-8")
debug_payload = deserialize_json_to_dagster_namedtuple(blob)
check.invariant(isinstance(debug_payload, DebugRunPayload))
debug_payload = deserialize_as(blob, DebugRunPayload)
debug_payloads.append(debug_payload)

with DagsterInstance.get() as instance:
Expand Down
4 changes: 3 additions & 1 deletion python_modules/dagster/dagster/cli/job.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

import click

from dagster import __version__ as dagster_version
Expand Down Expand Up @@ -94,7 +96,7 @@ def job_list_versions_command(**kwargs):
execute_list_versions_command(instance, kwargs)


def execute_list_versions_command(instance, kwargs):
def execute_list_versions_command(instance: DagsterInstance, kwargs: Dict[str, object]):
check.inst_param(instance, "instance", DagsterInstance)

config = list(check.opt_tuple_param(kwargs.get("config"), "config", default=(), of_type=str))
Expand Down
5 changes: 3 additions & 2 deletions python_modules/dagster/dagster/cli/load_handle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Dict, cast

from click import UsageError

Expand All @@ -16,7 +17,7 @@ def _cli_load_invariant(condition, msg=None):
raise UsageError(msg)


def recon_repo_for_cli_args(kwargs):
def recon_repo_for_cli_args(kwargs: Dict[str, object]):
"""Builds a ReconstructableRepository for CLI arguments, which can be any of the combinations
for repo loading above.
"""
Expand All @@ -39,7 +40,7 @@ def recon_repo_for_cli_args(kwargs):
_cli_load_invariant(kwargs.get("repository_yaml") is None)
_cli_load_invariant(kwargs.get("module_name") is None)
return ReconstructableRepository.for_file(
os.path.abspath(kwargs["python_file"]),
os.path.abspath(cast(str, kwargs["python_file"])),
kwargs["fn_name"],
get_working_directory_from_kwargs(kwargs),
)
Expand Down

0 comments on commit 3c2a0b5

Please sign in to comment.