Limit pods per run with k8s_job_executor (#7846)
Duplicate of #7752 so that I [could push changes](community/community#5634).

Closes #6580
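For reference, a minimal sketch of a run config that opts into the new per-run limit when a job uses `k8s_job_executor` (hypothetical illustration; `max_concurrent` is the field this change adds, the other values are placeholders):

# Hypothetical run config for a job wired to k8s_job_executor; not part of this diff.
run_config = {
    "execution": {
        "config": {
            "job_namespace": "dagster",  # existing field, illustrative value
            "max_concurrent": 2,  # new: at most 2 step pods at a time for this run
        }
    }
}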
johannkm committed May 12, 2022
1 parent 61afd9a commit 21237cf
Showing 5 changed files with 94 additions and 65 deletions.
@@ -258,7 +258,7 @@ def get_steps_to_execute(self, limit: Optional[int] = None) -> List[ExecutionSte
key=self._sort_key_fn,
)

if limit:
if limit is not None:
steps = steps[:limit]

for step in steps:
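The switch above from `if limit:` to `if limit is not None:` matters because the executor can pass a limit of 0 when a run is already at its `max_concurrent` capacity; a truthiness check would treat 0 as "no limit" and hand back every pending step. A small standalone sketch of the difference (hypothetical, not code from this diff):

pending = ["step_1", "step_2", "step_3"]
limit = 0  # run is already at max_concurrent, so no capacity this tick

old = pending if not limit else pending[:limit]      # `if limit:` behavior
new = pending if limit is None else pending[:limit]  # `if limit is not None:` behavior

assert old == ["step_1", "step_2", "step_3"]  # cap silently ignored
assert new == []                              # nothing new launches until a step finishes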
@@ -33,10 +33,16 @@ def __init__(
retries: RetryMode,
sleep_seconds: Optional[float] = None,
check_step_health_interval_seconds: Optional[int] = None,
max_concurrent: Optional[int] = None,
should_verify_step: bool = False,
):
self._step_handler = step_handler
self._retries = retries

self._max_concurrent = check.opt_int_param(max_concurrent, "max_concurrent")
if self._max_concurrent is not None:
check.invariant(self._max_concurrent > 0, "max_concurrent must be > 0")

self._sleep_seconds = cast(
float,
check.opt_float_param(sleep_seconds, "sleep_seconds", default=DEFAULT_SLEEP_SECONDS),
@@ -227,7 +233,15 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
running_steps,
)

for step in active_execution.get_steps_to_execute():
if self._max_concurrent is not None:
max_steps_to_run = self._max_concurrent - len(running_steps)
check.invariant(
max_steps_to_run >= 0, "More steps are active than max_concurrent"
)
else:
max_steps_to_run = None # disables limit

for step in active_execution.get_steps_to_execute(max_steps_to_run):
running_steps[step.key] = step
self._log_new_events(
self._step_handler.launch_step(
@@ -1,15 +1,14 @@
import subprocess
import time
from typing import List

from dagster import executor, pipeline, reconstructable, solid
from dagster import executor, job, op, reconstructable
from dagster.config.field_utils import Permissive
from dagster.core.definitions.executor_definition import multiple_process_executor_requirements
from dagster.core.definitions.mode import ModeDefinition
from dagster.core.events import DagsterEvent, DagsterEventType
from dagster.core.execution.api import execute_pipeline
from dagster.core.execution.retries import RetryMode
from dagster.core.executor.step_delegating import StepDelegatingExecutor, StepHandler
from dagster.core.storage.fs_io_manager import fs_io_manager
from dagster.core.test_utils import instance_for_test


@@ -18,7 +17,7 @@ class TestStepHandler(StepHandler):
# are left alive when the test ends. Non-test step handlers should not keep their own state in memory.
processes = [] # type: ignore
launch_step_count = 0 # type: ignore
saw_baz_solid = False
saw_baz_op = False
check_step_health_count = 0 # type: ignore
terminate_step_count = 0 # type: ignore
verify_step_count = 0 # type: ignore
@@ -30,9 +29,9 @@ def name(self):
def launch_step(self, step_handler_context):
if step_handler_context.execute_step_args.should_verify_step:
TestStepHandler.verify_step_count += 1
if step_handler_context.execute_step_args.step_keys_to_execute[0] == "baz_solid":
TestStepHandler.saw_baz_solid = True
assert step_handler_context.step_tags["baz_solid"] == {"foo": "bar"}
if step_handler_context.execute_step_args.step_keys_to_execute[0] == "baz_op":
TestStepHandler.saw_baz_op = True
assert step_handler_context.step_tags["baz_op"] == {"foo": "bar"}

TestStepHandler.launch_step_count += 1
print("TestStepHandler Launching Step!") # pylint: disable=print-call
@@ -70,45 +69,33 @@ def wait_for_processes(cls):
)
def test_step_delegating_executor(exc_init):
return StepDelegatingExecutor(
TestStepHandler(),
retries=RetryMode.DISABLED,
sleep_seconds=exc_init.executor_config.get("sleep_seconds"),
check_step_health_interval_seconds=exc_init.executor_config.get(
"check_step_health_interval_seconds"
),
TestStepHandler(), retries=RetryMode.DISABLED, **exc_init.executor_config
)


@solid
def bar_solid(_):
@op
def bar_op(_):
return "bar"


@solid(tags={"foo": "bar"})
def baz_solid(_, bar):
@op(tags={"foo": "bar"})
def baz_op(_, bar):
return bar * 2


@pipeline(
mode_defs=[
ModeDefinition(
executor_defs=[test_step_delegating_executor],
resource_defs={"io_manager": fs_io_manager},
)
]
)
def foo_pipline():
baz_solid(bar_solid())
bar_solid()
@job(executor_def=test_step_delegating_executor)
def foo_job():
baz_op(bar_op())
bar_op()


def test_execute():
TestStepHandler.reset()
with instance_for_test() as instance:
result = execute_pipeline(
reconstructable(foo_pipline),
reconstructable(foo_job),
instance=instance,
run_config={"execution": {"test_step_delegating_executor": {"config": {}}}},
run_config={"execution": {"config": {}}},
)
TestStepHandler.wait_for_processes()

@@ -120,7 +107,7 @@ def test_execute():
)
assert any(["STEP_START" in event for event in result.event_list])
assert result.success
assert TestStepHandler.saw_baz_solid
assert TestStepHandler.saw_baz_op
assert TestStepHandler.verify_step_count == 0


@@ -180,15 +167,9 @@ def test_execute_intervals():
TestStepHandler.reset()
with instance_for_test() as instance:
result = execute_pipeline(
reconstructable(foo_pipline),
reconstructable(foo_job),
instance=instance,
run_config={
"execution": {
"test_step_delegating_executor": {
"config": {"check_step_health_interval_seconds": 60}
}
}
},
run_config={"execution": {"config": {"check_step_health_interval_seconds": 60}}},
)
TestStepHandler.wait_for_processes()

@@ -201,15 +182,9 @@ def test_execute_intervals():
TestStepHandler.reset()
with instance_for_test() as instance:
result = execute_pipeline(
reconstructable(foo_pipline),
reconstructable(foo_job),
instance=instance,
run_config={
"execution": {
"test_step_delegating_executor": {
"config": {"check_step_health_interval_seconds": 0}
}
}
},
run_config={"execution": {"config": {"check_step_health_interval_seconds": 0}}},
)
TestStepHandler.wait_for_processes()

@@ -220,6 +195,41 @@ def test_execute_intervals():
assert TestStepHandler.check_step_health_count >= 3


@op
def slow_op(_):
time.sleep(2)


@job(executor_def=test_step_delegating_executor)
def three_op_job():
for i in range(3):
slow_op.alias(f"slow_op_{i}")()


def test_max_concurrent():
TestStepHandler.reset()
with instance_for_test() as instance:
result = execute_pipeline(
reconstructable(three_op_job),
instance=instance,
run_config={"execution": {"config": {"max_concurrent": 1}}},
)
TestStepHandler.wait_for_processes()
assert result.success

# test that all the steps run serially, since max_concurrent is 1
active_step = None
for event in result.event_list:
if event.event_type_value == DagsterEventType.STEP_START.value:
assert active_step is None, "A second step started before the first finished!"
active_step = event.step_key
elif event.event_type_value == DagsterEventType.STEP_SUCCESS.value:
assert (
active_step == event.step_key
), "A step finished that wasn't supposed to be active!"
active_step = None


@executor(
name="test_step_delegating_executor_verify_step",
requirements=multiple_process_executor_requirements(),
@@ -237,26 +247,19 @@ def test_step_delegating_executor_verify_step(exc_init):
)


@pipeline(
mode_defs=[
ModeDefinition(
executor_defs=[test_step_delegating_executor_verify_step],
resource_defs={"io_manager": fs_io_manager},
)
]
)
def foo_pipline_verify_step():
baz_solid(bar_solid())
bar_solid()
@job(executor_def=test_step_delegating_executor_verify_step)
def foo_job_verify_step():
baz_op(bar_op())
bar_op()


def test_execute_verify_step():
TestStepHandler.reset()
with instance_for_test() as instance:
result = execute_pipeline(
reconstructable(foo_pipline_verify_step),
reconstructable(foo_job_verify_step),
instance=instance,
run_config={"execution": {"test_step_delegating_executor_verify_step": {"config": {}}}},
run_config={"execution": {"config": {}}},
)
TestStepHandler.wait_for_processes()

14 changes: 13 additions & 1 deletion python_modules/libraries/dagster-k8s/dagster_k8s/executor.py
@@ -3,7 +3,7 @@
import kubernetes
from dagster_k8s.launcher import K8sRunLauncher

from dagster import Field, StringSource
from dagster import Field, IntSource, StringSource
from dagster import _check as check
from dagster import executor
from dagster.core.definitions.executor_definition import multiple_process_executor_requirements
@@ -35,6 +35,12 @@
{
"job_namespace": Field(StringSource, is_required=False),
"retries": get_retries_config(),
"max_concurrency": Field(
IntSource,
is_required=False,
description="Limit on the number of pods that will run concurrently within the scope "
"of a Dagster run. Note that this limit is per run, not global.",
),
},
),
requirements=multiple_process_executor_requirements(),
@@ -64,6 +70,11 @@ def k8s_job_executor(init_context: InitExecutorContext) -> Executor:
env_secrets: ...
env_vars: ...
job_image: ... # leave out if using userDeployments
max_concurrent: ...
`max_concurrent` limits the number of pods that will execute concurrently for one run. By default
there is no limit; steps will run with as much parallelism as the DAG allows. Note that this is not a
global limit.
Configuration set on the Kubernetes Jobs and Pods created by the `K8sRunLauncher` will also be
set on Kubernetes Jobs and Pods created by the `k8s_job_executor`.
@@ -100,6 +111,7 @@ def k8s_job_executor(init_context: InitExecutorContext) -> Executor:
kubeconfig_file=run_launcher.kubeconfig_file,
),
retries=RetryMode.from_config(init_context.executor_config["retries"]), # type: ignore
max_concurrent=exc_cfg.get("max_concurrent"),
should_verify_step=True,
)

@@ -118,7 +118,7 @@ def test_executor_init_container_context(
InitExecutorContext(
job=InMemoryPipeline(bar),
executor_def=k8s_job_executor,
executor_config={"env_vars": ["FOO_TEST"], "retries": {}},
executor_config={"env_vars": ["FOO_TEST"], "retries": {}, "max_concurrent": 4},
instance=k8s_run_launcher_instance,
)
)
@@ -146,7 +146,7 @@ def test_executor_init_container_context(
"BAZ_TEST",
]
)

assert executor._max_concurrent == 4
assert sorted(
executor._step_handler._get_container_context(step_handler_context).resources
) == sorted(
