Skip to content

Commit

Permalink
[dagster-aws] update emr pyspark step launcher (#7604)
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenKephart committed Apr 27, 2022
1 parent 3bcb757 commit 00fa234
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
import boto3
from dagster_aws.s3.file_manager import S3FileHandle, S3FileManager

from dagster.core.execution.plan.external_step import PICKLED_EVENTS_FILE_NAME, run_step_from_ref
from dagster.core.instance import DagsterInstance
from dagster.core.execution.plan.external_step import (
PICKLED_EVENTS_FILE_NAME,
external_instance_from_step_run_ref,
run_step_from_ref,
)
from dagster.serdes import serialize_value

DONE = object()

Expand All @@ -27,7 +31,7 @@ def main(step_run_ref_bucket, s3_dir_key):
events_s3_key = os.path.dirname(s3_dir_key) + "/" + PICKLED_EVENTS_FILE_NAME

def put_events(events):
file_obj = io.BytesIO(pickle.dumps(events))
file_obj = io.BytesIO(pickle.dumps(serialize_value(events)))
session.put_object(Body=file_obj, Bucket=events_bucket, Key=events_s3_key)

# Set up a thread to handle writing events back to the plan process, so execution doesn't get
Expand All @@ -39,13 +43,15 @@ def put_events(events):
)
event_writing_thread.start()

with DagsterInstance.ephemeral() as instance:
try:
for event in run_step_from_ref(step_run_ref, instance):
events_queue.put(event)
finally:
events_queue.put(DONE)
event_writing_thread.join()
try:
instance = external_instance_from_step_run_ref(
step_run_ref, event_listener_fn=events_queue.put
)
# consume iterator
list(run_step_from_ref(step_run_ref, instance))
finally:
events_queue.put(DONE)
event_writing_thread.join()


def event_writing_loop(events_queue, put_events_fn):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from dagster import Field, StringSource, check, resource
from dagster.core.definitions.step_launcher import StepLauncher
from dagster.core.errors import DagsterInvariantViolationError, raise_execution_interrupts
from dagster.core.events import log_step_event
from dagster.core.execution.plan.external_step import (
PICKLED_EVENTS_FILE_NAME,
PICKLED_STEP_RUN_REF_FILE_NAME,
step_context_to_step_run_ref,
)
from dagster.serdes import deserialize_value

# On EMR, Spark is installed here
EMR_SPARK_HOME = "/usr/lib/spark/"
Expand Down Expand Up @@ -307,25 +307,24 @@ def launch_step(self, step_context, prior_attempts_count):
0
]

yield from self.wait_for_completion_and_log(
log, run_id, step_key, emr_step_id, step_context
)
yield from self.wait_for_completion_and_log(run_id, step_key, emr_step_id, step_context)

def wait_for_completion_and_log(self, log, run_id, step_key, emr_step_id, step_context):
def wait_for_completion_and_log(self, run_id, step_key, emr_step_id, step_context):
s3 = boto3.resource("s3", region_name=self.region_name)
try:
for event in self.wait_for_completion(log, s3, run_id, step_key, emr_step_id):
log_step_event(step_context, event)
for event in self.wait_for_completion(step_context, s3, run_id, step_key, emr_step_id):
yield event
except EmrError as emr_error:
if self.wait_for_logs:
self._log_logs_from_s3(log, emr_step_id)
self._log_logs_from_s3(step_context.log, emr_step_id)
raise emr_error

if self.wait_for_logs:
self._log_logs_from_s3(log, emr_step_id)
self._log_logs_from_s3(step_context.log, emr_step_id)

