Create a new task definition for each launch
Summary:
When originally conceived, the EcsRunLauncher launched tasks using the
same task definition used by its current process. This was to punt on
needing to register new task definitions because:

1. There's not a great way to identify whether, given a set of input
   parameters, a task definition that satisfies those constraints
   already exists.
2. There's a lot of variety in how you can configure a task definition,
   and I thought constraining our options to just a known valid
   configuration would make development easier.

The fatal flaw of this plan is that the current process that calls
launch_run is usually the daemon. If your daemon container doesn't also
include your repository and pipeline code, then the ECS task stands up
correctly but the Dagster run fails because the task can't load the
pipeline definition.

ECS allows you to override a lot of container behavior - but it doesn't
allow you to override the actual image. So we're forced to either find a
suitable task definition that already exists or bite the bullet and
create a new suitable task definition.
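
For context, here's roughly what run_task does let you override (a
boto3 sketch; the family, cluster, and container names are
illustrative):

```
# Sketch: containerOverrides accepts things like command and
# environment, but there is no "image" key, so the image baked into
# the task definition always wins.
import boto3

ecs = boto3.client("ecs")
ecs.run_task(
    taskDefinition="daemon",  # illustrative family name
    cluster="default",
    overrides={
        "containerOverrides": [
            {
                "name": "daemon",
                "command": ["dagster", "api", "execute_run", "..."],
                # "image": ... would fail validation; it's not overridable
            }
        ]
    },
)
```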

We know the pipeline's repository origin's image at the time we want
to launch the run because it can be set via DAGSTER_CURRENT_IMAGE:

90079c4
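
At launch time, that image surfaces on the pipeline's origin, which is
where the new launcher code reads it (see the diff below):

```
pipeline_origin = external_pipeline.get_python_origin()
image = pipeline_origin.repository_origin.container_image
```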

We could list all running tasks and find ones that use that image. But
it's possible that multiple tasks use the image, and I could see it
getting confusing if we chose the "wrong" one (even though technically
things would still work). For example, if both a repo1 task and a repo2
task use the same image and we're trying to run a pipeline from repo1,
we could accidentally do so using a task definition for repo2. Blergh.
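
For illustration, the rejected approach would have looked something
like this sketch (the helper is invented and pagination is ignored):

```
def task_definitions_using(ecs, image):
    # Collect every task definition whose containers use the image.
    # Nothing here distinguishes a repo1 family from a repo2 family
    # that happens to share the image, hence the ambiguity.
    matches = []
    for arn in ecs.list_task_definitions()["taskDefinitionArns"]:
        task_definition = ecs.describe_task_definition(taskDefinition=arn)["taskDefinition"]
        if any(
            container.get("image") == image
            for container in task_definition["containerDefinitions"]
        ):
            matches.append(arn)
    return matches
```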

So instead, I've decided to forge ahead on creating a task definition
for each run. We start with the parent task definition just like we
previously did. But then we munge it so that it's suitable to pass back
in as arguments to `ecs.register_task_definition()`. We remove the
"daemon" container and add a new "run" container.

There are a few major improvements that we should make to this before
recommending it for production use:

1. Don't create a new task definition if a suitable active one already
   exists. This is perhaps easier said than done because ECS doesn't
   provide a mechanism for checking if a given task definition exists.
   So we'll probably need to either read through the ECS documentation
   and hardcode its default behaviors or we'll need to loosen our
   definition of "matching" (a rough sketch of this and the garbage
   collection below follows the list).
2. Garbage collect unused revisions. Once we're done running our task,
   we should deregister its task definition so users' AWS accounts
   aren't littered with tons of active but outdated task definitions.
3. Allow the task definition to be overridden. This can be done fairly
   trivially by changing:

    ```
    self.ecs.register_task_definition(**task_definition)
    ```
   to
    ```
    self.ecs.register_task_definition(**{**task_definition, **overrides})
    ```

   although we'll want to give some thought and care to the exact
   implementation.
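
None of this is in the commit yet, but a hypothetical sketch of the
first two items might look like (the helper names are invented):

```
def _reusable_task_definition_arn(ecs, family, container_definitions):
    # Item 1: reuse the family's latest active revision if its
    # containers already match. "Matching" is loose; ECS fills in
    # defaults on registration, so strict equality against our input
    # rarely holds.
    try:
        existing = ecs.describe_task_definition(taskDefinition=family)["taskDefinition"]
    except ecs.exceptions.ClientException:
        return None  # the family doesn't exist yet
    if existing["containerDefinitions"] == container_definitions:
        return existing["taskDefinitionArn"]
    return None


def _garbage_collect(ecs, task_definition_arn):
    # Item 2: once the run's task stops, mark its revision INACTIVE so
    # accounts aren't littered with active but outdated definitions.
    ecs.deregister_task_definition(taskDefinition=task_definition_arn)
```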

Test Plan: unit

Reviewers: alangenfeld, dgibson

Reviewed By: dgibson

Differential Revision: https://dagster.phacility.com/D8486
jmsanders committed Jun 23, 2021
1 parent b44e369 commit cb07e82
Showing 4 changed files with 181 additions and 42 deletions.
102 changes: 77 additions & 25 deletions python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
from typing import List
from typing import Any, Dict, List

import boto3
import requests
@@ -12,12 +12,11 @@

@dataclass
class TaskMetadata:
arn: str
container: str
family: str
cluster: str
subnets: List[str]
security_groups: List[str]
task_definition: Dict[str, Any]
container_definition: Dict[str, Any]


class EcsRunLauncher(RunLauncher, ConfigurableClass):
@@ -49,35 +48,78 @@ def _run_tags(self, task_arn):

def launch_run(self, run, external_pipeline):
"""
Launch a run using the same task definition as the parent process but
overriding its command to execute `dagster api execute_run` instead.
Launch a run in an ECS task.
Currently, Fargate is the only supported launchType and awsvpc is the
only supported networkMode. These are the defaults that are set up by
docker-compose when you use the Dagster ECS reference deployment.
When using the Dagster ECS reference deployment, the parent process
will be running in a daemon task so pipeline runs will all be part of
the daemon task definition family.
This method creates a new task definition revision for every run.
First, the process that calls this method finds its own task
definition. Next, it creates a new task definition based on its own
with several important overrides:
TODO: Support creating a new task definition (with custom config)
instead of spawning from the parent process.
1. The command is replaced with a call to `dagster api execute_run`
2. The image is overridden with the pipeline's origin's image.
"""
metadata = self._task_metadata()
pipeline_origin = external_pipeline.get_python_origin()
image = pipeline_origin.repository_origin.container_image

input_json = serialize_dagster_namedtuple(
ExecuteRunArgs(
pipeline_origin=external_pipeline.get_python_origin(),
pipeline_origin=pipeline_origin,
pipeline_run_id=run.run_id,
instance_ref=self._instance.get_ref(),
)
)
command = ["dagster", "api", "execute_run", input_json]

# Start with the current process's task definition but remove extra
# keys that aren't useful for creating a new task definition (status,
# revision, etc.)
expected_keys = [
key
for key in self.ecs.meta.service_model.shape_for(
"RegisterTaskDefinitionRequest"
).members
]
task_definition = dict(
(key, metadata.task_definition[key])
for key in expected_keys
if key in metadata.task_definition.keys()
)

# The current process might not be running in a container that has the
# pipeline's code installed. Inherit most of the process's container
# definition (things like environment, dependencies, etc.) but replace
# the image with the pipeline origin's image and give it a new name.
# TODO: Configurable task definitions
container_definitions = task_definition["containerDefinitions"]
container_definitions.remove(metadata.container_definition)
container_definitions.append(
{**metadata.container_definition, "name": "run", "image": image}
)
task_definition = {
**task_definition,
"family": "dagster-run",
"containerDefinitions": container_definitions,
}

# Register the overridden task definition as a revision to the
# "dagster-run" family.
# TODO: Only register the task definition if a matching one doesn't
# already exist. Otherwise, we risk exhausting the revisions limit
# (1,000,000 per family) with unnecessary revisions:
# https://docs.aws.amazon.com/AmazonECS/latest/developerguide/service-quotas.html
self.ecs.register_task_definition(**task_definition)

# Run a task using the new task definition and the same network
# configuration as this process's task.
response = self.ecs.run_task(
taskDefinition=metadata.family,
taskDefinition=task_definition["family"],
cluster=metadata.cluster,
overrides={"containerOverrides": [{"name": metadata.container, "command": command}]},
overrides={"containerOverrides": [{"name": "run", "command": command}]},
networkConfiguration={
"awsvpcConfiguration": {
"subnets": metadata.subnets,
@@ -118,23 +160,19 @@ def _task_metadata(self):
"""
ECS injects an environment variable into each Fargate task. The value
of this environment variable is a url that can be queried to introspect
information about the running task:
information about the current process's running task:
https://docs.aws.amazon.com/AmazonECS/latest/userguide/task-metadata-endpoint-v4-fargate.html
We use this so we can spawn new tasks using the same task definition as
the existing process.
"""
container_metadata_uri = os.environ.get("ECS_CONTAINER_METADATA_URI_V4")
container = requests.get(container_metadata_uri).json()["Name"]
name = requests.get(container_metadata_uri).json()["Name"]

task_metadata_uri = container_metadata_uri + "/task"
response = requests.get(task_metadata_uri).json()
cluster = response.get("Cluster")
arn = response.get("TaskARN")
family = response.get("Family")
task_arn = response.get("TaskARN")

task = self.ecs.describe_tasks(tasks=[arn], cluster=cluster)["tasks"][0]
task = self.ecs.describe_tasks(tasks=[task_arn], cluster=cluster)["tasks"][0]
enis = []
subnets = []
for attachment in task["attachments"]:
@@ -150,11 +188,25 @@
for group in eni.groups:
security_groups.append(group["GroupId"])

task_definition_arn = task["taskDefinitionArn"]
task_definition = self.ecs.describe_task_definition(taskDefinition=task_definition_arn)[
"taskDefinition"
]

container_definition = next(
iter(
[
container
for container in task_definition["containerDefinitions"]
if container["name"] == name
]
)
)

return TaskMetadata(
arn=arn,
container=container,
family=family,
cluster=cluster,
subnets=subnets,
security_groups=security_groups,
task_definition=task_definition,
container_definition=container_definition,
)
(next changed file: the stubbed ECS client used by the tests)
@@ -1,3 +1,4 @@
import copy
import itertools
import uuid
from collections import defaultdict
@@ -41,7 +42,7 @@ def wrapper(*args, **kwargs):
if not self.stub_count:
self.stubber.deactivate()
self.stubber.assert_no_pending_responses()
return response
return copy.deepcopy(response)
except Exception as ex:
# Exceptions should reset the stubber
self.stub_count = 0
@@ -80,6 +81,7 @@ class StubbedEcs:
def __init__(self, boto3_client):
self.client = boto3_client
self.stubber = Stubber(self.client)
self.meta = self.client.meta

self.tasks = defaultdict(list)
self.task_definitions = defaultdict(list)
@@ -167,6 +169,20 @@ def list_tags_for_resource(self, **kwargs):
)
return self.client.list_tags_for_resource(**kwargs)

@stubbed
def list_task_definitions(self, **kwargs):
arns = [
task_definition["taskDefinitionArn"]
for task_definition in itertools.chain.from_iterable(self.task_definitions.values())
]

self.stubber.add_response(
method="list_task_definitions",
service_response={"taskDefinitionArns": arns},
expected_params={**kwargs},
)
return self.client.list_task_definitions(**kwargs)

@stubbed
def list_tasks(self, **kwargs):
"""
@@ -230,7 +246,11 @@ def run_task(self, **kwargs):
)["taskDefinition"]

is_awsvpc = task_definition.get("networkMode") == "awsvpc"
containers = task_definition.get("containerDefinitions", [])
containers = []
for container in task_definition.get("containerDefinitions", []):
containers.append(
{key: value for key, value in container.items() if key in ["name", "image"]}
)

network_configuration = kwargs.get("networkConfiguration", {})
vpc_configuration = network_configuration.get("awsvpcConfiguration")
@@ -303,8 +323,8 @@ def stop_task(self, **kwargs):

if tasks:
stopped_task = tasks[0]
stopped_task["lastStatus"] = "STOPPED"
self.tasks[cluster].remove(tasks[0])
stopped_task["lastStatus"] = "STOPPED"
self.tasks[cluster].append(stopped_task)
self.stubber.add_response(
method="stop_task",
@@ -353,7 +373,7 @@ def _cluster_arn(self, cluster):
return self._arn("cluster", self._cluster(cluster))

def _task_arn(self, cluster):
return self._arn("task", f"{self._cluster(cluster)}/{uuid.uuid4()})")
return self._arn("task", f"{self._cluster(cluster)}/{uuid.uuid4()}")

def _task_definition_arn(self, family, revision):
return self._arn("task-definition", f"{family}:{revision}")
(next changed file: the ECS run launcher tests)
@@ -8,10 +8,26 @@


@pytest.fixture
def task_definition(ecs):
def image():
return "dagster:latest"


@pytest.fixture
def environment():
return [{"name": "FOO", "value": "bar"}]


@pytest.fixture
def task_definition(ecs, image, environment):
return ecs.register_task_definition(
family="dagster",
containerDefinitions=[{"name": "dagster", "image": "dagster:latest"}],
containerDefinitions=[
{
"name": "dagster",
"image": image,
"environment": environment,
}
],
networkMode="awsvpc",
)["taskDefinition"]

@@ -30,7 +46,7 @@ def task(ecs, network_interface, security_group, task_definition):


@pytest.fixture
def instance(ecs, ec2, task, task_definition, monkeypatch, requests_mock):
def instance(ecs, ec2, task, monkeypatch, requests_mock):
container_uri = "http://metadata_host"
monkeypatch.setenv("ECS_CONTAINER_METADATA_URI_V4", container_uri)
container = task["containers"][0]["name"]
@@ -42,7 +58,6 @@ def instance(ecs, ec2, task, task_definition, monkeypatch, requests_mock):
json={
"Cluster": task["clusterArn"],
"TaskARN": task["taskArn"],
"Family": task_definition["family"],
},
)
overrides = {"run_launcher": {"module": "dagster_aws.ecs", "class": "EcsRunLauncher"}}
@@ -58,9 +73,11 @@ def pipeline():


@pytest.fixture
def external_pipeline():
def external_pipeline(image):
with InProcessRepositoryLocationOrigin(
ReconstructableRepository.for_file(repo.__file__, repo.repository.__name__),
ReconstructableRepository.for_file(
repo.__file__, repo.repository.__name__, container_image=image
),
).create_location() as location:
yield location.get_repository(repo.repository.__name__).get_full_external_pipeline(
repo.pipeline.__name__
@@ -72,16 +89,39 @@ def run(instance, pipeline):
return instance.create_run_for_pipeline(pipeline)


def test_launching(ecs, instance, run, external_pipeline, subnet, network_interface):
def test_launching(
ecs, instance, run, external_pipeline, subnet, network_interface, image, environment
):
assert not run.tags
initial_task_definitions = ecs.list_task_definitions()["taskDefinitionArns"]
initial_tasks = ecs.list_tasks()["taskArns"]

instance.launch_run(run.run_id, external_pipeline)

# A new task definition is created
task_definitions = ecs.list_task_definitions()["taskDefinitionArns"]
assert len(task_definitions) == len(initial_task_definitions) + 1
task_definition_arn = list(set(task_definitions).difference(initial_task_definitions))[0]
task_definition = ecs.describe_task_definition(taskDefinition=task_definition_arn)
task_definition = task_definition["taskDefinition"]

# It has a new family, name, and image
assert task_definition["family"] == "dagster-run"
assert len(task_definition["containerDefinitions"]) == 1
container_definition = task_definition["containerDefinitions"][0]
assert container_definition["name"] == "run"
assert container_definition["image"] == image
# But other stuff is inherited from the parent task definition
assert container_definition["environment"] == environment

# A new task is launched
tasks = ecs.list_tasks()["taskArns"]
assert len(tasks) == len(initial_tasks) + 1
task_arn = list(set(tasks).difference(initial_tasks))[0]
task = ecs.describe_tasks(tasks=[task_arn])["tasks"][0]
assert subnet.id in str(task)
assert network_interface.id in str(task)
assert task["taskDefinitionArn"] == task_definition["taskDefinitionArn"]

# The run is tagged with info about the ECS task
assert instance.get_run_by_id(run.run_id).tags["ecs/task_arn"] == task_arn
@@ -91,11 +131,13 @@ def test_launching(ecs, instance, run, external_pipeline, subnet, network_interf
assert ecs.list_tags_for_resource(resourceArn=task_arn)["tags"][0]["key"] == "dagster/run_id"
assert ecs.list_tags_for_resource(resourceArn=task_arn)["tags"][0]["value"] == run.run_id

# We override the command to run our pipeline
task = ecs.describe_tasks(tasks=[task_arn])
assert subnet.id in str(task)
assert network_interface.id in str(task)
assert "execute_run" in task["tasks"][0]["overrides"]["containerOverrides"][0]["command"]
# We set pipeline-specific overrides
overrides = task["overrides"]["containerOverrides"]
assert len(overrides) == 1
override = overrides[0]
assert override["name"] == "run"
assert "execute_run" in override["command"]
assert run.run_id in str(override["command"])


def test_termination(instance, run, external_pipeline):