Skip to content

Commit

Permalink
[typing] fix typing in daemon tests (#12475)
Browse files Browse the repository at this point in the history
Adds typing to fixtures used in daemon tests.
  • Loading branch information
dpeng817 committed Feb 24, 2023
1 parent 65e92a1 commit 281ec4b
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 72 deletions.
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/_core/execution/backfill.py
Expand Up @@ -57,7 +57,7 @@ def __new__(
backfill_id: str,
status: BulkActionStatus,
from_failure: bool,
tags: Mapping[str, str],
tags: Optional[Mapping[str, str]],
backfill_timestamp: float,
error: Optional[SerializableErrorInfo] = None,
asset_selection: Optional[Sequence[AssetKey]] = None,
Expand Down
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/test_utils.py
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import NamedTuple, Optional, Sequence, TypeVar
from typing import Generator, NamedTuple, Optional, Sequence, TypeVar

import pendulum

Expand Down Expand Up @@ -452,7 +452,7 @@ def in_process_test_workspace(instance, loadable_target_origin, container_image=
def create_test_daemon_workspace_context(
workspace_load_target: WorkspaceLoadTarget,
instance: DagsterInstance,
):
) -> Generator[WorkspaceProcessContext, None, None]:
"""Creates a DynamicWorkspace suitable for passing into a DagsterDaemon loop when running tests.
"""
from dagster._daemon.controller import create_daemon_grpc_server_registry
Expand Down
30 changes: 21 additions & 9 deletions python_modules/dagster/dagster_tests/daemon_tests/conftest.py
@@ -1,18 +1,25 @@
import os
import sys
from typing import Iterator, Optional, cast

import pytest
from dagster._core.host_representation import InProcessRepositoryLocationOrigin
from dagster import DagsterInstance
from dagster._core.host_representation import (
ExternalRepository,
InProcessRepositoryLocationOrigin,
RepositoryLocation,
)
from dagster._core.test_utils import (
InProcessTestWorkspaceLoadTarget,
create_test_daemon_workspace_context,
instance_for_test,
)
from dagster._core.types.loadable_target_origin import LoadableTargetOrigin
from dagster._core.workspace.context import WorkspaceProcessContext


@pytest.fixture(name="instance_module_scoped", scope="module")
def instance_module_scoped_fixture():
def instance_module_scoped_fixture() -> Iterator[DagsterInstance]:
with instance_for_test(
overrides={
"run_launcher": {
Expand All @@ -25,7 +32,7 @@ def instance_module_scoped_fixture():


@pytest.fixture(name="instance", scope="function")
def instance_fixture(instance_module_scoped):
def instance_fixture(instance_module_scoped) -> Iterator[DagsterInstance]:
instance_module_scoped.wipe()
instance_module_scoped.wipe_all_schedules()
yield instance_module_scoped
Expand All @@ -41,21 +48,26 @@ def workspace_load_target(attribute=None):


@pytest.fixture(name="workspace_context", scope="module")
def workspace_fixture(instance_module_scoped):
def workspace_fixture(instance_module_scoped) -> Iterator[WorkspaceProcessContext]:
with create_test_daemon_workspace_context(
workspace_load_target=workspace_load_target(), instance=instance_module_scoped
) as workspace_context:
yield workspace_context


@pytest.fixture(name="external_repo", scope="module")
def external_repo_fixture(workspace_context):
return next(
iter(workspace_context.create_request_context().get_workspace_snapshot().values())
).repository_location.get_repository("the_repo")
def external_repo_fixture(
workspace_context: WorkspaceProcessContext,
) -> Iterator[ExternalRepository]:
yield cast(
RepositoryLocation,
next(
iter(workspace_context.create_request_context().get_workspace_snapshot().values())
).repository_location,
).get_repository("the_repo")


def loadable_target_origin(attribute=None):
def loadable_target_origin(attribute: Optional[str] = None) -> LoadableTargetOrigin:
return LoadableTargetOrigin(
executable_path=sys.executable,
module_name="dagster_tests.daemon_tests.test_backfill",
Expand Down
104 changes: 75 additions & 29 deletions python_modules/dagster/dagster_tests/daemon_tests/test_backfill.py
Expand Up @@ -10,6 +10,7 @@
Any,
AssetKey,
AssetsDefinition,
DagsterInstance,
Field,
In,
Nothing,
Expand All @@ -27,18 +28,24 @@
from dagster._core.execution.api import execute_pipeline
from dagster._core.execution.backfill import BulkActionStatus, PartitionBackfill
from dagster._core.host_representation import (
ExternalRepository,
ExternalRepositoryOrigin,
InProcessRepositoryLocationOrigin,
)
from dagster._core.storage.pipeline_run import DagsterRunStatus, RunsFilter
from dagster._core.storage.tags import BACKFILL_ID_TAG, PARTITION_NAME_TAG, PARTITION_SET_TAG
from dagster._core.test_utils import step_did_not_run, step_failed, step_succeeded
from dagster._core.test_utils import (
step_did_not_run,
step_failed,
step_succeeded,
)
from dagster._core.types.loadable_target_origin import LoadableTargetOrigin
from dagster._core.workspace.context import WorkspaceProcessContext
from dagster._daemon import get_default_daemon_logger
from dagster._daemon.backfill import execute_backfill_iteration
from dagster._legacy import ModeDefinition, pipeline
from dagster._seven import IS_WINDOWS, get_system_temp_directory
from dagster._utils import touch_file
from dagster._utils import len_iter, touch_file
from dagster._utils.error import SerializableErrorInfo

default_mode_def = ModeDefinition(resource_defs={"io_manager": fs_io_manager})
Expand Down Expand Up @@ -304,7 +311,11 @@ def wait_for_all_runs_to_finish(instance, timeout=10):
break


def test_simple_backfill(instance, workspace_context, external_repo):
def test_simple_backfill(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
external_partition_set = external_repo.get_external_partition_set("simple_partition_set")
instance.add_backfill(
PartitionBackfill(
Expand Down Expand Up @@ -333,7 +344,11 @@ def test_simple_backfill(instance, workspace_context, external_repo):
assert three.tags[PARTITION_NAME_TAG] == "three"


def test_canceled_backfill(instance, workspace_context, external_repo):
def test_canceled_backfill(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
external_partition_set = external_repo.get_external_partition_set("simple_partition_set")
instance.add_backfill(
PartitionBackfill(
Expand All @@ -349,8 +364,8 @@ def test_canceled_backfill(instance, workspace_context, external_repo):
)
assert instance.get_runs_count() == 0

iterator = execute_backfill_iteration(
workspace_context, get_default_daemon_logger("BackfillDaemon")
iterator = iter(
execute_backfill_iteration(workspace_context, get_default_daemon_logger("BackfillDaemon"))
)
next(iterator)
assert instance.get_runs_count() == 1
Expand All @@ -359,11 +374,16 @@ def test_canceled_backfill(instance, workspace_context, external_repo):
instance.update_backfill(backfill.with_status(BulkActionStatus.CANCELED))
list(iterator)
backfill = instance.get_backfill(backfill.backfill_id)
assert backfill
assert backfill.status == BulkActionStatus.CANCELED
assert instance.get_runs_count() == 1


def test_failure_backfill(instance, workspace_context, external_repo):
def test_failure_backfill(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
output_file = _failure_flag_file()
external_partition_set = external_repo.get_external_partition_set(
"conditionally_fail_partition_set"
Expand Down Expand Up @@ -464,7 +484,11 @@ def test_failure_backfill(instance, workspace_context, external_repo):


@pytest.mark.skipif(IS_WINDOWS, reason="flaky in windows")
def test_partial_backfill(instance, workspace_context, external_repo):
def test_partial_backfill(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
external_partition_set = external_repo.get_external_partition_set("partial_partition_set")

# create full runs, where every step is executed
Expand Down Expand Up @@ -552,7 +576,11 @@ def test_partial_backfill(instance, workspace_context, external_repo):
assert step_did_not_run(instance, three, "step_three")


def test_large_backfill(instance, workspace_context, external_repo):
def test_large_backfill(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
external_partition_set = external_repo.get_external_partition_set("large_partition_set")
instance.add_backfill(
PartitionBackfill(
Expand Down Expand Up @@ -597,7 +625,11 @@ def test_unloadable_backfill(instance, workspace_context):
assert isinstance(backfill.error, SerializableErrorInfo)


def test_backfill_from_partitioned_job(instance, workspace_context, external_repo):
def test_backfill_from_partitioned_job(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
partition_name_list = [
partition.name for partition in my_config.partitions_def.get_partitions()
]
Expand All @@ -621,17 +653,23 @@ def test_backfill_from_partitioned_job(instance, workspace_context, external_rep
list(execute_backfill_iteration(workspace_context, get_default_daemon_logger("BackfillDaemon")))

assert instance.get_runs_count() == 3
runs = reversed(instance.get_runs())
runs = reversed(list(instance.get_runs()))
for idx, run in enumerate(runs):
assert run.tags[BACKFILL_ID_TAG] == "partition_schedule_from_job"
assert run.tags[PARTITION_NAME_TAG] == partition_name_list[idx]
assert run.tags[PARTITION_SET_TAG] == "comp_always_succeed_partition_set"


def test_backfill_with_asset_selection(instance, workspace_context, external_repo):
def test_backfill_with_asset_selection(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
partition_name_list = [partition.name for partition in static_partitions.get_partitions()]
asset_selection = [AssetKey("foo"), AssetKey("a1"), AssetKey("bar")]
asset_job_name = the_repo.get_implicit_job_def_for_assets(asset_selection).name
job_def = the_repo.get_implicit_job_def_for_assets(asset_selection)
assert job_def
asset_job_name = job_def.name
partition_set_name = f"{asset_job_name}_partition_set"
external_partition_set = external_repo.get_external_partition_set(partition_set_name)
instance.add_backfill(
Expand All @@ -655,7 +693,7 @@ def test_backfill_with_asset_selection(instance, workspace_context, external_rep
wait_for_all_runs_to_finish(instance, timeout=30)

assert instance.get_runs_count() == 3
runs = reversed(instance.get_runs())
runs = reversed(list(instance.get_runs()))
for idx, run in enumerate(runs):
assert run.tags[BACKFILL_ID_TAG] == "backfill_with_asset_selection"
assert run.tags[PARTITION_NAME_TAG] == partition_name_list[idx]
Expand All @@ -665,13 +703,17 @@ def test_backfill_with_asset_selection(instance, workspace_context, external_rep
assert step_succeeded(instance, run, "bar")
# selected
for asset_key in asset_selection:
assert len(instance.run_ids_for_asset_key(asset_key)) == 3
assert len_iter(instance.run_ids_for_asset_key(asset_key)) == 3
# not selected
for asset_key in [AssetKey("a2"), AssetKey("b2"), AssetKey("baz")]:
assert len(instance.run_ids_for_asset_key(asset_key)) == 0
assert len_iter(instance.run_ids_for_asset_key(asset_key)) == 0


def test_pure_asset_backfill(instance, workspace_context, external_repo):
def test_pure_asset_backfill(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
del external_repo

partition_name_list = [partition.name for partition in static_partitions.get_partitions()]
Expand All @@ -690,9 +732,9 @@ def test_pure_asset_backfill(instance, workspace_context, external_repo):
)
)
assert instance.get_runs_count() == 0
assert (
instance.get_backfill("backfill_with_asset_selection").status == BulkActionStatus.REQUESTED
)
backfill = instance.get_backfill("backfill_with_asset_selection")
assert backfill
assert backfill.status == BulkActionStatus.REQUESTED

list(execute_backfill_iteration(workspace_context, get_default_daemon_logger("BackfillDaemon")))
assert instance.get_runs_count() == 3
Expand All @@ -701,25 +743,29 @@ def test_pure_asset_backfill(instance, workspace_context, external_repo):
wait_for_all_runs_to_finish(instance, timeout=30)

assert instance.get_runs_count() == 3
runs = reversed(instance.get_runs())
runs = reversed(list(instance.get_runs()))
for run in runs:
assert run.tags[BACKFILL_ID_TAG] == "backfill_with_asset_selection"
assert step_succeeded(instance, run, "foo")
assert step_succeeded(instance, run, "reusable")
assert step_succeeded(instance, run, "bar")
# selected
for asset_key in asset_selection:
assert len(instance.run_ids_for_asset_key(asset_key)) == 3
assert len_iter(instance.run_ids_for_asset_key(asset_key)) == 3
# not selected
for asset_key in [AssetKey("a2"), AssetKey("b2"), AssetKey("baz")]:
assert len(instance.run_ids_for_asset_key(asset_key)) == 0
assert len_iter(instance.run_ids_for_asset_key(asset_key)) == 0

assert (
instance.get_backfill("backfill_with_asset_selection").status == BulkActionStatus.COMPLETED
)
backfill = instance.get_backfill("backfill_with_asset_selection")
assert backfill
assert backfill.status == BulkActionStatus.COMPLETED


def test_backfill_from_failure_for_subselection(instance, workspace_context, external_repo):
def test_backfill_from_failure_for_subselection(
instance: DagsterInstance,
workspace_context: WorkspaceProcessContext,
external_repo: ExternalRepository,
):
partition = parallel_failure_partition_set.get_partition("one")
run_config = parallel_failure_partition_set.run_config_for_partition(partition)
tags = parallel_failure_partition_set.tags_for_partition(partition)
Expand All @@ -738,7 +784,7 @@ def test_backfill_from_failure_for_subselection(instance, workspace_context, ext

assert instance.get_runs_count() == 1
wait_for_all_runs_to_finish(instance)
run = instance.get_runs()[0]
run = list(instance.get_runs())[0]
assert run.status == DagsterRunStatus.FAILURE

instance.add_backfill(
Expand All @@ -756,7 +802,7 @@ def test_backfill_from_failure_for_subselection(instance, workspace_context, ext

list(execute_backfill_iteration(workspace_context, get_default_daemon_logger("BackfillDaemon")))
assert instance.get_runs_count() == 2
run = instance.get_runs(limit=1)[0]
run = list(instance.get_runs(limit=1))[0]
assert run.solids_to_execute
assert run.solid_selection
assert len(run.solids_to_execute) == 2
Expand Down

0 comments on commit 281ec4b

Please sign in to comment.