Skip to content

Commit

Permalink
assorted function type annotations (batch 1) (#6800)
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Feb 28, 2022
1 parent 63556df commit c1a6018
Show file tree
Hide file tree
Showing 14 changed files with 253 additions and 164 deletions.
7 changes: 6 additions & 1 deletion python_modules/dagster/dagster/api/get_server_id.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import TYPE_CHECKING

from dagster import check
from dagster.core.errors import DagsterUserCodeProcessError
from dagster.utils.error import SerializableErrorInfo

if TYPE_CHECKING:
from dagster.grpc.client import DagsterGrpcClient


def sync_get_server_id(api_client):
def sync_get_server_id(api_client: "DagsterGrpcClient") -> str:
from dagster.grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
Expand Down
28 changes: 17 additions & 11 deletions python_modules/dagster/dagster/api/list_repositories.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from typing import TYPE_CHECKING, Optional

from dagster import check
from dagster.core.errors import DagsterUserCodeProcessError
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.serdes import deserialize_json_to_dagster_namedtuple
from dagster.grpc.types import ListRepositoriesResponse
from dagster.serdes import deserialize_as
from dagster.utils.error import SerializableErrorInfo

if TYPE_CHECKING:
from dagster.grpc.client import DagsterGrpcClient


def sync_list_repositories_grpc(api_client):
def sync_list_repositories_grpc(api_client: "DagsterGrpcClient") -> ListRepositoriesResponse:
from dagster.grpc.client import DagsterGrpcClient
from dagster.grpc.types import ListRepositoriesResponse

check.inst_param(api_client, "api_client", DagsterGrpcClient)
result = check.inst(
deserialize_json_to_dagster_namedtuple(api_client.list_repositories()),
result = deserialize_as(
api_client.list_repositories(),
(ListRepositoriesResponse, SerializableErrorInfo),
)
if isinstance(result, SerializableErrorInfo):
Expand All @@ -23,19 +28,20 @@ def sync_list_repositories_grpc(api_client):


def sync_list_repositories_ephemeral_grpc(
executable_path,
python_file,
module_name,
working_directory,
attribute,
package_name,
executable_path: str,
python_file: Optional[str],
module_name: Optional[str],
working_directory: Optional[str],
attribute: Optional[str],
package_name: Optional[str],
):
from dagster.grpc.client import ephemeral_grpc_api_client

check.str_param(executable_path, "executable_path")
check.opt_str_param(python_file, "python_file")
check.opt_str_param(module_name, "module_name")
check.opt_str_param(working_directory, "working_directory")
check.opt_str_param(attribute, "attribute")
check.opt_str_param(package_name, "package_name")

with ephemeral_grpc_api_client(
Expand Down
9 changes: 8 additions & 1 deletion python_modules/dagster/dagster/api/notebook_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import TYPE_CHECKING

from dagster import check

if TYPE_CHECKING:
from dagster.grpc.client import DagsterGrpcClient


def sync_get_streaming_external_notebook_data_grpc(api_client, notebook_path):
def sync_get_streaming_external_notebook_data_grpc(
api_client: "DagsterGrpcClient", notebook_path: str
):
from dagster.grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
Expand Down
57 changes: 29 additions & 28 deletions python_modules/dagster/dagster/api/snapshot_execution_plan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING, Any, List, Mapping, Optional

from dagster import check
from dagster.core.errors import DagsterUserCodeProcessError
from dagster.core.execution.plan.state import KnownExecutionState
Expand All @@ -8,48 +10,47 @@
ExecutionPlanSnapshotErrorData,
)
from dagster.grpc.types import ExecutionPlanSnapshotArgs
from dagster.serdes import deserialize_json_to_dagster_namedtuple
from dagster.serdes import deserialize_as

if TYPE_CHECKING:
from dagster.grpc.client import DagsterGrpcClient


def sync_get_external_execution_plan_grpc(
api_client,
pipeline_origin,
run_config,
mode,
pipeline_snapshot_id,
solid_selection=None,
step_keys_to_execute=None,
known_state=None,
instance=None,
):
api_client: "DagsterGrpcClient",
pipeline_origin: ExternalPipelineOrigin,
run_config: Mapping[str, Any],
mode: str,
pipeline_snapshot_id: str,
solid_selection: Optional[List[str]] = None,
step_keys_to_execute: Optional[List[str]] = None,
known_state: Optional[KnownExecutionState] = None,
instance: Optional[DagsterInstance] = None,
) -> ExecutionPlanSnapshot:
from dagster.grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
check.inst_param(pipeline_origin, "pipeline_origin", ExternalPipelineOrigin)
check.opt_list_param(solid_selection, "solid_selection", of_type=str)
check.dict_param(run_config, "run_config")
check.dict_param(run_config, "run_config", key_type=str)
check.str_param(mode, "mode")
check.opt_nullable_list_param(step_keys_to_execute, "step_keys_to_execute", of_type=str)
check.str_param(pipeline_snapshot_id, "pipeline_snapshot_id")
check.opt_inst_param(known_state, "known_state", KnownExecutionState)
check.opt_inst_param(instance, "instance", DagsterInstance)

result = check.inst(
deserialize_json_to_dagster_namedtuple(
api_client.execution_plan_snapshot(
execution_plan_snapshot_args=ExecutionPlanSnapshotArgs(
pipeline_origin=pipeline_origin,
solid_selection=solid_selection,
run_config=run_config,
mode=mode,
step_keys_to_execute=step_keys_to_execute,
pipeline_snapshot_id=pipeline_snapshot_id,
known_state=known_state,
instance_ref=instance.get_ref()
if instance and instance.is_persistent
else None,
)
),
result = deserialize_as(
api_client.execution_plan_snapshot(
execution_plan_snapshot_args=ExecutionPlanSnapshotArgs(
pipeline_origin=pipeline_origin,
solid_selection=solid_selection,
run_config=run_config,
mode=mode,
step_keys_to_execute=step_keys_to_execute,
pipeline_snapshot_id=pipeline_snapshot_id,
known_state=known_state,
instance_ref=instance.get_ref() if instance and instance.is_persistent else None,
)
),
(ExecutionPlanSnapshot, ExecutionPlanSnapshotErrorData),
)
Expand Down
88 changes: 48 additions & 40 deletions python_modules/dagster/dagster/api/snapshot_partition.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING, List

from dagster import check
from dagster.core.errors import DagsterUserCodeProcessError
from dagster.core.host_representation.external_data import (
Expand All @@ -9,24 +11,27 @@
)
from dagster.core.host_representation.handle import RepositoryHandle
from dagster.grpc.types import PartitionArgs, PartitionNamesArgs, PartitionSetExecutionParamArgs
from dagster.serdes import deserialize_json_to_dagster_namedtuple
from dagster.serdes import deserialize_as

if TYPE_CHECKING:
from dagster.grpc.client import DagsterGrpcClient


def sync_get_external_partition_names_grpc(api_client, repository_handle, partition_set_name):
def sync_get_external_partition_names_grpc(
api_client: "DagsterGrpcClient", repository_handle: RepositoryHandle, partition_set_name: str
) -> ExternalPartitionNamesData:
from dagster.grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
check.inst_param(repository_handle, "repository_handle", RepositoryHandle)
check.str_param(partition_set_name, "partition_set_name")
repository_origin = repository_handle.get_external_origin()
result = check.inst(
deserialize_json_to_dagster_namedtuple(
api_client.external_partition_names(
partition_names_args=PartitionNamesArgs(
repository_origin=repository_origin,
partition_set_name=partition_set_name,
),
)
result = deserialize_as(
api_client.external_partition_names(
partition_names_args=PartitionNamesArgs(
repository_origin=repository_origin,
partition_set_name=partition_set_name,
),
),
(ExternalPartitionNamesData, ExternalPartitionExecutionErrorData),
)
Expand All @@ -37,23 +42,24 @@ def sync_get_external_partition_names_grpc(api_client, repository_handle, partit


def sync_get_external_partition_config_grpc(
api_client, repository_handle, partition_set_name, partition_name
):
api_client: "DagsterGrpcClient",
repository_handle: RepositoryHandle,
partition_set_name: str,
partition_name: str,
) -> ExternalPartitionConfigData:
from dagster.grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
check.inst_param(repository_handle, "repository_handle", RepositoryHandle)
check.str_param(partition_set_name, "partition_set_name")
check.str_param(partition_name, "partition_name")
repository_origin = repository_handle.get_external_origin()
result = check.inst(
deserialize_json_to_dagster_namedtuple(
api_client.external_partition_config(
partition_args=PartitionArgs(
repository_origin=repository_origin,
partition_set_name=partition_set_name,
partition_name=partition_name,
),
result = deserialize_as(
api_client.external_partition_config(
partition_args=PartitionArgs(
repository_origin=repository_origin,
partition_set_name=partition_set_name,
partition_name=partition_name,
),
),
(ExternalPartitionConfigData, ExternalPartitionExecutionErrorData),
Expand All @@ -65,8 +71,11 @@ def sync_get_external_partition_config_grpc(


def sync_get_external_partition_tags_grpc(
api_client, repository_handle, partition_set_name, partition_name
):
api_client: "DagsterGrpcClient",
repository_handle: RepositoryHandle,
partition_set_name: str,
partition_name: str,
) -> ExternalPartitionTagsData:
from dagster.grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
Expand All @@ -75,14 +84,12 @@ def sync_get_external_partition_tags_grpc(
check.str_param(partition_name, "partition_name")

repository_origin = repository_handle.get_external_origin()
result = check.inst(
deserialize_json_to_dagster_namedtuple(
api_client.external_partition_tags(
partition_args=PartitionArgs(
repository_origin=repository_origin,
partition_set_name=partition_set_name,
partition_name=partition_name,
),
result = deserialize_as(
api_client.external_partition_tags(
partition_args=PartitionArgs(
repository_origin=repository_origin,
partition_set_name=partition_set_name,
partition_name=partition_name,
),
),
(ExternalPartitionTagsData, ExternalPartitionExecutionErrorData),
Expand All @@ -94,8 +101,11 @@ def sync_get_external_partition_tags_grpc(


def sync_get_external_partition_set_execution_param_data_grpc(
api_client, repository_handle, partition_set_name, partition_names
):
api_client: "DagsterGrpcClient",
repository_handle: RepositoryHandle,
partition_set_name: str,
partition_names: List[str],
) -> ExternalPartitionSetExecutionParamData:
from dagster.grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
Expand All @@ -105,14 +115,12 @@ def sync_get_external_partition_set_execution_param_data_grpc(

repository_origin = repository_handle.get_external_origin()

result = check.inst(
deserialize_json_to_dagster_namedtuple(
api_client.external_partition_set_execution_params(
partition_set_execution_param_args=PartitionSetExecutionParamArgs(
repository_origin=repository_origin,
partition_set_name=partition_set_name,
partition_names=partition_names,
),
result = deserialize_as(
api_client.external_partition_set_execution_params(
partition_set_execution_param_args=PartitionSetExecutionParamArgs(
repository_origin=repository_origin,
partition_set_name=partition_set_name,
partition_names=partition_names,
),
),
(ExternalPartitionSetExecutionParamData, ExternalPartitionExecutionErrorData),
Expand Down
23 changes: 15 additions & 8 deletions python_modules/dagster/dagster/api/snapshot_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
from typing import TYPE_CHECKING, List, Optional

from dagster import check
from dagster.core.errors import DagsterUserCodeProcessError
from dagster.core.host_representation.external_data import ExternalPipelineSubsetResult
from dagster.core.host_representation.origin import ExternalPipelineOrigin
from dagster.grpc.types import PipelineSubsetSnapshotArgs
from dagster.serdes import deserialize_json_to_dagster_namedtuple
from dagster.serdes import deserialize_as

if TYPE_CHECKING:
from dagster.grpc.client import DagsterGrpcClient


def sync_get_external_pipeline_subset_grpc(api_client, pipeline_origin, solid_selection=None):
def sync_get_external_pipeline_subset_grpc(
api_client: "DagsterGrpcClient",
pipeline_origin: ExternalPipelineOrigin,
solid_selection: Optional[List[str]] = None,
) -> ExternalPipelineSubsetResult:
from dagster.grpc.client import DagsterGrpcClient

check.inst_param(api_client, "api_client", DagsterGrpcClient)
check.inst_param(pipeline_origin, "pipeline_origin", ExternalPipelineOrigin)
check.opt_list_param(solid_selection, "solid_selection", of_type=str)

result = check.inst(
deserialize_json_to_dagster_namedtuple(
api_client.external_pipeline_subset(
pipeline_subset_snapshot_args=PipelineSubsetSnapshotArgs(
pipeline_origin=pipeline_origin, solid_selection=solid_selection
),
result = deserialize_as(
api_client.external_pipeline_subset(
pipeline_subset_snapshot_args=PipelineSubsetSnapshotArgs(
pipeline_origin=pipeline_origin, solid_selection=solid_selection
),
),
ExternalPipelineSubsetResult,
Expand Down

0 comments on commit c1a6018

Please sign in to comment.