diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 750d1007bbfe0e..3ea5daa5ce9184 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -297,15 +297,31 @@ def get_workunits_internal( if use_cached_audit_log: logger.info(f"Using cached audit log at {audit_log_file}") else: - logger.info(f"Fetching audit log into {audit_log_file}") + # Check if any query-based features are enabled before fetching + needs_query_data = any( + [ + self.config.include_lineage, + self.config.include_queries, + self.config.include_usage_statistics, + self.config.include_query_usage_statistics, + self.config.include_operations, + ] + ) + + if not needs_query_data: + logger.info( + "All query-based features are disabled. Skipping expensive query log fetch." + ) + else: + logger.info(f"Fetching audit log into {audit_log_file}") - with self.report.copy_history_fetch_timer: - for copy_entry in self.fetch_copy_history(): - queries.append(copy_entry) + with self.report.copy_history_fetch_timer: + for copy_entry in self.fetch_copy_history(): + queries.append(copy_entry) - with self.report.query_log_fetch_timer: - for entry in self.fetch_query_log(users): - queries.append(entry) + with self.report.query_log_fetch_timer: + for entry in self.fetch_query_log(users): + queries.append(entry) stored_proc_tracker: StoredProcLineageTracker = self._exit_stack.enter_context( StoredProcLineageTracker( diff --git a/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py b/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py index a59b2824cde3b9..35e5de6ebbe973 100644 --- a/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py +++ b/metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py @@ -1,13 +1,24 @@ import datetime +from unittest.mock import Mock, patch import pytest import sqlglot from sqlglot.dialects.snowflake import Snowflake from datahub.configuration.common import AllowDenyPattern -from datahub.configuration.time_window_config import BucketDuration -from datahub.ingestion.source.snowflake.snowflake_config import QueryDedupStrategyType -from datahub.ingestion.source.snowflake.snowflake_queries import QueryLogQueryBuilder +from datahub.configuration.time_window_config import ( + BaseTimeWindowConfig, + BucketDuration, +) +from datahub.ingestion.source.snowflake.snowflake_config import ( + QueryDedupStrategyType, + SnowflakeIdentifierConfig, +) +from datahub.ingestion.source.snowflake.snowflake_queries import ( + QueryLogQueryBuilder, + SnowflakeQueriesExtractor, + SnowflakeQueriesExtractorConfig, +) from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery @@ -389,3 +400,202 @@ def test_get_views_for_schema_query_syntax(self): assert where_clause is not None where_str = str(where_clause).upper() assert "TABLE_SCHEMA" in where_str and "PUBLIC" in where_str + + +class TestSnowflakeQueriesExtractorOptimization: + """Tests for the query fetch optimization when all features are disabled.""" + + def _create_mock_extractor( + self, + include_lineage: bool = False, + include_queries: bool = False, + include_usage_statistics: bool = False, + include_query_usage_statistics: bool = False, + include_operations: bool = False, + ) -> SnowflakeQueriesExtractor: + """Helper to create a SnowflakeQueriesExtractor with mocked dependencies.""" + mock_connection = Mock() + mock_connection.query.return_value = [] + + config = SnowflakeQueriesExtractorConfig( + window=BaseTimeWindowConfig( + start_time=datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc), + end_time=datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc), + ), + include_lineage=include_lineage, + include_queries=include_queries, + include_usage_statistics=include_usage_statistics, + include_query_usage_statistics=include_query_usage_statistics, + include_operations=include_operations, + ) + + mock_report = Mock() + mock_filters = Mock() + mock_identifiers = Mock() + mock_identifiers.platform = "snowflake" + mock_identifiers.identifier_config = SnowflakeIdentifierConfig() + + extractor = SnowflakeQueriesExtractor( + connection=mock_connection, + config=config, + structured_report=mock_report, + filters=mock_filters, + identifiers=mock_identifiers, + ) + + return extractor + + def test_skip_query_fetch_when_all_features_disabled(self): + """Test that query fetching is skipped when all query features are disabled.""" + extractor = self._create_mock_extractor( + include_lineage=False, + include_queries=False, + include_usage_statistics=False, + include_query_usage_statistics=False, + include_operations=False, + ) + + # Mock the fetch methods + with ( + patch.object(extractor, "fetch_users", return_value={}) as mock_fetch_users, + patch.object( + extractor, "fetch_copy_history", return_value=[] + ) as mock_fetch_copy_history, + patch.object( + extractor, "fetch_query_log", return_value=[] + ) as mock_fetch_query_log, + ): + # Execute the method + list(extractor.get_workunits_internal()) + + # Verify fetch_users was called (always needed for setup) + mock_fetch_users.assert_called_once() + + # Verify expensive fetches were NOT called + mock_fetch_copy_history.assert_not_called() + mock_fetch_query_log.assert_not_called() + + def test_fetch_queries_when_lineage_enabled(self): + """Test that query fetching happens when lineage is enabled.""" + extractor = self._create_mock_extractor( + include_lineage=True, + include_queries=False, + include_usage_statistics=False, + include_query_usage_statistics=False, + include_operations=False, + ) + + with ( + patch.object(extractor, "fetch_users", return_value={}) as mock_fetch_users, + patch.object( + extractor, "fetch_copy_history", return_value=[] + ) as mock_fetch_copy_history, + patch.object( + extractor, "fetch_query_log", return_value=[] + ) as mock_fetch_query_log, + ): + list(extractor.get_workunits_internal()) + + mock_fetch_users.assert_called_once() + mock_fetch_copy_history.assert_called_once() + mock_fetch_query_log.assert_called_once() + + def test_fetch_queries_when_usage_statistics_enabled(self): + """Test that query fetching happens when usage statistics are enabled.""" + extractor = self._create_mock_extractor( + include_lineage=False, + include_queries=False, + include_usage_statistics=True, + include_query_usage_statistics=False, + include_operations=False, + ) + + with ( + patch.object(extractor, "fetch_users", return_value={}), + patch.object( + extractor, "fetch_copy_history", return_value=[] + ) as mock_fetch_copy_history, + patch.object( + extractor, "fetch_query_log", return_value=[] + ) as mock_fetch_query_log, + ): + list(extractor.get_workunits_internal()) + + mock_fetch_copy_history.assert_called_once() + mock_fetch_query_log.assert_called_once() + + def test_fetch_queries_when_operations_enabled(self): + """Test that query fetching happens when operations are enabled.""" + extractor = self._create_mock_extractor( + include_lineage=False, + include_queries=False, + include_usage_statistics=False, + include_query_usage_statistics=False, + include_operations=True, + ) + + with ( + patch.object(extractor, "fetch_users", return_value={}), + patch.object( + extractor, "fetch_copy_history", return_value=[] + ) as mock_fetch_copy_history, + patch.object( + extractor, "fetch_query_log", return_value=[] + ) as mock_fetch_query_log, + ): + list(extractor.get_workunits_internal()) + + mock_fetch_copy_history.assert_called_once() + mock_fetch_query_log.assert_called_once() + + def test_fetch_queries_when_any_single_feature_enabled(self): + """Test that query fetching happens when any single feature is enabled.""" + features = [ + "include_lineage", + "include_queries", + "include_usage_statistics", + "include_query_usage_statistics", + "include_operations", + ] + + for feature in features: + kwargs = {f: False for f in features} + kwargs[feature] = True + + extractor = self._create_mock_extractor(**kwargs) + + with ( + patch.object(extractor, "fetch_users", return_value={}), + patch.object( + extractor, "fetch_copy_history", return_value=[] + ) as mock_fetch_copy_history, + patch.object( + extractor, "fetch_query_log", return_value=[] + ) as mock_fetch_query_log, + ): + list(extractor.get_workunits_internal()) + + # Verify fetches were called + mock_fetch_copy_history.assert_called_once() + mock_fetch_query_log.assert_called_once() + + def test_report_counts_with_disabled_features(self): + """Test that report counts are zero when features are disabled.""" + extractor = self._create_mock_extractor( + include_lineage=False, + include_queries=False, + include_usage_statistics=False, + include_query_usage_statistics=False, + include_operations=False, + ) + + with ( + patch.object(extractor, "fetch_users", return_value={}), + patch.object(extractor, "fetch_copy_history", return_value=[]), + patch.object(extractor, "fetch_query_log", return_value=[]), + ): + list(extractor.get_workunits_internal()) + + # Verify that num_preparsed_queries is 0 + assert extractor.report.sql_aggregator is not None + assert extractor.report.sql_aggregator.num_preparsed_queries == 0