Skip to content

Commit

Permalink
Derive origin from pipeline run instead of the arg to ExecuteRunArgs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
gibsondan committed Jun 7, 2022
1 parent 20f5889 commit 38b71a2
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 35 deletions.
41 changes: 31 additions & 10 deletions python_modules/dagster/dagster/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
import sys
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, cast

import click

Expand All @@ -12,13 +12,16 @@
get_working_directory_from_kwargs,
python_origin_target_argument,
)
from dagster.core.definitions.reconstruct import ReconstructablePipeline
from dagster.core.errors import DagsterExecutionInterruptedError
from dagster.core.events import DagsterEvent, DagsterEventType, EngineEventData
from dagster.core.execution.api import create_execution_plan, execute_plan_iterator
from dagster.core.execution.run_cancellation_thread import start_run_cancellation_thread
from dagster.core.instance import DagsterInstance
from dagster.core.origin import DEFAULT_DAGSTER_ENTRY_POINT, get_python_environment_entry_point
from dagster.core.origin import (
DEFAULT_DAGSTER_ENTRY_POINT,
PipelinePythonOrigin,
get_python_environment_entry_point,
)
from dagster.core.storage.pipeline_run import PipelineRun
from dagster.core.test_utils import mock_system_timezone
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
Expand Down Expand Up @@ -53,7 +56,6 @@ def api_cli():
def execute_run_command(input_json):
with capture_interrupts():
args = deserialize_as(input_json, ExecuteRunArgs)
recon_pipeline = recon_pipeline_from_origin(args.pipeline_origin)

