diff --git a/pyhive/common.py b/pyhive/common.py index d4aa0771..590a8f1b 100644 --- a/pyhive/common.py +++ b/pyhive/common.py @@ -23,6 +23,7 @@ class DBAPICursor(with_metaclass(abc.ABCMeta, object)): _STATE_NONE = 0 _STATE_RUNNING = 1 _STATE_FINISHED = 2 + _STATE_CANCELLED = 3 def __init__(self, poll_interval=1): self._poll_interval = poll_interval diff --git a/pyhive/presto.py b/pyhive/presto.py index 2236dbb3..c5a413ad 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -172,6 +172,16 @@ def execute(self, operation, parameters=None): response = requests.post(url, data=sql.encode('utf-8'), headers=headers) self._process_response(response) + def cancel(self): + if self._state == self._STATE_NONE: + raise ProgrammingError("No query yet") + if self._nextUri is None: + assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" + return None + response = requests.delete(self._nextUri) + self._process_response(response) + return response.status_code + def poll(self): """Poll for and return the raw status data provided by the Presto REST API. @@ -207,12 +217,18 @@ def _process_response(self, response): URI and any data from the response """ # TODO handle HTTP 503 - if response.status_code != requests.codes.ok: + if response.status_code not in (requests.codes.ok, requests.codes.no_content): fmt = "Unexpected status code {}\n{}" raise OperationalError(fmt.format(response.status_code, response.content)) + + if response.status_code == requests.codes.no_content: + self._state = self._STATE_CANCELLED + return + response_json = response.json() _logger.debug("Got response %s", response_json) - assert self._state == self._STATE_RUNNING, "Should be running if processing response" + assert self._state in (self._STATE_RUNNING, self._STATE_CANCELLED), \ + "Should be running or cancelled if processing response" self._nextUri = response_json.get('nextUri') self._columns = response_json.get('columns') if 'X-Presto-Clear-Session' in response.headers: diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py index 35a62b5e..bb3b470d 100644 --- a/pyhive/tests/test_presto.py +++ b/pyhive/tests/test_presto.py @@ -3,7 +3,6 @@ These rely on having a Presto+Hadoop cluster set up. They also require a tables created by make_test_tables.sh. """ - from __future__ import absolute_import from __future__ import unicode_literals @@ -17,13 +16,14 @@ import unittest _HOST = 'localhost' +_PORT = '8080' class TestPresto(unittest.TestCase, DBAPITestCase): __test__ = True def connect(self): - return presto.connect(host=_HOST, source=self.id()) + return presto.connect(host=_HOST, port=_PORT, source=self.id()) @with_cursor def test_description(self, cursor): @@ -78,6 +78,17 @@ def test_complex(self, cursor): # catch unicode/str self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) + @with_cursor + def test_cancel(self, cursor): + cursor.execute( + "SELECT a.a * rand(), b.a*rand()" + "FROM many_rows a " + "CROSS JOIN many_rows b " + ) + self.assertIn(cursor.poll()['stats']['state'], ('PLANNING', 'RUNNING')) + cursor.cancel() + self.assertRaises(exc.DatabaseError, cursor.poll) + def test_noops(self): """The DB-API specification requires that certain actions exist, even though they might not be applicable."""