Skip to content

Commit

Permalink
Handle cases where tables are preceded by complex queries (#423)
Browse files Browse the repository at this point in the history
* Add test_join_followed_by_tables

* Handle ON keyword followed by comma

* Update sql_metadata/parser.py

Fix flake8 linting warning

* Black-formatted test/test_getting_tables.py

* Black-formatted test/test_getting_tables.py

* add test_getting_tables

* add _preceded_keywords to store keywords followed by subqueries

---------

Co-authored-by: Byunk <clearman001@gmail.com>
Co-authored-by: Maciej Brencz <maciej.brencz@gmail.com>
Co-authored-by: Kyungho Byoun <kyungho.byoun@sap.com>
  • Loading branch information
4 people committed Sep 19, 2023
1 parent d5d8fce commit 0e37db9
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 0 deletions.
6 changes: 6 additions & 0 deletions sql_metadata/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self, sql: str = "", disable_logging: bool = False) -> None:
self._nested_level = 0
self._parenthesis_level = 0
self._open_parentheses: List[SQLToken] = []
self._preceded_keywords: List[SQLToken] = []
self._aliases_to_check = None
self._is_in_nested_function = False
self._is_in_with_block = False
Expand Down Expand Up @@ -164,6 +165,8 @@ def tokens(self) -> List[SQLToken]: # noqa: C901
elif token.is_right_parenthesis:
token.token_type = TokenType.PARENTHESIS
self._determine_closing_parenthesis_type(token=token)
if token.is_subquery_end:
last_keyword = self._preceded_keywords.pop()

last_keyword = self._determine_last_relevant_keyword(
token=token, last_keyword=last_keyword
Expand Down Expand Up @@ -856,6 +859,7 @@ def _determine_opening_parenthesis_type(self, token: SQLToken):
# inside subquery / derived table
token.is_subquery_start = True
self._subquery_level += 1
self._preceded_keywords.append(token.last_keyword_normalized)
token.subquery_level = self._subquery_level
elif token.previous_token.normalized in KEYWORDS_BEFORE_COLUMNS.union({","}):
# we are in columns and in a column subquery definition
Expand Down Expand Up @@ -970,6 +974,8 @@ def replace_back_quotes_in_string(match):
return query

def _determine_last_relevant_keyword(self, token: SQLToken, last_keyword: str):
if token.value == "," and token.last_keyword_normalized == "ON":
return "FROM"
if token.is_keyword and "".join(token.normalized.split()) in RELEVANT_KEYWORDS:
if (
not (
Expand Down
133 changes: 133 additions & 0 deletions test/test_getting_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,3 +666,136 @@ def test_insert_into_on_duplicate_key_ipdate():
" VALUES ('user1', 'john doe', 20)"
" ON DUPLICATE KEY UPDATE name='john doe', age=20"
).tables == ["user"]


def test_join_followed_by_tables():
    """Tables listed after a derived table or after a JOIN ... ON are reported.

    Two scenarios are checked:
    * a FROM clause whose first item is a UNION ALL subquery (containing a
      LEFT OUTER JOIN) followed by comma-separated plain tables,
    * a JOIN ... ON condition followed by a comma, a subquery and another table.
    """
    union_subquery_sql = """
SELECT
web_site_id,
sum(sales_price) AS sales,
sum(profit) AS profit,
sum(return_amt) AS returns1,
sum(net_loss) AS profit_loss
FROM
(
SELECT
ws_web_site_sk AS wsr_web_site_sk,
ws_sold_date_sk AS date_sk,
ws_ext_sales_price AS sales_price,
ws_net_profit AS profit,
cast(
0 AS decimal(7, 2)
) AS return_amt,
cast(
0 AS decimal(7, 2)
) AS net_loss
FROM
web_sales
UNION ALL
SELECT
ws_web_site_sk AS wsr_web_site_sk,
wr_returned_date_sk AS date_sk,
cast(
0 AS decimal(7, 2)
) AS sales_price,
cast(
0 AS decimal(7, 2)
) AS profit,
wr_return_amt AS return_amt,
wr_net_loss AS net_loss
FROM
web_returns
LEFT OUTER JOIN web_sales ON (
wr_item_sk = ws_item_sk
AND wr_order_number = ws_order_number
)
) salesreturns,
date_dim,
web_site
WHERE
date_sk = d_date_sk
AND d_date BETWEEN cast('2002-08-22' AS date)
AND (
cast('2002-08-22' AS date) + INTERVAL '14' day
)
AND wsr_web_site_sk = web_site_sk
GROUP BY
web_site_id
"""
    # Tables inside the subquery come first, then the ones listed after it.
    tables_after_subquery = ["web_sales", "web_returns", "date_dim", "web_site"]
    assert Parser(union_subquery_sql).tables == tables_after_subquery

    join_then_comma_sql = """
SELECT *
FROM Sales
JOIN Customers
ON Sales.CustomerID = Customers.CustomerID,
(SELECT MAX(Revenue) FROM Sales),
Stores
"""
    # "Sales" is referenced twice in the query but expected only once.
    tables_after_join = ["Sales", "Customers", "Stores"]
    assert Parser(join_then_comma_sql).tables == tables_after_join


def test_subquery_followed_by_tables():
    """Plain tables listed after an aliased derived table are reported.

    The outer FROM clause starts with a subquery (alias ``dn``) and is then
    followed by ``customer`` and an aliased ``customer_address``; the expected
    result lists the subquery's tables first and each table only once.
    """
    derived_table_sql = """
SELECT top 100 c_last_name ,
c_first_name ,
ca_city ,
bought_city ,
ss_ticket_number ,
amt,
profit
FROM
(SELECT ss_ticket_number ,
ss_customer_sk ,
ca_city bought_city ,
sum(ss_coupon_amt) amt ,
sum(ss_net_profit) profit
FROM store_sales,
date_dim,
store,
household_demographics,
customer_address
WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk
AND store_sales.ss_store_sk = store.s_store_sk
AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk
AND store_sales.ss_addr_sk = customer_address.ca_address_sk
AND (household_demographics.hd_dep_count = 3
OR household_demographics.hd_vehicle_count= 2)
AND date_dim.d_dow IN (6,
0)
AND date_dim.d_year IN (1998,
1998+1,
1998+2)
AND store.s_city IN ('Oak Grove',
'Fairview',
'Riverside',
'Five Points',
'Midway')
GROUP BY ss_ticket_number,
ss_customer_sk,
ss_addr_sk,
ca_city) dn,
customer,
customer_address current_addr
WHERE ss_customer_sk = c_customer_sk
AND customer.c_current_addr_sk = current_addr.ca_address_sk
AND current_addr.ca_city <> bought_city
ORDER BY c_last_name ,
c_first_name ,
ca_city ,
bought_city ,
ss_ticket_number
"""

    # customer_address appears in both the subquery and the outer FROM,
    # yet is expected exactly once in the result.
    expected_tables = [
        "store_sales",
        "date_dim",
        "store",
        "household_demographics",
        "customer_address",
        "customer",
    ]
    assert Parser(derived_table_sql).tables == expected_tables

0 comments on commit 0e37db9

Please sign in to comment.