with (
DagsterInstance.from_ref(args.instance_ref)
Expand All @@ -66,7 +68,6 @@ def send_to_buffer(event):
buffer.append(serialize_dagster_namedtuple(event))

return_code = _execute_run_command_body(
recon_pipeline,
args.pipeline_run_id,
instance,
send_to_buffer,
Expand All @@ -81,7 +82,6 @@ def send_to_buffer(event):


def _execute_run_command_body(
recon_pipeline: ReconstructablePipeline,
pipeline_run_id: str,
instance: DagsterInstance,
write_stream_fn: Callable[[DagsterEvent], Any],
Expand All @@ -97,6 +97,16 @@ def _execute_run_command_body(
"Pipeline run with id '{}' not found for run execution.".format(pipeline_run_id),
)

check.inst(
pipeline_run.pipeline_code_origin,
PipelinePythonOrigin,
"Pipeline run with id '{}' does not include an origin.".format(pipeline_run_id),
)

recon_pipeline = recon_pipeline_from_origin(
cast(PipelinePythonOrigin, pipeline_run.pipeline_code_origin)
)

pid = os.getpid()
instance.report_engine_event(
"Started process for run (pid: {pid}).".format(pid=pid),
Expand Down Expand Up @@ -148,7 +158,6 @@ def _execute_run_command_body(
def resume_run_command(input_json):
with capture_interrupts():
args = deserialize_as(input_json, ResumeRunArgs)
recon_pipeline = recon_pipeline_from_origin(args.pipeline_origin)

with (
DagsterInstance.from_ref(args.instance_ref)
Expand All @@ -161,7 +170,6 @@ def send_to_buffer(event):
buffer.append(serialize_dagster_namedtuple(event))

return_code = _resume_run_command_body(
recon_pipeline,
args.pipeline_run_id,
instance,
send_to_buffer,
Expand All @@ -176,7 +184,6 @@ def send_to_buffer(event):


def _resume_run_command_body(
recon_pipeline: ReconstructablePipeline,
pipeline_run_id: Optional[str],
instance: DagsterInstance,
write_stream_fn: Callable[[DagsterEvent], Any],
Expand All @@ -192,6 +199,15 @@ def _resume_run_command_body(
PipelineRun,
"Pipeline run with id '{}' not found for run execution.".format(pipeline_run_id),
)
check.inst(
pipeline_run.pipeline_code_origin,
PipelinePythonOrigin,
"Pipeline run with id '{}' does not include an origin.".format(pipeline_run_id),
)

recon_pipeline = recon_pipeline_from_origin(
cast(PipelinePythonOrigin, pipeline_run.pipeline_code_origin)
)

pid = os.getpid()
instance.report_engine_event(
Expand Down Expand Up @@ -339,6 +355,11 @@ def _execute_step_command_body(
PipelineRun,
"Pipeline run with id '{}' not found for step execution".format(args.pipeline_run_id),
)
check.inst(
pipeline_run.pipeline_code_origin,
PipelinePythonOrigin,
"Pipeline run with id '{}' does not include an origin.".format(args.pipeline_run_id),
)

if args.should_verify_step:
success = verify_step(
Expand All @@ -351,7 +372,7 @@ def _execute_step_command_body(
return

recon_pipeline = recon_pipeline_from_origin(
args.pipeline_origin
cast(PipelinePythonOrigin, pipeline_run.pipeline_code_origin)
).subset_for_execution_from_existing_pipeline(
pipeline_run.solids_to_execute, pipeline_run.asset_selection
)
Expand Down
2 changes: 2 additions & 0 deletions python_modules/dagster/dagster/core/instance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,7 @@ def register_managed_run(
execution_plan_snapshot,
parent_pipeline_snapshot,
solid_selection=None,
pipeline_code_origin=None,
):
# The usage of this method is limited to dagster-airflow, specifically in Dagster
# Operators that are executed in Airflow. Because a common workflow in Airflow is to
Expand Down Expand Up @@ -1190,6 +1191,7 @@ def register_managed_run(
pipeline_snapshot=pipeline_snapshot,
execution_plan_snapshot=execution_plan_snapshot,
parent_pipeline_snapshot=parent_pipeline_snapshot,
pipeline_code_origin=pipeline_code_origin,
)

def get_run():
Expand Down
10 changes: 6 additions & 4 deletions python_modules/dagster/dagster/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,12 +653,14 @@ def StartRun(self, request, _context):
)

try:
execute_run_args = check.inst(
execute_external_pipeline_args = check.inst(
deserialize_json_to_dagster_namedtuple(request.serialized_execute_run_args),
ExecuteExternalPipelineArgs,
)
run_id = execute_run_args.pipeline_run_id
recon_pipeline = self._recon_pipeline_from_origin(execute_run_args.pipeline_origin)
run_id = execute_external_pipeline_args.pipeline_run_id
recon_pipeline = self._recon_pipeline_from_origin(
execute_external_pipeline_args.pipeline_origin
)

except:
return api_pb2.StartRunReply(
Expand Down Expand Up @@ -689,7 +691,7 @@ def StartRun(self, request, _context):
execution_process.start()
self._executions[run_id] = (
execution_process,
execute_run_args.instance_ref,
execute_external_pipeline_args.instance_ref,
)
self._termination_events[run_id] = termination_event

Expand Down
3 changes: 3 additions & 0 deletions python_modules/dagster/dagster/grpc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ExecuteRunArgs(
NamedTuple(
"_ExecuteRunArgs",
[
# Deprecated: kept only for back-compat, since the origin can now be read from the PipelineRun
("pipeline_origin", PipelinePythonOrigin),
("pipeline_run_id", str),
("instance_ref", Optional[InstanceRef]),
Expand Down Expand Up @@ -123,6 +124,7 @@ class ResumeRunArgs(
NamedTuple(
"_ResumeRunArgs",
[
# Deprecated: kept only for back-compat, since the origin can now be read from the PipelineRun
("pipeline_origin", PipelinePythonOrigin),
("pipeline_run_id", str),
("instance_ref", Optional[InstanceRef]),
Expand Down Expand Up @@ -196,6 +198,7 @@ class ExecuteStepArgs(
NamedTuple(
"_ExecuteStepArgs",
[
# Deprecated: kept only for back-compat, since the origin can now be read from the PipelineRun
("pipeline_origin", PipelinePythonOrigin),
("pipeline_run_id", str),
("step_keys_to_execute", Optional[List[str]]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def test_execute_run():
runner = CliRunner()

instance = DagsterInstance.get()
run = create_run_for_test(instance, pipeline_name="foo", run_id="new_run")
run = create_run_for_test(
instance,
pipeline_name="foo",
run_id="new_run",
pipeline_code_origin=pipeline_handle.get_python_origin(),
)

input_json = serialize_dagster_namedtuple(
ExecuteRunArgs(
Expand Down Expand Up @@ -83,7 +88,12 @@ def test_execute_run_fail_pipeline():
runner = CliRunner()

instance = DagsterInstance.get()
run = create_run_for_test(instance, pipeline_name="foo", run_id="new_run")
run = create_run_for_test(
instance,
pipeline_name="foo",
run_id="new_run",
pipeline_code_origin=pipeline_handle.get_python_origin(),
)

input_json = serialize_dagster_namedtuple(
ExecuteRunArgs(
Expand All @@ -102,7 +112,10 @@ def test_execute_run_fail_pipeline():
assert "RUN_FAILURE" in result.stdout, "no match, result: {}".format(result)

run = create_run_for_test(
instance, pipeline_name="foo", run_id="new_run_raise_on_error"
instance,
pipeline_name="foo",
run_id="new_run_raise_on_error",
pipeline_code_origin=pipeline_handle.get_python_origin(),
)

input_json_raise_on_failure = serialize_dagster_namedtuple(
Expand Down Expand Up @@ -208,7 +221,12 @@ def test_execute_step():
with get_foo_pipeline_handle(instance) as pipeline_handle:
runner = CliRunner()

run = create_run_for_test(instance, pipeline_name="foo", run_id="new_run")
run = create_run_for_test(
instance,
pipeline_name="foo",
run_id="new_run",
pipeline_code_origin=pipeline_handle.get_python_origin(),
)

input_json = serialize_dagster_namedtuple(
ExecuteStepArgs(
Expand Down Expand Up @@ -240,7 +258,12 @@ def test_execute_step_1():
with get_foo_pipeline_handle(instance) as pipeline_handle:
runner = CliRunner()

run = create_run_for_test(instance, pipeline_name="foo", run_id="new_run")
run = create_run_for_test(
instance,
pipeline_name="foo",
run_id="new_run",
pipeline_code_origin=pipeline_handle.get_python_origin(),
)

input_json = serialize_dagster_namedtuple(
ExecuteStepArgs(
Expand Down Expand Up @@ -275,6 +298,7 @@ def test_execute_step_verify_step():
instance,
pipeline_name="foo",
run_id="new_run",
pipeline_code_origin=pipeline_handle.get_python_origin(),
)

input_json = serialize_dagster_namedtuple(
Expand Down Expand Up @@ -338,6 +362,7 @@ def test_execute_step_verify_step_framework_error(mock_verify_step):
instance,
pipeline_name="foo",
run_id="new_run",
pipeline_code_origin=pipeline_handle.get_python_origin(),
)

input_json = serialize_dagster_namedtuple(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def foo_pipline():
def test_step_handler_context():
recon_pipeline = reconstructable(foo_pipline)
with instance_for_test() as instance:
run = create_run_for_test(instance)
run = create_run_for_test(instance, pipeline_code_origin=recon_pipeline.get_python_origin())

args = ExecuteStepArgs(
pipeline_origin=recon_pipeline.get_python_origin(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def execute(self, context):
try:
tags = {AIRFLOW_EXECUTION_DATE_STR: context.get("ts")} if "ts" in context else {}

recon_pipeline = self.recon_repo.get_reconstructable_pipeline(self.pipeline_name)

pipeline_run = self.instance.register_managed_run(
pipeline_name=self.pipeline_name,
run_id=self.run_id,
Expand All @@ -242,6 +244,7 @@ def execute(self, context):
pipeline_snapshot=self.pipeline_snapshot,
execution_plan_snapshot=self.execution_plan_snapshot,
parent_pipeline_snapshot=self.parent_pipeline_snapshot,
pipeline_code_origin=recon_pipeline.get_python_origin(),
)
if self._should_skip(pipeline_run):
raise AirflowSkipException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def invoke_steps_within_python_operator(
invocation_args, ts, dag_run, **kwargs
): # pylint: disable=unused-argument
mode = invocation_args.mode
recon_repo = invocation_args.recon_repo
pipeline_name = invocation_args.pipeline_name
step_keys = invocation_args.step_keys
instance_ref = invocation_args.instance_ref
Expand All @@ -93,6 +94,8 @@ def invoke_steps_within_python_operator(
execution_plan_snapshot = invocation_args.execution_plan_snapshot
parent_pipeline_snapshot = invocation_args.parent_pipeline_snapshot

recon_pipeline = recon_repo.get_reconstructable_pipeline(pipeline_name)

run_id = dag_run.run_id

instance = DagsterInstance.from_ref(instance_ref) if instance_ref else None
Expand All @@ -112,6 +115,7 @@ def invoke_steps_within_python_operator(
pipeline_snapshot=pipeline_snapshot,
execution_plan_snapshot=execution_plan_snapshot,
parent_pipeline_snapshot=parent_pipeline_snapshot,
pipeline_code_origin=recon_pipeline.get_python_origin(),
)

recon_pipeline = recon_repo.get_reconstructable_pipeline(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def _execute_step_docker(
docker_image = (
docker_config["image"]
if docker_config.get("image")
else execute_step_args.pipeline_origin.repository_origin.container_image
else pipeline_run.pipeline_code_origin.repository_origin.container_image
)

if not docker_image:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _execute_step_k8s_job(
pod_name,
component="step_worker",
labels={
"dagster/job": execute_step_args.pipeline_origin.pipeline_name,
"dagster/job": pipeline_run.pipeline_name,
"dagster/op": step_key,
"dagster/run-id": execute_step_args.pipeline_run_id,
},
Expand Down

0 comments on commit 38b71a2

Please sign in to comment.