diff --git a/redash/query_runner/__init__.py b/redash/query_runner/__init__.py index 5a293b83227..104848feedc 100644 --- a/redash/query_runner/__init__.py +++ b/redash/query_runner/__init__.py @@ -9,6 +9,7 @@ __all__ = [ 'ValidationError', 'BaseQueryRunner', + 'InterruptException', 'TYPE_DATETIME', 'TYPE_BOOLEAN', 'TYPE_INTEGER', @@ -38,6 +39,9 @@ TYPE_DATE ]) +class InterruptException(Exception): + pass + class BaseQueryRunner(object): def __init__(self, configuration): jsonschema.validate(configuration, self.configuration_schema()) diff --git a/redash/query_runner/pg.py b/redash/query_runner/pg.py index 0cd6403f65d..a69a03bb516 100644 --- a/redash/query_runner/pg.py +++ b/redash/query_runner/pg.py @@ -142,7 +142,7 @@ def run_query(self, query): logging.exception(e) error = e.message json_data = None - except KeyboardInterrupt: + except (KeyboardInterrupt, InterruptException): connection.cancel() error = "Query cancelled by user." json_data = None diff --git a/redash/tasks.py b/redash/tasks.py index 4693957d88b..6d5396d2fea 100644 --- a/redash/tasks.py +++ b/redash/tasks.py @@ -1,5 +1,6 @@ import time import logging +import signal from flask.ext.mail import Message import redis from celery import Task @@ -8,7 +9,7 @@ from redash import redis_connection, models, statsd_client, settings, utils, mail from redash.utils import gen_query_hash from redash.worker import celery -from redash.query_runner import get_query_runner +from redash.query_runner import get_query_runner, InterruptException logger = get_task_logger(__name__) @@ -132,7 +133,7 @@ def ready(self): return self._async_result.ready() def cancel(self): - return self._async_result.revoke(terminate=True) + return self._async_result.revoke(terminate=True, signal='SIGINT') @staticmethod def _job_lock_id(query_hash, data_source_id): @@ -263,9 +264,12 @@ def check_alerts_for_query(self, query_id): mail.send(message) +def signal_handler(*args): + raise InterruptException @celery.task(bind=True, base=BaseTask, track_started=True) def execute_query(self, query, data_source_id, metadata): + signal.signal(signal.SIGINT, signal_handler) start_time = time.time() logger.info("Loading data source (%d)...", data_source_id)