-
Notifications
You must be signed in to change notification settings - Fork 208
/
connections.py
158 lines (122 loc) · 4.26 KB
/
connections.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from contextlib import contextmanager
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.logger import GLOBAL_LOGGER as logger
import dbt.exceptions
from pyhive import hive
from thrift.transport import THttpClient
import base64
# URL template for the Databricks Thrift-over-HTTP gateway; placeholders are
# filled from the profile's host/port/cluster fields in
# SparkConnectionManager.open().
SPARK_CONNECTION_URL = "https://{host}:{port}/sql/protocolv1/o/0/{cluster}"

# JSON-schema contract that validates the `spark` target in profiles.yml.
# NOTE(review): `token` is not in `required`, but open() reads
# credentials['token'] unconditionally — verify profiles always supply it.
SPARK_CREDENTIALS_CONTRACT = {
    'type': 'object',
    'additionalProperties': False,
    'properties': {
        'host': {
            'type': 'string'
        },
        'port': {
            'type': 'integer',
            'minimum': 0,
            'maximum': 65535,
        },
        'cluster': {
            'type': 'string'
        },
        'database': {
            'type': 'string',
        },
        'schema': {
            'type': 'string',
        },
        'token': {
            'type': 'string',
        },
    },
    'required': ['host', 'database', 'schema', 'cluster'],
}
class SparkCredentials(Credentials):
    """Credentials for a Spark target, validated by SPARK_CREDENTIALS_CONTRACT."""

    SCHEMA = SPARK_CREDENTIALS_CONTRACT

    def __init__(self, *args, **kwargs):
        # When the profile omits `database`, fall back to the schema name.
        if 'database' not in kwargs:
            kwargs['database'] = kwargs.get('schema')
        super(SparkCredentials, self).__init__(*args, **kwargs)

    @property
    def type(self):
        return 'spark'

    def _connection_keys(self):
        # Keys shown in `dbt debug` output (token deliberately excluded).
        return ('host', 'port', 'cluster', 'schema')
class ConnectionWrapper(object):
    "Wrap a Spark connection in a way that no-ops transactions"
    # Spark SQL has no transaction support, so rollback is a logged no-op:
    # https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html

    def __init__(self, handle):
        # Underlying PyHive connection; a cursor is created lazily.
        self.handle = handle
        self._cursor = None
        self._fetch_result = None

    def cursor(self):
        # dbt expects cursor() on the handle; return self so this object
        # serves as both connection and cursor.
        self._cursor = self.handle.cursor()
        return self

    def cancel(self):
        cur = self._cursor
        if cur is not None:
            cur.cancel()

    def close(self):
        self.handle.close()

    def rollback(self, *args, **kwargs):
        logger.debug("NotImplemented: rollback")

    def fetchall(self):
        return self._cursor.fetchall()

    def execute(self, sql, bindings=None):
        # PyHive rejects statements with a trailing semicolon; drop it.
        stripped = sql.strip()
        if stripped.endswith(";"):
            sql = stripped[:-1]
        return self._cursor.execute(sql, bindings)

    @property
    def description(self):
        return self._cursor.description
class SparkConnectionManager(SQLConnectionManager):
    """Connection manager for Spark over Databricks' Thrift-HTTP gateway."""
    TYPE = 'spark'

    @contextmanager
    def exception_handler(self, sql, connection_name='master'):
        """Re-raise errors from running `sql` as dbt RuntimeExceptions.

        Thrift server errors arrive as exc.args[0] with a `.status`
        carrying the server-side message; surface that message when present.
        """
        try:
            yield
        except Exception as exc:
            logger.debug("Error while running:\n{}".format(sql))
            logger.debug(exc)
            if len(exc.args) == 0:
                raise

            thrift_resp = exc.args[0]
            if hasattr(thrift_resp, 'status'):
                msg = thrift_resp.status.errorMessage
                raise dbt.exceptions.RuntimeException(msg)
            else:
                raise dbt.exceptions.RuntimeException(str(exc))

    # Spark has no transactions, so BEGIN/COMMIT/ROLLBACK are logged no-ops.
    def add_begin_query(self, *args, **kwargs):
        logger.debug("NotImplemented: add_begin_query")

    def add_commit_query(self, *args, **kwargs):
        logger.debug("NotImplemented: add_commit_query")

    def commit(self, *args, **kwargs):
        logger.debug("NotImplemented: commit")

    def rollback(self, *args, **kwargs):
        logger.debug("NotImplemented: rollback")

    @classmethod
    def open(cls, connection):
        """Open a Thrift-over-HTTP connection to the configured cluster.

        Authenticates with HTTP Basic auth using the profile's `token`
        (base64 of "token:<token>"), wraps the PyHive connection so
        transaction calls become no-ops, and marks the connection open.
        """
        if connection.state == 'open':
            logger.debug('Connection is already open, skipping open.')
            return connection

        conn_url = SPARK_CONNECTION_URL.format(**connection.credentials)
        transport = THttpClient.THttpClient(conn_url)

        # NOTE(review): `token` is optional in the credentials contract but
        # required here — a profile without it raises KeyError.
        creds = "token:{}".format(connection.credentials['token']).encode()
        token = base64.standard_b64encode(creds).decode()
        transport.setCustomHeaders({
            'Authorization': 'Basic {}'.format(token)
        })

        conn = hive.connect(thrift_transport=transport)
        wrapped = ConnectionWrapper(conn)

        connection.state = 'open'
        connection.handle = wrapped
        return connection

    @classmethod
    def get_status(cls, cursor):
        # PyHive cursors expose no reliable completion status; a
        # successfully returned execute() is treated as success.
        return 'OK'

    def cancel(self, connection):
        """Cancel the in-flight query on `connection`, if any."""
        # Fix: removed leftover `import ipdb; ipdb.set_trace()` debugger
        # breakpoint, which halted execution and required an undeclared
        # ipdb dependency.
        connection.handle.cancel()