diff --git a/pglast/visitors.py b/pglast/visitors.py index 1cb96d7..127d7cb 100644 --- a/pglast/visitors.py +++ b/pglast/visitors.py @@ -349,17 +349,20 @@ class RelationNames(Visitor): relations referenced by the given :class:`node `. """ + class CTENames(Visitor): + def __call__(self, node): + self.ctenames = set() + super().__call__(node) + return self.ctenames + + def visit_CommonTableExpr(self, ancestors, node): + self.ctenames.add(node.ctename) + def __call__(self, node): - self.ctenames = set() + self.ctenames = self.CTENames()(node) self.rnames = set() - self.ctes_rnames = set() super().__call__(node) - return (self.rnames - self.ctenames).union(self.ctes_rnames) - - def visit_CommonTableExpr(self, ancestors, node): - "Collect CTE names." - - self.ctenames.add(node.ctename) + return self.rnames def visit_DropStmt(self, ancestors, node): from .enums import ObjectType @@ -370,7 +373,7 @@ def visit_DropStmt(self, ancestors, node): self.rnames.add('.'.join(maybe_double_quote_name(n.val) for n in obj)) def visit_RangeVar(self, ancestors, node): - "Collect relation names." + "Collect relation names, taking into account defined CTE names" from .stream import maybe_double_quote_name @@ -382,10 +385,17 @@ def visit_RangeVar(self, ancestors, node): if node.catalogname: tname = f'{maybe_double_quote_name(node.catalogname)}.{tname}' - if ast.CommonTableExpr in ancestors: - self.ctes_rnames.add(tname) - else: + if tname not in self.ctenames: self.rnames.add(tname) + else: + # If the name is within a non-recursive CTE and matches its name, + # then it is a concrete relation + nearest_cte = ancestors.find_nearest(ast.CommonTableExpr) + if nearest_cte is not None: + with_clause = nearest_cte.find_nearest(ast.WithClause) + if not with_clause.node.recursive: + if tname == nearest_cte.node.ctename: + self.rnames.add(tname) def referenced_relations(stmt): diff --git a/tests/test_visitors.py b/tests/test_visitors.py index dc569ea..61b1a48 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -26,6 +26,21 @@ {'"my.schema".bar', 'bar."my.table"', '"foo.bar"'}), ('with my_ref as (select * from my_ref where a=1) select * from my_ref', {'my_ref'}), + ('with cte1 as (select 1), cte2 as (select * from cte1) select * from cte2', + set()), + (''' + with recursive t(n) as (values (1) union all select n+1 from t where n < 100) + select sum(n) from t + ''', set()), + (''' + with cte1 as (select 1) + select * from (with cte2 as (select * from cte1) + select * from cte2) as a + ''', set()), + (''' + with to_archive as (delete from products where date < '2010-11-01' returning *) + insert into products_log select * from to_archive + ''', {'products', 'products_log'}), )) def test_referenced_tables(stmt, rnames): assert visitors.referenced_relations(stmt) == rnames