From 315956f45978ed121d04149ffe1f4d802a93c6e3 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Mon, 19 Dec 2022 10:49:24 -0800 Subject: [PATCH 1/8] Restrict numpy due to deprecated aliases (#1376) Signed-off-by: Eduardo Apolinario Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 44266c708c..3d9710004f 100644 --- a/setup.py +++ b/setup.py @@ -79,6 +79,9 @@ "cloudpickle>=2.0.0", "cookiecutter>=1.7.3", "numpy<1.22.0; python_version < '3.8.0'", + # TODO: We should remove mentions to the deprecated numpy + # aliases. More details in https://github.com/flyteorg/flyte/issues/3166 + "numpy<1.24.0", ], extras_require=extras_require, scripts=[ From e4911e7415ec5670947f6fe90a90544934e53306 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 20 Dec 2022 06:44:10 +0800 Subject: [PATCH 2/8] Add Databricks config to Spark Job (#1358) Signed-off-by: Kevin Su --- .github/workflows/pythonbuild.yml | 2 +- dev-requirements.txt | 2 +- doc-requirements.txt | 2 +- .../flytekitplugins/spark/models.py | 70 ++++++++++++++++--- .../flytekitplugins/spark/task.py | 23 ++++++ plugins/flytekit-spark/requirements.txt | 2 +- .../flytekit-spark/tests/test_spark_task.py | 41 ++++++++++- requirements-spark2.txt | 2 +- requirements.txt | 2 +- 9 files changed, 131 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 4c96e9f011..4c16153be9 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -80,7 +80,7 @@ jobs: - flytekit-modin - flytekit-onnx-pytorch - flytekit-onnx-scikitlearn - # onnxx-tensorflow needs a version of tensorflow that does not work with protobuf>4. + # onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4. 
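
# A minimal sketch of the failure mode that the "numpy<1.24.0" pin in the
# first patch above guards against. NumPy 1.24 removed the builtin-type
# aliases (np.float, np.int, np.object, ...) that had been deprecated since
# NumPy 1.20, so code that still references them breaks. Assumes numpy>=1.24
# is installed; on older versions the second call only emits a deprecation warning.
import numpy as np

ok = np.array([1.0, 2.0], dtype=np.float64)  # np.float64 is a real dtype and still exists
try:
    np.array([1.0], dtype=np.float)  # removed alias: raises AttributeError on numpy>=1.24
except AttributeError as err:
    print(err)  # module 'numpy' has no attribute 'float'
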
# The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693 # flytekit-onnx-tensorflow - flytekit-pandera diff --git a/dev-requirements.txt b/dev-requirements.txt index a9de992c14..14510b09cb 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -130,7 +130,7 @@ filelock==3.8.2 # via virtualenv flatbuffers==22.12.6 # via tensorflow -flyteidl==1.3.0 +flyteidl==1.3.1 # via # -c requirements.txt # flytekit diff --git a/doc-requirements.txt b/doc-requirements.txt index be5d737bfd..7c92fcb018 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -182,7 +182,7 @@ filelock==3.8.2 # virtualenv flatbuffers==22.12.6 # via tensorflow -flyteidl==1.3.0 +flyteidl==1.3.1 # via flytekit fonttools==4.38.0 # via matplotlib diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py index 53b1620331..28e67ac631 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/models.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py @@ -1,7 +1,9 @@ import enum -import typing +from typing import Dict, Optional from flyteidl.plugins import spark_pb2 as _spark_task +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common @@ -17,12 +19,15 @@ class SparkType(enum.Enum): class SparkJob(_common.FlyteIdlEntity): def __init__( self, - spark_type, - application_file, - main_class, - spark_conf, - hadoop_conf, - executor_path, + spark_type: SparkType, + application_file: str, + main_class: str, + spark_conf: Dict[str, str], + hadoop_conf: Dict[str, str], + executor_path: str, + databricks_conf: Dict[str, Dict[str, Dict]] = {}, + databricks_token: Optional[str] = None, + databricks_instance: Optional[str] = None, ): """ This defines a SparkJob target. It will execute the appropriate SparkJob. @@ -30,6 +35,9 @@ def __init__( :param application_file: The main application file to execute. :param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job. :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job. + :param Optional[dict[Text, dict]] databricks_conf: A definition of key-value pairs for databricks config for the job. Refer to https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit. + :param Optional[str] databricks_token: databricks access token. + :param Optional[str] databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. 
""" self._application_file = application_file self._spark_type = spark_type @@ -37,9 +45,15 @@ def __init__( self._executor_path = executor_path self._spark_conf = spark_conf self._hadoop_conf = hadoop_conf + self._databricks_conf = databricks_conf + self._databricks_token = databricks_token + self._databricks_instance = databricks_instance def with_overrides( - self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None + self, + new_spark_conf: Optional[Dict[str, str]] = None, + new_hadoop_conf: Optional[Dict[str, str]] = None, + new_databricks_conf: Optional[Dict[str, Dict]] = None, ) -> "SparkJob": if not new_spark_conf: new_spark_conf = self.spark_conf @@ -47,12 +61,18 @@ def with_overrides( if not new_hadoop_conf: new_hadoop_conf = self.hadoop_conf + if not new_databricks_conf: + new_databricks_conf = self.databricks_conf + return SparkJob( spark_type=self.spark_type, application_file=self.application_file, main_class=self.main_class, spark_conf=new_spark_conf, hadoop_conf=new_hadoop_conf, + databricks_conf=new_databricks_conf, + databricks_token=self.databricks_token, + databricks_instance=self.databricks_instance, executor_path=self.executor_path, ) @@ -104,6 +124,31 @@ def hadoop_conf(self): """ return self._hadoop_conf + @property + def databricks_conf(self) -> Dict[str, Dict]: + """ + databricks_conf: Databricks job configuration. + Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + :rtype: dict[Text, dict[Text, Text]] + """ + return self._databricks_conf + + @property + def databricks_token(self) -> str: + """ + Databricks access token + :rtype: str + """ + return self._databricks_token + + @property + def databricks_instance(self) -> str: + """ + Domain name of your deployment. Use the form .cloud.databricks.com. + :rtype: str + """ + return self._databricks_instance + def to_flyte_idl(self): """ :rtype: flyteidl.plugins.spark_pb2.SparkJob @@ -120,6 +165,9 @@ def to_flyte_idl(self): else: raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified") + databricks_conf = Struct() + databricks_conf.update(self.databricks_conf) + return _spark_task.SparkJob( applicationType=application_type, mainApplicationFile=self.application_file, @@ -127,6 +175,9 @@ def to_flyte_idl(self): executorPath=self.executor_path, sparkConf=self.spark_conf, hadoopConf=self.hadoop_conf, + databricksConf=databricks_conf, + databricksToken=self.databricks_token, + databricksInstance=self.databricks_instance, ) @classmethod @@ -151,4 +202,7 @@ def from_flyte_idl(cls, pb2_object): main_class=pb2_object.mainClass, hadoop_conf=pb2_object.hadoopConf, executor_path=pb2_object.executorPath, + databricks_conf=json_format.MessageToDict(pb2_object.databricksConf), + databricks_token=pb2_object.databricksToken, + databricks_instance=pb2_object.databricksInstance, ) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 8428b492ce..180a28bb87 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -36,6 +36,23 @@ def __post_init__(self): self.hadoop_conf = {} +@dataclass +class Databricks(Spark): + """ + Use this to configure a Databricks task. Task's marked with this will automatically execute + natively onto databricks platform as a distributed execution of spark + + Args: + databricks_conf: Databricks job configuration. 
Config structure can be found here: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
+        databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html.
+        databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
+    """
+
+    databricks_conf: typing.Optional[Dict[str, typing.Union[str, dict]]] = None
+    databricks_token: Optional[str] = None
+    databricks_instance: Optional[str] = None
+
+
 # This method does not reset the SparkSession since it's a bit hard to handle multiple
 # Spark sessions in a single application as it's described in:
 # https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application.
@@ -100,6 +117,12 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
             main_class="",
             spark_type=SparkType.PYTHON,
         )
+        if isinstance(self.task_config, Databricks):
+            cfg = typing.cast(Databricks, self.task_config)
+            job._databricks_conf = cfg.databricks_conf
+            job._databricks_token = cfg.databricks_token
+            job._databricks_instance = cfg.databricks_instance
+
         return MessageToDict(job.to_flyte_idl())

     def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
diff --git a/plugins/flytekit-spark/requirements.txt b/plugins/flytekit-spark/requirements.txt
index a8df06cc24..b7087441ef 100644
--- a/plugins/flytekit-spark/requirements.txt
+++ b/plugins/flytekit-spark/requirements.txt
@@ -46,7 +46,7 @@ docker-image-py==0.1.12
     # via flytekit
 docstring-parser==0.15
     # via flytekit
-flyteidl==1.3.0
+flyteidl==1.3.1
     # via flytekit
 flytekit==1.3.0b2
     # via flytekitplugins-spark
diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py
index 38555fa9b8..7049684a2d 100644
--- a/plugins/flytekit-spark/tests/test_spark_task.py
+++ b/plugins/flytekit-spark/tests/test_spark_task.py
@@ -2,7 +2,7 @@
 import pyspark
 import pytest
 from flytekitplugins.spark import Spark
-from flytekitplugins.spark.task import new_spark_session
+from flytekitplugins.spark.task import Databricks, new_spark_session
 from pyspark.sql import SparkSession

 import flytekit
@@ -19,6 +19,23 @@ def reset_spark_session() -> None:


 def test_spark_task(reset_spark_session):
+    databricks_conf = {
+        "name": "flytekit databricks plugin example",
+        "new_cluster": {
+            "spark_version": "11.0.x-scala2.12",
+            "node_type_id": "r3.xlarge",
+            "aws_attributes": {"availability": "ON_DEMAND"},
+            "num_workers": 4,
+            "docker_image": {"url": "pingsutw/databricks:latest"},
+        },
+        "timeout_seconds": 3600,
+        "max_retries": 1,
+        "spark_python_task": {
+            "python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
+            "parameters": "ls",
+        },
+    }
+
     @task(task_config=Spark(spark_conf={"spark": "1"}))
     def my_spark(a: str) -> int:
         session = flytekit.current_context().spark_session
@@ -53,6 +70,28 @@ def my_spark(a: str) -> int:
     assert ("spark", "1") in configs
     assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs

+    databricks_token = "token"
+    databricks_instance = "account.cloud.databricks.com"
+
+    @task(
+        task_config=Databricks(
+            spark_conf={"spark": "2"},
+            databricks_conf=databricks_conf,
+            databricks_instance="account.cloud.databricks.com",
+            databricks_token="token",
+        )
+    )
+    def my_databricks(a: str) -> int:
+        session = flytekit.current_context().spark_session
+        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
+        return 10
+
+    assert my_databricks.task_config is not None
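
# A small round-trip sketch (not part of the original test) of the protobuf
# Struct packing that SparkJob.to_flyte_idl/from_flyte_idl in this patch rely
# on: the free-form databricks_conf dict is packed into a Struct on the way
# out and recovered with json_format.MessageToDict on the way back in.
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

conf = {"new_cluster": {"node_type_id": "r3.xlarge"}, "timeout_seconds": 3600}
s = Struct()
s.update(conf)  # dict -> Struct, as in to_flyte_idl
recovered = json_format.MessageToDict(s)  # Struct -> dict, as in from_flyte_idl
assert recovered["new_cluster"]["node_type_id"] == "r3.xlarge"
# note: Struct stores numbers as doubles, so 3600 comes back as 3600.0
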
+    assert my_databricks.task_config.spark_conf == {"spark": "2"}
+    assert my_databricks.task_config.databricks_conf == databricks_conf
+    assert my_databricks.task_config.databricks_instance == databricks_instance
+    assert my_databricks.task_config.databricks_token == databricks_token
+

 def test_new_spark_session():
     name = "SessionName"
diff --git a/requirements-spark2.txt b/requirements-spark2.txt
index 3c07cb3cee..c6d0ff7fc0 100644
--- a/requirements-spark2.txt
+++ b/requirements-spark2.txt
@@ -52,7 +52,7 @@ docker-image-py==0.1.12
     # via flytekit
 docstring-parser==0.15
     # via flytekit
-flyteidl==1.3.0
+flyteidl==1.3.1
     # via flytekit
 googleapis-common-protos==1.57.0
     # via
diff --git a/requirements.txt b/requirements.txt
index caff0db497..5623078a25 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -50,7 +50,7 @@ docker-image-py==0.1.12
     # via flytekit
 docstring-parser==0.15
     # via flytekit
-flyteidl==1.3.0
+flyteidl==1.3.1
     # via flytekit
 googleapis-common-protos==1.57.0
     # via

From d9878b61095271a899af0ad7798fefbe3178ed23 Mon Sep 17 00:00:00 2001
From: "H. Furkan Vural" <33652917+hfurkanvural@users.noreply.github.com>
Date: Tue, 20 Dec 2022 18:41:38 +0100
Subject: [PATCH 3/8] Add overwrite_cache option to the calls of remote and
 local executions (#1375)

Signed-off-by: H. Furkan Vural

The cache overwrite feature is implemented in flytekit as well, for
completeness. In order to support the cache eviction RFC, an overwrite
parameter was added, indicating that the data store should replace an
existing artifact instead of creating a new one on local calls.
---
 flytekit/models/execution.py |  4 ++++
 flytekit/remote/remote.py    | 27 +++++++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py
index 08fb3c938d..9c2d5ba2ec 100644
--- a/flytekit/models/execution.py
+++ b/flytekit/models/execution.py
@@ -175,6 +175,7 @@ def __init__(
         raw_output_data_config=None,
         max_parallelism=None,
         security_context: typing.Optional[security.SecurityContext] = None,
+        overwrite_cache: bool = None,
     ):
         """
         :param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute
@@ -200,6 +201,7 @@ def __init__(
         self._raw_output_data_config = raw_output_data_config
         self._max_parallelism = max_parallelism
         self._security_context = security_context
+        self.overwrite_cache = overwrite_cache

     @property
     def launch_plan(self):
@@ -283,6 +285,7 @@ def to_flyte_idl(self):
             else None,
             max_parallelism=self.max_parallelism,
             security_context=self.security_context.to_flyte_idl() if self.security_context else None,
+            overwrite_cache=self.overwrite_cache,
         )

     @classmethod
@@ -306,6 +309,7 @@ def from_flyte_idl(cls, p):
             security_context=security.SecurityContext.from_flyte_idl(p.security_context)
             if p.security_context
             else None,
+            overwrite_cache=p.overwrite_cache,
         )


diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index 00227f88f0..14cd7e11bb 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -744,6 +744,7 @@ def _execute(
         options: typing.Optional[Options] = None,
         wait: bool = False,
         type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
+        overwrite_cache: bool = None,
     ) -> FlyteWorkflowExecution:
         """Common method for execution across all entities.

         :param domain: domain on which to execute the entity referenced by flyte_id
         :param execution_name: name of the execution
         :param wait: if True, waits for execution to complete
         :param type_hints: map of python types to inputs so that the TypeEngine knows how to convert the input
             values into Flyte Literals.
+ :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten + for a single execution. If enabled, all calculations are performed even if cached results would + be available, overwriting the stored data once execution finishes successfully. :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` """ execution_name = execution_name or "f" + uuid.uuid4().hex[:19] @@ -810,6 +814,7 @@ def _execute( "placeholder", # Admin replaces this from oidc token if auth is enabled. 0, ), + overwrite_cache=overwrite_cache, notifications=notifications, disable_all=options.disable_notifications, labels=options.labels, @@ -873,6 +878,7 @@ def execute( options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """ Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity. @@ -906,6 +912,9 @@ def execute( using the type engine, and then to ``type(v)``. Providing the correct Python types is particularly important if the inputs are containers like lists or maps, or if the Python type is one of the more complex Flyte provided classes (like a StructuredDataset that's annotated with columns). + :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten + for a single execution. If enabled, all calculations are performed even if cached results would + be available, overwriting the stored data once execution finishes successfully. .. note: @@ -924,6 +933,7 @@ def execute( options=options, wait=wait, type_hints=type_hints, + overwrite_cache=overwrite_cache, ) if isinstance(entity, FlyteWorkflow): return self.execute_remote_wf( @@ -935,6 +945,7 @@ def execute( options=options, wait=wait, type_hints=type_hints, + overwrite_cache=overwrite_cache, ) if isinstance(entity, PythonTask): return self.execute_local_task( @@ -947,6 +958,7 @@ def execute( execution_name=execution_name, image_config=image_config, wait=wait, + overwrite_cache=overwrite_cache, ) if isinstance(entity, WorkflowBase): return self.execute_local_workflow( @@ -960,6 +972,7 @@ def execute( image_config=image_config, options=options, wait=wait, + overwrite_cache=overwrite_cache, ) if isinstance(entity, LaunchPlan): return self.execute_local_launch_plan( @@ -971,6 +984,7 @@ def execute( execution_name=execution_name, options=options, wait=wait, + overwrite_cache=overwrite_cache, ) raise NotImplementedError(f"entity type {type(entity)} not recognized for execution") @@ -987,6 +1001,7 @@ def execute_remote_task_lp( options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """Execute a FlyteTask, or FlyteLaunchplan. @@ -1001,6 +1016,7 @@ def execute_remote_task_lp( wait=wait, options=options, type_hints=type_hints, + overwrite_cache=overwrite_cache, ) def execute_remote_wf( @@ -1013,6 +1029,7 @@ def execute_remote_wf( options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """Execute a FlyteWorkflow. 
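
# A hedged usage sketch for the overwrite_cache flag this patch threads
# through FlyteRemote: a single execution recomputes everything and
# overwrites any previously cached results. Assumes a flytekit build with
# this patch applied; the endpoint, project, domain, workflow name, and
# inputs below are placeholders, not values taken from the patch.
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

remote = FlyteRemote(
    Config.for_endpoint("flyte.example.com"),  # placeholder endpoint
    default_project="flytesnacks",
    default_domain="development",
)
wf = remote.fetch_workflow(name="my.workflows.wf", version="v1")  # placeholder identifiers
execution = remote.execute(wf, inputs={"a": 3}, overwrite_cache=True, wait=True)
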
@@ -1028,6 +1045,7 @@ def execute_remote_wf( options=options, wait=wait, type_hints=type_hints, + overwrite_cache=overwrite_cache, ) # Flytekit Entities @@ -1044,6 +1062,7 @@ def execute_local_task( execution_name: str = None, image_config: typing.Optional[ImageConfig] = None, wait: bool = False, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """ Execute an @task-decorated function or TaskTemplate task. @@ -1058,6 +1077,7 @@ def execute_local_task( :param execution_name: :param image_config: :param wait: + :param overwrite_cache: :return: """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1084,6 +1104,7 @@ def execute_local_task( execution_name=execution_name, wait=wait, type_hints=entity.python_interface.inputs, + overwrite_cache=overwrite_cache, ) def execute_local_workflow( @@ -1098,6 +1119,7 @@ def execute_local_workflow( image_config: typing.Optional[ImageConfig] = None, options: typing.Optional[Options] = None, wait: bool = False, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. @@ -1111,6 +1133,7 @@ def execute_local_workflow( :param image_config: :param options: :param wait: + :param overwrite_cache: :return: """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1155,6 +1178,7 @@ def execute_local_workflow( wait=wait, options=options, type_hints=entity.python_interface.inputs, + overwrite_cache=overwrite_cache, ) def execute_local_launch_plan( @@ -1167,6 +1191,7 @@ def execute_local_launch_plan( execution_name: typing.Optional[str] = None, options: typing.Optional[Options] = None, wait: bool = False, + overwrite_cache: bool = None, ) -> FlyteWorkflowExecution: """ @@ -1178,6 +1203,7 @@ def execute_local_launch_plan( :param execution_name: If specified, will be used as the execution name instead of randomly generating. 
:param options: :param wait: + :param overwrite_cache: :return: """ try: @@ -1203,6 +1229,7 @@ def execute_local_launch_plan( options=options, wait=wait, type_hints=entity.python_interface.inputs, + overwrite_cache=overwrite_cache, ) ################################### From b3bfef5815623ba3e717540ce700b21d494fefe9 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 20 Dec 2022 15:10:29 -0800 Subject: [PATCH 4/8] Remove project/domain from being overridden with execution values in serialized context (#1378) Signed-off-by: Yee Hing Tong --- flytekit/bin/entrypoint.py | 2 -- .../unit/bin/test_python_entrypoint.py | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 1f8dd78ef0..3d5017675e 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -264,8 +264,6 @@ def setup_execution( if compressed_serialization_settings: ss = SerializationSettings.from_transport(compressed_serialization_settings) ssb = ss.new_builder() - ssb.project = exe_project - ssb.domain = exe_domain ssb.version = tk_version if dynamic_addl_distro: ssb.fast_serialization_settings = FastSerializationSettings( diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index c8d1d93cbd..479ad9e7bd 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -1,3 +1,4 @@ +import os import typing from collections import OrderedDict @@ -5,6 +6,7 @@ from flyteidl.core.errors_pb2 import ErrorDocument from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution +from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs from flytekit.core.data_persistence import DiskPersistence @@ -290,6 +292,22 @@ def test_setup_cloud_prefix(): assert isinstance(ctx.file_access._default_remote, GCSPersistence) +def test_persist_ss(): + default_img = Image(name="default", fqn="test", tag="tag") + ss = SerializationSettings( + project="proj1", + domain="dom", + version="version123", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + ss_txt = ss.serialized_context + os.environ["_F_SS_C"] = ss_txt + with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: + assert ctx.serialization_settings.project == "proj1" + assert ctx.serialization_settings.domain == "dom" + + def test_normalize_inputs(): assert normalize_inputs("{{.rawOutputDataPrefix}}", "{{.checkpointOutputPrefix}}", "{{.prevCheckpointPrefix}}") == ( None, From c665616f12ae5fba072c4e9c37848f38cd14e913 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Tue, 20 Dec 2022 16:34:48 -0800 Subject: [PATCH 5/8] Use TaskSpec instead of TaskTemplate for fetch_task and avoid network when loading module (#1348) Signed-off-by: Ketan Umare --- flytekit/clients/raw.py | 2 - flytekit/clis/sdk_in_container/register.py | 9 + flytekit/core/node.py | 4 + flytekit/core/promise.py | 1 - flytekit/core/python_function_task.py | 7 +- flytekit/remote/__init__.py | 14 +- flytekit/remote/component_nodes.py | 163 ---- flytekit/remote/entities.py | 791 ++++++++++++++++++ flytekit/remote/executions.py | 3 +- flytekit/remote/launch_plan.py | 92 -- flytekit/remote/lazy_entity.py | 62 ++ flytekit/remote/nodes.py | 164 ---- flytekit/remote/remote.py | 93 
+- flytekit/remote/remote_callable.py | 4 +- flytekit/remote/task.py | 51 -- flytekit/remote/workflow.py | 149 ---- flytekit/tools/repo.py | 41 +- flytekit/tools/translator.py | 156 ++-- .../responses/CompiledWorkflowClosure.pb | Bin 0 -> 2118 bytes tests/flytekit/unit/remote/test_calling.py | 25 +- .../flytekit/unit/remote/test_lazy_entity.py | 65 ++ tests/flytekit/unit/remote/test_remote.py | 48 +- .../unit/remote/test_with_responses.py | 4 +- .../unit/remote/test_wrapper_classes.py | 26 +- 24 files changed, 1242 insertions(+), 732 deletions(-) delete mode 100644 flytekit/remote/component_nodes.py create mode 100644 flytekit/remote/entities.py delete mode 100644 flytekit/remote/launch_plan.py create mode 100644 flytekit/remote/lazy_entity.py delete mode 100644 flytekit/remote/nodes.py delete mode 100644 flytekit/remote/task.py delete mode 100644 flytekit/remote/workflow.py create mode 100644 tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb create mode 100644 tests/flytekit/unit/remote/test_lazy_entity.py diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index bf4c1e2d0c..7c4439d83d 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -79,7 +79,6 @@ def handler(self, create_request): cli_logger.error(_MessageToJson(create_request)) cli_logger.error("Details returned from the flyte admin: ") cli_logger.error(e.details) - e.details += "create_request: " + _MessageToJson(create_request) # Re-raise since we're not handling the error here and add the create_request details raise e @@ -260,7 +259,6 @@ def _refresh_credentials_from_command(self): :param self: RawSynchronousFlyteClient :return: """ - command = self._cfg.command if not command: raise FlyteAuthenticationException("No command specified in configuration for command authentication") diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 1556e343bf..afb6d613fe 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -107,6 +107,12 @@ is_flag=True, help="Enables to skip zipping and uploading the package", ) +@click.option( + "--dry-run", + default=False, + is_flag=True, + help="Execute registration in dry-run mode. 
Skips actual registration to remote", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -122,6 +128,7 @@ def register( deref_symlinks: bool, non_fast: bool, package_or_module: typing.Tuple[str], + dry_run: bool, ): """ see help @@ -156,6 +163,7 @@ def register( # Create and save FlyteRemote, remote = get_and_save_remote_with_click_context(ctx, project, domain) + click.secho(f"Registering against {remote.config.platform.endpoint}") try: repo.register( project, @@ -170,6 +178,7 @@ def register( fast=not non_fast, package_or_module=package_or_module, remote=remote, + dry_run=dry_run, ) except Exception as e: raise e diff --git a/flytekit/core/node.py b/flytekit/core/node.py index d849ef5397..d8b43f2728 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -51,6 +51,10 @@ def __rshift__(self, other: Node): self.runs_before(other) return other + @property + def name(self) -> str: + return self._id + @property def outputs(self): if self._outputs is None: diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 9f7f84e4bb..53048cb03f 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -870,7 +870,6 @@ def create_and_link_node_from_remote( ) flytekit_node = Node( - # TODO: Better naming, probably a derivative of the function name. id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}", metadata=entity.construct_node_metadata(), bindings=sorted(bindings, key=lambda b: b.var), diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index bcb80f34ca..81f6739a39 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -217,12 +217,7 @@ def compile_into_workflow( for entity, model in model_entities.items(): # We only care about gathering tasks here. Launch plans are handled by # propeller. Subworkflows should already be in the workflow spec. 
- if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskTemplate): - continue - - # Handle FlyteTask - if isinstance(entity, task_models.TaskTemplate): - tts.append(entity) + if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskSpec): continue # We are currently not supporting reference tasks since these will diff --git a/flytekit/remote/__init__.py b/flytekit/remote/__init__.py index 643d613231..174928a5b4 100644 --- a/flytekit/remote/__init__.py +++ b/flytekit/remote/__init__.py @@ -85,10 +85,14 @@ """ -from flytekit.remote.component_nodes import FlyteTaskNode, FlyteWorkflowNode +from flytekit.remote.entities import ( + FlyteBranchNode, + FlyteLaunchPlan, + FlyteNode, + FlyteTask, + FlyteTaskNode, + FlyteWorkflow, + FlyteWorkflowNode, +) from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution -from flytekit.remote.launch_plan import FlyteLaunchPlan -from flytekit.remote.nodes import FlyteNode from flytekit.remote.remote import FlyteRemote -from flytekit.remote.task import FlyteTask -from flytekit.remote.workflow import FlyteWorkflow diff --git a/flytekit/remote/component_nodes.py b/flytekit/remote/component_nodes.py deleted file mode 100644 index bdf5fab38a..0000000000 --- a/flytekit/remote/component_nodes.py +++ /dev/null @@ -1,163 +0,0 @@ -from typing import Dict - -from flytekit.exceptions import system as _system_exceptions -from flytekit.loggers import remote_logger -from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models import task as _task_model -from flytekit.models.core import identifier as id_models -from flytekit.models.core import workflow as _workflow_model - - -class FlyteTaskNode(_workflow_model.TaskNode): - """ - A class encapsulating a task that a Flyte node needs to execute. - """ - - def __init__(self, flyte_task: "flytekit.remote.task.FlyteTask"): - self._flyte_task = flyte_task - super(FlyteTaskNode, self).__init__(None) - - @property - def reference_id(self) -> id_models.Identifier: - """ - A globally unique identifier for the task. - """ - return self._flyte_task.id - - @property - def flyte_task(self) -> "flytekit.remote.tasks.task.FlyteTask": - return self._flyte_task - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.TaskNode, - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteTaskNode": - """ - Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the - FlyteTask control plane. - - :param base_model: - :param tasks: - """ - from flytekit.remote.task import FlyteTask - - if base_model.reference_id in tasks: - task = tasks[base_model.reference_id] - remote_logger.debug(f"Found existing task template for {task.id}, will not retrieve from Admin") - flyte_task = FlyteTask.promote_from_model(task) - return cls(flyte_task) - - raise _system_exceptions.FlyteSystemException(f"Task template {base_model.reference_id} not found.") - - -class FlyteWorkflowNode(_workflow_model.WorkflowNode): - """A class encapsulating a workflow that a Flyte node needs to execute.""" - - def __init__( - self, - flyte_workflow: "flytekit.remote.workflow.FlyteWorkflow" = None, - flyte_launch_plan: "flytekit.remote.launch_plan.FlyteLaunchPlan" = None, - ): - if flyte_workflow and flyte_launch_plan: - raise _system_exceptions.FlyteSystemException( - "FlyteWorkflowNode cannot be called with both a workflow and a launchplan specified, please pick " - f"one. 
workflow: {flyte_workflow} launchPlan: {flyte_launch_plan}", - ) - - self._flyte_workflow = flyte_workflow - self._flyte_launch_plan = flyte_launch_plan - super(FlyteWorkflowNode, self).__init__( - launchplan_ref=self._flyte_launch_plan.id if self._flyte_launch_plan else None, - sub_workflow_ref=self._flyte_workflow.id if self._flyte_workflow else None, - ) - - def __repr__(self) -> str: - if self.flyte_workflow is not None: - return f"FlyteWorkflowNode with workflow: {self.flyte_workflow}" - return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}" - - @property - def launchplan_ref(self) -> id_models.Identifier: - """A globally unique identifier for the launch plan, which should map to Admin.""" - return self._flyte_launch_plan.id if self._flyte_launch_plan else None - - @property - def sub_workflow_ref(self): - return self._flyte_workflow.id if self._flyte_workflow else None - - @property - def flyte_launch_plan(self) -> "flytekit.remote.launch_plan.FlyteLaunchPlan": - return self._flyte_launch_plan - - @property - def flyte_workflow(self) -> "flytekit.remote.workflow.FlyteWorkflow": - return self._flyte_workflow - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.WorkflowNode, - sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], - node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteWorkflowNode": - from flytekit.remote import launch_plan as _launch_plan - from flytekit.remote import workflow as _workflow - - if base_model.launchplan_ref is not None: - return cls( - flyte_launch_plan=_launch_plan.FlyteLaunchPlan.promote_from_model( - base_model.launchplan_ref, node_launch_plans[base_model.launchplan_ref] - ) - ) - elif base_model.sub_workflow_ref is not None: - # the workflow templates for sub-workflows should have been included in the original response - if base_model.reference in sub_workflows: - return cls( - flyte_workflow=_workflow.FlyteWorkflow.promote_from_model( - sub_workflows[base_model.reference], - sub_workflows=sub_workflows, - node_launch_plans=node_launch_plans, - tasks=tasks, - ) - ) - raise _system_exceptions.FlyteSystemException(f"Subworkflow {base_model.reference} not found.") - - raise _system_exceptions.FlyteSystemException( - "Bad workflow node model, neither subworkflow nor launchplan specified." 
- ) - - -class FlyteBranchNode(_workflow_model.BranchNode): - def __init__(self, if_else: _workflow_model.IfElseBlock): - super().__init__(if_else) - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.BranchNode, - sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], - node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteBranchNode": - - from flytekit.remote.nodes import FlyteNode - - block = base_model.if_else - - else_node = None - if block.else_node: - else_node = FlyteNode.promote_from_model(block.else_node, sub_workflows, node_launch_plans, tasks) - - block.case._then_node = FlyteNode.promote_from_model( - block.case.then_node, sub_workflows, node_launch_plans, tasks - ) - - for o in block.other: - o._then_node = FlyteNode.promote_from_model(o.then_node, sub_workflows, node_launch_plans, tasks) - - new_if_else_block = _workflow_model.IfElseBlock(block.case, block.other, else_node, block.error) - - return cls(new_if_else_block) diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py new file mode 100644 index 0000000000..0c745c11bb --- /dev/null +++ b/flytekit/remote/entities.py @@ -0,0 +1,791 @@ +"""This module contains shadow entities for all Flyte entities as represented in Flyte Admin / Control Plane. +The goal is to enable easy access, manipulation of these entities. """ +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple, Union + +from flytekit.core import constants as _constants +from flytekit.core import hash as _hash_mixin +from flytekit.core import hash as hash_mixin +from flytekit.exceptions import system as _system_exceptions +from flytekit.exceptions import user as _user_exceptions +from flytekit.loggers import remote_logger +from flytekit.models import interface as _interface_models +from flytekit.models import launch_plan as _launch_plan_model +from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import launch_plan as launch_plan_models +from flytekit.models import task as _task_model +from flytekit.models import task as _task_models +from flytekit.models.admin.workflow import WorkflowSpec +from flytekit.models.core import compiler as compiler_models +from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import identifier as id_models +from flytekit.models.core import workflow as _workflow_model +from flytekit.models.core import workflow as _workflow_models +from flytekit.models.core.identifier import Identifier +from flytekit.models.core.workflow import Node, WorkflowMetadata, WorkflowMetadataDefaults +from flytekit.models.interface import TypedInterface +from flytekit.models.literals import Binding +from flytekit.models.task import TaskSpec +from flytekit.remote import interface as _interface +from flytekit.remote import interface as _interfaces +from flytekit.remote.remote_callable import RemoteEntity + + +class FlyteTask(hash_mixin.HashOnReferenceMixin, RemoteEntity, TaskSpec): + """A class encapsulating a remote Flyte task.""" + + def __init__( + self, + id, + type, + metadata, + interface, + custom, + container=None, + task_type_version: int = 0, + config=None, + should_register: bool = False, + ): + super(FlyteTask, self).__init__( + template=_task_model.TaskTemplate( + id, + type, + metadata, + interface, + custom, + container=container, + task_type_version=task_type_version, + config=config, 
+ ) + ) + self._should_register = should_register + + @property + def id(self): + """ + This is generated by the system and uniquely identifies the task. + :rtype: flytekit.models.core.identifier.Identifier + """ + return self.template.id + + @property + def type(self): + """ + This is used to identify additional extensions for use by Propeller or SDK. + :rtype: Text + """ + return self.template.type + + @property + def metadata(self): + """ + This contains information needed at runtime to determine behavior such as whether or not outputs are + discoverable, timeouts, and retries. + :rtype: TaskMetadata + """ + return self.template.metadata + + @property + def interface(self): + """ + The interface definition for this task. + :rtype: flytekit.models.interface.TypedInterface + """ + return self.template.interface + + @property + def custom(self): + """ + Arbitrary dictionary containing metadata for custom plugins. + :rtype: dict[Text, T] + """ + return self.template.custom + + @property + def task_type_version(self): + return self.template.task_type_version + + @property + def container(self): + """ + If not None, the target of execution should be a container. + :rtype: Container + """ + return self.template.container + + @property + def config(self): + """ + Arbitrary dictionary containing metadata for parsing and handling custom plugins. + :rtype: dict[Text, T] + """ + return self.template.config + + @property + def security_context(self): + return self.template.security_context + + @property + def k8s_pod(self): + return self.template.k8s_pod + + @property + def sql(self): + return self.template.sql + + @property + def should_register(self) -> bool: + return self._should_register + + @property + def name(self) -> str: + return self.template.id.name + + @property + def resource_type(self) -> _identifier_model.ResourceType: + return _identifier_model.ResourceType.TASK + + @property + def entity_type_text(self) -> str: + return "Task" + + @classmethod + def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask: + t = cls( + id=base_model.id, + type=base_model.type, + metadata=base_model.metadata, + interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), + custom=base_model.custom, + container=base_model.container, + task_type_version=base_model.task_type_version, + ) + # Override the newly generated name if one exists in the base model + if not base_model.id.is_empty: + t._id = base_model.id + + return t + + +class FlyteTaskNode(_workflow_model.TaskNode): + """ + A class encapsulating a task that a Flyte node needs to execute. + """ + + def __init__(self, flyte_task: FlyteTask): + super(FlyteTaskNode, self).__init__(None) + self._flyte_task = flyte_task + + @property + def reference_id(self) -> id_models.Identifier: + """ + A globally unique identifier for the task. + """ + return self._flyte_task.id + + @property + def flyte_task(self) -> FlyteTask: + return self._flyte_task + + @classmethod + def promote_from_model(cls, task: FlyteTask) -> FlyteTaskNode: + """ + Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the + FlyteTask control plane. 
+ """ + return cls(flyte_task=task) + + +class FlyteWorkflowNode(_workflow_model.WorkflowNode): + """A class encapsulating a workflow that a Flyte node needs to execute.""" + + def __init__( + self, + flyte_workflow: FlyteWorkflow = None, + flyte_launch_plan: FlyteLaunchPlan = None, + ): + if flyte_workflow and flyte_launch_plan: + raise _system_exceptions.FlyteSystemException( + "FlyteWorkflowNode cannot be called with both a workflow and a launchplan specified, please pick " + f"one. workflow: {flyte_workflow} launchPlan: {flyte_launch_plan}", + ) + + self._flyte_workflow = flyte_workflow + self._flyte_launch_plan = flyte_launch_plan + super(FlyteWorkflowNode, self).__init__( + launchplan_ref=self._flyte_launch_plan.id if self._flyte_launch_plan else None, + sub_workflow_ref=self._flyte_workflow.id if self._flyte_workflow else None, + ) + + def __repr__(self) -> str: + if self.flyte_workflow is not None: + return f"FlyteWorkflowNode with workflow: {self.flyte_workflow}" + return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}" + + @property + def launchplan_ref(self) -> id_models.Identifier: + """A globally unique identifier for the launch plan, which should map to Admin.""" + return self._flyte_launch_plan.id if self._flyte_launch_plan else None + + @property + def sub_workflow_ref(self): + return self._flyte_workflow.id if self._flyte_workflow else None + + @property + def flyte_launch_plan(self) -> FlyteLaunchPlan: + return self._flyte_launch_plan + + @property + def flyte_workflow(self) -> FlyteWorkflow: + return self._flyte_workflow + + @classmethod + def _promote_workflow( + cls, + wf: _workflow_models.WorkflowTemplate, + sub_workflows: Optional[Dict[Identifier, _workflow_models.WorkflowTemplate]] = None, + tasks: Optional[Dict[Identifier, FlyteTask]] = None, + node_launch_plans: Optional[Dict[Identifier, launch_plan_models.LaunchPlanSpec]] = None, + ) -> FlyteWorkflow: + return FlyteWorkflow.promote_from_model( + wf, + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_model.WorkflowNode, + sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], + node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], + tasks: Dict[Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[FlyteWorkflowNode, Dict[id_models.Identifier, FlyteWorkflow]]: + if base_model.launchplan_ref is not None: + return ( + cls( + flyte_launch_plan=FlyteLaunchPlan.promote_from_model( + base_model.launchplan_ref, node_launch_plans[base_model.launchplan_ref] + ) + ), + converted_sub_workflows, + ) + elif base_model.sub_workflow_ref is not None: + # the workflow templates for sub-workflows should have been included in the original response + if base_model.reference in sub_workflows: + wf = None + if base_model.reference not in converted_sub_workflows: + wf = cls._promote_workflow( + sub_workflows[base_model.reference], + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + converted_sub_workflows[base_model.reference] = wf + else: + wf = converted_sub_workflows[base_model.reference] + return cls(flyte_workflow=wf), converted_sub_workflows + raise _system_exceptions.FlyteSystemException(f"Subworkflow {base_model.reference} not found.") + + raise _system_exceptions.FlyteSystemException( + "Bad workflow node model, neither subworkflow nor launchplan specified." 
+ ) + + +class FlyteBranchNode(_workflow_model.BranchNode): + def __init__(self, if_else: _workflow_model.IfElseBlock): + super().__init__(if_else) + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_model.BranchNode, + sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], + node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], + tasks: Dict[id_models.Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[FlyteBranchNode, Dict[id_models.Identifier, FlyteWorkflow]]: + + block = base_model.if_else + block.case._then_node, converted_sub_workflows = FlyteNode.promote_from_model( + block.case.then_node, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + + for o in block.other: + o._then_node, converted_sub_workflows = FlyteNode.promote_from_model( + o.then_node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows + ) + + else_node = None + if block.else_node: + else_node, converted_sub_workflows = FlyteNode.promote_from_model( + block.else_node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows + ) + + new_if_else_block = _workflow_model.IfElseBlock(block.case, block.other, else_node, block.error) + + return cls(new_if_else_block), converted_sub_workflows + + +class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): + """A class encapsulating a remote Flyte node.""" + + def __init__( + self, + id, + upstream_nodes, + bindings, + metadata, + task_node: FlyteTaskNode = None, + workflow_node: FlyteWorkflowNode = None, + branch_node: FlyteBranchNode = None, + ): + if not task_node and not workflow_node and not branch_node: + raise _user_exceptions.FlyteAssertion( + "An Flyte node must have one of task|workflow|branch entity specified at once" + ) + # todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from + # the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it. 
+ if task_node: + self._flyte_entity = task_node.flyte_task + elif workflow_node: + self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan + else: + self._flyte_entity = branch_node + + super(FlyteNode, self).__init__( + id=id, + metadata=metadata, + inputs=bindings, + upstream_node_ids=[n.id for n in upstream_nodes], + output_aliases=[], + task_node=task_node, + workflow_node=workflow_node, + branch_node=branch_node, + ) + self._upstream = upstream_nodes + + @property + def flyte_entity(self) -> Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan, FlyteBranchNode]: + return self._flyte_entity + + @classmethod + def _promote_task_node(cls, t: FlyteTask) -> FlyteTaskNode: + return FlyteTaskNode.promote_from_model(t) + + @classmethod + def _promote_workflow_node( + cls, + wn: _workflow_model.WorkflowNode, + sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], + node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], + tasks: Dict[Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[FlyteWorkflowNode, Dict[id_models.Identifier, FlyteWorkflow]]: + return FlyteWorkflowNode.promote_from_model( + wn, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + + @classmethod + def promote_from_model( + cls, + model: _workflow_model.Node, + sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]], + node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]], + tasks: Dict[id_models.Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[Optional[FlyteNode], Dict[id_models.Identifier, FlyteWorkflow]]: + node_model_id = model.id + # TODO: Consider removing + if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: + remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") + return None, converted_sub_workflows + + flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None + if model.task_node is not None: + if model.task_node.reference_id not in tasks: + raise RuntimeError( + f"Remote Workflow closure does not have task with id {model.task_node.reference_id}." + ) + flyte_task_node = cls._promote_task_node(tasks[model.task_node.reference_id]) + elif model.workflow_node is not None: + flyte_workflow_node, converted_sub_workflows = cls._promote_workflow_node( + model.workflow_node, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + elif model.branch_node is not None: + flyte_branch_node, converted_sub_workflows = FlyteBranchNode.promote_from_model( + model.branch_node, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + else: + raise _system_exceptions.FlyteSystemException( + f"Bad Node model, neither task nor workflow detected, node: {model}" + ) + + # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a + # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out. 
+ # TODO: Consider removing + for model_input in model.inputs: + if ( + model_input.binding.promise is not None + and model_input.binding.promise.node_id == _constants.START_NODE_ID + ): + model_input.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID + + return ( + cls( + id=node_model_id, + upstream_nodes=[], # set downstream, model doesn't contain this information + bindings=model.inputs, + metadata=model.metadata, + task_node=flyte_task_node, + workflow_node=flyte_workflow_node, + branch_node=flyte_branch_node, + ), + converted_sub_workflows, + ) + + @property + def upstream_nodes(self) -> List[FlyteNode]: + return self._upstream + + @property + def upstream_node_ids(self) -> List[str]: + return list(sorted(n.id for n in self.upstream_nodes)) + + def __repr__(self) -> str: + return f"Node(ID: {self.id})" + + +class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, RemoteEntity, WorkflowSpec): + """A class encapsulating a remote Flyte workflow.""" + + def __init__( + self, + id: id_models.Identifier, + nodes: List[FlyteNode], + interface, + output_bindings, + metadata, + metadata_defaults, + subworkflows: Optional[List[FlyteWorkflow]] = None, + tasks: Optional[List[FlyteTask]] = None, + launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, + compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None, + should_register: bool = False, + ): + # TODO: Remove check + for node in nodes: + for upstream in node.upstream_nodes: + if upstream.id is None: + raise _user_exceptions.FlyteAssertion( + "Some nodes contained in the workflow were not found in the workflow description. Please " + "ensure all nodes are either assigned to attributes within the class or an element in a " + "list, dict, or tuple which is stored as an attribute in the class." + ) + + self._flyte_sub_workflows = subworkflows + template_subworkflows = [] + if subworkflows: + template_subworkflows = [swf.template for swf in subworkflows] + + super(FlyteWorkflow, self).__init__( + template=_workflow_models.WorkflowTemplate( + id=id, + metadata=metadata, + metadata_defaults=metadata_defaults, + interface=interface, + nodes=nodes, + outputs=output_bindings, + ), + sub_workflows=template_subworkflows, + ) + self._flyte_nodes = nodes + + # Optional things that we save for ease of access when promoting from a model or CompiledWorkflowClosure + self._tasks = tasks + self._launch_plans = launch_plans + self._compiled_closure = compiled_closure + self._node_map = None + self._name = id.name + self._should_register = should_register + + @property + def name(self) -> str: + return self._name + + @property + def flyte_tasks(self) -> Optional[List[FlyteTask]]: + return self._tasks + + @property + def should_register(self) -> bool: + return self._should_register + + @property + def flyte_sub_workflows(self) -> List[FlyteWorkflow]: + return self._flyte_sub_workflows + + @property + def entity_type_text(self) -> str: + return "Workflow" + + @property + def resource_type(self): + return id_models.ResourceType.WORKFLOW + + @property + def flyte_nodes(self) -> List[FlyteNode]: + return self._flyte_nodes + + @property + def id(self) -> Identifier: + """ + This is an autogenerated id by the system. The id is globally unique across Flyte. + """ + return self.template.id + + @property + def metadata(self) -> WorkflowMetadata: + """ + This contains information on how to run the workflow. 
+ """ + return self.template.metadata + + @property + def metadata_defaults(self) -> WorkflowMetadataDefaults: + """ + This contains information on how to run the workflow. + :rtype: WorkflowMetadataDefaults + """ + return self.template.metadata_defaults + + @property + def interface(self) -> TypedInterface: + """ + Defines a strongly typed interface for the Workflow (inputs, outputs). This can include some optional + parameters. + """ + return self.template.interface + + @property + def nodes(self) -> List[Node]: + """ + A list of nodes. In addition, "globals" is a special reserved node id that can be used to consume + workflow inputs + """ + return self.template.nodes + + @property + def outputs(self) -> List[Binding]: + """ + A list of output bindings that specify how to construct workflow outputs. Bindings can + pull node outputs or specify literals. All workflow outputs specified in the interface field must be bound + in order for the workflow to be validated. A workflow has an implicit dependency on all of its nodes + to execute successfully in order to bind final outputs. + """ + return self.template.outputs + + @property + def failure_node(self) -> Node: + """ + Node failure_node: A catch-all node. This node is executed whenever the execution engine determines the + workflow has failed. The interface of this node must match the Workflow interface with an additional input + named "error" of type pb.lyft.flyte.core.Error. + """ + return self.template.failure_node + + @classmethod + def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]: + return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] + + @classmethod + def _promote_node( + cls, + model: _workflow_model.Node, + sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]], + node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]], + tasks: Dict[id_models.Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[Optional[FlyteNode], Dict[id_models.Identifier, FlyteWorkflow]]: + return FlyteNode.promote_from_model(model, sub_workflows, node_launch_plans, tasks, converted_sub_workflows) + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_models.WorkflowTemplate, + sub_workflows: Optional[Dict[Identifier, _workflow_models.WorkflowTemplate]] = None, + tasks: Optional[Dict[Identifier, FlyteTask]] = None, + node_launch_plans: Optional[Dict[Identifier, launch_plan_models.LaunchPlanSpec]] = None, + ) -> FlyteWorkflow: + + base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) + + node_map = {} + converted_sub_workflows = {} + for node in base_model_non_system_nodes: + flyte_node, converted_sub_workflows = cls._promote_node( + node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows + ) + node_map[node.id] = flyte_node + + # Set upstream nodes for each node + for n in base_model_non_system_nodes: + current = node_map[n.id] + for upstream_id in n.upstream_node_ids: + upstream_node = node_map[upstream_id] + current._upstream.append(upstream_node) + + subworkflow_list = [] + if converted_sub_workflows: + subworkflow_list = [v for _, v in converted_sub_workflows.items()] + + task_list = [] + if tasks: + task_list = [t for _, t in tasks.items()] + + # No inputs/outputs specified, see the constructor for more information on the overrides. 
+ wf = cls( + id=base_model.id, + nodes=list(node_map.values()), + metadata=base_model.metadata, + metadata_defaults=base_model.metadata_defaults, + interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), + output_bindings=base_model.outputs, + subworkflows=subworkflow_list, + tasks=task_list, + launch_plans=node_launch_plans, + ) + + wf._node_map = node_map + + return wf + + @classmethod + def _promote_task(cls, t: _task_models.TaskTemplate) -> FlyteTask: + return FlyteTask.promote_from_model(t) + + @classmethod + def promote_from_closure( + cls, + closure: compiler_models.CompiledWorkflowClosure, + node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None, + ): + """ + Extracts out the relevant portions of a FlyteWorkflow from a closure from the control plane. + + :param closure: This is the closure returned by Admin + :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans. + It only has subworkflows and tasks. Why this is is unclear. If supplied, this map of launch plans will be + :return: + """ + sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows} + tasks = {} + if closure.tasks: + tasks = {t.template.id: cls._promote_task(t.template) for t in closure.tasks} + + flyte_wf = cls.promote_from_model( + base_model=closure.primary.template, + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + flyte_wf._compiled_closure = closure + return flyte_wf + + +class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_plan_models.LaunchPlanSpec): + """A class encapsulating a remote Flyte launch plan.""" + + def __init__(self, id, *args, **kwargs): + super(FlyteLaunchPlan, self).__init__(*args, **kwargs) + # Set all the attributes we expect this class to have + self._id = id + self._name = id.name + + # The interface is not set explicitly unless fetched in an engine context + self._interface = None + # If fetched when creating this object, can store it here. 
+ self._flyte_workflow = None
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def flyte_workflow(self) -> Optional[FlyteWorkflow]:
+ return self._flyte_workflow
+
+ @classmethod
+ def promote_from_model(cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec) -> FlyteLaunchPlan:
+ lp = cls(
+ id=id,
+ workflow_id=model.workflow_id,
+ default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters),
+ fixed_inputs=model.fixed_inputs,
+ entity_metadata=model.entity_metadata,
+ labels=model.labels,
+ annotations=model.annotations,
+ auth_role=model.auth_role,
+ raw_output_data_config=model.raw_output_data_config,
+ max_parallelism=model.max_parallelism,
+ security_context=model.security_context,
+ )
+ return lp
+
+ @property
+ def id(self) -> id_models.Identifier:
+ return self._id
+
+ @property
+ def is_scheduled(self) -> bool:
+ if self.entity_metadata.schedule.cron_expression:
+ return True
+ elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value:
+ return True
+ elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule:
+ return True
+ else:
+ return False
+
+ @property
+ def workflow_id(self) -> id_models.Identifier:
+ return self._workflow_id
+
+ @property
+ def interface(self) -> Optional[_interface.TypedInterface]:
+ """
+ The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and
+ from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the
+ object and get a node.
+ """
+ return self._interface
+
+ @property
+ def resource_type(self) -> id_models.ResourceType:
+ return id_models.ResourceType.LAUNCH_PLAN
+
+ @property
+ def entity_type_text(self) -> str:
+ return "Launch Plan"
+
+ def __repr__(self) -> str:
+ return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface}) - Spec {super().__repr__()})"
diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py
index 607b15c889..292b6f0218 100644
--- a/flytekit/remote/executions.py
+++ b/flytekit/remote/executions.py
@@ -9,8 +9,7 @@
 from flytekit.models import node_execution as node_execution_models
 from flytekit.models.admin import task_execution as admin_task_execution_models
 from flytekit.models.core import execution as core_execution_models
-from flytekit.remote.task import FlyteTask
-from flytekit.remote.workflow import FlyteWorkflow
+from flytekit.remote.entities import FlyteTask, FlyteWorkflow


 class RemoteExecutionBase(object):
diff --git a/flytekit/remote/launch_plan.py b/flytekit/remote/launch_plan.py
deleted file mode 100644
index b6c8e1f9e6..0000000000
--- a/flytekit/remote/launch_plan.py
+++ /dev/null
@@ -1,92 +0,0 @@
- from __future__ import annotations
-
- from typing import Optional
-
- from flytekit.core import hash as hash_mixin
- from flytekit.models import interface as _interface_models
- from flytekit.models import launch_plan as _launch_plan_models
- from flytekit.models.core import identifier as id_models
- from flytekit.remote import interface as _interface
- from flytekit.remote.remote_callable import RemoteEntity
-
-
- class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_plan_models.LaunchPlanSpec):
- """A class encapsulating a remote Flyte launch plan."""
-
- def __init__(self, id, *args, **kwargs):
- super(FlyteLaunchPlan, self).__init__(*args, **kwargs)
- # Set all the attributes we expect this class to have
- self._id = id
- self._name = id.name
-
- # The interface is not set explicitly unless fetched in an engine context
- self._interface = None
-
- @property
- def name(self) -> str:
- return self._name
-
- # If fetched when creating this object, can store it here.
- self._flyte_workflow = None
-
- @property
- def flyte_workflow(self) -> Optional["FlyteWorkflow"]:
- return self._flyte_workflow
-
- @classmethod
- def promote_from_model(
- cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec
- ) -> "FlyteLaunchPlan":
- lp = cls(
- id=id,
- workflow_id=model.workflow_id,
- default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters),
- fixed_inputs=model.fixed_inputs,
- entity_metadata=model.entity_metadata,
- labels=model.labels,
- annotations=model.annotations,
- auth_role=model.auth_role,
- raw_output_data_config=model.raw_output_data_config,
- max_parallelism=model.max_parallelism,
- security_context=model.security_context,
- )
- return lp
-
- @property
- def id(self) -> id_models.Identifier:
- return self._id
-
- @property
- def is_scheduled(self) -> bool:
- if self.entity_metadata.schedule.cron_expression:
- return True
- elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value:
- return True
- elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule:
- return True
- else:
- return False
-
- @property
- def workflow_id(self) -> id_models.Identifier:
- return self._workflow_id
-
- @property
- def interface(self) -> Optional[_interface.TypedInterface]:
- """
- The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and
- from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the=
- object and get a node.
- """
- return self._interface
-
- @property
- def resource_type(self) -> id_models.ResourceType:
- return id_models.ResourceType.LAUNCH_PLAN
-
- @property
- def entity_type_text(self) -> str:
- return "Launch Plan"
-
- def __repr__(self) -> str:
- return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface}) - Spec {super().__repr__()})"
diff --git a/flytekit/remote/lazy_entity.py b/flytekit/remote/lazy_entity.py
new file mode 100644
index 0000000000..b40c6e3ff7
--- /dev/null
+++ b/flytekit/remote/lazy_entity.py
@@ -0,0 +1,62 @@
+import typing
+from threading import Lock
+
+from flytekit import FlyteContext
+from flytekit.remote.remote_callable import RemoteEntity
+
+T = typing.TypeVar("T", bound=RemoteEntity)
+
+
+class LazyEntity(RemoteEntity, typing.Generic[T]):
+ """
+ Defers fetching of the entity until it is called or one of its attributes is accessed.
+ The entity is derived from RemoteEntity so that it behaves exactly like the mimicked entity.
+ """
+
+ def __init__(self, name: str, getter: typing.Callable[[], T], *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._entity = None
+ self._getter = getter
+ self._name = name
+ if not self._getter:
+ raise ValueError("getter method is required to create a Lazy loadable Remote Entity.")
+ self._mutex = Lock()
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ def entity_fetched(self) -> bool:
+ with self._mutex:
+ return self._entity is not None
+
+ @property
+ def entity(self) -> T:
+ """
+ If not already fetched / available, then the entity will be force fetched.
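+
+ An illustrative sketch (the entity name and getter here are hypothetical):
+
+ .. code-block:: python
+
+ e = LazyEntity("my_task", getter=lambda: remote.fetch_task(name="my_task"))
+ assert not e.entity_fetched()
+ task = e.entity # first access invokes the getter, exactly once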
+ """ + with self._mutex: + if self._entity is None: + self._entity = self._getter() + return self._entity + + def __getattr__(self, item: str) -> typing.Any: + """ + Forwards all other attributes to entity, causing the entity to be fetched! + """ + return getattr(self.entity, item) + + def compile(self, ctx: FlyteContext, *args, **kwargs): + return self.entity.compile(ctx, *args, **kwargs) + + def __call__(self, *args, **kwargs): + """ + Forwards the call to the underlying entity. The entity will be fetched if not already present + """ + return self.entity(*args, **kwargs) + + def __repr__(self) -> str: + return str(self) + + def __str__(self) -> str: + return f"Promise for entity [{self._name}]" diff --git a/flytekit/remote/nodes.py b/flytekit/remote/nodes.py deleted file mode 100644 index 0d73678b7e..0000000000 --- a/flytekit/remote/nodes.py +++ /dev/null @@ -1,164 +0,0 @@ -from __future__ import annotations - -from typing import Dict, List, Optional, Union - -from flytekit.core import constants as _constants -from flytekit.core import hash as _hash_mixin -from flytekit.core.promise import NodeOutput -from flytekit.exceptions import system as _system_exceptions -from flytekit.exceptions import user as _user_exceptions -from flytekit.loggers import remote_logger -from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models import task as _task_model -from flytekit.models.core import identifier as id_models -from flytekit.models.core import workflow as _workflow_model -from flytekit.remote import component_nodes as _component_nodes - - -class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): - """A class encapsulating a remote Flyte node.""" - - def __init__( - self, - id, - upstream_nodes, - bindings, - metadata, - flyte_task: Optional["FlyteTask"] = None, - flyte_workflow: Optional["FlyteWorkflow"] = None, - flyte_launch_plan: Optional["FlyteLaunchPlan"] = None, - flyte_branch_node: Optional["FlyteBranchNode"] = None, - ): - # todo: flyte_branch_node is the only non-entity here, feels wrong, it should probably be a Condition - # or the other ones changed. - non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch_node])) - if len(non_none_entities) != 1: - raise _user_exceptions.FlyteAssertion( - "An Flyte node must have one underlying entity specified at once. Received the following " - "entities: {}".format(non_none_entities) - ) - # todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from - # the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it. 
- self._flyte_entity = flyte_task or flyte_workflow or flyte_launch_plan or flyte_branch_node - - workflow_node = None - if flyte_workflow is not None: - workflow_node = _component_nodes.FlyteWorkflowNode(flyte_workflow=flyte_workflow) - elif flyte_launch_plan is not None: - workflow_node = _component_nodes.FlyteWorkflowNode(flyte_launch_plan=flyte_launch_plan) - - task_node = None - if flyte_task: - task_node = _component_nodes.FlyteTaskNode(flyte_task) - - super(FlyteNode, self).__init__( - id=id, - metadata=metadata, - inputs=bindings, - upstream_node_ids=[n.id for n in upstream_nodes], - output_aliases=[], - task_node=task_node, - workflow_node=workflow_node, - branch_node=flyte_branch_node, - ) - self._upstream = upstream_nodes - - @property - def flyte_entity(self) -> Union["FlyteTask", "FlyteWorkflow", "FlyteLaunchPlan"]: - return self._flyte_entity - - @classmethod - def promote_from_model( - cls, - model: _workflow_model.Node, - sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]], - node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]], - tasks: Optional[Dict[id_models.Identifier, _task_model.TaskTemplate]], - ) -> FlyteNode: - node_model_id = model.id - # TODO: Consider removing - if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: - remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") - return None - - flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None - if model.task_node is not None: - flyte_task_node = _component_nodes.FlyteTaskNode.promote_from_model(model.task_node, tasks) - elif model.workflow_node is not None: - flyte_workflow_node = _component_nodes.FlyteWorkflowNode.promote_from_model( - model.workflow_node, - sub_workflows, - node_launch_plans, - tasks, - ) - elif model.branch_node is not None: - flyte_branch_node = _component_nodes.FlyteBranchNode.promote_from_model( - model.branch_node, sub_workflows, node_launch_plans, tasks - ) - else: - raise _system_exceptions.FlyteSystemException( - f"Bad Node model, neither task nor workflow detected, node: {model}" - ) - - # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a - # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out. 
- # TODO: Consider removing - for model_input in model.inputs: - if ( - model_input.binding.promise is not None - and model_input.binding.promise.node_id == _constants.START_NODE_ID - ): - model_input.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID - - if flyte_task_node is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_task=flyte_task_node.flyte_task, - ) - elif flyte_workflow_node is not None: - if flyte_workflow_node.flyte_workflow is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_workflow=flyte_workflow_node.flyte_workflow, - ) - elif flyte_workflow_node.flyte_launch_plan is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_launch_plan=flyte_workflow_node.flyte_launch_plan, - ) - raise _system_exceptions.FlyteSystemException( - "Bad FlyteWorkflowNode model, both launch plan and workflow are None" - ) - elif flyte_branch_node is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_branch_node=flyte_branch_node, - ) - raise _system_exceptions.FlyteSystemException("Bad FlyteNode model, both task and workflow nodes are empty") - - @property - def upstream_nodes(self) -> List[FlyteNode]: - return self._upstream - - @property - def upstream_node_ids(self) -> List[str]: - return list(sorted(n.id for n in self.upstream_nodes)) - - @property - def outputs(self) -> Dict[str, NodeOutput]: - return self._outputs - - def __repr__(self) -> str: - return f"Node(ID: {self.id})" diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 14cd7e11bb..6473d46ec9 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -53,13 +53,11 @@ NotificationList, WorkflowExecutionGetDataResponse, ) +from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteWorkflow from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution from flytekit.remote.interface import TypedInterface -from flytekit.remote.launch_plan import FlyteLaunchPlan -from flytekit.remote.nodes import FlyteNode +from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote_callable import RemoteEntity -from flytekit.remote.task import FlyteTask -from flytekit.remote.workflow import FlyteWorkflow from flytekit.tools.fast_registration import fast_package from flytekit.tools.script_mode import fast_register_single_script, hash_file from flytekit.tools.translator import ( @@ -75,6 +73,14 @@ MOST_RECENT_FIRST = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING) +class RegistrationSkipped(Exception): + """ + RegistrationSkipped error is raised when trying to register an entity that is not registrable. 
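+
+ Callers that register a mix of local and remote entities can catch it and continue,
+ as _serialize_and_register below does. A minimal sketch of the pattern:
+
+ try:
+ remote.raw_register(cp_entity, settings, version)
+ except RegistrationSkipped:
+ pass # fetched from Admin; nothing to register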
+ """ + + pass + + @dataclass class ResolvedIdentifiers: project: str @@ -190,6 +196,20 @@ def remote_context(self): FlyteContextManager.current_context().with_file_access(self.file_access) ) + def fetch_task_lazy( + self, project: str = None, domain: str = None, name: str = None, version: str = None + ) -> LazyEntity: + """ + Similar to fetch_task, just that it returns a LazyEntity, which will fetch the workflow lazily. + """ + if name is None: + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") + + def _fetch(): + return self.fetch_task(project=project, domain=domain, name=name, version=version) + + return LazyEntity(name=name, getter=_fetch) + def fetch_task(self, project: str = None, domain: str = None, name: str = None, version: str = None) -> FlyteTask: """Fetch a task entity from flyte admin. @@ -213,14 +233,28 @@ def fetch_task(self, project: str = None, domain: str = None, name: str = None, ) admin_task = self.client.get_task(task_id) flyte_task = FlyteTask.promote_from_model(admin_task.closure.compiled_task.template) - flyte_task._id = task_id + flyte_task.template._id = task_id return flyte_task + def fetch_workflow_lazy( + self, project: str = None, domain: str = None, name: str = None, version: str = None + ) -> LazyEntity[FlyteWorkflow]: + """ + Similar to fetch_workflow, just that it returns a LazyEntity, which will fetch the workflow lazily. + """ + if name is None: + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") + + def _fetch(): + return self.fetch_workflow(project, domain, name, version) + + return LazyEntity(name=name, getter=_fetch) + def fetch_workflow( self, project: str = None, domain: str = None, name: str = None, version: str = None ) -> FlyteWorkflow: - """Fetch a workflow entity from flyte admin. - + """ + Fetch a workflow entity from flyte admin. :param project: fetch entity from this project. If None, uses the default_project attribute. :param domain: fetch entity from this domain. If None, uses the default_domain attribute. :param name: fetch entity with matching name. 
@@ -237,6 +271,7 @@ def fetch_workflow( name, version, ) + admin_workflow = self.client.get_workflow(workflow_id) compiled_wf = admin_workflow.closure.compiled_workflow @@ -359,8 +394,8 @@ def list_tasks_by_version( def _resolve_identifier(self, t: int, name: str, version: str, ss: SerializationSettings) -> Identifier: ident = Identifier( resource_type=t, - project=ss.project or self.default_project if ss else self.default_project, - domain=ss.domain or self.default_domain if ss else self.default_domain, + project=ss.project if ss and ss.project else self.default_project, + domain=ss.domain if ss and ss.domain else self.default_domain, name=name, version=version or ss.version, ) @@ -374,7 +409,7 @@ def _resolve_identifier(self, t: int, name: str, version: str, ss: Serialization def raw_register( self, cp_entity: FlyteControlPlaneEntity, - settings: typing.Optional[SerializationSettings], + settings: SerializationSettings, version: str, create_default_launchplan: bool = True, options: Options = None, @@ -393,6 +428,15 @@ def raw_register( :param og_entity: Pass in the original workflow (flytekit type) if create_default_launchplan is true :return: Identifier of the created entity """ + if isinstance(cp_entity, RemoteEntity): + if isinstance(cp_entity, (FlyteWorkflow, FlyteTask)): + if not cp_entity.should_register: + remote_logger.debug(f"Skipping registration of remote entity: {cp_entity.name}") + raise RegistrationSkipped(f"Remote task/Workflow {cp_entity.name} is not registrable.") + else: + remote_logger.debug(f"Skipping registration of remote entity: {cp_entity.name}") + raise RegistrationSkipped(f"Remote task/Workflow {cp_entity.name} is not registrable.") + if isinstance( cp_entity, ( @@ -410,6 +454,8 @@ def raw_register( return None if isinstance(cp_entity, task_models.TaskSpec): + if isinstance(cp_entity, FlyteTask): + version = cp_entity.id.version ident = self._resolve_identifier(ResourceType.TASK, cp_entity.template.id.name, version, settings) try: self.client.create_task(task_identifer=ident, task_spec=cp_entity) @@ -418,6 +464,8 @@ def raw_register( return ident if isinstance(cp_entity, admin_workflow_models.WorkflowSpec): + if isinstance(cp_entity, FlyteWorkflow): + version = cp_entity.id.version ident = self._resolve_identifier(ResourceType.WORKFLOW, cp_entity.template.id.name, version, settings) try: self.client.create_workflow(workflow_identifier=ident, workflow_spec=cp_entity) @@ -484,10 +532,6 @@ def _serialize_and_register( ident = None for entity, cp_entity in m.items(): - if isinstance(entity, RemoteEntity): - remote_logger.debug(f"Skipping registration of remote entity: {entity.name}") - continue - if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and is_dummy_serialization_setting: # Only in the case of workflows can we use the dummy serialization settings. raise user_exceptions.FlyteValueException( @@ -495,14 +539,17 @@ def _serialize_and_register( f"No serialization settings set, but workflow contains entities that need to be registered. 
{cp_entity.id.name}", ) - ident = self.raw_register( - cp_entity, - settings=settings, - version=version, - create_default_launchplan=True, - options=options, - og_entity=entity, - ) + try: + ident = self.raw_register( + cp_entity, + settings=settings, + version=version, + create_default_launchplan=True, + options=options, + og_entity=entity, + ) + except RegistrationSkipped: + pass return ident @@ -602,7 +649,7 @@ def _upload_file( filename=to_upload.name, ) self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url) - remote_logger.warning( + remote_logger.debug( f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}" ) diff --git a/flytekit/remote/remote_callable.py b/flytekit/remote/remote_callable.py index c04ec75f66..9adfd4846f 100644 --- a/flytekit/remote/remote_callable.py +++ b/flytekit/remote/remote_callable.py @@ -63,10 +63,10 @@ def __call__(self, *args, **kwargs): return self.execute(**kwargs) def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: - raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.") + return self.execute(**kwargs) def execute(self, **kwargs) -> Any: - raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.") + raise AssertionError(f"Remotely fetched entities cannot be run locally. Please mock the {self.name}.execute.") @property def python_interface(self) -> Optional[Dict[str, Type]]: diff --git a/flytekit/remote/task.py b/flytekit/remote/task.py deleted file mode 100644 index 3c2c8f8d92..0000000000 --- a/flytekit/remote/task.py +++ /dev/null @@ -1,51 +0,0 @@ -from flytekit.core import hash as hash_mixin -from flytekit.models import task as _task_model -from flytekit.models.core import identifier as _identifier_model -from flytekit.remote import interface as _interfaces -from flytekit.remote.remote_callable import RemoteEntity - - -class FlyteTask(hash_mixin.HashOnReferenceMixin, RemoteEntity, _task_model.TaskTemplate): - """A class encapsulating a remote Flyte task.""" - - def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None): - super(FlyteTask, self).__init__( - id, - type, - metadata, - interface, - custom, - container=container, - task_type_version=task_type_version, - config=config, - ) - self._name = id.name - - @property - def name(self) -> str: - return self._name - - @property - def resource_type(self) -> _identifier_model.ResourceType: - return _identifier_model.ResourceType.TASK - - @property - def entity_type_text(self) -> str: - return "Task" - - @classmethod - def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> "FlyteTask": - t = cls( - id=base_model.id, - type=base_model.type, - metadata=base_model.metadata, - interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), - custom=base_model.custom, - container=base_model.container, - task_type_version=base_model.task_type_version, - ) - # Override the newly generated name if one exists in the base model - if not base_model.id.is_empty: - t._id = base_model.id - - return t diff --git a/flytekit/remote/workflow.py b/flytekit/remote/workflow.py deleted file mode 100644 index 3133f8a1fe..0000000000 --- a/flytekit/remote/workflow.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -from typing import Dict, List, Optional - -from flytekit.core import constants as _constants -from flytekit.core import hash 
as _hash_mixin -from flytekit.exceptions import user as _user_exceptions -from flytekit.models import launch_plan as launch_plan_models -from flytekit.models import task as _task_models -from flytekit.models.core import compiler as compiler_models -from flytekit.models.core import identifier as id_models -from flytekit.models.core import workflow as _workflow_models -from flytekit.remote import interface as _interfaces -from flytekit.remote import nodes as _nodes -from flytekit.remote.remote_callable import RemoteEntity - - -class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, RemoteEntity, _workflow_models.WorkflowTemplate): - """A class encapsulating a remote Flyte workflow.""" - - def __init__( - self, - id: id_models.Identifier, - nodes: List[_nodes.FlyteNode], - interface, - output_bindings, - metadata, - metadata_defaults, - subworkflows: Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]] = None, - tasks: Optional[Dict[id_models.Identifier, _task_models.TaskTemplate]] = None, - launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, - compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None, - ): - # TODO: Remove check - for node in nodes: - for upstream in node.upstream_nodes: - if upstream.id is None: - raise _user_exceptions.FlyteAssertion( - "Some nodes contained in the workflow were not found in the workflow description. Please " - "ensure all nodes are either assigned to attributes within the class or an element in a " - "list, dict, or tuple which is stored as an attribute in the class." - ) - super(FlyteWorkflow, self).__init__( - id=id, - metadata=metadata, - metadata_defaults=metadata_defaults, - interface=interface, - nodes=nodes, - outputs=output_bindings, - ) - self._flyte_nodes = nodes - - # Optional things that we save for ease of access when promoting from a model or CompiledWorkflowClosure - self._subworkflows = subworkflows - self._tasks = tasks - self._launch_plans = launch_plans - self._compiled_closure = compiled_closure - self._node_map = None - self._name = id.name - - @property - def name(self) -> str: - return self._name - - @property - def sub_workflows(self) -> Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]]: - return self._subworkflows - - @property - def entity_type_text(self) -> str: - return "Workflow" - - @property - def resource_type(self): - return id_models.ResourceType.WORKFLOW - - @property - def flyte_nodes(self) -> List[_nodes.FlyteNode]: - return self._flyte_nodes - - @classmethod - def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]: - return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_models.WorkflowTemplate, - sub_workflows: Optional[Dict[id_models, _workflow_models.WorkflowTemplate]] = None, - node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None, - tasks: Optional[Dict[id_models, _task_models.TaskTemplate]] = None, - ) -> FlyteWorkflow: - base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) - sub_workflows = sub_workflows or {} - tasks = tasks or {} - node_map = { - node.id: _nodes.FlyteNode.promote_from_model(node, sub_workflows, node_launch_plans, tasks) - for node in base_model_non_system_nodes - } - - # Set upstream nodes for each node - for n in base_model_non_system_nodes: - current = node_map[n.id] - for upstream_id in 
n.upstream_node_ids: - upstream_node = node_map[upstream_id] - current._upstream.append(upstream_node) - - # No inputs/outputs specified, see the constructor for more information on the overrides. - wf = cls( - id=base_model.id, - nodes=list(node_map.values()), - metadata=base_model.metadata, - metadata_defaults=base_model.metadata_defaults, - interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), - output_bindings=base_model.outputs, - subworkflows=sub_workflows, - tasks=tasks, - launch_plans=node_launch_plans, - ) - - wf._node_map = node_map - - return wf - - @classmethod - def promote_from_closure( - cls, - closure: compiler_models.CompiledWorkflowClosure, - node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None, - ): - """ - Extracts out the relevant portions of a FlyteWorkflow from a closure from the control plane. - - :param closure: This is the closure returned by Admin - :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans. - It only has subworkflows and tasks. Why this is is unclear. If supplied, this map of launch plans will be - :return: - """ - sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows} - tasks = {t.template.id: t.template for t in closure.tasks} - - flyte_wf = FlyteWorkflow.promote_from_model( - base_model=closure.primary.template, - sub_workflows=sub_workflows, - node_launch_plans=node_launch_plans, - tasks=tasks, - ) - flyte_wf._compiled_closure = closure - return flyte_wf diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 870299e5ad..50bac67844 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -10,7 +10,9 @@ from flytekit.core.context_manager import FlyteContextManager from flytekit.loggers import logger from flytekit.models import launch_plan +from flytekit.models.core.identifier import Identifier from flytekit.remote import FlyteRemote +from flytekit.remote.remote import RegistrationSkipped from flytekit.tools import fast_registration, module_loader from flytekit.tools.script_mode import _find_project_root from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities @@ -179,6 +181,27 @@ def load_packages_and_modules( return registrable_entities +def secho(i: Identifier, state: str = "success", reason: str = None): + state_ind = "[ ]" + fg = "white" + nl = False + if state == "success": + state_ind = "\r[✔]" + fg = "green" + nl = True + reason = f"successful with version {i.version}" if not reason else reason + elif state == "failed": + state_ind = "\r[x]" + fg = "red" + nl = True + reason = "skipped!" + click.secho( + click.style(f"{state_ind}", fg=fg) + f" Registration {i.name} type {i.resource_type_name()} {reason}", + dim=True, + nl=nl, + ) + + def register( project: str, domain: str, @@ -192,6 +215,7 @@ def register( fast: bool, package_or_module: typing.Tuple[str], remote: FlyteRemote, + dry_run: bool = False, ): detected_root = find_common_root(package_or_module) click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") @@ -234,11 +258,18 @@ def register( if len(serializable_entities) == 0: click.secho("No Flyte entities were detected. 
Aborting!", fg="red") return - click.secho(f"Found and serialized {len(serializable_entities)} entities") for cp_entity in serializable_entities: - name = cp_entity.id.name if isinstance(cp_entity, launch_plan.LaunchPlan) else cp_entity.template.id.name - click.secho(f" Registering {name}....", dim=True, nl=False) - i = remote.raw_register(cp_entity, serialization_settings, version=version, create_default_launchplan=False) - click.secho(f"done, {i.resource_type_name()} with version {i.version}.", dim=True) + og_id = cp_entity.id if isinstance(cp_entity, launch_plan.LaunchPlan) else cp_entity.template.id + secho(og_id, "") + try: + if not dry_run: + i = remote.raw_register( + cp_entity, serialization_settings, version=version, create_default_launchplan=False + ) + secho(i) + else: + secho(og_id, reason="Dry run Mode!") + except RegistrationSkipped: + secho(og_id, "failed") click.secho(f"Successfully registered {len(serializable_entities)} entities", fg="green") diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index abea7019f1..f0ad5e96c6 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -22,7 +22,6 @@ from flytekit.models import interface as interface_models from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import security -from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _core_wf @@ -30,6 +29,7 @@ from flytekit.models.core.workflow import ApproveCondition from flytekit.models.core.workflow import BranchNode as BranchNodeModel from flytekit.models.core.workflow import GateNode, SignalCondition, SleepCondition, TaskNodeOverrides +from flytekit.models.task import TaskSpec, TaskTemplate FlyteLocalEntity = Union[ PythonTask, @@ -43,7 +43,7 @@ ReferenceEntity, ] FlyteControlPlaneEntity = Union[ - task_models.TaskSpec, + TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec, workflow_model.Node, @@ -154,10 +154,9 @@ def fn(settings: SerializationSettings) -> List[str]: def get_serializable_task( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, -) -> task_models.TaskSpec: +) -> TaskSpec: task_id = _identifier_model.Identifier( _identifier_model.ResourceType.TASK, settings.project, @@ -197,7 +196,7 @@ def get_serializable_task( pod = entity.get_k8s_pod(settings) entity.reset_command_fn() - tt = task_models.TaskTemplate( + tt = TaskTemplate( id=task_id, type=entity.task_type, metadata=entity.metadata.to_taskmetadata_model(), @@ -212,7 +211,7 @@ def get_serializable_task( ) if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask): entity.reset_command_fn() - return task_models.TaskSpec(template=tt) + return TaskSpec(template=tt) def get_serializable_workflow( @@ -221,18 +220,19 @@ def get_serializable_workflow( entity: WorkflowBase, options: Optional[Options] = None, ) -> admin_workflow_models.WorkflowSpec: - # TODO: Try to move up following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.workflow import FlyteWorkflow - - # Get node models - upstream_node_models = [ - get_serializable(entity_mapping, settings, n, options) - for n in entity.nodes - if n.id != _common_constants.GLOBAL_INPUT_NODE_ID - ] - + # Serialize all nodes + serialized_nodes = [] sub_wfs = [] for n in entity.nodes: + # Ignore start nodes + if 
n.id == _common_constants.GLOBAL_INPUT_NODE_ID: + continue + + # Recursively serialize the node + serialized_nodes.append(get_serializable(entity_mapping, settings, n, options)) + + # If the node is workflow Node or Branch node, we need to handle it specially, to extract all subworkflows, + # so that they can be added to the workflow being serialized if isinstance(n.flyte_entity, WorkflowBase): # We are currently not supporting reference workflows since these will # require a network call to flyteadmin to populate the WorkflowTemplate @@ -249,10 +249,14 @@ def get_serializable_workflow( sub_wfs.append(sub_wf_spec.template) sub_wfs.extend(sub_wf_spec.sub_workflows) + from flytekit.remote import FlyteWorkflow + if isinstance(n.flyte_entity, FlyteWorkflow): - get_serializable(entity_mapping, settings, n.flyte_entity, options) - sub_wfs.append(n.flyte_entity) - sub_wfs.extend([s for s in n.flyte_entity.sub_workflows.values()]) + for swf in n.flyte_entity.flyte_sub_workflows: + sub_wf = get_serializable(entity_mapping, settings, swf, options) + sub_wfs.append(sub_wf.template) + main_wf = get_serializable(entity_mapping, settings, n.flyte_entity, options) + sub_wfs.append(main_wf.template) if isinstance(n.flyte_entity, BranchNode): if_else: workflow_model.IfElseBlock = n.flyte_entity._ifelse_block @@ -288,7 +292,7 @@ def get_serializable_workflow( metadata=entity.workflow_metadata.to_flyte_model(), metadata_defaults=entity.workflow_metadata_defaults.to_flyte_model(), interface=entity.interface, - nodes=upstream_node_models, + nodes=serialized_nodes, outputs=entity.output_bindings, ) return admin_workflow_models.WorkflowSpec( @@ -376,12 +380,7 @@ def get_serializable_node( if entity.flyte_entity is None: raise Exception(f"Node {entity.id} has no flyte entity") - # TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.launch_plan import FlyteLaunchPlan - from flytekit.remote.task import FlyteTask - from flytekit.remote.workflow import FlyteWorkflow - - upstream_sdk_nodes = [ + upstream_nodes = [ get_serializable(entity_mapping, settings, n, options=options) for n in entity.upstream_nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID @@ -395,7 +394,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], ) if ref_template.resource_type == _identifier_model.ResourceType.TASK: @@ -410,13 +409,15 @@ def get_serializable_node( ) return node_model + from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow + if isinstance(entity.flyte_entity, PythonTask): task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options) node_model = workflow_model.Node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], task_node=workflow_model.TaskNode( reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources) @@ -431,7 +432,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id), ) @@ -441,7 +442,7 @@ def 
get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], branch_node=get_serializable(entity_mapping, settings, entity.flyte_entity, options=options), ) @@ -459,7 +460,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=node_input, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id), ) @@ -480,7 +481,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], gate_node=gn, ) @@ -492,23 +493,23 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], task_node=workflow_model.TaskNode( reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources) ), ) elif isinstance(entity.flyte_entity, FlyteWorkflow): - wf_template = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options) - for _, sub_wf in entity.flyte_entity.sub_workflows.items(): + wf_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options) + for sub_wf in entity.flyte_entity.flyte_sub_workflows: get_serializable(entity_mapping, settings, sub_wf, options=options) node_model = workflow_model.Node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], - workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_template.id), + workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.id), ) elif isinstance(entity.flyte_entity, FlyteLaunchPlan): # Recursive call doesn't do anything except put the entity on the map. 
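+        # (The entity is already a fully-formed FlyteLaunchPlan, so there is nothing to
+        # convert; the node model below only needs its ID as a launchplan_ref.)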
@@ -523,7 +524,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=node_input, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], workflow_node=workflow_model.WorkflowNode(launchplan_ref=entity.flyte_entity.id), ) @@ -563,6 +564,54 @@ def get_reference_spec( return ReferenceSpec(template) +def get_serializable_flyte_workflow( + entity: "FlyteWorkflow", settings: SerializationSettings +) -> FlyteControlPlaneEntity: + """ + TODO replace with deep copy + """ + + def _mutate_task_node(tn: workflow_model.TaskNode): + tn.reference_id._project = settings.project + tn.reference_id._domain = settings.domain + + def _mutate_branch_node_task_ids(bn: workflow_model.BranchNode): + _mutate_node(bn.if_else.case.then_node) + for c in bn.if_else.other: + _mutate_node(c.then_node) + if bn.if_else.else_node: + _mutate_node(bn.if_else.else_node) + + def _mutate_workflow_node(wn: workflow_model.WorkflowNode): + wn.sub_workflow_ref._project = settings.project + wn.sub_workflow_ref._domain = settings.domain + + def _mutate_node(n: workflow_model.Node): + if n.task_node: + _mutate_task_node(n.task_node) + elif n.branch_node: + _mutate_branch_node_task_ids(n.branch_node) + elif n.workflow_node: + _mutate_workflow_node(n.workflow_node) + + for n in entity.flyte_nodes: + _mutate_node(n) + + entity.id._project = settings.project + entity.id._domain = settings.domain + + return entity + + +def get_serializable_flyte_task(entity: "FlyteTask", settings: SerializationSettings) -> FlyteControlPlaneEntity: + """ + TODO replace with deep copy + """ + entity.id._project = settings.project + entity.id._domain = settings.domain + return entity + + def get_serializable( entity_mapping: OrderedDict, settings: SerializationSettings, @@ -586,19 +635,16 @@ def get_serializable( :return: The resulting control plane entity, in addition to being added to the mutable entity_mapping parameter is also returned. """ - # TODO: Try to replace following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.launch_plan import FlyteLaunchPlan - from flytekit.remote.task import FlyteTask - from flytekit.remote.workflow import FlyteWorkflow - if entity in entity_mapping: return entity_mapping[entity] + from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow + if isinstance(entity, ReferenceEntity): cp_entity = get_reference_spec(entity_mapping, settings, entity) elif isinstance(entity, PythonTask): - cp_entity = get_serializable_task(entity_mapping, settings, entity) + cp_entity = get_serializable_task(settings, entity) elif isinstance(entity, WorkflowBase): cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options) @@ -612,7 +658,21 @@ def get_serializable( elif isinstance(entity, BranchNode): cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options) - elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow) or isinstance(entity, FlyteLaunchPlan): + elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow): + if entity.should_register: + if isinstance(entity, FlyteTask): + cp_entity = get_serializable_flyte_task(entity, settings) + else: + if entity.should_register: + # We only add the tasks if the should register flag is set. This is to avoid adding + # unnecessary tasks to the registrable list. 
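+                    # Serializing each constituent task here also guarantees that it lands
+                    # in entity_mapping ahead of the workflow spec that references it.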
+ for t in entity.flyte_tasks: + get_serializable(entity_mapping, settings, t, options) + cp_entity = get_serializable_flyte_workflow(entity, settings) + else: + cp_entity = entity + + elif isinstance(entity, FlyteLaunchPlan): cp_entity = entity else: @@ -626,7 +686,7 @@ def get_serializable( def gather_dependent_entities( serialized: OrderedDict, ) -> Tuple[ - Dict[_identifier_model.Identifier, task_models.TaskTemplate], + Dict[_identifier_model.Identifier, TaskTemplate], Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec], Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec], ]: @@ -639,12 +699,12 @@ def gather_dependent_entities( :param serialized: This should be the filled in OrderedDict used in the get_serializable function above. :return: """ - task_templates: Dict[_identifier_model.Identifier, task_models.TaskTemplate] = {} + task_templates: Dict[_identifier_model.Identifier, TaskTemplate] = {} workflow_specs: Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec] = {} launch_plan_specs: Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec] = {} for cp_entity in serialized.values(): - if isinstance(cp_entity, task_models.TaskSpec): + if isinstance(cp_entity, TaskSpec): task_templates[cp_entity.template.id] = cp_entity.template elif isinstance(cp_entity, _launch_plan_models.LaunchPlan): launch_plan_specs[cp_entity.id] = cp_entity.spec diff --git a/tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb b/tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb new file mode 100644 index 0000000000000000000000000000000000000000..1f3ce5c79a76a959969c4e4bef24ebb9150699f3 GIT binary patch literal 2118 zcmd5--EI;=6fPDN4{Zq3-%M>&sThmd<*$UF3k_6SF)cI{HC~v_!p`c3Wryr8f2nw3 z;)QSEdl+wh2XB1`6Cc1thlL8PqG(L?f@C@KotgQ*^UXQ%-UV-*@XF!7fIPC=HHk?K zljztZkrq*DqfRu_pQ%z>tFRw$i1mg_gf`V{vP#v^rwG`HD#+lr-v&3) zBXeuZ$B^ymt8+~-br+N4vq-AEl1p(OU+u&C=UV$YT_uLPS92v=k1C#JoybOy} zVlHpi4X$rqG|!%d|DD4gUa|s0A1Cy2mQ~kbo&}H(dnXF!5TnoG8GC=zyTjh>Cbg_; zo%tK~oQ)^yq&MtD$9=dtLCb4{W;bntNlRrnM!F+=F8<<=2NO!b;oPS z$13Iu^Km|mFAi4vMfPvOB6tRcw%R1J{xX*g%ocs?gq%dHHP*-#8dAHm8M1}~=~#W= zNYC2E@q$70mQp1`i)bzEZFH@DhV7N$Mb5e@SEo`_C3tsc3t`VNQxoe1!^lwB&2dV9 zL>L&(&Uovy*TY_p8_Wf+p%Yo@U>{o?L{6+gug8>0z}17JN9CN5FO+koLgq-=FBT5+ zhxbdFa=us)PBP`Y2!(V?&5%l~qRLmPCM!>Zk73!Cm_N=O=Fs0!F&KbV+x91=;=SB{ rdEz++*$lJ6?i7zR`2rfB;z8D6Y^Xx!I2X)*cgBNU>6EcXGtca=v{ske literal 0 HcmV?d00001 diff --git a/tests/flytekit/unit/remote/test_calling.py b/tests/flytekit/unit/remote/test_calling.py index 00d80464c3..34e4f8e8b8 100644 --- a/tests/flytekit/unit/remote/test_calling.py +++ b/tests/flytekit/unit/remote/test_calling.py @@ -12,11 +12,10 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion -from flytekit.models.core.workflow import WorkflowTemplate -from flytekit.models.task import TaskTemplate -from flytekit.remote import FlyteLaunchPlan, FlyteTask +from flytekit.models.admin.workflow import WorkflowSpec +from flytekit.models.task import TaskSpec +from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow from flytekit.remote.interface import TypedInterface -from flytekit.remote.workflow import FlyteWorkflow from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") @@ -63,7 +62,7 @@ def wf(a: int) -> int: serialized = OrderedDict() wf_spec = get_serializable(serialized, serialization_settings, wf) 
vals = [v for v in serialized.values()] - tts = [f for f in filter(lambda x: isinstance(x, TaskTemplate), vals)] + tts = [f for f in filter(lambda x: isinstance(x, TaskSpec), vals)] assert len(tts) == 1 assert wf_spec.template.nodes[0].id == "foobar" assert wf_spec.template.outputs[0].binding.promise.node_id == "foobar" @@ -143,9 +142,11 @@ def my_subwf(a: int) -> typing.List[int]: def test_calling_wf(): # No way to fetch from Admin in unit tests so we serialize and then promote back serialized = OrderedDict() - wf_spec = get_serializable(serialized, serialization_settings, sub_wf) + wf_spec: WorkflowSpec = get_serializable(serialized, serialization_settings, sub_wf) task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) - fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=task_templates) + fwf = FlyteWorkflow.promote_from_model( + wf_spec.template, tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()} + ) @workflow def parent_1(a: int, b: str) -> typing.Tuple[int, str]: @@ -162,8 +163,14 @@ def parent_1(a: int, b: str) -> typing.Tuple[int, str]: # Pick out the subworkflow templates from the ordereddict. We can't use the output of the gather_dependent_entities # function because that only looks for WorkflowSpecs - subwf_templates = {x.id: x for x in list(filter(lambda x: isinstance(x, WorkflowTemplate), serialized.values()))} - fwf_p1 = FlyteWorkflow.promote_from_model(wf_spec.template, sub_workflows=subwf_templates, tasks=task_templates_p1) + subwf_templates = { + x.template.id: x.template for x in list(filter(lambda x: isinstance(x, WorkflowSpec), serialized.values())) + } + fwf_p1 = FlyteWorkflow.promote_from_model( + wf_spec.template, + sub_workflows=subwf_templates, + tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates_p1.items()}, + ) @workflow def parent_2(a: int, b: str) -> typing.Tuple[int, str]: diff --git a/tests/flytekit/unit/remote/test_lazy_entity.py b/tests/flytekit/unit/remote/test_lazy_entity.py new file mode 100644 index 0000000000..1ed191aea4 --- /dev/null +++ b/tests/flytekit/unit/remote/test_lazy_entity.py @@ -0,0 +1,65 @@ +import pytest +from mock import patch + +from flytekit import TaskMetadata +from flytekit.core import context_manager +from flytekit.models.core.identifier import Identifier, ResourceType +from flytekit.models.interface import TypedInterface +from flytekit.remote import FlyteTask +from flytekit.remote.lazy_entity import LazyEntity + + +def test_missing_getter(): + with pytest.raises(ValueError): + LazyEntity("x", None) + + +dummy_task = FlyteTask( + id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), + type="t", + metadata=TaskMetadata().to_taskmetadata_model(), + interface=TypedInterface(inputs={}, outputs={}), + custom=None, +) + + +def test_lazy_loading(): + once = True + + def _getter(): + nonlocal once + if not once: + raise ValueError("Should be called once only") + once = False + return dummy_task + + e = LazyEntity("x", _getter) + assert e.__repr__() == "Promise for entity [x]" + assert e.name == "x" + assert e._entity is None + assert not e.entity_fetched() + v = e.entity + assert e._entity is not None + assert v == dummy_task + assert e.entity == dummy_task + assert e.entity_fetched() + + +@patch("flytekit.remote.remote_callable.create_and_link_node_from_remote") +def test_lazy_loading_compile(create_and_link_node_from_remote_mock): + once = True + + def _getter(): + nonlocal once + if not once: + raise ValueError("Should be called once only") + once = False + 
return dummy_task
+
+ e = LazyEntity("x", _getter)
+ assert e.name == "x"
+ assert e._entity is None
+ ctx = context_manager.FlyteContext.current_context()
+ e.compile(ctx)
+ assert e._entity is not None
+ assert e.entity == dummy_task
diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py
index dd37b97f87..01688ea825 100644
--- a/tests/flytekit/unit/remote/test_remote.py
+++ b/tests/flytekit/unit/remote/test_remote.py
@@ -3,6 +3,7 @@
 import tempfile

 import pytest
+from flyteidl.core import compiler_pb2 as _compiler_pb2
 from mock import MagicMock, patch

 import flytekit.configuration
@@ -10,10 +11,15 @@
 from flytekit.exceptions import user as user_exceptions
 from flytekit.models import common as common_models
 from flytekit.models import security
-from flytekit.models.core.identifier import ResourceType, WorkflowExecutionIdentifier
+from flytekit.models.admin.workflow import Workflow, WorkflowClosure
+from flytekit.models.core.compiler import CompiledWorkflowClosure
+from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier
 from flytekit.models.execution import Execution
+from flytekit.models.task import Task
+from flytekit.remote.lazy_entity import LazyEntity
 from flytekit.remote.remote import FlyteRemote
 from flytekit.tools.translator import Options
+from tests.flytekit.common.parameterizers import LIST_OF_TASK_CLOSURES

 CLIENT_METHODS = {
     ResourceType.WORKFLOW: "list_workflows_paginated",
@@ -247,3 +253,43 @@ def test_generate_console_http_domain_sandbox_rewrite(mock_client):
         os.remove(temp_filename)
     except OSError:
         pass
+
+
+def get_compiled_workflow_closure():
+    """
+    :rtype: flytekit.models.core.compiler.CompiledWorkflowClosure
+    """
+    cwc_pb = _compiler_pb2.CompiledWorkflowClosure()
+    # So that tests that use this work when run from any directory
+    basepath = os.path.dirname(__file__)
+    filepath = os.path.abspath(os.path.join(basepath, "responses", "CompiledWorkflowClosure.pb"))
+    with open(filepath, "rb") as fh:
+        cwc_pb.ParseFromString(fh.read())
+
+    return CompiledWorkflowClosure.from_flyte_idl(cwc_pb)
+
+
+@patch("flytekit.remote.remote.SynchronousFlyteClient")
+def test_fetch_lazy(mock_client):
+    mock_client.get_task.return_value = Task(
+        id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), closure=LIST_OF_TASK_CLOSURES[0]
+    )
+
+    mock_client.get_workflow.return_value = Workflow(
+        id=Identifier(ResourceType.WORKFLOW, "p", "d", "n", "v"),
+        closure=WorkflowClosure(compiled_workflow=get_compiled_workflow_closure()),
+    )
+
+    remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
+    lw = remote.fetch_workflow_lazy(name="wn", version="v")
+    assert isinstance(lw, LazyEntity)
+    assert lw._getter
+    assert lw._entity is None
+    assert lw.entity
+
+    lt = remote.fetch_task_lazy(name="n", version="v")
+    assert isinstance(lt, LazyEntity)
+    assert lt._getter
+    assert lt._entity is None
+    tk = lt.entity
+    assert tk.name == "n"
diff --git a/tests/flytekit/unit/remote/test_with_responses.py b/tests/flytekit/unit/remote/test_with_responses.py
index ee3fbb4d8a..7dd7b97910 100644
--- a/tests/flytekit/unit/remote/test_with_responses.py
+++ b/tests/flytekit/unit/remote/test_with_responses.py
@@ -66,11 +66,11 @@ def test_normal_task(mock_client):
     )
     admin_task = task_models.Task.from_flyte_idl(merge_sort_remotely)
     mock_client.get_task.return_value = admin_task
-    ft = rr.fetch_task(name="merge_sort_remotely", version="tst")
+    remote_task = rr.fetch_task(name="merge_sort_remotely", version="tst")
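+    # The fetched FlyteTask can be composed into a local workflow like any other task: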
@workflow def my_wf(numbers: typing.List[int], run_local_at_count: int) -> typing.List[int]: - t1_node = create_node(ft, numbers=numbers, run_local_at_count=run_local_at_count) + t1_node = create_node(remote_task, numbers=numbers, run_local_at_count=run_local_at_count) return t1_node.o0 serialization_settings = flytekit.configuration.SerializationSettings( diff --git a/tests/flytekit/unit/remote/test_wrapper_classes.py b/tests/flytekit/unit/remote/test_wrapper_classes.py index 4a08cb7724..82ba538883 100644 --- a/tests/flytekit/unit/remote/test_wrapper_classes.py +++ b/tests/flytekit/unit/remote/test_wrapper_classes.py @@ -9,7 +9,7 @@ from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.remote import FlyteWorkflow +from flytekit.remote import FlyteTask, FlyteWorkflow from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") @@ -58,11 +58,14 @@ def wf(b: int) -> int: serialized = OrderedDict() wf_spec = get_serializable(serialized, serialization_settings, wf) - sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows} task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) + sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows} fwf = FlyteWorkflow.promote_from_model( - wf_spec.template, sub_workflows=sub_wf_dict, node_launch_plans=lp_specs, tasks=task_templates + wf_spec.template, + sub_workflows=sub_wf_dict, + node_launch_plans=lp_specs, + tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()}, ) assert len(fwf.outputs) == 1 assert list(fwf.interface.inputs.keys()) == ["b"] @@ -79,7 +82,10 @@ def wf2(b: int) -> int: task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) fwf = FlyteWorkflow.promote_from_model( - wf_spec.template, sub_workflows={}, node_launch_plans=lp_specs, tasks=task_templates + wf_spec.template, + sub_workflows={}, + node_launch_plans=lp_specs, + tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()}, ) assert len(fwf.outputs) == 1 assert list(fwf.interface.inputs.keys()) == ["b"] @@ -111,7 +117,10 @@ def my_wf(a: int) -> str: task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) fwf = FlyteWorkflow.promote_from_model( - wf_spec.template, sub_workflows={}, node_launch_plans={}, tasks=task_templates + wf_spec.template, + sub_workflows={}, + node_launch_plans={}, + tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()}, ) assert len(fwf.flyte_nodes[0].upstream_nodes) == 0 @@ -125,11 +134,14 @@ def parent(a: int) -> (str, str): serialized = OrderedDict() wf_spec = get_serializable(serialized, serialization_settings, parent) - sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows} task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) + sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows} fwf = FlyteWorkflow.promote_from_model( - wf_spec.template, sub_workflows=sub_wf_dict, node_launch_plans={}, tasks=task_templates + wf_spec.template, + sub_workflows=sub_wf_dict, + node_launch_plans={}, + tasks={k: FlyteTask.promote_from_model(v) for k, v in task_templates.items()}, ) # Test upstream nodes don't get confused by subworkflows assert len(fwf.flyte_nodes[0].upstream_nodes) == 0 From 604e9a615b75ce6864e2981cca8d9412a40410bf Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 21 Dec 2022 13:16:52 +0800 Subject: [PATCH 6/8] Register Databricks config (#1379) * 
From 604e9a615b75ce6864e2981cca8d9412a40410bf Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Wed, 21 Dec 2022 13:16:52 +0800
Subject: [PATCH 6/8] Register Databricks config (#1379)

* Register databricks plugin

Signed-off-by: Kevin Su

* Update databricks plugin

Signed-off-by: Kevin Su

* register databricks

Signed-off-by: Kevin Su

* nit

Signed-off-by: Kevin Su

* nit

Signed-off-by: Kevin Su

Signed-off-by: Kevin Su
Co-authored-by: Yee Hing Tong
---
 plugins/flytekit-spark/flytekitplugins/spark/__init__.py | 2 +-
 plugins/flytekit-spark/flytekitplugins/spark/task.py     | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
index 7e0c0b77e7..e769540aea 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py
@@ -20,4 +20,4 @@
 from .pyspark_transformers import PySparkPipelineModelTransformer
 from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer  # noqa
 from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler
-from .task import Spark, new_spark_session  # noqa
+from .task import Databricks, Spark, new_spark_session  # noqa
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py
index 180a28bb87..7b32e9f28b 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/task.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -118,7 +118,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
             spark_type=SparkType.PYTHON,
         )
         if isinstance(self.task_config, Databricks):
-            cfg = typing.cast(self.task_config, Databricks)
+            cfg = typing.cast(Databricks, self.task_config)
             job._databricks_conf = cfg.databricks_conf
             job._databricks_token = cfg.databricks_token
             job._databricks_instance = cfg.databricks_instance
@@ -150,3 +150,4 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
 
 # Inject the Spark plugin into flytekit's dynamic plugin loading system
 TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)
+TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask)
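With Databricks now exported from the plugin package and registered against PysparkFunctionTask, it can be passed straight to @task as a task config. A hedged sketch follows; the cluster spec mirrors the Databricks Jobs runs-submit payload referenced in the SparkJob docstring, and every value, including the token and instance domain, is a placeholder to substitute:

from flytekit import task
from flytekitplugins.spark import Databricks


@task(
    task_config=Databricks(
        spark_conf={"spark.driver.memory": "1000M"},
        databricks_conf={
            "run_name": "flytekit databricks example",
            "new_cluster": {
                "spark_version": "11.0.x-scala2.12",
                "node_type_id": "r3.xlarge",
                "num_workers": 4,
            },
            "timeout_seconds": 3600,
            "max_retries": 1,
        },
        # Placeholders: supply a real access token and workspace domain.
        databricks_token="<DATABRICKS_TOKEN>",
        databricks_instance="<account>.cloud.databricks.com",
    )
)
def hello_spark(i: int) -> int:
    return i + 1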
From 9ed0c18adde78b227ab7e26e531913d5761377f9 Mon Sep 17 00:00:00 2001
From: Ketan Umare <16888709+kumare3@users.noreply.github.com>
Date: Wed, 21 Dec 2022 14:50:54 -0800
Subject: [PATCH 7/8] PodSpec should not require primary_container name (#1380)

For Pod tasks, if the primary_container_name is not specified, it should
default.

Signed-off-by: Ketan Umare
---
 .../flytekitplugins/pod/task.py | 42 +++++--------------
 1 file changed, 11 insertions(+), 31 deletions(-)

diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py
index e81728ddb4..c38ad33834 100644
--- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py
+++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from flyteidl.core import tasks_pb2 as _core_task
@@ -18,6 +19,7 @@ def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> s
     return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")
 
 
+@dataclass
 class Pod(object):
     """
     Pod is a platform-wide configuration that uses pod templates. By default, every task is launched as a container in a pod.
@@ -29,39 +31,17 @@ class Pod(object):
     :param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec.
     """
 
-    def __init__(
-        self,
-        pod_spec: V1PodSpec,
-        primary_container_name: str,
-        labels: Optional[Dict[str, str]] = None,
-        annotations: Optional[Dict[str, str]] = None,
-    ):
-        if not pod_spec:
+    pod_spec: V1PodSpec
+    primary_container_name: str = _PRIMARY_CONTAINER_NAME_FIELD
+    labels: Optional[Dict[str, str]] = None
+    annotations: Optional[Dict[str, str]] = None
+
+    def __post_init__(self):
+        if not self.pod_spec:
             raise _user_exceptions.FlyteValidationException("A pod spec cannot be undefined")
 
-        if not primary_container_name:
+        if not self.primary_container_name:
             raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")
 
-        self._pod_spec = pod_spec
-        self._primary_container_name = primary_container_name
-        self._labels = labels
-        self._annotations = annotations
-
-    @property
-    def pod_spec(self) -> V1PodSpec:
-        return self._pod_spec
-
-    @property
-    def primary_container_name(self) -> str:
-        return self._primary_container_name
-
-    @property
-    def labels(self) -> Optional[Dict[str, str]]:
-        return self._labels
-
-    @property
-    def annotations(self) -> Optional[Dict[str, str]]:
-        return self._annotations
-
 
 class PodFunctionTask(PythonFunctionTask[Pod]):
     def __init__(self, task_config: Pod, task_function: Callable, **kwargs):
@@ -114,7 +94,7 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]
 
         final_containers.append(container)
 
-        self.task_config._pod_spec.containers = final_containers
+        self.task_config.pod_spec.containers = final_containers
 
         return ApiClient().sanitize_for_serialization(self.task_config.pod_spec)
 
From 46eae16b6fd9adc9ee45b5f9c6ee1e4ea948b810 Mon Sep 17 00:00:00 2001
From: mcloney-ddm <119345186+mcloney-ddm@users.noreply.github.com>
Date: Thu, 22 Dec 2022 09:27:32 -0500
Subject: [PATCH 8/8] fix(pyflyte): change -d to -D for --destination-dir as -d
 is already for --domain (#1381)

Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com>
---
 flytekit/clis/sdk_in_container/register.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py
index afb6d613fe..e1bf4eb5c3 100644
--- a/flytekit/clis/sdk_in_container/register.py
+++ b/flytekit/clis/sdk_in_container/register.py
@@ -67,7 +67,7 @@
     help="Directory to write the output zip file containing the protobuf definitions",
 )
 @click.option(
-    "-d",
+    "-D",
     "--destination-dir",
    required=False,
    type=str,
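Taken together with the Pod change in PATCH 7, pod_spec is now the only required Pod argument, since primary_container_name falls back to the module-level default shown in that hunk. A hedged sketch of a pod task relying on the default; the sidecar name and image are illustrative, and whether the defaulted primary name lines up with your pod spec's containers is an assumption to verify against the serializer:

from flytekit import task
from flytekitplugins.pod import Pod
from kubernetes.client.models import V1Container, V1PodSpec


@task(
    task_config=Pod(
        # primary_container_name is intentionally omitted: after PATCH 7 it
        # defaults instead of raising FlyteValidationException at construction.
        pod_spec=V1PodSpec(
            containers=[V1Container(name="sidecar", image="busybox:1.36")],
        ),
    )
)
def pod_task() -> str:
    return "hello from a pod task"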