Skip to content

Commit

Permalink
Merge branch 'master' into add-tls-auth
Browse files Browse the repository at this point in the history
  • Loading branch information
donbowman committed Mar 13, 2019
2 parents e314772 + fa7f4f2 commit 6248f4d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ __pycache__
*~
.tox
env
venv
32 changes: 25 additions & 7 deletions pydruid/db/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,24 @@ class Type(object):
BOOLEAN = 3


def connect(host='localhost', port=8082, path='/druid/v2/sql/', scheme='http',
user=None, password=None):
def connect(
host='localhost',
port=8082,
path='/druid/v2/sql/',
scheme='http',
user=None,
password=None,
context=None,
):
"""
Constructor for creating a connection to the database.
>>> conn = connect('localhost', 8082)
>>> curs = conn.cursor()
"""
return Connection(host, port, path, scheme, user, password)
context = context or {}
return Connection(host, port, path, scheme, user, password, context)


def check_closed(f):
Expand Down Expand Up @@ -100,10 +108,12 @@ def __init__(
scheme='http',
user=None,
password=None,
context=None,
):
netloc = '{host}:{port}'.format(host=host, port=port)
self.url = parse.urlunparse(
(scheme, netloc, path, None, None, None))
self.context = context or {}
self.closed = False
self.cursors = []
self.user = user
Expand Down Expand Up @@ -131,7 +141,7 @@ def commit(self):
@check_closed
def cursor(self):
"""Return a new Cursor Object using the connection."""
cursor = Cursor(self.url, self.user, self.password)
cursor = Cursor(self.url, self.user, self.password, self.context)
self.cursors.append(cursor)

return cursor
Expand All @@ -152,10 +162,11 @@ class Cursor(object):

"""Connection cursor."""

def __init__(self, url, user=None, password=None):
def __init__(self, url, user=None, password=None, context=None):
self.url = url
self.user = user
self.password = password
self.context = context or {}

# This read/write attribute specifies the number of rows to fetch at a
# time with .fetchmany(). It defaults to 1 meaning to fetch a single
Expand Down Expand Up @@ -269,7 +280,7 @@ def _stream_query(self, query):
self.description = None

headers = {'Content-Type': 'application/json'}
payload = {'query': query}
payload = {'query': query, 'context': self.context}
auth = requests.auth.HTTPBasicAuth(self.user,
self.password) if self.user else None
r = requests.post(self.url, stream=True, headers=headers, json=payload,
Expand All @@ -279,7 +290,14 @@ def _stream_query(self, query):

# raise any error messages
if r.status_code != 200:
payload = r.json()
try:
payload = r.json()
except Exception:
payload = {
'error': 'Unknown error',
'errorClass': 'Unknown',
'errorMessage': r.text,
}
msg = (
'{error} ({errorClass}): {errorMessage}'.format(**payload)
)
Expand Down
5 changes: 5 additions & 0 deletions pydruid/db/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ class DruidDialect(default.DefaultDialect):
description_encoding = None
supports_native_boolean = True

def __init__(self, context=None, *args, **kwargs):
super(DruidDialect, self).__init__(*args, **kwargs)
self.context = context or {}

@classmethod
def dbapi(cls):
return pydruid.db
Expand All @@ -122,6 +126,7 @@ def create_connect_args(self, url):
'password': url.password or None,
'path': url.database,
'scheme': self.scheme,
'context': self.context,
}
return ([], kwargs)

Expand Down
21 changes: 21 additions & 0 deletions tests/db/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,27 @@ def test_execute_empty_result(self, requests_post_mock):
expected = []
self.assertEquals(result, expected)

@patch('requests.post')
def test_context(self, requests_post_mock):
response = Response()
response.status_code = 200
response.raw = BytesIO(b'[]')
requests_post_mock.return_value = response

url = 'http://example.com/'
query = 'SELECT * FROM table'
context = {'source': 'unittest'}

cursor = Cursor(url, context)
cursor.execute(query)

requests_post_mock.assert_called_with(
'http://example.com/',
stream=True,
headers={'Content-Type': 'application/json'},
json={'query': query, 'context': context},
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 6248f4d

Please sign in to comment.