diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index a8f4c06b9a..4ba2cf8894 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -62,6 +62,7 @@ jobs: plugin-names: # Please maintain an alphabetical order in the following list - flytekit-aws-athena + - flytekit-aws-batch - flytekit-aws-sagemaker - flytekit-bigquery - flytekit-data-fsspec diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 4fc9f69d3e..9d07d38a63 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -58,7 +58,9 @@ def _compute_array_job_index(): offset = 0 if _os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"): offset = int(_os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET")) - return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"))) + if _os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"): + return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"))) + return offset def _dispatch_execute( @@ -345,6 +347,27 @@ def _execute_map_task( dynamic_addl_distro: Optional[str] = None, dynamic_dest_dir: Optional[str] = None, ): + """ + This function should be called by map task and aws-batch task + resolver should be something like: + flytekit.core.python_auto_container.default_task_resolver + resolver args should be something like + task_module app.workflows task_name task_1 + have dashes seems to mess up click, like --task_module seems to interfere + + :param inputs: Where to read inputs + :param output_prefix: Where to write primitive outputs + :param raw_output_data_prefix: Where to write offloaded data (files, directories, dataframes). + :param test: Dry run + :param resolver: The task resolver to use. This needs to be loadable directly from importlib (and thus cannot be + nested). + :param resolver_args: Args that will be passed to the aforementioned resolver's load_task function + :param dynamic_addl_distro: In the case of parent tasks executed using the 'fast' mode this captures where the + compressed code archive has been uploaded. + :param dynamic_dest_dir: In the case of parent tasks executed using the 'fast' mode this captures where compressed + code archives should be installed in the flyte task container. + :return: + """ if len(resolver_args) < 1: raise Exception(f"Resolver args cannot be <1, got {resolver_args}") diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index f760be5d3c..42ba785a41 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -52,7 +52,7 @@ def __init__( collection_interface = transform_interface_to_list_interface(python_function_task.python_interface) instance = next(self._ids) - name = f"{python_function_task._task_function.__module__}.mapper_{python_function_task._task_function.__name__}_{instance}" + name = f"{python_function_task.task_function.__module__}.mapper_{python_function_task.task_function.__name__}_{instance}" self._run_task = python_function_task self._max_concurrency = concurrency diff --git a/plugins/flytekit-aws-batch/README.md b/plugins/flytekit-aws-batch/README.md new file mode 100644 index 0000000000..b56663237b --- /dev/null +++ b/plugins/flytekit-aws-batch/README.md @@ -0,0 +1,9 @@ +# Flytekit AWS Batch Plugin + +Flyte backend can be connected with AWS batch. Once enabled, it allows you to run flyte task on AWS batch service + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-awsbatch +``` diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/__init__.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/__init__.py new file mode 100644 index 0000000000..244716ebe7 --- /dev/null +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/__init__.py @@ -0,0 +1 @@ +from .task import AWSBatchConfig diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py new file mode 100644 index 0000000000..c9b30b7af1 --- /dev/null +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +from dataclasses_json import dataclass_json +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct + +from flytekit import PythonFunctionTask +from flytekit.extend import SerializationSettings, TaskPlugins + + +@dataclass_json +@dataclass +class AWSBatchConfig(object): + """ + Use this to configure SubmitJobInput for a AWS batch job. Task's marked with this will automatically execute + natively onto AWS batch service. + Refer to AWS SubmitJobInput for more detail: https://docs.aws.amazon.com/sdk-for-go/api/service/batch/#SubmitJobInput + """ + + parameters: Optional[Dict[str, str]] = None + schedulingPriority: Optional[int] = None + platformCapabilities: str = "EC2" + propagateTags: Optional[bool] = None + tags: Optional[Dict[str, str]] = None + + def to_dict(self): + s = Struct() + s.update(self.to_dict()) + return json_format.MessageToDict(s) + + +class AWSBatchFunctionTask(PythonFunctionTask): + """ + Actual Plugin that transforms the local python code for execution within AWS batch job + """ + + _AWS_BATCH_TASK_TYPE = "aws-batch" + + def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwargs): + if task_config is None: + task_config = AWSBatchConfig() + super(AWSBatchFunctionTask, self).__init__( + task_config=task_config, task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, **kwargs + ) + self._task_config = task_config + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + # task_config will be used to create SubmitJobInput in propeller except platformCapabilities. + return self._task_config.to_dict() + + def get_config(self, settings: SerializationSettings) -> Dict[str, str]: + # Parameters in taskTemplate config will be used to create aws job definition. + # More detail about job definition: https://docs.aws.amazon.com/batch/latest/userguide/job_definition_parameters.html + return {"platformCapabilities": self._task_config.platformCapabilities} + + def get_command(self, settings: SerializationSettings) -> List[str]: + container_args = [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + # As of FlytePropeller v0.16.28, aws array batch plugin support to run single job. + # This task will call aws batch plugin to execute the task on aws batch service. + # For single job, FlytePropeller will always read the output from this directory (outputPrefix/0) + # More detail, see https://github.com/flyteorg/flyteplugins/blob/0dd93c23ed2edeca65d58e89b0edb613f88120e0/go/tasks/plugins/array/catalog.go#L501. + "{{.outputPrefix}}/0", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--resolver", + self.task_resolver.location, + "--", + *self.task_resolver.loader_args(settings, self), + ] + + return container_args + + +# Inject the AWS batch plugin into flytekits dynamic plugin loading system +TaskPlugins.register_pythontask_plugin(AWSBatchConfig, AWSBatchFunctionTask) diff --git a/plugins/flytekit-aws-batch/requirements.in b/plugins/flytekit-aws-batch/requirements.in new file mode 100644 index 0000000000..4c2c1fae97 --- /dev/null +++ b/plugins/flytekit-aws-batch/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-awsbatch diff --git a/plugins/flytekit-aws-batch/requirements.txt b/plugins/flytekit-aws-batch/requirements.txt new file mode 100644 index 0000000000..9181490234 --- /dev/null +++ b/plugins/flytekit-aws-batch/requirements.txt @@ -0,0 +1,148 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-awsbatch + # via -r requirements.in +arrow==1.2.1 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2021.10.8 + # via requests +chardet==4.0.0 + # via binaryornot +charset-normalizer==2.0.10 + # via requests +checksumdir==1.2.0 + # via flytekit +click==7.1.2 + # via + # cookiecutter + # flytekit +cloudpickle==2.0.0 + # via flytekit +cookiecutter==1.7.3 + # via flytekit +croniter==1.2.0 + # via flytekit +dataclasses-json==0.5.6 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.13 + # via flytekit +flyteidl==0.21.23 + # via flytekit +flytekit==0.26.0 + # via flytekitplugins-awsbatch +grpcio==1.43.0 + # via flytekit +idna==3.3 + # via requests +importlib-metadata==4.10.1 + # via keyring +jinja2==3.0.3 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +keyring==23.5.0 + # via flytekit +markupsafe==2.0.1 + # via jinja2 +marshmallow==3.14.1 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.0.2 + # via flytekit +numpy==1.22.1 + # via + # pandas + # pyarrow +pandas==1.3.5 + # via flytekit +poyo==0.5.0 + # via cookiecutter +protobuf==3.19.3 + # via + # flyteidl + # flytekit +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +python-dateutil==2.8.1 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.2 + # via flytekit +python-slugify==5.0.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2021.3 + # via + # flytekit + # pandas +regex==2022.1.18 + # via docker-image-py +requests==2.27.1 + # via + # cookiecutter + # flytekit + # responses +responses==0.17.0 + # via flytekit +retry==0.9.2 + # via flytekit +six==1.16.0 + # via + # cookiecutter + # flytekit + # grpcio + # python-dateutil + # responses +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +typing-extensions==4.0.1 + # via typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.8 + # via + # flytekit + # requests + # responses +wheel==0.37.1 + # via flytekit +wrapt==1.13.3 + # via + # deprecated + # flytekit +zipp==3.7.0 + # via importlib-metadata diff --git a/plugins/flytekit-aws-batch/setup.py b/plugins/flytekit-aws-batch/setup.py new file mode 100644 index 0000000000..43613fe244 --- /dev/null +++ b/plugins/flytekit-aws-batch/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "awsbatch" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=0.19.0,<1.0.0"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the AWS Batch plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-aws-batch/tests/__init__.py b/plugins/flytekit-aws-batch/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-aws-batch/tests/test_aws_batch.py b/plugins/flytekit-aws-batch/tests/test_aws_batch.py new file mode 100644 index 0000000000..5cd1e5e5c2 --- /dev/null +++ b/plugins/flytekit-aws-batch/tests/test_aws_batch.py @@ -0,0 +1,49 @@ +from flytekitplugins.awsbatch import AWSBatchConfig + +from flytekit import PythonFunctionTask, task +from flytekit.extend import Image, ImageConfig, SerializationSettings + +config = AWSBatchConfig( + parameters={"codec": "mp4"}, + platformCapabilities="EC2", + propagateTags=True, + tags={"hello": "world"}, +) + + +def test_aws_batch_task(): + @task(task_config=config) + def t1(a: int) -> str: + inc = a + 2 + return str(inc) + + assert t1.task_config is not None + assert t1.task_config == config + assert t1.task_type == "aws-batch" + assert isinstance(t1, PythonFunctionTask) + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + assert t1.get_custom(settings) == config.to_dict() + assert t1.get_command(settings) == [ + "pyflyte-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}/0", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "tests.test_aws_batch", + "task-name", + "t1", + ]