Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core feature] Add Raw AWS Batch Task #782

Merged
merged 16 commits into from
Feb 17, 2022
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
plugin-names:
# Please maintain an alphabetical order in the following list
- flytekit-aws-athena
- flytekit-aws-batch
- flytekit-aws-sagemaker
- flytekit-bigquery
- flytekit-data-fsspec
Expand Down
25 changes: 24 additions & 1 deletion flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def _compute_array_job_index():
offset = 0
if _os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"):
offset = int(_os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"))
return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")))
if _os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"):
return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")))
return offset


def _dispatch_execute(
Expand Down Expand Up @@ -345,6 +347,27 @@ def _execute_map_task(
dynamic_addl_distro: Optional[str] = None,
dynamic_dest_dir: Optional[str] = None,
):
"""
This function should be called by map task and aws-batch task
resolver should be something like:
flytekit.core.python_auto_container.default_task_resolver
resolver args should be something like
task_module app.workflows task_name task_1
have dashes seems to mess up click, like --task_module seems to interfere

:param inputs: Where to read inputs
:param output_prefix: Where to write primitive outputs
:param raw_output_data_prefix: Where to write offloaded data (files, directories, dataframes).
:param test: Dry run
:param resolver: The task resolver to use. This needs to be loadable directly from importlib (and thus cannot be
nested).
:param resolver_args: Args that will be passed to the aforementioned resolver's load_task function
:param dynamic_addl_distro: In the case of parent tasks executed using the 'fast' mode this captures where the
compressed code archive has been uploaded.
:param dynamic_dest_dir: In the case of parent tasks executed using the 'fast' mode this captures where compressed
code archives should be installed in the flyte task container.
:return:
"""
if len(resolver_args) < 1:
raise Exception(f"Resolver args cannot be <1, got {resolver_args}")

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(

collection_interface = transform_interface_to_list_interface(python_function_task.python_interface)
instance = next(self._ids)
name = f"{python_function_task._task_function.__module__}.mapper_{python_function_task._task_function.__name__}_{instance}"
name = f"{python_function_task.task_function.__module__}.mapper_{python_function_task.task_function.__name__}_{instance}"

self._run_task = python_function_task
self._max_concurrency = concurrency
Expand Down
9 changes: 9 additions & 0 deletions plugins/flytekit-aws-batch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Flytekit AWS Batch Plugin

The Flyte backend can be connected to AWS Batch. Once enabled, it allows you to run Flyte tasks on the AWS Batch service.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-awsbatch
```
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .task import AWSBatchConfig
78 changes: 78 additions & 0 deletions plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from dataclasses_json import dataclass_json
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit import PythonFunctionTask
from flytekit.extend import SerializationSettings, TaskPlugins


@dataclass_json
@dataclass
class AWSBatchConfig(object):
"""
Use this to configure SubmitJobInput for an AWS Batch job. Tasks marked with this will automatically execute
natively on the AWS Batch service.
Refer to AWS SubmitJobInput for more detail: https://docs.aws.amazon.com/sdk-for-go/api/service/batch/#SubmitJobInput
"""

parameters: Optional[Dict[str, str]] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are all five of these in the SubmitJobInput documentation? Or are some of these Flyte concepts?

Copy link
Member Author

@pingsutw pingsutw Feb 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all are in the SubmitJobInput documentation

schedulingPriority: Optional[int] = None
platformCapabilities: str = "EC2"
propagateTags: Optional[bool] = None
tags: Optional[Dict[str, str]] = None

def to_dict(self):
    """
    Serialize this config into a plain dict by round-tripping it through a protobuf ``Struct``.

    The dataclass field values are loaded into a ``Struct`` and converted back with
    ``json_format.MessageToDict``.

    NOTE(review): the ``@dataclass_json`` decorator applied to this class may install its own
    ``to_dict`` attribute; confirm which implementation is actually live at runtime.

    :return: The dict produced by ``json_format.MessageToDict`` for the populated ``Struct``.
    """
    # Function-scope import: the module only imports ``dataclass`` from ``dataclasses``.
    from dataclasses import asdict

    s = Struct()
    # Read the field values with dataclasses.asdict() instead of self.to_dict(): calling
    # self.to_dict() from inside to_dict() recurses with no base case.
    s.update(asdict(self))
    return json_format.MessageToDict(s)


class AWSBatchFunctionTask(PythonFunctionTask):
"""
Actual Plugin that transforms the local python code for execution within AWS batch job
"""

_AWS_BATCH_TASK_TYPE = "aws-batch"

def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwargs):
    """
    Create an AWS Batch flavoured python function task.

    :param task_config: AWS Batch settings for the task; when ``None`` a default ``AWSBatchConfig()`` is used.
    :param task_function: The python callable wrapped by this task.
    :param kwargs: Forwarded unchanged to the parent class constructor.
    """
    # Substitute a default config so the rest of the class can always rely on one being present.
    effective_config = AWSBatchConfig() if task_config is None else task_config
    super().__init__(
        task_config=effective_config,
        task_type=self._AWS_BATCH_TASK_TYPE,
        task_function=task_function,
        **kwargs,
    )
    self._task_config = effective_config

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
# task_config will be used to create SubmitJobInput in propeller except platformCapabilities.
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
return self._task_config.to_dict()

def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
    """
    Build the config map emitted for this task.

    :param settings: Serialization settings; not consulted by this method.
    :return: A single-entry dict mapping ``"platformCapabilities"`` to the value held by the task config.
    """
    # Parameters in taskTemplate config will be used to create aws job definition.
    # More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html
    platform_capabilities = self._task_config.platformCapabilities
    return {"platformCapabilities": platform_capabilities}

def get_command(self, settings: SerializationSettings) -> List[str]:
container_args = [
"pyflyte-execute",
"--inputs",
"{{.input}}",
"--output-prefix",
# FlytePropeller will always read the output from this directory (outputPrefix/0)
# More detail, see https://github.com/flyteorg/flyteplugins/blob/0dd93c23ed2edeca65d58e89b0edb613f88120e0/go/tasks/plugins/array/catalog.go#L501.
"{{.outputPrefix}}/0",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the reason why I used pyflyte-map-execute before. map task will reconstruct output_prefix to output_prefix/0. check

output_prefix = _os.path.join(output_prefix, str(task_index))

Agree with you, we should use pyflyte-execute instead. it will be more concise.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add another comment here on why the single batch task is calling the array batch plugin?

"--raw-output-data-prefix",
"{{.rawOutputDataPrefix}}",
"--resolver",
self.task_resolver.location,
"--",
*self.task_resolver.loader_args(settings, self),
]

return container_args


# Inject the AWS Batch plugin into flytekit's dynamic plugin loading system by registering
# AWSBatchFunctionTask as the task class associated with the AWSBatchConfig config type.
TaskPlugins.register_pythontask_plugin(AWSBatchConfig, AWSBatchFunctionTask)
2 changes: 2 additions & 0 deletions plugins/flytekit-aws-batch/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.
-e file:.#egg=flytekitplugins-awsbatch
148 changes: 148 additions & 0 deletions plugins/flytekit-aws-batch/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# pip-compile requirements.in
#
-e file:.#egg=flytekitplugins-awsbatch
# via -r requirements.in
arrow==1.2.1
# via jinja2-time
binaryornot==0.4.4
# via cookiecutter
certifi==2021.10.8
# via requests
chardet==4.0.0
# via binaryornot
charset-normalizer==2.0.10
# via requests
checksumdir==1.2.0
# via flytekit
click==7.1.2
# via
# cookiecutter
# flytekit
cloudpickle==2.0.0
# via flytekit
cookiecutter==1.7.3
# via flytekit
croniter==1.2.0
# via flytekit
dataclasses-json==0.5.6
# via flytekit
decorator==5.1.1
# via retry
deprecated==1.2.13
# via flytekit
diskcache==5.4.0
# via flytekit
docker-image-py==0.1.12
# via flytekit
docstring-parser==0.13
# via flytekit
flyteidl==0.21.23
# via flytekit
flytekit==0.26.0
# via flytekitplugins-awsbatch
grpcio==1.43.0
# via flytekit
idna==3.3
# via requests
importlib-metadata==4.10.1
# via keyring
jinja2==3.0.3
# via
# cookiecutter
# jinja2-time
jinja2-time==0.2.0
# via cookiecutter
keyring==23.5.0
# via flytekit
markupsafe==2.0.1
# via jinja2
marshmallow==3.14.1
# via
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.13.0
# via flytekit
mypy-extensions==0.4.3
# via typing-inspect
natsort==8.0.2
# via flytekit
numpy==1.22.1
# via
# pandas
# pyarrow
pandas==1.3.5
# via flytekit
poyo==0.5.0
# via cookiecutter
protobuf==3.19.3
# via
# flyteidl
# flytekit
py==1.11.0
# via retry
pyarrow==6.0.1
# via flytekit
python-dateutil==2.8.1
# via
# arrow
# croniter
# flytekit
# pandas
python-json-logger==2.0.2
# via flytekit
python-slugify==5.0.2
# via cookiecutter
pytimeparse==1.1.8
# via flytekit
pytz==2021.3
# via
# flytekit
# pandas
regex==2022.1.18
# via docker-image-py
requests==2.27.1
# via
# cookiecutter
# flytekit
# responses
responses==0.17.0
# via flytekit
retry==0.9.2
# via flytekit
six==1.16.0
# via
# cookiecutter
# flytekit
# grpcio
# python-dateutil
# responses
sortedcontainers==2.4.0
# via flytekit
statsd==3.3.0
# via flytekit
text-unidecode==1.3
# via python-slugify
typing-extensions==4.0.1
# via typing-inspect
typing-inspect==0.7.1
# via dataclasses-json
urllib3==1.26.8
# via
# flytekit
# requests
# responses
wheel==0.37.1
# via flytekit
wrapt==1.13.3
# via
# deprecated
# flytekit
zipp==3.7.0
# via importlib-metadata
36 changes: 36 additions & 0 deletions plugins/flytekit-aws-batch/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from setuptools import setup

# Short plugin identifier; the distribution name and the package path below are derived from it.
PLUGIN_NAME = "awsbatch"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=0.19.0,<1.0.0"]

# NOTE(review): placeholder version string; presumably rewritten by release tooling, so keep
# this line's text stable. Confirm against the publish workflow.
__version__ = "0.0.0+develop"

# Trove classifiers advertised for the distribution.
_CLASSIFIERS = [
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3.7",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries",
    "Topic :: Software Development :: Libraries :: Python Modules",
]

setup(
    name=microlib_name,
    version=__version__,
    author="flyteorg",
    author_email="admin@flyte.org",
    description="This package holds the AWS Batch plugins for flytekit",
    # The plugin lives under the shared "flytekitplugins" namespace package.
    namespace_packages=["flytekitplugins"],
    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
    install_requires=plugin_requires,
    license="apache2",
    python_requires=">=3.7",
    classifiers=_CLASSIFIERS,
)
Empty file.
49 changes: 49 additions & 0 deletions plugins/flytekit-aws-batch/tests/test_aws_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from flytekitplugins.awsbatch import AWSBatchConfig

from flytekit import PythonFunctionTask, task
from flytekit.extend import Image, ImageConfig, SerializationSettings

# Shared AWS Batch task configuration used by the test in this module.
config = AWSBatchConfig(
    parameters={"codec": "mp4"},
    platformCapabilities="EC2",
    propagateTags=True,
    tags={"hello": "world"},
)


def test_aws_batch_task():
    """Check that a task declared with an AWSBatchConfig gets the aws-batch type, custom dict and command."""

    @task(task_config=config)
    def t1(a: int) -> str:
        return str(a + 2)

    # The decorator must keep the supplied config and mark the task as an AWS Batch task.
    assert t1.task_config is not None
    assert t1.task_config == config
    assert t1.task_type == "aws-batch"
    assert isinstance(t1, PythonFunctionTask)

    default_img = Image(name="default", fqn="test", tag="tag")
    images = ImageConfig(default_image=default_img, images=[default_img])
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=images,
    )

    # The custom payload is the serialized form of the task config.
    assert t1.get_custom(settings) == config.to_dict()

    # The container command writes outputs under "<outputPrefix>/0" and loads t1 via the default resolver.
    expected_command = [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}/0",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.test_aws_batch",
        "task-name",
        "t1",
    ]
    assert t1.get_command(settings) == expected_command