-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
client.py
398 lines (350 loc) · 17.9 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
from itertools import chain
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union
import dagster._check as check
import requests.exceptions
from dagster import DagsterRunStatus
from dagster._annotations import deprecated, public
from dagster._core.definitions.run_config import RunConfig, convert_config_input
from dagster._core.definitions.utils import normalize_tags
from gql import Client, gql
from gql.transport import Transport
from gql.transport.requests import RequestsHTTPTransport
from .client_queries import (
CLIENT_GET_REPO_LOCATIONS_NAMES_AND_PIPELINES_QUERY,
CLIENT_SUBMIT_PIPELINE_RUN_MUTATION,
GET_PIPELINE_RUN_STATUS_QUERY,
RELOAD_REPOSITORY_LOCATION_MUTATION,
SHUTDOWN_REPOSITORY_LOCATION_MUTATION,
TERMINATE_RUN_JOB_MUTATION,
)
from .utils import (
DagsterGraphQLClientError,
InvalidOutputErrorInfo,
JobInfo,
ReloadRepositoryLocationInfo,
ReloadRepositoryLocationStatus,
ShutdownRepositoryLocationInfo,
ShutdownRepositoryLocationStatus,
)
class DagsterGraphQLClient:
    """Official Dagster Python Client for GraphQL.

    Utilizes the gql library to dispatch queries over HTTP to a remote Dagster GraphQL Server

    As of now, all operations on this client are synchronous.

    Intended usage:

    .. code-block:: python

        client = DagsterGraphQLClient("localhost", port_number=3000)
        status = client.get_run_status(**SOME_RUN_ID**)

    Args:
        hostname (str): Hostname for the Dagster GraphQL API, like `localhost` or
            `dagster.YOUR_ORG_HERE`.
        port_number (Optional[int]): Port number to connect to on the host.
            Defaults to None.
        transport (Optional[Transport], optional): A custom transport to use to connect to the
            GraphQL API with (e.g. for custom auth). Defaults to None, in which case a
            ``RequestsHTTPTransport`` targeting ``http(s)://<hostname>[:<port>]/graphql`` is used.
        use_https (bool, optional): Whether to use https in the URL connection string for the
            GraphQL API. Defaults to False.
        timeout (int): Number of seconds before requests should time out. Only used by the
            default transport (ignored when ``transport`` is given). Defaults to 300.
        headers (Optional[Dict[str, str]]): Additional headers to include in the request. Only
            used by the default transport (ignored when ``transport`` is given). To use this
            client in Dagster Cloud, set the "Dagster-Cloud-Api-Token" header to a user token
            generated in the Dagster Cloud UI.

    Raises:
        DagsterGraphQLClientError: if the client cannot connect to the host (raised from the
            underlying :py:class:`~requests.exceptions.ConnectionError`).
    """

    def __init__(
        self,
        hostname: str,
        port_number: Optional[int] = None,
        transport: Optional[Transport] = None,
        use_https: bool = False,
        timeout: int = 300,
        headers: Optional[Dict[str, str]] = None,
    ):
        self._hostname = check.str_param(hostname, "hostname")
        self._port_number = check.opt_int_param(port_number, "port_number")
        self._use_https = check.bool_param(use_https, "use_https")

        self._url = (
            ("https://" if self._use_https else "http://")
            + (f"{self._hostname}:{self._port_number}" if self._port_number else self._hostname)
            + "/graphql"
        )

        # Only build the default HTTP transport when the caller did not supply one, so a
        # custom transport does not cause an unused RequestsHTTPTransport to be created.
        transport = check.opt_inst_param(transport, "transport", Transport)
        if transport is None:
            transport = RequestsHTTPTransport(
                url=self._url, use_json=True, timeout=timeout, headers=headers
            )
        self._transport = transport

        try:
            self._client = Client(transport=self._transport, fetch_schema_from_transport=True)
        except requests.exceptions.ConnectionError as exc:
            raise DagsterGraphQLClientError(
                f"Error when connecting to url {self._url}. "
                + f"Did you specify hostname: {self._hostname} "
                + (f"and port_number: {self._port_number} " if self._port_number else "")
                + "correctly?"
            ) from exc

    def _execute(self, query: str, variables: Optional[Dict[str, Any]] = None):
        """Parse and execute a GraphQL document through the gql client.

        Args:
            query (str): The GraphQL query/mutation source text.
            variables (Optional[Dict[str, Any]]): Variable values for the document.

        Returns:
            The result returned by ``gql.Client.execute``.

        Raises:
            DagsterGraphQLClientError: wrapping any exception raised while parsing or executing.
        """
        try:
            return self._client.execute(gql(query), variable_values=variables)
        except Exception as exc:  # catch generic Exception from the gql client
            raise DagsterGraphQLClientError(
                f"Exception occured during execution of query \n{query}\n with variables"
                f" \n{variables}\n"
            ) from exc

    def _get_repo_locations_and_names_with_pipeline(self, job_name: str) -> List[JobInfo]:
        """Return every JobInfo on the deployment whose job name equals ``job_name``.

        Raises:
            DagsterGraphQLClientError: if the repositories query does not return a
                ``RepositoryConnection``.
        """
        res_data = self._execute(CLIENT_GET_REPO_LOCATIONS_NAMES_AND_PIPELINES_QUERY)
        query_res = res_data["repositoriesOrError"]
        repo_connection_status = query_res["__typename"]
        if repo_connection_status == "RepositoryConnection":
            # JobInfo.from_node yields the job infos of one repository node; flatten them.
            valid_nodes: Iterable[JobInfo] = chain.from_iterable(
                map(JobInfo.from_node, query_res["nodes"])
            )
            return [info for info in valid_nodes if info.job_name == job_name]
        else:
            raise DagsterGraphQLClientError(repo_connection_status, query_res["message"])

    def _core_submit_execution(
        self,
        pipeline_name: str,
        repository_location_name: Optional[str] = None,
        repository_name: Optional[str] = None,
        run_config: Optional[Union[RunConfig, Mapping[str, Any]]] = None,
        mode: str = "default",
        preset: Optional[str] = None,
        tags: Optional[Mapping[str, str]] = None,
        op_selection: Optional[Sequence[str]] = None,
        is_using_job_op_graph_apis: Optional[bool] = False,
    ):
        """Launch a run of ``pipeline_name`` and return its run id.

        If either ``repository_location_name`` or ``repository_name`` is missing, both are
        inferred by searching the deployment for a uniquely-named job/pipeline.

        Raises:
            DagsterGraphQLClientError: if the job cannot be located unambiguously, or the
                launch mutation returns any non-success result type.
        """
        check.opt_str_param(repository_location_name, "repository_location_name")
        check.opt_str_param(repository_name, "repository_name")
        check.str_param(pipeline_name, "pipeline_name")
        check.opt_str_param(mode, "mode")
        check.opt_str_param(preset, "preset")
        run_config = check.opt_mapping_param(convert_config_input(run_config), "run_config")

        # The following invariant will never fail when a job is executed
        check.invariant(
            (mode is not None and run_config is not None) or preset is not None,
            "Either a mode and run_config or a preset must be specified in order to "
            f"submit the pipeline {pipeline_name} for execution",
        )
        tags = normalize_tags(tags).tags

        pipeline_or_job = "Job" if is_using_job_op_graph_apis else "Pipeline"

        if not repository_location_name or not repository_name:
            job_info_lst = self._get_repo_locations_and_names_with_pipeline(pipeline_name)
            if len(job_info_lst) == 0:
                raise DagsterGraphQLClientError(
                    f"{pipeline_or_job}NotFoundError",
                    f"No {'jobs' if is_using_job_op_graph_apis else 'pipelines'} with the name"
                    f" `{pipeline_name}` exist",
                )
            elif len(job_info_lst) == 1:
                job_info = job_info_lst[0]
                repository_location_name = job_info.repository_location_name
                repository_name = job_info.repository_name
            else:
                raise DagsterGraphQLClientError(
                    "Must specify repository_location_name and repository_name since there are"
                    f" multiple {'jobs' if is_using_job_op_graph_apis else 'pipelines'} with the"
                    f" name {pipeline_name}.\n\tchoose one of: {job_info_lst}"
                )

        variables: Dict[str, Any] = {
            "executionParams": {
                "selector": {
                    "repositoryLocationName": repository_location_name,
                    "repositoryName": repository_name,
                    "pipelineName": pipeline_name,
                    "solidSelection": op_selection,
                }
            }
        }
        if preset is not None:
            variables["executionParams"]["preset"] = preset
        if mode is not None and run_config is not None:
            variables["executionParams"] = {
                **variables["executionParams"],
                "runConfigData": run_config,
                "mode": mode,
                "executionMetadata": (
                    {"tags": [{"key": k, "value": v} for k, v in tags.items()]} if tags else {}
                ),
            }

        res_data: Dict[str, Any] = self._execute(CLIENT_SUBMIT_PIPELINE_RUN_MUTATION, variables)
        query_result = res_data["launchPipelineExecution"]
        query_result_type = query_result["__typename"]
        if query_result_type in ("LaunchRunSuccess", "LaunchPipelineRunSuccess"):
            return query_result["run"]["runId"]
        elif query_result_type == "InvalidStepError":
            raise DagsterGraphQLClientError(query_result_type, query_result["invalidStepKey"])
        elif query_result_type == "InvalidOutputError":
            error_info = InvalidOutputErrorInfo(
                step_key=query_result["stepKey"],
                invalid_output_name=query_result["invalidOutputName"],
            )
            raise DagsterGraphQLClientError(query_result_type, body=error_info)
        elif query_result_type in ("RunConfigValidationInvalid", "PipelineConfigValidationInvalid"):
            raise DagsterGraphQLClientError(query_result_type, query_result["errors"])
        else:
            # query_result_type is a ConflictingExecutionParamsError, a PresetNotFoundError
            # a PipelineNotFoundError, a RunConflict, an UnauthorizedError or a PythonError
            raise DagsterGraphQLClientError(query_result_type, query_result["message"])

    @public
    def submit_job_execution(
        self,
        job_name: str,
        repository_location_name: Optional[str] = None,
        repository_name: Optional[str] = None,
        run_config: Optional[Union[RunConfig, Mapping[str, Any]]] = None,
        tags: Optional[Dict[str, Any]] = None,
        op_selection: Optional[Sequence[str]] = None,
    ) -> str:
        """Submits a job with attached configuration for execution.

        Args:
            job_name (str): The job's name
            repository_location_name (Optional[str]): The name of the repository location where
                the job is located. If omitted, the client will try to infer the repository location
                from the available options on the Dagster deployment. Defaults to None.
            repository_name (Optional[str]): The name of the repository where the job is located.
                If omitted, the client will try to infer the repository from the available options
                on the Dagster deployment. Defaults to None.
            run_config (Optional[Union[RunConfig, Mapping[str, Any]]]): This is the run config to execute the job with.
                Note that runConfigData is any-typed in the GraphQL type system. This type is used when passing in
                an arbitrary object for run config. However, it must conform to the constraints of the config
                schema for this job. If it does not, the client will throw a DagsterGraphQLClientError whose
                first argument is "RunConfigValidationInvalid" (or "PipelineConfigValidationInvalid").
                Defaults to None.
            tags (Optional[Dict[str, Any]]): A set of tags to add to the job execution.
            op_selection (Optional[Sequence[str]]): Names of the ops to execute; sent to the server
                as ``solidSelection``. Defaults to None.

        Raises:
            DagsterGraphQLClientError("InvalidStepError", invalid_step_key): the job has an invalid step
            DagsterGraphQLClientError("InvalidOutputError", body=error_object): some op has an invalid output within the job.
                The error_object is of type dagster_graphql.InvalidOutputErrorInfo.
            DagsterGraphQLClientError("RunConflict", message): a `DagsterRunConflict` occured during execution.
                This indicates that a conflicting job run already exists in run storage.
            DagsterGraphQLClientError("RunConfigValidationInvalid", errors): the run_config is not in the expected format
                for the job
            DagsterGraphQLClientError("JobNotFoundError", message): the requested job does not exist
            DagsterGraphQLClientError("PythonError", message): an internal framework error occurred

        Returns:
            str: run id of the submitted pipeline run
        """
        return self._core_submit_execution(
            pipeline_name=job_name,
            repository_location_name=repository_location_name,
            repository_name=repository_name,
            run_config=run_config,
            mode="default",
            preset=None,
            tags=tags,
            op_selection=op_selection,
            is_using_job_op_graph_apis=True,
        )

    @public
    def get_run_status(self, run_id: str) -> DagsterRunStatus:
        """Get the status of a given Pipeline Run.

        Args:
            run_id (str): run id of the requested pipeline run.

        Raises:
            DagsterGraphQLClientError("PipelineNotFoundError", message): if the requested run id is not found
            DagsterGraphQLClientError("PythonError", message): on internal framework errors

        Returns:
            DagsterRunStatus: returns a status Enum describing the state of the requested pipeline run
        """
        check.str_param(run_id, "run_id")

        res_data: Dict[str, Dict[str, Any]] = self._execute(
            GET_PIPELINE_RUN_STATUS_QUERY, {"runId": run_id}
        )
        query_result: Dict[str, Any] = res_data["pipelineRunOrError"]
        query_result_type: str = query_result["__typename"]
        if query_result_type in ("PipelineRun", "Run"):
            return DagsterRunStatus(query_result["status"])
        else:
            raise DagsterGraphQLClientError(query_result_type, query_result["message"])

    @public
    def reload_repository_location(
        self, repository_location_name: str
    ) -> ReloadRepositoryLocationInfo:
        """Reloads a Dagster Repository Location, which reloads all repositories in that repository location.

        This is useful in a variety of contexts, including refreshing the Dagster UI without restarting
        the server.

        Args:
            repository_location_name (str): The name of the repository location

        Returns:
            ReloadRepositoryLocationInfo: Object with information about the result of the reload request
        """
        check.str_param(repository_location_name, "repository_location_name")

        res_data: Dict[str, Dict[str, Any]] = self._execute(
            RELOAD_REPOSITORY_LOCATION_MUTATION,
            {"repositoryLocationName": repository_location_name},
        )

        query_result: Dict[str, Any] = res_data["reloadRepositoryLocation"]
        query_result_type: str = query_result["__typename"]
        if query_result_type == "WorkspaceLocationEntry":
            location_or_error_type = query_result["locationOrLoadError"]["__typename"]
            if location_or_error_type == "RepositoryLocation":
                return ReloadRepositoryLocationInfo(status=ReloadRepositoryLocationStatus.SUCCESS)
            else:
                # Any non-RepositoryLocation payload is reported as a load (Python) error.
                return ReloadRepositoryLocationInfo(
                    status=ReloadRepositoryLocationStatus.FAILURE,
                    failure_type="PythonError",
                    message=query_result["locationOrLoadError"]["message"],
                )
        else:
            # query_result_type is either ReloadNotSupported or RepositoryLocationNotFound
            return ReloadRepositoryLocationInfo(
                status=ReloadRepositoryLocationStatus.FAILURE,
                failure_type=query_result_type,
                message=query_result["message"],
            )

    @deprecated(breaking_version="2.0")
    @public
    def shutdown_repository_location(
        self, repository_location_name: str
    ) -> ShutdownRepositoryLocationInfo:
        """Shuts down the server that is serving metadata for the provided repository location.

        This is primarily useful when you want the server to be restarted by the compute environment
        in which it is running (for example, in Kubernetes, the pod in which the server is running
        will automatically restart when the server is shut down, and the repository metadata will
        be reloaded)

        Args:
            repository_location_name (str): The name of the repository location

        Returns:
            ShutdownRepositoryLocationInfo: Object with information about the result of the shutdown request

        Raises:
            DagsterGraphQLClientError: if the server returns an unexpected result type.
        """
        check.str_param(repository_location_name, "repository_location_name")

        res_data: Dict[str, Dict[str, Any]] = self._execute(
            SHUTDOWN_REPOSITORY_LOCATION_MUTATION,
            {"repositoryLocationName": repository_location_name},
        )

        query_result: Dict[str, Any] = res_data["shutdownRepositoryLocation"]
        query_result_type: str = query_result["__typename"]

        if query_result_type == "ShutdownRepositoryLocationSuccess":
            return ShutdownRepositoryLocationInfo(status=ShutdownRepositoryLocationStatus.SUCCESS)
        elif query_result_type in ("RepositoryLocationNotFound", "PythonError"):
            return ShutdownRepositoryLocationInfo(
                status=ShutdownRepositoryLocationStatus.FAILURE,
                message=query_result["message"],
            )
        else:
            # Raise the client's own error type (a subclass of Exception) for consistency with
            # every other method on this client.
            raise DagsterGraphQLClientError(f"Unexpected query result type {query_result_type}")

    def terminate_run(self, run_id: str) -> None:
        """Terminates a pipeline run. This method it is useful when you would like to stop a pipeline run
        based on a external event.

        Args:
            run_id (str): The run id of the pipeline run to terminate

        Raises:
            DagsterGraphQLClientError("RunNotFoundError", message): if no run with ``run_id`` exists
            DagsterGraphQLClientError(<result type>, message): for any other non-success result
        """
        check.str_param(run_id, "run_id")

        res_data: Dict[str, Dict[str, Any]] = self._execute(
            TERMINATE_RUN_JOB_MUTATION, {"runId": run_id}
        )

        query_result: Dict[str, Any] = res_data["terminateRun"]
        query_result_type: str = query_result["__typename"]
        if query_result_type == "TerminateRunSuccess":
            return
        elif query_result_type == "RunNotFoundError":
            raise DagsterGraphQLClientError("RunNotFoundError", f"Run Id {run_id} not found")
        else:
            raise DagsterGraphQLClientError(query_result_type, query_result["message"])