Add Databricks config to Spark Job #1358

Merged
12 commits merged on Dec 19, 2022
19 changes: 17 additions & 2 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -1,5 +1,5 @@
import enum
import typing
from typing import Dict, Optional

from flyteidl.plugins import spark_pb2 as _spark_task

@@ -22,6 +22,7 @@ def __init__(
main_class,
spark_conf,
hadoop_conf,
databricks_conf,
executor_path,
):
"""
@@ -37,22 +38,30 @@ def __init__(
self._executor_path = executor_path
self._spark_conf = spark_conf
self._hadoop_conf = hadoop_conf
self._databricks_conf = databricks_conf

def with_overrides(
self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None
self,
new_spark_conf: Optional[Dict[str, str]] = None,
new_hadoop_conf: Optional[Dict[str, str]] = None,
new_databricks_conf: Optional[str] = None,
) -> "SparkJob":
if not new_spark_conf:
new_spark_conf = self.spark_conf

if not new_hadoop_conf:
new_hadoop_conf = self.hadoop_conf

if not new_databricks_conf:
new_databricks_conf = self.databricks_conf

return SparkJob(
spark_type=self.spark_type,
application_file=self.application_file,
main_class=self.main_class,
spark_conf=new_spark_conf,
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
executor_path=self.executor_path,
)

@@ -104,6 +113,10 @@ def hadoop_conf(self):
"""
return self._hadoop_conf

@property
def databricks_conf(self) -> str:
return self._databricks_conf

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -127,6 +140,7 @@ def to_flyte_idl(self):
executorPath=self.executor_path,
sparkConf=self.spark_conf,
hadoopConf=self.hadoop_conf,
databricksConf=self.databricks_conf,
)

@classmethod
@@ -151,4 +165,5 @@ def from_flyte_idl(cls, pb2_object):
main_class=pb2_object.mainClass,
hadoop_conf=pb2_object.hadoopConf,
executor_path=pb2_object.executorPath,
databricks_conf=pb2_object.databricksConf,
)
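
The `with_overrides` helper above copies every field of an existing `SparkJob` and swaps in only the configuration that is explicitly replaced, now including the base64-encoded Databricks payload. A minimal usage sketch, assuming the plugin's `SparkType` enum exposes a `PYTHON` member (paths and values below are illustrative, not taken from this PR):

```python
from flytekitplugins.spark.models import SparkJob, SparkType

base_job = SparkJob(
    spark_type=SparkType.PYTHON,                 # assumed enum member
    application_file="local:///usr/bin/entrypoint.py",
    main_class="",
    spark_conf={"spark.executor.memory": "2g"},
    hadoop_conf={},
    databricks_conf="",                          # empty until a Databricks run is configured
    executor_path="/usr/bin/python3",
)

# Only the Databricks config is replaced; Spark and Hadoop settings carry over.
# The value is base64 of '{"run_name": "example"}', matching the encoding used in task.py.
overridden = base_job.with_overrides(new_databricks_conf="eyJydW5fbmFtZSI6ICJleGFtcGxlIn0=")
assert overridden.spark_conf == base_job.spark_conf
assert overridden.databricks_conf == "eyJydW5fbmFtZSI6ICJleGFtcGxlIn0="
```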
7 changes: 7 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -1,3 +1,5 @@
import base64
import json
import os
import typing
from dataclasses import dataclass
@@ -27,6 +29,7 @@ class Spark(object):

spark_conf: Optional[Dict[str, str]] = None
hadoop_conf: Optional[Dict[str, str]] = None
databricks_conf: typing.Optional[dict] = None

Review comment: nit: define the type of dict

def __post_init__(self):
if self.spark_conf is None:
@@ -35,6 +38,9 @@ def __post_init__(self):
if self.hadoop_conf is None:
self.hadoop_conf = {}

if self.databricks_conf is None:
self.databricks_conf = {}


# This method does not reset the SparkSession since it's a bit hard to handle multiple
# Spark sessions in a single application as it's described in:
@@ -95,6 +101,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = SparkJob(
spark_conf=self.task_config.spark_conf,
hadoop_conf=self.task_config.hadoop_conf,
databricks_conf=base64.b64encode(json.dumps(self.task_config.databricks_conf).encode()).decode(),
application_file="local://" + settings.entrypoint_settings.path,
executor_path=settings.python_interpreter,
main_class="",
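
`get_custom` above serializes the user-supplied `databricks_conf` dict as base64-encoded JSON before attaching it to the `SparkJob` model. A small standalone sketch of that round trip (the dict contents are illustrative):

```python
import base64
import json

databricks_conf = {"run_name": "flytekit databricks plugin example", "timeout_seconds": 3600}

# What get_custom stores on the SparkJob: JSON-serialize, then base64-encode.
encoded = base64.b64encode(json.dumps(databricks_conf).encode()).decode()

# A consumer (e.g. a backend plugin) can reverse the transformation.
decoded = json.loads(base64.b64decode(encoded).decode())
assert decoded == databricks_conf
```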
20 changes: 19 additions & 1 deletion plugins/flytekit-spark/tests/test_spark_task.py
@@ -19,14 +19,32 @@ def reset_spark_session() -> None:


def test_spark_task(reset_spark_session):
@task(task_config=Spark(spark_conf={"spark": "1"}))
databricks_conf = {
"name": "flytekit databricks plugin example",
"new_cluster": {
"spark_version": "11.0.x-scala2.12",
"node_type_id": "r3.xlarge",
"aws_attributes": {"availability": "ON_DEMAND"},
"num_workers": 4,
"docker_image": {"url": "pingsutw/databricks:latest"},
},
"timeout_seconds": 3600,
"max_retries": 1,
"spark_python_task": {
"python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
"parameters": "ls",
},
}

@task(task_config=Spark(spark_conf={"spark": "1"}, databricks_conf=databricks_conf))
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_spark.task_config is not None
assert my_spark.task_config.spark_conf == {"spark": "1"}
assert my_spark.task_config.databricks_conf == databricks_conf

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(
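
The remainder of the test is truncated above. As a hedged sketch only: assuming `get_custom` renders the `SparkJob` message as a dict keyed by the camelCase field names used in `to_flyte_idl` (an assumption, not shown in this diff), the new field could be verified roughly like this:

```python
import base64
import json

custom = my_spark.get_custom(settings)
# The "databricksConf" key and its base64-encoded JSON value are inferred from
# models.py / task.py above, not copied from the truncated test.
assert json.loads(base64.b64decode(custom["databricksConf"]).decode()) == databricks_conf
```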