task.py
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union, cast

import click
from google.protobuf.json_format import MessageToDict

from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec

from .models import SparkJob, SparkType

pyspark_sql = lazy_module("pyspark.sql")
SparkSession = pyspark_sql.SparkSession


@dataclass
class Spark(object):
    """
    Use this to configure a SparkContext for your task. Tasks marked with this will automatically execute
    natively on K8s as a distributed Spark execution.

    Args:
        spark_conf: Dictionary of Spark config. The variables should match what Spark expects.
        hadoop_conf: Dictionary of Hadoop conf. The variables should match a typical Hadoop configuration for Spark.
        executor_path: Python binary executable to use for PySpark in the driver and executors.
        applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
    """

    spark_conf: Optional[Dict[str, str]] = None
    hadoop_conf: Optional[Dict[str, str]] = None
    executor_path: Optional[str] = None
    applications_path: Optional[str] = None

    def __post_init__(self):
        if self.spark_conf is None:
            self.spark_conf = {}

        if self.hadoop_conf is None:
            self.hadoop_conf = {}
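

# Example usage (illustrative sketch, not part of this module): a task configured to run as a
# distributed Spark job. The spark_conf keys are standard Spark settings; inside the task body
# the session is available via flytekit.current_context().spark_session, as set up by
# PysparkFunctionTask.pre_execute below.
#
#     from flytekit import task, current_context
#     from flytekitplugins.spark import Spark
#
#     @task(
#         task_config=Spark(
#             spark_conf={
#                 "spark.executor.instances": "2",
#                 "spark.executor.memory": "1g",
#             }
#         )
#     )
#     def mean_of_range(n: int) -> float:
#         sess = current_context().spark_session
#         return sess.sparkContext.parallelize(range(1, n + 1)).sum() / n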


@dataclass
class Databricks(Spark):
    """
    Deprecated. Use DatabricksV2 instead.
    """

    databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
    databricks_instance: Optional[str] = None

    def __post_init__(self):
        logger.warn(
            "Databricks is deprecated. Use 'from flytekitplugins.spark import DatabricksV2' instead,"
            " and make sure to upgrade the version of the flyteagent deployment to >v1.13.0.",
        )


@dataclass
class DatabricksV2(Spark):
    """
    Use this to configure a Databricks task. Tasks marked with this will automatically execute
    natively on the Databricks platform as a distributed Spark execution.

    Args:
        databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases.
            For the configuration structure, visit: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure
            For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html
        databricks_instance: Domain name of your deployment. Use the form <account>.cloud.databricks.com.
    """

    databricks_conf: Optional[Dict[str, Union[str, dict]]] = None
    databricks_instance: Optional[str] = None
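

# Example usage (illustrative sketch; the instance name and cluster spec below are placeholders):
# run the task on Databricks instead of Spark-on-K8s. databricks_conf follows the Databricks
# Jobs API 2.1 request structure linked in the docstring above.
#
#     from flytekit import task
#     from flytekitplugins.spark import DatabricksV2
#
#     @task(
#         task_config=DatabricksV2(
#             spark_conf={"spark.driver.memory": "1g"},
#             databricks_conf={
#                 "run_name": "flytekit databricks example",
#                 "new_cluster": {
#                     "spark_version": "12.2.x-scala2.12",
#                     "node_type_id": "n2-highmem-4",
#                     "num_workers": 2,
#                 },
#                 "timeout_seconds": 3600,
#                 "max_retries": 1,
#             },
#             databricks_instance="<account>.cloud.databricks.com",
#         )
#     )
#     def hello_databricks(partitions: int) -> str:
#         ...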


# This method does not reset the SparkSession since it's a bit hard to handle multiple
# Spark sessions in a single application, as described in:
# https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application.
def new_spark_session(name: str, conf: Optional[Dict[str, str]] = None):
    """
    Optionally creates a new spark session and returns it.
    In cluster mode (running in hosted Flyte), this will disregard the spark conf passed in.

    This method is safe to be used from any other method. That is one reason why we have duplicated this code
    fragment with the pre-execute. For example, in the notebook scenario we might want to call it from a separate kernel.
    """
    import pyspark as _pyspark

    # We run in cluster-mode in Flyte.
    # Ref https://github.com/lyft/flyteplugins/blob/master/go/tasks/v1/flytek8s/k8s_resource_adds.go#L46
    sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {name}")
    if "FLYTE_INTERNAL_EXECUTION_ID" not in os.environ and conf is not None:
        # If either of the above cases is not true, then we are in a local execution of this task.
        # Add system spark-conf for local/notebook based execution.
        sess_builder = sess_builder.master("local[*]")
        spark_conf = _pyspark.SparkConf()
        for k, v in conf.items():
            spark_conf.set(k, v)
        spark_conf.set("spark.driver.bindAddress", "127.0.0.1")
        # In local execution, propagate PYTHONPATH to executors too. This makes the spark
        # execution hermetic to the execution environment. For example, it allows running
        # Spark applications using Bazel, without major changes.
        if "PYTHONPATH" in os.environ:
            spark_conf.setExecutorEnv("PYTHONPATH", os.environ["PYTHONPATH"])
        sess_builder = sess_builder.config(conf=spark_conf)

    # If there is a global SparkSession available, get it and try to stop it.
    _pyspark.sql.SparkSession.builder.getOrCreate().stop()

    return sess_builder.getOrCreate()
    # SparkSession.stop does not work correctly, as it stops the session before all the data is written
    # sess.stop()
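

# Example usage (illustrative sketch): create a local session outside of a Flyte cluster
# execution, e.g. from a notebook or test. When FLYTE_INTERNAL_EXECUTION_ID is not set and a
# conf is passed, the session runs on local[*] with the given settings; on a cluster the conf
# is ignored.
#
#     sess = new_spark_session("my-notebook", conf={"spark.ui.enabled": "false"})
#     df = sess.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
#     print(df.count())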


class PysparkFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[Spark]):
    """
    Actual plugin that transforms the local Python code for execution within a Spark context.
    """

    _SPARK_TASK_TYPE = "spark"

    def __init__(
        self,
        task_config: Spark,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        self.sess: Optional[SparkSession] = None
        self._default_executor_path: Optional[str] = task_config.executor_path
        self._default_applications_path: Optional[str] = task_config.applications_path

        if isinstance(container_image, ImageSpec):
            if container_image.base_image is None:
                img = f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
                container_image.base_image = img
                # default executor path and applications path in apache/spark-py:3.3.1
                self._default_executor_path = self._default_executor_path or "/usr/bin/python3"
                self._default_applications_path = (
                    self._default_applications_path or "local:///usr/local/bin/entrypoint.py"
                )

        if isinstance(task_config, DatabricksV2):
            task_type = "databricks"
        else:
            task_type = "spark"

        super(PysparkFunctionTask, self).__init__(
            task_config=task_config,
            task_type=task_type,
            task_function=task_function,
            container_image=container_image,
            **kwargs,
        )

    def get_image(self, settings: SerializationSettings) -> str:
        if isinstance(self.container_image, ImageSpec):
            # Ensure that the code is always copied into the image, even during fast-registration.
            self.container_image.source_root = settings.source_root
        return get_registerable_container_image(self.container_image, settings.image_config)

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        job = SparkJob(
            spark_conf=self.task_config.spark_conf,
            hadoop_conf=self.task_config.hadoop_conf,
            application_file=self._default_applications_path or "local://" + settings.entrypoint_settings.path,
            executor_path=self._default_executor_path or settings.python_interpreter,
            main_class="",
            spark_type=SparkType.PYTHON,
        )
        if isinstance(self.task_config, (Databricks, DatabricksV2)):
            cfg = cast(DatabricksV2, self.task_config)
            job._databricks_conf = cfg.databricks_conf
            job._databricks_instance = cfg.databricks_instance

        return MessageToDict(job.to_flyte_idl())

    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        import pyspark as _pyspark

        ctx = FlyteContextManager.current_context()
        sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}")
        if not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION):
            # If the above case is not true, then we are in a local execution of this task.
            # Add system spark-conf for local/notebook based execution.
            spark_conf = _pyspark.SparkConf()
            spark_conf.set("spark.driver.bindAddress", "127.0.0.1")
            for k, v in self.task_config.spark_conf.items():
                spark_conf.set(k, v)
            # In local execution, propagate PYTHONPATH to executors too. This makes the spark
            # execution hermetic to the execution environment. For example, it allows running
            # Spark applications using Bazel, without major changes.
            if "PYTHONPATH" in os.environ:
                spark_conf.setExecutorEnv("PYTHONPATH", os.environ["PYTHONPATH"])
            sess_builder = sess_builder.config(conf=spark_conf)

        self.sess = sess_builder.getOrCreate()
        return user_params.builder().add_attr("SPARK_SESSION", self.sess).build()

    def execute(self, **kwargs) -> Any:
        if isinstance(self.task_config, (Databricks, DatabricksV2)):
            # Use the Databricks agent to run it by default.
            try:
                ctx = FlyteContextManager.current_context()
                if not ctx.file_access.is_remote(ctx.file_access.raw_output_prefix):
                    raise ValueError(
                        "To submit a Databricks job locally,"
                        " please set --raw-output-data-prefix to a remote path, e.g. s3://, gcs://, etc."
                    )
                if ctx.execution_state and ctx.execution_state.is_local_execution():
                    return AsyncAgentExecutorMixin.execute(self, **kwargs)
            except Exception as e:
                click.secho(f"❌ Agent failed to run the task with error: {e}", fg="red")
                click.secho("Falling back to local execution", fg="red")
        return PythonFunctionTask.execute(self, **kwargs)


# Inject the Spark plugin into flytekit's dynamic plugin loading system.
TaskPlugins.register_pythontask_plugin(Spark, PysparkFunctionTask)
TaskPlugins.register_pythontask_plugin(Databricks, PysparkFunctionTask)
TaskPlugins.register_pythontask_plugin(DatabricksV2, PysparkFunctionTask)
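
# With the registrations above, any @task whose task_config is a Spark, Databricks, or
# DatabricksV2 instance is built as a PysparkFunctionTask rather than a plain
# PythonFunctionTask. Illustrative check (assumes a task defined as in the Spark example
# earlier in this file):
#
#     assert isinstance(mean_of_range, PysparkFunctionTask)
#     assert mean_of_range.task_type == "spark"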