Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,31 @@ def get_workunits_internal(
if use_cached_audit_log:
logger.info(f"Using cached audit log at {audit_log_file}")
else:
logger.info(f"Fetching audit log into {audit_log_file}")
# Check if any query-based features are enabled before fetching
needs_query_data = any(
[
self.config.include_lineage,
self.config.include_queries,
self.config.include_usage_statistics,
self.config.include_query_usage_statistics,
self.config.include_operations,
]
)

if not needs_query_data:
logger.info(
"All query-based features are disabled. Skipping expensive query log fetch."
)
else:
logger.info(f"Fetching audit log into {audit_log_file}")

with self.report.copy_history_fetch_timer:
for copy_entry in self.fetch_copy_history():
queries.append(copy_entry)
with self.report.copy_history_fetch_timer:
for copy_entry in self.fetch_copy_history():
queries.append(copy_entry)

with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log(users):
queries.append(entry)
with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log(users):
queries.append(entry)

stored_proc_tracker: StoredProcLineageTracker = self._exit_stack.enter_context(
StoredProcLineageTracker(
Expand Down
216 changes: 213 additions & 3 deletions metadata-ingestion/tests/unit/snowflake/test_snowflake_queries.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
import datetime
from unittest.mock import Mock, patch

import pytest
import sqlglot
from sqlglot.dialects.snowflake import Snowflake

from datahub.configuration.common import AllowDenyPattern
from datahub.configuration.time_window_config import BucketDuration
from datahub.ingestion.source.snowflake.snowflake_config import QueryDedupStrategyType
from datahub.ingestion.source.snowflake.snowflake_queries import QueryLogQueryBuilder
from datahub.configuration.time_window_config import (
BaseTimeWindowConfig,
BucketDuration,
)
from datahub.ingestion.source.snowflake.snowflake_config import (
QueryDedupStrategyType,
SnowflakeIdentifierConfig,
)
from datahub.ingestion.source.snowflake.snowflake_queries import (
QueryLogQueryBuilder,
SnowflakeQueriesExtractor,
SnowflakeQueriesExtractorConfig,
)
from datahub.ingestion.source.snowflake.snowflake_query import SnowflakeQuery


Expand Down Expand Up @@ -389,3 +400,202 @@ def test_get_views_for_schema_query_syntax(self):
assert where_clause is not None
where_str = str(where_clause).upper()
assert "TABLE_SCHEMA" in where_str and "PUBLIC" in where_str


class TestSnowflakeQueriesExtractorOptimization:
"""Tests for the query fetch optimization when all features are disabled."""

def _create_mock_extractor(
self,
include_lineage: bool = False,
include_queries: bool = False,
include_usage_statistics: bool = False,
include_query_usage_statistics: bool = False,
include_operations: bool = False,
) -> SnowflakeQueriesExtractor:
"""Helper to create a SnowflakeQueriesExtractor with mocked dependencies."""
mock_connection = Mock()
mock_connection.query.return_value = []

config = SnowflakeQueriesExtractorConfig(
window=BaseTimeWindowConfig(
start_time=datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc),
end_time=datetime.datetime(2021, 1, 2, tzinfo=datetime.timezone.utc),
),
include_lineage=include_lineage,
include_queries=include_queries,
include_usage_statistics=include_usage_statistics,
include_query_usage_statistics=include_query_usage_statistics,
include_operations=include_operations,
)

mock_report = Mock()
mock_filters = Mock()
mock_identifiers = Mock()
mock_identifiers.platform = "snowflake"
mock_identifiers.identifier_config = SnowflakeIdentifierConfig()

extractor = SnowflakeQueriesExtractor(
connection=mock_connection,
config=config,
structured_report=mock_report,
filters=mock_filters,
identifiers=mock_identifiers,
)

return extractor

def test_skip_query_fetch_when_all_features_disabled(self):
"""Test that query fetching is skipped when all query features are disabled."""
extractor = self._create_mock_extractor(
include_lineage=False,
include_queries=False,
include_usage_statistics=False,
include_query_usage_statistics=False,
include_operations=False,
)

# Mock the fetch methods
with (
patch.object(extractor, "fetch_users", return_value={}) as mock_fetch_users,
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
# Execute the method
list(extractor.get_workunits_internal())

# Verify fetch_users was called (always needed for setup)
mock_fetch_users.assert_called_once()

# Verify expensive fetches were NOT called
mock_fetch_copy_history.assert_not_called()
mock_fetch_query_log.assert_not_called()

def test_fetch_queries_when_lineage_enabled(self):
"""Test that query fetching happens when lineage is enabled."""
extractor = self._create_mock_extractor(
include_lineage=True,
include_queries=False,
include_usage_statistics=False,
include_query_usage_statistics=False,
include_operations=False,
)

with (
patch.object(extractor, "fetch_users", return_value={}) as mock_fetch_users,
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
list(extractor.get_workunits_internal())

mock_fetch_users.assert_called_once()
mock_fetch_copy_history.assert_called_once()
mock_fetch_query_log.assert_called_once()

def test_fetch_queries_when_usage_statistics_enabled(self):
"""Test that query fetching happens when usage statistics are enabled."""
extractor = self._create_mock_extractor(
include_lineage=False,
include_queries=False,
include_usage_statistics=True,
include_query_usage_statistics=False,
include_operations=False,
)

with (
patch.object(extractor, "fetch_users", return_value={}),
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
list(extractor.get_workunits_internal())

mock_fetch_copy_history.assert_called_once()
mock_fetch_query_log.assert_called_once()

def test_fetch_queries_when_operations_enabled(self):
"""Test that query fetching happens when operations are enabled."""
extractor = self._create_mock_extractor(
include_lineage=False,
include_queries=False,
include_usage_statistics=False,
include_query_usage_statistics=False,
include_operations=True,
)

with (
patch.object(extractor, "fetch_users", return_value={}),
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
list(extractor.get_workunits_internal())

mock_fetch_copy_history.assert_called_once()
mock_fetch_query_log.assert_called_once()

def test_fetch_queries_when_any_single_feature_enabled(self):
"""Test that query fetching happens when any single feature is enabled."""
features = [
"include_lineage",
"include_queries",
"include_usage_statistics",
"include_query_usage_statistics",
"include_operations",
]

for feature in features:
kwargs = {f: False for f in features}
kwargs[feature] = True

extractor = self._create_mock_extractor(**kwargs)

with (
patch.object(extractor, "fetch_users", return_value={}),
patch.object(
extractor, "fetch_copy_history", return_value=[]
) as mock_fetch_copy_history,
patch.object(
extractor, "fetch_query_log", return_value=[]
) as mock_fetch_query_log,
):
list(extractor.get_workunits_internal())

# Verify fetches were called
mock_fetch_copy_history.assert_called_once()
mock_fetch_query_log.assert_called_once()

def test_report_counts_with_disabled_features(self):
"""Test that report counts are zero when features are disabled."""
extractor = self._create_mock_extractor(
include_lineage=False,
include_queries=False,
include_usage_statistics=False,
include_query_usage_statistics=False,
include_operations=False,
)

with (
patch.object(extractor, "fetch_users", return_value={}),
patch.object(extractor, "fetch_copy_history", return_value=[]),
patch.object(extractor, "fetch_query_log", return_value=[]),
):
list(extractor.get_workunits_internal())

# Verify that num_preparsed_queries is 0
assert extractor.report.sql_aggregator is not None
assert extractor.report.sql_aggregator.num_preparsed_queries == 0
Loading