def wait_for_completion(self, log, s3, run_id, step_key, emr_step_id, check_interval=15):
def wait_for_completion(
self, step_context, s3, run_id, step_key, emr_step_id, check_interval=15
):
"""We want to wait for the EMR steps to complete, and while that's happening, we want to
yield any events that have been written to S3 for us by the remote process.
After the EMR steps complete, we want a final chance to fetch events before finishing
Expand All @@ -339,13 +338,19 @@ def wait_for_completion(self, log, s3, run_id, step_key, emr_step_id, check_inte
while not done:
with raise_execution_interrupts():
time.sleep(check_interval) # AWS rate-limits us if we poll it too often
done = self.emr_job_runner.is_emr_step_complete(log, self.cluster_id, emr_step_id)
done = self.emr_job_runner.is_emr_step_complete(
step_context.log, self.cluster_id, emr_step_id
)

all_events_new = self.read_events(s3, run_id, step_key)

if len(all_events_new) > len(all_events):
for i in range(len(all_events), len(all_events_new)):
yield all_events_new[i]
event = all_events_new[i]
# write each event from the EMR instance to the local instance
step_context.instance.handle_new_event(event)
if event.is_dagster_event:
yield event.dagster_event
all_events = all_events_new

def read_events(self, s3, run_id, step_key):
Expand All @@ -355,7 +360,7 @@ def read_events(self, s3, run_id, step_key):

try:
events_data = events_s3_obj.get()["Body"].read()
return pickle.loads(events_data)
return deserialize_value(pickle.loads(events_data))
except ClientError as ex:
# The file might not be there yet, which is fine
if ex.response["Error"]["Code"] == "NoSuchKey":
Expand Down Expand Up @@ -434,9 +439,20 @@ def _main_file_name(self):
def _main_file_local_path(self):
return emr_step_main.__file__

def _sanitize_step_key(self, step_key: str) -> str:
# step_keys of dynamic steps contain brackets, which are invalid characters
return step_key.replace("[", "__").replace("]", "__")

def _artifact_s3_uri(self, run_id, step_key, filename):
key = self._artifact_s3_key(run_id, step_key, filename)
key = self._artifact_s3_key(run_id, self._sanitize_step_key(step_key), filename)
return "s3://{bucket}/{key}".format(bucket=self.staging_bucket, key=key)

def _artifact_s3_key(self, run_id, step_key, filename):
return "/".join([self.staging_prefix, run_id, step_key, os.path.basename(filename)])
return "/".join(
[
self.staging_prefix,
run_id,
self._sanitize_step_key(step_key),
os.path.basename(filename),
]
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from dagster.core.definitions.no_step_launcher import no_step_launcher
from dagster.core.errors import DagsterSubprocessError
from dagster.core.test_utils import instance_for_test
from dagster.utils.merger import deep_merge_dicts
from dagster.utils.test import create_test_pipeline_execution_context

Expand Down Expand Up @@ -114,9 +115,11 @@ def test_local():
@mock.patch("dagster_aws.emr.pyspark_step_launcher.EmrPySparkStepLauncher.read_events")
@mock.patch("dagster_aws.emr.emr.EmrJobRunner.is_emr_step_complete")
def test_pyspark_emr(mock_is_emr_step_complete, mock_read_events, mock_s3_bucket):
mock_read_events.return_value = execute_pipeline(
reconstructable(define_do_nothing_pipe), mode="local"
).events_by_step_key["do_nothing_solid"]
with instance_for_test() as instance:
execute_pipeline(reconstructable(define_do_nothing_pipe), mode="local", instance=instance)
mock_read_events.return_value = [
record.event_log_entry for record in instance.get_event_records()
]

run_job_flow_args = dict(
Instances={
Expand Down Expand Up @@ -233,6 +236,8 @@ def test_fetch_logs_on_fail(
_mock_log_step_event, mock_log_logs, mock_wait_for_completion, _mock_boto3_resource
):
mock_log = mock.MagicMock()
mock_step_context = mock.MagicMock()
mock_step_context.log = mock_log
mock_wait_for_completion.side_effect = EmrError()

step_launcher = EmrPySparkStepLauncher(
Expand All @@ -248,7 +253,7 @@ def test_fetch_logs_on_fail(
)

with pytest.raises(EmrError):
for _ in step_launcher.wait_for_completion_and_log(mock_log, None, None, None, None):
for _ in step_launcher.wait_for_completion_and_log(None, None, None, mock_step_context):
pass

assert mock_log_logs.call_count == 1
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,39 @@

from dagster_aws.emr.pyspark_step_launcher import EmrPySparkStepLauncher

EVENTS = [object(), object(), object()]
from dagster import DagsterEvent, EventLogEntry
from dagster.core.execution.plan.objects import StepSuccessData

EVENTS = [
EventLogEntry(
run_id="1",
error_info=None,
level=20,
user_message="foo",
timestamp=1.0,
dagster_event=DagsterEvent(event_type_value="STEP_START", pipeline_name="foo"),
),
EventLogEntry(
run_id="1",
error_info=None,
level=20,
user_message="bar",
timestamp=2.0,
dagster_event=None,
),
EventLogEntry(
run_id="1",
error_info=None,
level=20,
user_message="baz",
timestamp=3.0,
dagster_event=DagsterEvent(
event_type_value="STEP_SUCCESS",
pipeline_name="foo",
event_specific_data=StepSuccessData(duration_ms=2.0),
),
),
]


@mock.patch(
Expand All @@ -27,4 +59,4 @@ def test_wait_for_completion(_mock_is_emr_step_complete, _mock_read_events):
yielded_events = list(
launcher.wait_for_completion(mock.MagicMock(), None, None, None, None, check_interval=0)
)
assert yielded_events == EVENTS
assert yielded_events == [event.dagster_event for event in EVENTS if event.is_dagster_event]

0 comments on commit 00fa234

Please sign in to comment.