Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pipes] subprocess termination forwarding #18685

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions python_modules/dagster/dagster/_core/pipes/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import signal
from subprocess import Popen
from typing import Mapping, Optional, Sequence, Union

Expand All @@ -7,7 +8,7 @@
from dagster import _check as check
from dagster._annotations import experimental, public
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.errors import DagsterPipesExecutionError
from dagster._core.errors import DagsterExecutionInterruptedError, DagsterPipesExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClient,
Expand Down Expand Up @@ -37,6 +38,10 @@ class _PipesSubprocess(PipesClient):
context into the subprocess. Defaults to :py:class:`PipesTempFileContextInjector`.
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages from
the subprocess. Defaults to :py:class:`PipesTempFileMessageReader`.
forward_termination (bool): Whether to send a SIGINT signal to the subprocess
if the orchestration process is interrupted or canceled. Defaults to True.
termination_timeout_seconds (float): How long to wait after forwarding termination
for the subprocess to exit. Defaults to 2.
"""

def __init__(
Expand All @@ -45,6 +50,8 @@ def __init__(
cwd: Optional[str] = None,
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
forward_termination: bool = True,
termination_timeout_seconds: float = 2,
):
self.env = check.opt_mapping_param(env, "env", key_type=str, value_type=str)
self.cwd = check.opt_str_param(cwd, "cwd")
Expand All @@ -64,6 +71,10 @@ def __init__(
)
or PipesTempFileMessageReader()
)
self.forward_termination = check.bool_param(forward_termination, "forward_termination")
self.termination_timeout_seconds = check.numeric_param(
termination_timeout_seconds, "termination_timeout_seconds"
)

@classmethod
def _is_dagster_maintained(cls) -> bool:
Expand Down Expand Up @@ -108,11 +119,20 @@ def run(
**pipes_session.get_bootstrap_env_vars(),
},
)
process.wait()
if process.returncode != 0:
raise DagsterPipesExecutionError(
f"External execution process failed with code {process.returncode}"
)
try:
process.wait()
if process.returncode != 0:
raise DagsterPipesExecutionError(
f"External execution process failed with code {process.returncode}"
)
except DagsterExecutionInterruptedError:
if self.forward_termination:
context.log.info("[pipes] execution interrupted, sending SIGINT to subprocess.")
# send sigint to give external process chance to exit gracefully
process.send_signal(signal.SIGINT)
process.wait(timeout=self.termination_timeout_seconds)
raise

return PipesClientCompletedInvocation(pipes_session)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import shutil
import subprocess
import textwrap
import time
from contextlib import contextmanager
from multiprocessing import Process
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Iterator

Expand Down Expand Up @@ -37,6 +39,7 @@
from dagster._core.errors import DagsterInvariantViolationError, DagsterPipesExecutionError
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
from dagster._core.execution.context.invocation import build_asset_context
from dagster._core.instance import DagsterInstance
from dagster._core.instance_for_test import instance_for_test
from dagster._core.pipes.subprocess import (
PipesSubprocessClient,
Expand All @@ -48,6 +51,7 @@
open_pipes_session,
)
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster._utils import process_is_alive
from dagster._utils.env import environ
from dagster_pipes import DagsterPipesError

Expand Down Expand Up @@ -694,3 +698,66 @@ def bad_msg(context: OpExecutionContext, pipes_client: PipesSubprocessClient):
"Object of type Cursed is not JSON serializable"
in pipes_events[1].dagster_event.engine_event_data.error.message
)


def _execute_job(spin_timeout, subproc_log_path):
def script_fn():
import os
import time

from dagster_pipes import open_dagster_pipes

with open_dagster_pipes() as pipes:
timeout = pipes.get_extra("timeout")
log_path = pipes.get_extra("log_path")
with open(log_path, "w") as f:
f.write(f"{os.getpid()}")
f.flush()
start = time.time()
while time.time() - start < timeout:
...
Comment on lines +705 to +718
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2]


with temp_script(script_fn) as script_path:

@op
def stalling_pipes_op(
context: OpExecutionContext,
):
cmd = [_PYTHON_EXECUTABLE, script_path]
PipesSubprocessClient().run(
command=cmd,
context=context,
extras={
"timeout": spin_timeout,
"log_path": subproc_log_path,
},
)

@job
def pipes_job():
stalling_pipes_op()

return pipes_job.execute_in_process(
instance=DagsterInstance.get(),
raise_on_error=False,
)


def test_cancellation():
spin_timeout = 600
with instance_for_test(), NamedTemporaryFile() as subproc_log_path:
p = Process(target=_execute_job, args=(spin_timeout, subproc_log_path.name))
p.start()
pid = None
while p.is_alive():
data = subproc_log_path.read().decode("utf-8")
if data:
pid = int(data)
time.sleep(0.1)
p.terminate()
break

p.join(timeout=1)
assert not p.is_alive()
assert pid
assert not process_is_alive(pid)