namedtuple to NamedTuple (2nd batch) (#6753)
smackesey committed Feb 25, 2022
1 parent a1acd3f commit 032f442
Showing 7 changed files with 132 additions and 53 deletions.
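The pattern in this batch matches the first: each class swaps its untyped `collections.namedtuple` base for a `typing.NamedTuple` base with explicit field types, and the corresponding `__new__` gains matching annotations. A minimal sketch of the before-and-after shape, with illustrative names rather than anything from this diff:

    from typing import NamedTuple, Optional

    # Before: fields declared as a space-separated string, invisible to type checkers.
    #     from collections import namedtuple
    #     class RunRef(namedtuple("_RunRef", "run_id parent_run_id")): ...

    # After: each field carries an explicit type.
    class RunRef(
        NamedTuple(
            "_RunRef",
            [("run_id", str), ("parent_run_id", Optional[str])],
        )
    ):
        def __new__(cls, run_id: str, parent_run_id: Optional[str] = None):
            return super(RunRef, cls).__new__(cls, run_id, parent_run_id)

Runtime behavior is unchanged; the payoff is that mypy can now check construction and field access.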
37 changes: 25 additions & 12 deletions python_modules/dagster/dagster/core/definitions/step_launcher.py
@@ -1,16 +1,29 @@
 from abc import ABC, abstractmethod
-from collections import namedtuple
+from typing import TYPE_CHECKING, Dict, NamedTuple, Optional

 from dagster import check
 from dagster.core.definitions.reconstructable import ReconstructablePipeline
 from dagster.core.execution.retries import RetryMode
 from dagster.core.storage.pipeline_run import PipelineRun

 if TYPE_CHECKING:
     from dagster.core.execution.plan.state import KnownExecutionState


 class StepRunRef(
-    namedtuple(
+    NamedTuple(
         "_StepRunRef",
-        "run_config pipeline_run run_id retry_mode step_key recon_pipeline prior_attempts_count known_state parent_run",
+        [
+            ("run_config", Dict[str, object]),
+            ("pipeline_run", PipelineRun),
+            ("run_id", str),
+            ("retry_mode", RetryMode),
+            ("step_key", str),
+            ("recon_pipeline", ReconstructablePipeline),
+            ("prior_attempts_count", int),
+            ("known_state", Optional["KnownExecutionState"]),
+            ("parent_run", Optional[PipelineRun]),
+        ],
     )
 ):
     """
@@ -20,15 +33,15 @@ class StepRunRef(

     def __new__(
         cls,
-        run_config,
-        pipeline_run,
-        run_id,
-        retry_mode,
-        step_key,
-        recon_pipeline,
-        prior_attempts_count,
-        known_state,
-        parent_run,
+        run_config: Dict[str, object],
+        pipeline_run: PipelineRun,
+        run_id: str,
+        retry_mode: RetryMode,
+        step_key: str,
+        recon_pipeline: ReconstructablePipeline,
+        prior_attempts_count: int,
+        known_state: Optional["KnownExecutionState"],
+        parent_run: Optional[PipelineRun],
     ):
         from dagster.core.execution.plan.state import KnownExecutionState

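A detail worth noting in this file: `known_state` is annotated with the string `"KnownExecutionState"`, resolved only under `if TYPE_CHECKING:`, which keeps a runtime import cycle out while still giving the checker a real type. A minimal sketch of that pattern, assuming a hypothetical `mypkg.state` module on the other side of the cycle:

    from typing import TYPE_CHECKING, NamedTuple, Optional

    if TYPE_CHECKING:
        # Evaluated only by the type checker, never at runtime.
        from mypkg.state import KnownExecutionState

    class Ref(
        NamedTuple("_Ref", [("known_state", Optional["KnownExecutionState"])])
    ):
        pass

The string form is never resolved at runtime, which is why `StepRunRef.__new__` above re-imports `KnownExecutionState` locally when it needs the actual class for validation.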
@@ -157,7 +157,7 @@ def step_context_to_step_run_ref(
         run_id=step_context.pipeline_run.run_id,
         step_key=step_context.step.key,
         retry_mode=retry_mode,
-        recon_pipeline=recon_pipeline,
+        recon_pipeline=recon_pipeline,  # type: ignore
         prior_attempts_count=prior_attempts_count,
         known_state=step_context.execution_plan.known_state,
         parent_run=parent_run,
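The single change in this file adds a `# type: ignore` where `recon_pipeline` is passed, presumably because the value at this call site does not satisfy the newly typed `ReconstructablePipeline` field. A generic sketch of the mechanism, with illustrative names:

    def takes_int(x: int) -> int:
        return x

    value: object = 3
    # mypy rejects passing `object` where `int` is expected; the trailing
    # comment suppresses the error on this line only.
    result = takes_int(value)  # type: ignore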
21 changes: 13 additions & 8 deletions python_modules/dagster/dagster/core/executor/init.py
@@ -1,14 +1,19 @@
-from collections import namedtuple
+from typing import Dict, NamedTuple

 from dagster import check
 from dagster.core.definitions import ExecutorDefinition, IPipeline
 from dagster.core.instance import DagsterInstance


 class InitExecutorContext(
-    namedtuple(
+    NamedTuple(
         "InitExecutorContext",
-        "job executor_def executor_config instance",
+        [
+            ("job", IPipeline),
+            ("executor_def", ExecutorDefinition),
+            ("executor_config", Dict[str, object]),
+            ("instance", DagsterInstance),
+        ],
     )
 ):
     """Executor-specific initialization context.
@@ -23,16 +28,16 @@

     def __new__(
         cls,
-        job,
-        executor_def,
-        executor_config,
-        instance,
+        job: IPipeline,
+        executor_def: ExecutorDefinition,
+        executor_config: Dict[str, object],
+        instance: DagsterInstance,
     ):
         return super(InitExecutorContext, cls).__new__(
             cls,
             job=check.inst_param(job, "job", IPipeline),
             executor_def=check.inst_param(executor_def, "executor_def", ExecutorDefinition),
-            executor_config=check.dict_param(executor_config, executor_config, key_type=str),
+            executor_config=check.dict_param(executor_config, "executor_config", key_type=str),
             instance=check.inst_param(instance, "instance", DagsterInstance),
         )

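The last hunk also lands a drive-by fix: the old call passed the `executor_config` dict itself where `check.dict_param` expects the parameter *name*, so a validation failure would have interpolated the dict into the error message instead of the string "executor_config". Corrected usage, as it appears in this diff:

    from dagster import check

    executor_config = {"max_concurrent": 4}
    # Second argument is the parameter name used in error messages;
    # key_type=str requires every key to be a string.
    validated = check.dict_param(executor_config, "executor_config", key_type=str)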
@@ -37,19 +37,26 @@

 @whitelist_for_serdes
 class ExternalRepositoryData(
-    namedtuple(
+    NamedTuple(
         "_ExternalRepositoryData",
-        "name external_pipeline_datas external_schedule_datas external_partition_set_datas external_sensor_datas external_asset_graph_data",
+        [
+            ("name", str),
+            ("external_pipeline_datas", Sequence["ExternalPipelineData"]),
+            ("external_schedule_datas", Sequence["ExternalScheduleData"]),
+            ("external_partition_set_datas", Sequence["ExternalPartitionSetData"]),
+            ("external_sensor_datas", Sequence["ExternalSensorData"]),
+            ("external_asset_graph_data", Sequence["ExternalAssetNode"]),
+        ],
     )
 ):
     def __new__(
         cls,
-        name,
-        external_pipeline_datas,
-        external_schedule_datas,
-        external_partition_set_datas,
-        external_sensor_datas=None,
-        external_asset_graph_data=None,
+        name: str,
+        external_pipeline_datas: Sequence["ExternalPipelineData"],
+        external_schedule_datas: Sequence["ExternalScheduleData"],
+        external_partition_set_datas: Sequence["ExternalPartitionSetData"],
+        external_sensor_datas: Sequence["ExternalSensorData"] = None,
+        external_asset_graph_data: Sequence["ExternalAssetNode"] = None,
     ):
         return super(ExternalRepositoryData, cls).__new__(
             cls,
@@ -125,9 +132,21 @@ def get_external_sensor_data(self, name):

 @whitelist_for_serdes
 class ExternalPipelineSubsetResult(
-    namedtuple("_ExternalPipelineSubsetResult", "success error external_pipeline_data")
+    NamedTuple(
+        "_ExternalPipelineSubsetResult",
+        [
+            ("success", bool),
+            ("error", Optional[SerializableErrorInfo]),
+            ("external_pipeline_data", Optional["ExternalPipelineData"]),
+        ],
+    )
 ):
-    def __new__(cls, success, error=None, external_pipeline_data=None):
+    def __new__(
+        cls,
+        success: bool,
+        error: Optional[SerializableErrorInfo] = None,
+        external_pipeline_data: Optional["ExternalPipelineData"] = None,
+    ):
         return super(ExternalPipelineSubsetResult, cls).__new__(
             cls,
             success=check.bool_param(success, "success"),
@@ -140,13 +159,24 @@ def __new__(cls, success, error=None, external_pipeline_data=None):

 @whitelist_for_serdes
 class ExternalPipelineData(
-    namedtuple(
+    NamedTuple(
         "_ExternalPipelineData",
-        "name pipeline_snapshot active_presets parent_pipeline_snapshot is_job",
+        [
+            ("name", str),
+            ("pipeline_snapshot", PipelineSnapshot),
+            ("active_presets", Sequence["ExternalPresetData"]),
+            ("parent_pipeline_snapshot", Optional[PipelineSnapshot]),
+            ("is_job", bool),
+        ],
     )
 ):
     def __new__(
-        cls, name, pipeline_snapshot, active_presets, parent_pipeline_snapshot, is_job=False
+        cls,
+        name: str,
+        pipeline_snapshot: PipelineSnapshot,
+        active_presets: Sequence["ExternalPresetData"],
+        parent_pipeline_snapshot: Optional[PipelineSnapshot],
+        is_job: bool = False,
     ):
         return super(ExternalPipelineData, cls).__new__(
             cls,
@@ -581,7 +611,9 @@ def __new__(
         )


-def external_repository_data_from_def(repository_def):
+def external_repository_data_from_def(
+    repository_def: RepositoryDefinition,
+) -> ExternalRepositoryData:
     check.inst_param(repository_def, "repository_def", RepositoryDefinition)

     pipelines = repository_def.get_all_pipelines()
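One loose end above: `external_sensor_datas: Sequence["ExternalSensorData"] = None` pairs a `None` default with a non-Optional annotation, which strict mypy settings flag; `Optional[Sequence[...]]` would state the intent exactly. For classes that need no `__new__` override at all, the class-based `NamedTuple` syntax handles such defaults more directly — a sketch with illustrative fields, not the classes from this diff:

    from typing import NamedTuple, Optional, Sequence

    class RepositoryData(NamedTuple):
        name: str
        sensor_names: Optional[Sequence[str]] = None  # explicit Optional with a default

    data = RepositoryData(name="my_repo")
    assert data.sensor_names is None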
@@ -1,5 +1,5 @@
-from collections import namedtuple
 from enum import Enum
+from typing import NamedTuple, Optional

 from dagster import check

@@ -12,9 +12,23 @@ class LocationStateChangeEventType(Enum):


 class LocationStateChangeEvent(
-    namedtuple("_LocationStateChangeEvent", "event_type location_name message server_id")
+    NamedTuple(
+        "_LocationStateChangeEvent",
+        [
+            ("event_type", LocationStateChangeEventType),
+            ("location_name", str),
+            ("message", str),
+            ("server_id", Optional[str]),
+        ],
+    )
 ):
-    def __new__(cls, event_type, location_name, message, server_id=None):
+    def __new__(
+        cls,
+        event_type: LocationStateChangeEventType,
+        location_name: str,
+        message: str,
+        server_id: Optional[str] = None,
+    ):
         return super(LocationStateChangeEvent, cls).__new__(
             cls,
             check.inst_param(event_type, "event_type", LocationStateChangeEventType),
29 changes: 20 additions & 9 deletions python_modules/dagster/dagster/core/host_representation/origin.py
@@ -1,7 +1,6 @@
 import os
 import sys
 from abc import ABC, abstractmethod
-from collections import namedtuple
 from contextlib import contextmanager
 from inspect import Parameter
 from typing import (
@@ -320,13 +319,16 @@ def shutdown_server(self):

 @whitelist_for_serdes
 class ExternalRepositoryOrigin(
-    namedtuple("_ExternalRepositoryOrigin", "repository_location_origin repository_name")
+    NamedTuple(
+        "_ExternalRepositoryOrigin",
+        [("repository_location_origin", RepositoryLocationOrigin), ("repository_name", str)],
+    )
 ):
     """Serializable representation of an ExternalRepository that can be used to
     uniquely identify it or reload it across process boundaries.
     """

-    def __new__(cls, repository_location_origin, repository_name):
+    def __new__(cls, repository_location_origin: RepositoryLocationOrigin, repository_name: str):
         return super(ExternalRepositoryOrigin, cls).__new__(
             cls,
             check.inst_param(
@@ -350,13 +352,16 @@ def get_partition_set_origin(self, partition_set_name):

 @whitelist_for_serdes
 class ExternalPipelineOrigin(
-    namedtuple("_ExternalPipelineOrigin", "external_repository_origin pipeline_name")
+    NamedTuple(
+        "_ExternalPipelineOrigin",
+        [("external_repository_origin", ExternalRepositoryOrigin), ("pipeline_name", str)],
+    )
 ):
     """Serializable representation of an ExternalPipeline that can be used to
     uniquely identify it or reload it across process boundaries.
     """

-    def __new__(cls, external_repository_origin, pipeline_name):
+    def __new__(cls, external_repository_origin: ExternalRepositoryOrigin, pipeline_name: str):
         return super(ExternalPipelineOrigin, cls).__new__(
             cls,
             check.inst_param(
@@ -420,13 +425,16 @@ def value_to_storage_dict(

 @whitelist_for_serdes(serializer=ExternalInstigatorOriginSerializer)
 class ExternalInstigatorOrigin(
-    namedtuple("_ExternalInstigatorOrigin", "external_repository_origin instigator_name")
+    NamedTuple(
+        "_ExternalInstigatorOrigin",
+        [("external_repository_origin", ExternalRepositoryOrigin), ("instigator_name", str)],
+    )
 ):
     """Serializable representation of an ExternalJob that can be used to
     uniquely identify it or reload it across process boundaries.
     """

-    def __new__(cls, external_repository_origin, instigator_name):
+    def __new__(cls, external_repository_origin: ExternalRepositoryOrigin, instigator_name: str):
         return super(ExternalInstigatorOrigin, cls).__new__(
             cls,
             check.inst_param(
@@ -451,13 +459,16 @@ def get_id(self):

 @whitelist_for_serdes
 class ExternalPartitionSetOrigin(
-    namedtuple("_PartitionSetOrigin", "external_repository_origin partition_set_name")
+    NamedTuple(
+        "_PartitionSetOrigin",
+        [("external_repository_origin", ExternalRepositoryOrigin), ("partition_set_name", str)],
+    )
 ):
     """Serializable representation of an ExternalPartitionSet that can be used to
     uniquely identify it or reload it across process boundaries.
     """

-    def __new__(cls, external_repository_origin, partition_set_name):
+    def __new__(cls, external_repository_origin: ExternalRepositoryOrigin, partition_set_name: str):
         return super(ExternalPartitionSetOrigin, cls).__new__(
             cls,
             check.inst_param(
16 changes: 10 additions & 6 deletions python_modules/dagster/dagster/core/scheduler/execution.py
@@ -1,4 +1,4 @@
-from collections import namedtuple
+from typing import List, NamedTuple, Optional

 from dagster import check
 from dagster.serdes import whitelist_for_serdes
@@ -11,16 +11,20 @@ class ScheduledExecutionResult:

 @whitelist_for_serdes
 class ScheduledExecutionSkipped(
-    namedtuple("_ScheduledExecutionSkipped", ""), ScheduledExecutionResult
+    NamedTuple("_ScheduledExecutionSkipped", []), ScheduledExecutionResult
 ):
     pass


 @whitelist_for_serdes
 class ScheduledExecutionFailed(
-    namedtuple("_ScheduledExecutionFailed", "run_id errors"), ScheduledExecutionResult
+    NamedTuple(
+        "_ScheduledExecutionFailed",
+        [("run_id", Optional[str]), ("errors", List[SerializableErrorInfo])],
+    ),
+    ScheduledExecutionResult,
 ):
-    def __new__(cls, run_id, errors):
+    def __new__(cls, run_id: Optional[str], errors: List[SerializableErrorInfo]):
         return super(ScheduledExecutionFailed, cls).__new__(
             cls,
             run_id=check.opt_str_param(run_id, "run_id"),
@@ -30,9 +34,9 @@ def __new__(cls, run_id, errors):

 @whitelist_for_serdes
 class ScheduledExecutionSuccess(
-    namedtuple("_ScheduledExecutionSuccess", "run_id"), ScheduledExecutionResult
+    NamedTuple("_ScheduledExecutionSuccess", [("run_id", str)]), ScheduledExecutionResult
 ):
-    def __new__(cls, run_id):
+    def __new__(cls, run_id: str):
         return super(ScheduledExecutionSuccess, cls).__new__(
             cls, run_id=check.str_param(run_id, "run_id")
         )
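As the three classes in this file show, a converted class can keep mixing its `NamedTuple` base with a plain marker base (`ScheduledExecutionResult`), and the empty field list `NamedTuple("_ScheduledExecutionSkipped", [])` is the typed replacement for the old `namedtuple(..., "")`. A minimal sketch of the same shape, with illustrative names:

    from typing import List, NamedTuple

    class Result:
        """Marker base mirroring ScheduledExecutionResult's role."""

    class Skipped(NamedTuple("_Skipped", []), Result):
        pass

    class Failed(NamedTuple("_Failed", [("errors", List[str])]), Result):
        pass

    outcome: Result = Failed(errors=["boom"])
    assert isinstance(outcome, Result) and isinstance(outcome, tuple)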
