/
dagster_operator.py
141 lines (122 loc) · 5.02 KB
/
dagster_operator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import json
from airflow import __version__ as airflow_version
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from dagster_airflow.hooks.dagster_hook import DagsterHook
from dagster_airflow.links.dagster_link import LINK_FMT, DagsterLink
class DagsterOperator(BaseOperator):
    """DagsterOperator

    Uses the dagster graphql api to run and monitor dagster jobs on remote dagster infrastructure

    Parameters:
        repository_name (str): the name of the repository to use
        repostitory_location_name (str): the name of the repository location to use
            (note the spelling of the parameter name itself)
        job_name (str): the name of the job to run
        run_config (Optional[Dict[str, Any]]): the run config to use for the job run
        dagster_conn_id (Optional[str]): the id of the dagster connection, airflow 2.0+ only
            (it is replaced by None on older airflow versions)
        organization_id (Optional[str]): the id of the dagster cloud organization
        deployment_name (Optional[str]): the name of the dagster cloud deployment
        user_token (Optional[str]): the dagster cloud user token to use
        url (str): base url to which "<organization_id>/<deployment_name>/graphql" is
            appended to build the endpoint handed to the hook
    """

    # run_config is templated; values ending in these extensions are treated as template files.
    template_fields = ["run_config"]
    template_ext = (".yaml", ".yml", ".json")
    ui_color = "#663399"
    ui_fgcolor = "#e0e3fc"
    operator_extra_links = (DagsterLink(),)

    @apply_defaults
    def __init__(
        self,
        dagster_conn_id="dagster_default",
        run_config=None,
        repository_name="",
        repostitory_location_name="",
        job_name="",
        # params for airflow < 2.0.0 where custom connections aren't supported
        deployment_name="prod",
        user_token=None,
        organization_id="",
        url="https://dagster.cloud/",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Filled in by _execute once a run has been launched; on_kill relies on it.
        self.run_id = None
        self.dagster_conn_id = (
            dagster_conn_id if self._supports_connections(airflow_version) else None
        )
        self.run_config = run_config or {}
        self.repository_name = repository_name
        self.repostitory_location_name = repostitory_location_name
        self.job_name = job_name
        self.user_token = user_token
        self.url = url
        self.organization_id = organization_id
        self.deployment_name = deployment_name
        self.hook = DagsterHook(
            dagster_conn_id=self.dagster_conn_id,
            user_token=self.user_token,
            url=f"{self.url}{self.organization_id}/{self.deployment_name}/graphql",
        )

    @staticmethod
    def _supports_connections(version) -> bool:
        """Return True when ``version`` denotes airflow 2.0 or newer.

        The major component is compared numerically because a plain string
        comparison orders "10.0.0" before "2.0.0".
        """
        major = str(version).split(".", 1)[0]
        if major.isdigit():
            return int(major) >= 2
        # Unparseable version string: keep the historical string comparison.
        return str(version) >= "2.0.0"

    def _is_json(self, blob) -> bool:
        """Return True if ``blob`` can be decoded by ``json.loads``, False otherwise."""
        try:
            json.loads(blob)
        except (TypeError, ValueError):
            # TypeError covers non-string input such as None or an already-parsed dict.
            return False
        return True

    def pre_execute(self, context):
        """Render ``run_config`` again with the execution ``context``."""
        # force re-rendering to ensure run_config renders any templated
        # content from run_config that couldn't be accessed on init
        self.run_config = self.render_template(self.run_config, context)

    def on_kill(self):
        """Ask the hook to terminate the launched run, if there is one."""
        if self.run_id is None:
            # Killed before launch_run returned: there is no run to terminate.
            self.log.info("No run has been launched, nothing to terminate")
            return
        self.log.info("Terminating Run")
        self.hook.terminate_run(
            run_id=self.run_id,
        )

    def execute(self, context):
        """Launch the job and wait for it; any exception propagates to the caller."""
        return self._execute(context)

    def _execute(self, context):
        """Launch the run, publish its identifiers to XCom, log a link and wait for it."""
        self.run_id = self.hook.launch_run(
            repository_name=self.repository_name,
            repostitory_location_name=self.repostitory_location_name,
            job_name=self.job_name,
            run_config=self.run_config,
        )

        # With a connection id the hook's values are used; otherwise the
        # operator's own arguments are the only source available.
        if self.dagster_conn_id:
            organization_id = self.hook.organization_id
            deployment_name = self.hook.deployment_name
        else:
            organization_id = self.organization_id
            deployment_name = self.deployment_name

        # save relevant info in xcom for use in links
        task_instance = context["task_instance"]
        task_instance.xcom_push(key="run_id", value=self.run_id)
        task_instance.xcom_push(key="organization_id", value=organization_id)
        task_instance.xcom_push(key="deployment_name", value=deployment_name)

        self.log.info("Run Starting....")
        # Log the same identifiers that were pushed to XCom so the printed
        # link matches the one built from XCom.
        self.log.info(
            "Run tracking: %s",
            LINK_FMT.format(
                organization_id=organization_id,
                deployment_name=deployment_name,
                run_id=self.run_id,
            ),
        )
        self.hook.wait_for_run(
            run_id=self.run_id,
        )
class DagsterCloudOperator(DagsterOperator):
    """DagsterCloudOperator

    Uses the dagster cloud graphql api to run and monitor dagster jobs on dagster cloud.
    This subclass adds no behavior of its own on top of DagsterOperator.

    Parameters:
        repository_name (str): name of the repository to use
        repostitory_location_name (str): name of the repository location to use
        job_name (str): name of the job to run
        run_config (Optional[Dict[str, Any]]): run config to use for the job run
        dagster_conn_id (Optional[str]): id of the dagster connection, airflow 2.0+ only
        organization_id (Optional[str]): id of the dagster cloud organization
        deployment_name (Optional[str]): name of the dagster cloud deployment
        user_token (Optional[str]): dagster cloud user token to use
    """