Skip to content

Commit

Permalink
Close socket on KeyboardInterrupt
Browse files Browse the repository at this point in the history
  • Loading branch information
xzkostyan committed Oct 20, 2019
1 parent 0140d5c commit e69809a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 13 deletions.
8 changes: 4 additions & 4 deletions clickhouse_driver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def packet_generator(self):

yield packet

except Exception:
except (Exception, KeyboardInterrupt):
self.disconnect()
raise

Expand Down Expand Up @@ -216,7 +216,7 @@ def execute(self, query, params=None, with_column_types=False,
self.last_query.store_elapsed(time() - start_time)
return rv

except Exception:
except (Exception, KeyboardInterrupt):
self.disconnect()
raise

Expand Down Expand Up @@ -261,7 +261,7 @@ def execute_with_progress(
types_check=types_check, columnar=columnar
)

except Exception:
except (Exception, KeyboardInterrupt):
self.disconnect()
raise

Expand Down Expand Up @@ -304,7 +304,7 @@ def execute_iter(
query_id=query_id, types_check=types_check
)

except Exception:
except (Exception, KeyboardInterrupt):
self.disconnect()
raise

Expand Down
46 changes: 37 additions & 9 deletions tests/test_blocks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import types

from mock import patch

from clickhouse_driver.errors import ServerException
from tests.testcase import BaseTestCase
from tests.util import capture_logging, require_server_version
Expand Down Expand Up @@ -61,6 +63,24 @@ def test_select_with_column_types(self):
)
self.assertEqual(rv, ([(1,)], [('x', 'Int32')]))

def test_select_with_columnar_with_column_types(self):
progress = self.client.execute_with_progress(
'SELECT arrayJoin(A) -1 as j,'
'arrayJoin(A)+1 as k FROM('
'SELECT range(3) as A)',
columnar=True, with_column_types=True)
rv = ([(-1, 0, 1), (1, 2, 3)], [('j', 'Int16'), ('k', 'UInt16')])
self.assertEqual(progress.get_result(), rv)

def test_close_connection_on_keyboard_interrupt(self):
connection = self.client.connection
with self.assertRaises(KeyboardInterrupt):
with patch.object(connection, 'send_query') as mocked_send_query:
mocked_send_query.side_effect = KeyboardInterrupt
self.client.execute('SELECT 1')

self.assertFalse(self.client.connection.connected)


class ProgressTestCase(BaseTestCase):
def test_select_with_progress(self):
Expand Down Expand Up @@ -111,6 +131,15 @@ def test_select_with_progress_with_params(self):
self.assertEqual(progress.get_result(), [(2,)])
self.assertTrue(self.client.connection.connected)

def test_close_connection_on_keyboard_interrupt(self):
connection = self.client.connection
with self.assertRaises(KeyboardInterrupt):
with patch.object(connection, 'send_query') as mocked_send_query:
mocked_send_query.side_effect = KeyboardInterrupt
self.client.execute_with_progress('SELECT 1')

self.assertFalse(self.client.connection.connected)


class IteratorTestCase(BaseTestCase):
def test_select_with_iter(self):
Expand Down Expand Up @@ -145,15 +174,14 @@ def test_select_with_iter_error(self):

self.assertFalse(self.client.connection.connected)

def test_select_with_columar_with_column_types(self):
progress = self.client.execute_with_progress(
'SELECT arrayJoin(A) -1 as j,'
'arrayJoin(A)+1 as k FROM('
'SELECT range(3) as A)',
columnar=True, with_column_types=True)
rv = ([(-1, 0, 1), (1, 2, 3)], [('j', 'Int16'), ('k', 'UInt16')])
self.assertEqual(progress.get_result(), rv)
self.assertTrue(self.client.connection.connected)
def test_close_connection_on_keyboard_interrupt(self):
connection = self.client.connection
with self.assertRaises(KeyboardInterrupt):
with patch.object(connection, 'send_query') as mocked_send_query:
mocked_send_query.side_effect = KeyboardInterrupt
self.client.execute_iter('SELECT 1')

self.assertFalse(self.client.connection.connected)


class LogTestCase(BaseTestCase):
Expand Down

0 comments on commit e69809a

Please sign in to comment.