-
Notifications
You must be signed in to change notification settings - Fork 294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Core feature] Add Raw AWS Batch Task #782
Changes from 15 commits
e8673ef
01cf848
7de5364
85c8552
97508bb
70f5f90
f81516b
cbded6c
68fd00e
35f8464
7629730
7a2320b
b09a9c7
9292099
6c4cd48
caa6f1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Flytekit AWS Batch Plugin | ||
|
||
Flyte backend can be connected with AWS batch. Once enabled, it allows you to run flyte task on AWS batch service | ||
|
||
To install the plugin, run the following command: | ||
|
||
```bash | ||
pip install flytekitplugins-awsbatch | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .task import AWSBatchConfig |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,78 @@ | ||||
from dataclasses import dataclass | ||||
from typing import Any, Callable, Dict, List, Optional | ||||
|
||||
from dataclasses_json import dataclass_json | ||||
from google.protobuf import json_format | ||||
from google.protobuf.struct_pb2 import Struct | ||||
|
||||
from flytekit import PythonFunctionTask | ||||
from flytekit.extend import SerializationSettings, TaskPlugins | ||||
|
||||
|
||||
@dataclass_json | ||||
@dataclass | ||||
class AWSBatchConfig(object): | ||||
""" | ||||
Use this to configure SubmitJobInput for a AWS batch job. Task's marked with this will automatically execute | ||||
natively onto AWS batch service. | ||||
Refer to AWS SubmitJobInput for more detail: https://docs.aws.amazon.com/sdk-for-go/api/service/batch/#SubmitJobInput | ||||
""" | ||||
|
||||
parameters: Optional[Dict[str, str]] = None | ||||
schedulingPriority: Optional[int] = None | ||||
platformCapabilities: str = "EC2" | ||||
propagateTags: Optional[bool] = None | ||||
tags: Optional[Dict[str, str]] = None | ||||
|
||||
def to_dict(self): | ||||
s = Struct() | ||||
s.update(self.to_dict()) | ||||
return json_format.MessageToDict(s) | ||||
|
||||
|
||||
class AWSBatchFunctionTask(PythonFunctionTask): | ||||
""" | ||||
Actual Plugin that transforms the local python code for execution within AWS batch job | ||||
""" | ||||
|
||||
_AWS_BATCH_TASK_TYPE = "aws-batch" | ||||
|
||||
def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwargs): | ||||
if task_config is None: | ||||
task_config = AWSBatchConfig() | ||||
super(AWSBatchFunctionTask, self).__init__( | ||||
task_config=task_config, task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, **kwargs | ||||
) | ||||
self._task_config = task_config | ||||
|
||||
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: | ||||
# task_config will be used to create SubmitJobInput in propeller except platformCapabilities. | ||||
pingsutw marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
return self._task_config.to_dict() | ||||
|
||||
def get_config(self, settings: SerializationSettings) -> Dict[str, str]: | ||||
# Parameters in taskTemplate config will be used to create aws job definition. | ||||
# More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html | ||||
return {"platformCapabilities": self._task_config.platformCapabilities} | ||||
|
||||
def get_command(self, settings: SerializationSettings) -> List[str]: | ||||
container_args = [ | ||||
"pyflyte-execute", | ||||
"--inputs", | ||||
"{{.input}}", | ||||
"--output-prefix", | ||||
# FlytePropeller will always read the output from this directory (outputPrefix/0) | ||||
# More detail, see https://github.com/flyteorg/flyteplugins/blob/0dd93c23ed2edeca65d58e89b0edb613f88120e0/go/tasks/plugins/array/catalog.go#L501. | ||||
"{{.outputPrefix}}/0", | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is the reason why I used pyflyte-map-execute before. map task will reconstruct output_prefix to flytekit/flytekit/bin/entrypoint.py Line 386 in b09a9c7
Agree with you, we should use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add another comment here on why the single batch task is calling the array batch plugin? |
||||
"--raw-output-data-prefix", | ||||
"{{.rawOutputDataPrefix}}", | ||||
"--resolver", | ||||
self.task_resolver.location, | ||||
"--", | ||||
*self.task_resolver.loader_args(settings, self), | ||||
] | ||||
|
||||
return container_args | ||||
|
||||
|
||||
# Inject the AWS batch plugin into flytekits dynamic plugin loading system | ||||
TaskPlugins.register_pythontask_plugin(AWSBatchConfig, AWSBatchFunctionTask) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
. | ||
-e file:.#egg=flytekitplugins-awsbatch |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# | ||
# This file is autogenerated by pip-compile with python 3.9 | ||
# To update, run: | ||
# | ||
# pip-compile requirements.in | ||
# | ||
-e file:.#egg=flytekitplugins-awsbatch | ||
# via -r requirements.in | ||
arrow==1.2.1 | ||
# via jinja2-time | ||
binaryornot==0.4.4 | ||
# via cookiecutter | ||
certifi==2021.10.8 | ||
# via requests | ||
chardet==4.0.0 | ||
# via binaryornot | ||
charset-normalizer==2.0.10 | ||
# via requests | ||
checksumdir==1.2.0 | ||
# via flytekit | ||
click==7.1.2 | ||
# via | ||
# cookiecutter | ||
# flytekit | ||
cloudpickle==2.0.0 | ||
# via flytekit | ||
cookiecutter==1.7.3 | ||
# via flytekit | ||
croniter==1.2.0 | ||
# via flytekit | ||
dataclasses-json==0.5.6 | ||
# via flytekit | ||
decorator==5.1.1 | ||
# via retry | ||
deprecated==1.2.13 | ||
# via flytekit | ||
diskcache==5.4.0 | ||
# via flytekit | ||
docker-image-py==0.1.12 | ||
# via flytekit | ||
docstring-parser==0.13 | ||
# via flytekit | ||
flyteidl==0.21.23 | ||
# via flytekit | ||
flytekit==0.26.0 | ||
# via flytekitplugins-awsbatch | ||
grpcio==1.43.0 | ||
# via flytekit | ||
idna==3.3 | ||
# via requests | ||
importlib-metadata==4.10.1 | ||
# via keyring | ||
jinja2==3.0.3 | ||
# via | ||
# cookiecutter | ||
# jinja2-time | ||
jinja2-time==0.2.0 | ||
# via cookiecutter | ||
keyring==23.5.0 | ||
# via flytekit | ||
markupsafe==2.0.1 | ||
# via jinja2 | ||
marshmallow==3.14.1 | ||
# via | ||
# dataclasses-json | ||
# marshmallow-enum | ||
# marshmallow-jsonschema | ||
marshmallow-enum==1.5.1 | ||
# via dataclasses-json | ||
marshmallow-jsonschema==0.13.0 | ||
# via flytekit | ||
mypy-extensions==0.4.3 | ||
# via typing-inspect | ||
natsort==8.0.2 | ||
# via flytekit | ||
numpy==1.22.1 | ||
# via | ||
# pandas | ||
# pyarrow | ||
pandas==1.3.5 | ||
# via flytekit | ||
poyo==0.5.0 | ||
# via cookiecutter | ||
protobuf==3.19.3 | ||
# via | ||
# flyteidl | ||
# flytekit | ||
py==1.11.0 | ||
# via retry | ||
pyarrow==6.0.1 | ||
# via flytekit | ||
python-dateutil==2.8.1 | ||
# via | ||
# arrow | ||
# croniter | ||
# flytekit | ||
# pandas | ||
python-json-logger==2.0.2 | ||
# via flytekit | ||
python-slugify==5.0.2 | ||
# via cookiecutter | ||
pytimeparse==1.1.8 | ||
# via flytekit | ||
pytz==2021.3 | ||
# via | ||
# flytekit | ||
# pandas | ||
regex==2022.1.18 | ||
# via docker-image-py | ||
requests==2.27.1 | ||
# via | ||
# cookiecutter | ||
# flytekit | ||
# responses | ||
responses==0.17.0 | ||
# via flytekit | ||
retry==0.9.2 | ||
# via flytekit | ||
six==1.16.0 | ||
# via | ||
# cookiecutter | ||
# flytekit | ||
# grpcio | ||
# python-dateutil | ||
# responses | ||
sortedcontainers==2.4.0 | ||
# via flytekit | ||
statsd==3.3.0 | ||
# via flytekit | ||
text-unidecode==1.3 | ||
# via python-slugify | ||
typing-extensions==4.0.1 | ||
# via typing-inspect | ||
typing-inspect==0.7.1 | ||
# via dataclasses-json | ||
urllib3==1.26.8 | ||
# via | ||
# flytekit | ||
# requests | ||
# responses | ||
wheel==0.37.1 | ||
# via flytekit | ||
wrapt==1.13.3 | ||
# via | ||
# deprecated | ||
# flytekit | ||
zipp==3.7.0 | ||
# via importlib-metadata |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from setuptools import setup | ||
|
||
PLUGIN_NAME = "awsbatch" | ||
|
||
microlib_name = f"flytekitplugins-{PLUGIN_NAME}" | ||
|
||
plugin_requires = ["flytekit>=0.19.0,<1.0.0"] | ||
|
||
__version__ = "0.0.0+develop" | ||
|
||
setup( | ||
name=microlib_name, | ||
version=__version__, | ||
author="flyteorg", | ||
author_email="admin@flyte.org", | ||
description="This package holds the AWS Batch plugins for flytekit", | ||
namespace_packages=["flytekitplugins"], | ||
packages=[f"flytekitplugins.{PLUGIN_NAME}"], | ||
install_requires=plugin_requires, | ||
license="apache2", | ||
python_requires=">=3.7", | ||
classifiers=[ | ||
"Intended Audience :: Science/Research", | ||
"Intended Audience :: Developers", | ||
"License :: OSI Approved :: Apache Software License", | ||
"Programming Language :: Python :: 3.7", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Topic :: Scientific/Engineering", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Topic :: Software Development", | ||
"Topic :: Software Development :: Libraries", | ||
"Topic :: Software Development :: Libraries :: Python Modules", | ||
], | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from flytekitplugins.awsbatch import AWSBatchConfig | ||
|
||
from flytekit import PythonFunctionTask, task | ||
from flytekit.extend import Image, ImageConfig, SerializationSettings | ||
|
||
config = AWSBatchConfig( | ||
parameters={"codec": "mp4"}, | ||
platformCapabilities="EC2", | ||
propagateTags=True, | ||
tags={"hello": "world"}, | ||
) | ||
|
||
|
||
def test_aws_batch_task(): | ||
@task(task_config=config) | ||
def t1(a: int) -> str: | ||
inc = a + 2 | ||
return str(inc) | ||
|
||
assert t1.task_config is not None | ||
assert t1.task_config == config | ||
assert t1.task_type == "aws-batch" | ||
assert isinstance(t1, PythonFunctionTask) | ||
|
||
default_img = Image(name="default", fqn="test", tag="tag") | ||
settings = SerializationSettings( | ||
project="project", | ||
domain="domain", | ||
version="version", | ||
env={"FOO": "baz"}, | ||
image_config=ImageConfig(default_image=default_img, images=[default_img]), | ||
) | ||
assert t1.get_custom(settings) == config.to_dict() | ||
assert t1.get_command(settings) == [ | ||
"pyflyte-execute", | ||
"--inputs", | ||
"{{.input}}", | ||
"--output-prefix", | ||
"{{.outputPrefix}}/0", | ||
"--raw-output-data-prefix", | ||
"{{.rawOutputDataPrefix}}", | ||
"--resolver", | ||
"flytekit.core.python_auto_container.default_task_resolver", | ||
"--", | ||
"task-module", | ||
"tests.test_aws_batch", | ||
"task-name", | ||
"t1", | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are all five of these in the SubmitJobInput documentation? Or are some of these Flyte concepts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all are in the SubmitJobInput documentation