diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 0c491553aefc..3c90975fa3e6 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -22,7 +22,7 @@ import tempfile import time from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING from urllib import parse import numpy as np @@ -576,3 +576,38 @@ def has_implicit_cancel(cls) -> bool: """ return True + + @classmethod + def get_view_names( + cls, + database: "Database", + inspector: Inspector, + schema: Optional[str], + ) -> Set[str]: + """ + Get all the view names within the specified schema. + + Per the SQLAlchemy definition if the schema is omitted the database’s default + schema is used, however some dialects infer the request as schema agnostic. + + Note that PyHive's Hive SQLAlchemy dialect does not adhere to the specification + where the `get_view_names` method returns both real tables and views. Futhermore + the dialect wrongfully infers the request as schema agnostic when the schema is + omitted. + + :param database: The database to inspect + :param inspector: The SQLAlchemy inspector + :param schema: The schema to inspect + :returns: The view names + """ + + sql = "SHOW VIEWS" + + if schema: + sql += f" IN `{schema}`" + + with database.get_raw_connection(schema=schema) as conn: + cursor = conn.cursor() + cursor.execute(sql) + results = cursor.fetchall() + return {row[0] for row in results} diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index e1aa918879f9..2e8fc09fd1fb 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -638,9 +638,10 @@ def get_view_names( Per the SQLAlchemy definition if the schema is omitted the database’s default schema is used, however some dialects infer the request as schema agnostic. - Note that PyHive's Hive and Presto SQLAlchemy dialects do not implement the - `get_view_names` method. To ensure consistency with the `get_table_names` method - the request is deemed schema agnostic when the schema is omitted. + Note that PyHive's Presto SQLAlchemy dialect does not adhere to the + specification as the `get_view_names` method is not defined. Futhermore the + dialect wrongfully infers the request as schema agnostic when the schema is + omitted. :param database: The database to inspect :param inspector: The SQLAlchemy inspector diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index 366648effa98..b39f2658979c 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -403,3 +403,41 @@ def is_correct_result(data: List, result: List) -> bool: ["ds=01-01-19/hour=1", "ds=01-03-19/hour=1", "ds=01-02-19/hour=2"], ["01-03-19", "1"], ) + + +def test_get_view_names_with_schema(): + database = mock.MagicMock() + mock_execute = mock.MagicMock() + database.get_raw_connection().__enter__().cursor().execute = mock_execute + database.get_raw_connection().__enter__().cursor().fetchall = mock.MagicMock( + return_value=[["a", "b,", "c"], ["d", "e"]] + ) + + schema = "schema" + result = HiveEngineSpec.get_view_names(database, mock.Mock(), schema) + mock_execute.assert_called_once_with(f"SHOW VIEWS IN `{schema}`") + assert result == {"a", "d"} + + +def test_get_view_names_without_schema(): + database = mock.MagicMock() + mock_execute = mock.MagicMock() + database.get_raw_connection().__enter__().cursor().execute = mock_execute + database.get_raw_connection().__enter__().cursor().fetchall = mock.MagicMock( + return_value=[["a", "b,", "c"], ["d", "e"]] + ) + result = HiveEngineSpec.get_view_names(database, mock.Mock(), None) + mock_execute.assert_called_once_with("SHOW VIEWS") + assert result == {"a", "d"} + + +@mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names") +@mock.patch("superset.db_engine_specs.hive.HiveEngineSpec.get_view_names") +def test_get_table_names( + mock_get_view_names, + mock_get_table_names, +): + mock_get_view_names.return_value = {"view1", "view2"} + mock_get_table_names.return_value = {"table1", "table2", "view1", "view2"} + tables = HiveEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None) + assert tables == {"table1", "table2"}