diff --git a/deepnote_toolkit/runtime_initialization.py b/deepnote_toolkit/runtime_initialization.py index d8a7e3c..bfafaf5 100644 --- a/deepnote_toolkit/runtime_initialization.py +++ b/deepnote_toolkit/runtime_initialization.py @@ -6,6 +6,8 @@ import psycopg2.extensions import psycopg2.extras +from deepnote_toolkit.runtime_patches import apply_runtime_patches + from .dataframe_utils import add_formatters from .execute_post_start_hooks import execute_post_start_hooks from .logging import LoggerManager @@ -24,6 +26,11 @@ def init_deepnote_runtime(): logger.debug("Initializing Deepnote runtime environment started.") + try: + apply_runtime_patches() + except Exception as e: + logger.error("Failed to apply runtime patches with a error: %s", e) + # Register sparksql magic try: IPython.get_ipython().register_magics(SparkSql) diff --git a/deepnote_toolkit/runtime_patches.py b/deepnote_toolkit/runtime_patches.py new file mode 100644 index 0000000..9dda657 --- /dev/null +++ b/deepnote_toolkit/runtime_patches.py @@ -0,0 +1,53 @@ +from typing import Any, Optional, Union + +from deepnote_toolkit.logging import LoggerManager + +logger = LoggerManager().get_logger() + + +# TODO(BLU-5171): Temporary hack to allow cancelling BigQuery jobs on KeyboardInterrupt (e.g. when user cancels cell execution) +# Can be removed once +# 1. https://github.com/googleapis/python-bigquery/pull/2331 is merged and released +# 2. Dependencies updated for the toolkit. We don't depend on google-cloud-bigquery directly, but it's transitive +# dependency through sqlalchemy-bigquery +def _monkeypatch_bigquery_wait_or_cancel(): + try: + import google.cloud.bigquery._job_helpers as _job_helpers + from google.cloud.bigquery import job, table + + def _wait_or_cancel( + job_obj: job.QueryJob, + api_timeout: Optional[float], + wait_timeout: Optional[Union[object, float]], + retry: Optional[Any], + page_size: Optional[int], + max_results: Optional[int], + ) -> table.RowIterator: + try: + return job_obj.result( + page_size=page_size, + max_results=max_results, + retry=retry, + timeout=wait_timeout, + ) + except (KeyboardInterrupt, Exception): + try: + job_obj.cancel(retry=retry, timeout=api_timeout) + except (KeyboardInterrupt, Exception): + pass + raise + + _job_helpers._wait_or_cancel = _wait_or_cancel + logger.debug( + "Successfully monkeypatched google.cloud.bigquery._job_helpers._wait_or_cancel" + ) + except ImportError: + logger.warning( + "Could not monkeypatch BigQuery _wait_or_cancel: google.cloud.bigquery not available" + ) + except Exception as e: + logger.warning("Failed to monkeypatch BigQuery _wait_or_cancel: %s", repr(e)) + + +def apply_runtime_patches(): + _monkeypatch_bigquery_wait_or_cancel() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a871dd8..5f9179d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -7,6 +7,14 @@ import pytest +@pytest.fixture(autouse=True, scope="session") +def apply_patches() -> None: + """Apply runtime patches once before any tests run.""" + from deepnote_toolkit.runtime_patches import apply_runtime_patches + + apply_runtime_patches() + + @pytest.fixture(autouse=True) def clean_runtime_state() -> Generator[None, None, None]: """Automatically clean in-memory env state and config cache before and after each test.""" diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index 9e0187c..05eecf0 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -9,6 +9,27 @@ from deepnote_toolkit.sql import sql_execution as se +def test_bigquery_wait_or_cancel_handles_keyboard_interrupt(): + import google.cloud.bigquery._job_helpers as _job_helpers + + mock_job = mock.Mock() + mock_job.result.side_effect = KeyboardInterrupt("User interrupted") + mock_job.cancel = mock.Mock() + + with pytest.raises(KeyboardInterrupt): + # _wait_or_cancel should be monkeypatched by `_monkeypatch_bigquery_wait_or_cancel` + _job_helpers._wait_or_cancel( + job_obj=mock_job, + api_timeout=30.0, + wait_timeout=60.0, + retry=None, + page_size=None, + max_results=None, + ) + + mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0) + + def test_build_params_for_bigquery_oauth_ok(): with mock.patch( "deepnote_toolkit.sql.sql_execution.bigquery.Client"