Skip to content

Commit

Permalink
Further fixes to the logic to extract relation names from a statement
Browse files Browse the repository at this point in the history
Traverse the statement tree twice: first to collect the CTE names, then
to recognize concrete relation names.

This should fix additional defects reported in issue #106.
  • Loading branch information
lelit committed Jun 18, 2022
1 parent f6b187c commit aa80697
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
34 changes: 22 additions & 12 deletions pglast/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,17 +349,20 @@ class RelationNames(Visitor):
relations referenced by the given :class:`node <pglast.ast.Node>`.
"""

class CTENames(Visitor):
    """Helper visitor that gathers the name of every CTE it encounters."""

    def __call__(self, node):
        "Visit `node` and return the set of CTE names collected on the way."

        # Start from a fresh set on every invocation, so an instance can be
        # reused on different statements.
        self.ctenames = set()
        super().__call__(node)
        return self.ctenames

    def visit_CommonTableExpr(self, ancestors, node):
        "Remember the name of the visited CTE."

        self.ctenames.add(node.ctename)

def __call__(self, node):
self.ctenames = set()
self.ctenames = self.CTENames()(node)
self.rnames = set()
self.ctes_rnames = set()
super().__call__(node)
return (self.rnames - self.ctenames).union(self.ctes_rnames)

def visit_CommonTableExpr(self, ancestors, node):
"Collect CTE names."

self.ctenames.add(node.ctename)
return self.rnames

def visit_DropStmt(self, ancestors, node):
from .enums import ObjectType
Expand All @@ -370,7 +373,7 @@ def visit_DropStmt(self, ancestors, node):
self.rnames.add('.'.join(maybe_double_quote_name(n.val) for n in obj))

def visit_RangeVar(self, ancestors, node):
"Collect relation names."
"Collect relation names, taking into account defined CTE names"

from .stream import maybe_double_quote_name

Expand All @@ -382,10 +385,17 @@ def visit_RangeVar(self, ancestors, node):
if node.catalogname:
tname = f'{maybe_double_quote_name(node.catalogname)}.{tname}'

if ast.CommonTableExpr in ancestors:
self.ctes_rnames.add(tname)
else:
if tname not in self.ctenames:
self.rnames.add(tname)
else:
# If the name is within a non-recursive CTE and matches its name,
# then it is a concrete relation
nearest_cte = ancestors.find_nearest(ast.CommonTableExpr)
if nearest_cte is not None:
with_clause = nearest_cte.find_nearest(ast.WithClause)
if not with_clause.node.recursive:
if tname == nearest_cte.node.ctename:
self.rnames.add(tname)


def referenced_relations(stmt):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@
{'"my.schema".bar', 'bar."my.table"', '"foo.bar"'}),
('with my_ref as (select * from my_ref where a=1) select * from my_ref',
{'my_ref'}),
('with cte1 as (select 1), cte2 as (select * from cte1) select * from cte2',
set()),
('''
with recursive t(n) as (values (1) union all select n+1 from t where n < 100)
select sum(n) from t
''', set()),
('''
with cte1 as (select 1)
select * from (with cte2 as (select * from cte1)
select * from cte2) as a
''', set()),
('''
with to_archive as (delete from products where date < '2010-11-01' returning *)
insert into products_log select * from to_archive
''', {'products', 'products_log'}),
))
def test_referenced_tables(stmt, rnames):
    "The relations referenced by `stmt` must be exactly `rnames`."

    found = visitors.referenced_relations(stmt)
    assert found == rnames
Expand Down

0 comments on commit aa80697

Please sign in to comment.