diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 4c96e9f011..4c16153be9 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -80,7 +80,7 @@ jobs: - flytekit-modin - flytekit-onnx-pytorch - flytekit-onnx-scikitlearn - # onnxx-tensorflow needs a version of tensorflow that does not work with protobuf>4. + # onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4. # The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693 # flytekit-onnx-tensorflow - flytekit-pandera diff --git a/dev-requirements.txt b/dev-requirements.txt index a9de992c14..14510b09cb 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -130,7 +130,7 @@ filelock==3.8.2 # via virtualenv flatbuffers==22.12.6 # via tensorflow -flyteidl==1.3.0 +flyteidl==1.3.1 # via # -c requirements.txt # flytekit diff --git a/doc-requirements.txt b/doc-requirements.txt index be5d737bfd..7c92fcb018 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -182,7 +182,7 @@ filelock==3.8.2 # virtualenv flatbuffers==22.12.6 # via tensorflow -flyteidl==1.3.0 +flyteidl==1.3.1 # via flytekit fonttools==4.38.0 # via matplotlib diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py index 53b1620331..28e67ac631 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/models.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py @@ -1,7 +1,9 @@ import enum -import typing +from typing import Dict, Optional from flyteidl.plugins import spark_pb2 as _spark_task +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common @@ -17,12 +19,15 @@ class SparkType(enum.Enum): class SparkJob(_common.FlyteIdlEntity): def __init__( self, - spark_type, - application_file, - main_class, - spark_conf, - hadoop_conf, - executor_path, + spark_type: SparkType, + application_file: str, + main_class: str, + spark_conf: Dict[str, str], + hadoop_conf: Dict[str, str], + executor_path: str, + databricks_conf: Dict[str, Dict[str, Dict]] = {}, + databricks_token: Optional[str] = None, + databricks_instance: Optional[str] = None, ): """ This defines a SparkJob target. It will execute the appropriate SparkJob. @@ -30,6 +35,9 @@ def __init__( :param application_file: The main application file to execute. :param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job. :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job. + :param Optional[dict[Text, dict]] databricks_conf: A definition of key-value pairs for databricks config for the job. Refer to https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit. + :param Optional[str] databricks_token: databricks access token. + :param Optional[str] databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. 
""" self._application_file = application_file self._spark_type = spark_type @@ -37,9 +45,15 @@ def __init__( self._executor_path = executor_path self._spark_conf = spark_conf self._hadoop_conf = hadoop_conf + self._databricks_conf = databricks_conf + self._databricks_token = databricks_token + self._databricks_instance = databricks_instance def with_overrides( - self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None + self, + new_spark_conf: Optional[Dict[str, str]] = None, + new_hadoop_conf: Optional[Dict[str, str]] = None, + new_databricks_conf: Optional[Dict[str, Dict]] = None, ) -> "SparkJob": if not new_spark_conf: new_spark_conf = self.spark_conf @@ -47,12 +61,18 @@ def with_overrides( if not new_hadoop_conf: new_hadoop_conf = self.hadoop_conf + if not new_databricks_conf: + new_databricks_conf = self.databricks_conf + return SparkJob( spark_type=self.spark_type, application_file=self.application_file, main_class=self.main_class, spark_conf=new_spark_conf, hadoop_conf=new_hadoop_conf, + databricks_conf=new_databricks_conf, + databricks_token=self.databricks_token, + databricks_instance=self.databricks_instance, executor_path=self.executor_path, ) @@ -104,6 +124,31 @@ def hadoop_conf(self): """ return self._hadoop_conf + @property + def databricks_conf(self) -> Dict[str, Dict]: + """ + databricks_conf: Databricks job configuration. + Config structure can be found here. https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + :rtype: dict[Text, dict[Text, Text]] + """ + return self._databricks_conf + + @property + def databricks_token(self) -> str: + """ + Databricks access token + :rtype: str + """ + return self._databricks_token + + @property + def databricks_instance(self) -> str: + """ + Domain name of your deployment. Use the form .cloud.databricks.com. + :rtype: str + """ + return self._databricks_instance + def to_flyte_idl(self): """ :rtype: flyteidl.plugins.spark_pb2.SparkJob @@ -120,6 +165,9 @@ def to_flyte_idl(self): else: raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified") + databricks_conf = Struct() + databricks_conf.update(self.databricks_conf) + return _spark_task.SparkJob( applicationType=application_type, mainApplicationFile=self.application_file, @@ -127,6 +175,9 @@ def to_flyte_idl(self): executorPath=self.executor_path, sparkConf=self.spark_conf, hadoopConf=self.hadoop_conf, + databricksConf=databricks_conf, + databricksToken=self.databricks_token, + databricksInstance=self.databricks_instance, ) @classmethod @@ -151,4 +202,7 @@ def from_flyte_idl(cls, pb2_object): main_class=pb2_object.mainClass, hadoop_conf=pb2_object.hadoopConf, executor_path=pb2_object.executorPath, + databricks_conf=json_format.MessageToDict(pb2_object.databricksConf), + databricks_token=pb2_object.databricksToken, + databricks_instance=pb2_object.databricksInstance, ) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 8428b492ce..180a28bb87 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -36,6 +36,23 @@ def __post_init__(self): self.hadoop_conf = {} +@dataclass +class Databricks(Spark): + """ + Use this to configure a Databricks task. Task's marked with this will automatically execute + natively onto databricks platform as a distributed execution of spark + + Args: + databricks_conf: Databricks job configuration. 
Config structure can be found here: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure +        databricks_token: Databricks access token. https://docs.databricks.com/dev-tools/api/latest/authentication.html. +        databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. +    """ + +    databricks_conf: typing.Optional[Dict[str, typing.Union[str, dict]]] = None +    databricks_token: Optional[str] = None +    databricks_instance: Optional[str] = None + +  # This method does not reset the SparkSession since it's a bit hard to handle multiple # Spark sessions in a single application as it's described in: # https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application. @@ -100,6 +117,12 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:             main_class="",             spark_type=SparkType.PYTHON,         ) +        if isinstance(self.task_config, Databricks): +            cfg = typing.cast(Databricks, self.task_config) +            job._databricks_conf = cfg.databricks_conf +            job._databricks_token = cfg.databricks_token +            job._databricks_instance = cfg.databricks_instance +         return MessageToDict(job.to_flyte_idl())      def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: diff --git a/plugins/flytekit-spark/requirements.txt b/plugins/flytekit-spark/requirements.txt index a8df06cc24..b7087441ef 100644 --- a/plugins/flytekit-spark/requirements.txt +++ b/plugins/flytekit-spark/requirements.txt @@ -46,7 +46,7 @@ docker-image-py==0.1.12     # via flytekit docstring-parser==0.15     # via flytekit -flyteidl==1.3.0 +flyteidl==1.3.1     # via flytekit flytekit==1.3.0b2     # via flytekitplugins-spark diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 38555fa9b8..7049684a2d 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -2,7 +2,7 @@ import pyspark import pytest from flytekitplugins.spark import Spark -from flytekitplugins.spark.task import new_spark_session +from flytekitplugins.spark.task import Databricks, new_spark_session from pyspark.sql import SparkSession  import flytekit @@ -19,6 +19,23 @@ def reset_spark_session() -> None:  def test_spark_task(reset_spark_session): +    databricks_conf = { +        "name": "flytekit databricks plugin example", +        "new_cluster": { +            "spark_version": "11.0.x-scala2.12", +            "node_type_id": "r3.xlarge", +            "aws_attributes": {"availability": "ON_DEMAND"}, +            "num_workers": 4, +            "docker_image": {"url": "pingsutw/databricks:latest"}, +        }, +        "timeout_seconds": 3600, +        "max_retries": 1, +        "spark_python_task": { +            "python_file": "dbfs:///FileStore/tables/entrypoint-1.py", +            "parameters": "ls", +        }, +    } +     @task(task_config=Spark(spark_conf={"spark": "1"}))     def my_spark(a: str) -> int:         session = flytekit.current_context().spark_session @@ -53,6 +70,28 @@ def my_spark(a: str) -> int:     assert ("spark", "1") in configs     assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs  +    databricks_token = "token" +    databricks_instance = "account.cloud.databricks.com" + +    @task( +        task_config=Databricks( +            spark_conf={"spark": "2"}, +            databricks_conf=databricks_conf, +            databricks_instance="account.cloud.databricks.com", +            databricks_token="token", +        ) +    ) +    def my_databricks(a: str) -> int: +        session = flytekit.current_context().spark_session +        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" +        return 10 + +    assert my_databricks.task_config is not None + 
assert my_databricks.task_config.spark_conf == {"spark": "2"} + assert my_databricks.task_config.databricks_conf == databricks_conf + assert my_databricks.task_config.databricks_instance == databricks_instance + assert my_databricks.task_config.databricks_token == databricks_token + def test_new_spark_session(): name = "SessionName" diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 3c07cb3cee..c6d0ff7fc0 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -52,7 +52,7 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.0 +flyteidl==1.3.1 # via flytekit googleapis-common-protos==1.57.0 # via diff --git a/requirements.txt b/requirements.txt index caff0db497..5623078a25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,7 +50,7 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.0 +flyteidl==1.3.1 # via flytekit googleapis-common-protos==1.57.0 # via
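Usage note (not part of the patch): a minimal sketch of how the Databricks task config introduced above might be used, modeled on the new test in test_spark_task.py. The import path follows that test (flytekitplugins.spark.task); the function name, cluster settings, token, and instance below are illustrative placeholders, not working values.

```python
# Hypothetical usage sketch of the Databricks task config added by this patch.
# Values mirror the plugin's own test; token, instance, and cluster settings are
# placeholders to be replaced with real workspace details.
import flytekit
from flytekit import task
from flytekitplugins.spark.task import Databricks

databricks_conf = {
    "name": "flytekit databricks plugin example",
    "new_cluster": {
        "spark_version": "11.0.x-scala2.12",
        "node_type_id": "r3.xlarge",
        "num_workers": 4,
    },
    "timeout_seconds": 3600,
    "max_retries": 1,
    "spark_python_task": {
        "python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
        "parameters": "ls",
    },
}


@task(
    task_config=Databricks(
        spark_conf={"spark.driver.memory": "1000M"},
        databricks_conf=databricks_conf,
        databricks_instance="account.cloud.databricks.com",
        databricks_token="token",
    )
)
def hello_databricks(a: str) -> int:
    # Locally this behaves like a regular Spark task; at serialization time
    # get_custom() copies the Databricks settings onto the SparkJob model.
    session = flytekit.current_context().spark_session
    print(session.sparkContext.appName)
    return len(a)
```

At serialization time, to_flyte_idl() packs databricks_conf into a protobuf Struct, which is why the job spec can carry arbitrary nested JSON matching the Databricks Jobs API request structure.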