From 419a05ff8ab1c22c6526e622f5e7b334b5b8d663 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Fri, 22 Sep 2023 14:14:28 -0700 Subject: [PATCH] Fix function params whose name collide with SQL type names We need to make sure to quote them. Unfortunately we can't *unconditionally* quote them in quote_ident, because when referring to them as types or builtin functions they need to not be quoted. Fixes #6062. --- edb/pgsql/codegen.py | 18 +++++++++--------- edb/pgsql/common.py | 19 ++++++++++++++----- edb/pgsql/dbops/functions.py | 6 +++--- tests/test_edgeql_ddl.py | 12 ++++++++++++ 4 files changed, 38 insertions(+), 17 deletions(-) diff --git a/edb/pgsql/codegen.py b/edb/pgsql/codegen.py index 258a460ea2c..25b1adc226b 100644 --- a/edb/pgsql/codegen.py +++ b/edb/pgsql/codegen.py @@ -212,7 +212,7 @@ def gen_ctes(self, ctes: List[pgast.CommonTableExpr]) -> None: if cte.aliascolnames: self.write('(') for (index, col_name) in enumerate(cte.aliascolnames): - self.write(common.qname(col_name)) + self.write(common.qname(col_name, column=True)) if index + 1 < len(cte.aliascolnames): self.write(',') self.write(')') @@ -578,18 +578,18 @@ def visit_ResTarget(self, node: pgast.ResTarget) -> None: if node.indirection: self._visit_indirection_ops(node.indirection) if node.name: - self.write(' AS ' + common.quote_ident(node.name)) + self.write(' AS ' + common.quote_col(node.name)) def visit_InsertTarget(self, node: pgast.InsertTarget) -> None: - self.write(common.quote_ident(node.name)) + self.write(common.quote_col(node.name)) def visit_UpdateTarget(self, node: pgast.UpdateTarget) -> None: if isinstance(node.name, list): self.write('(') - self.write(', '.join(common.quote_ident(n) for n in node.name)) + self.write(', '.join(common.quote_col(n) for n in node.name)) self.write(')') else: - self.write(common.quote_ident(node.name)) + self.write(common.quote_col(node.name)) if node.indirection: self._visit_indirection_ops(node.indirection) self.write(' = ') @@ -599,7 +599,7 @@ def visit_Alias(self, node: pgast.Alias) -> None: self.write(common.quote_ident(node.aliasname)) if node.colnames: self.write('(') - self.write(', '.join(common.quote_ident(n) for n in node.colnames)) + self.write(', '.join(common.quote_col(n) for n in node.colnames)) self.write(')') def visit_Keyword(self, node: pgast.Keyword) -> None: @@ -668,15 +668,15 @@ def visit_ColumnRef(self, node: pgast.ColumnRef) -> None: self.write(names[0]) if len(names) > 1: self.write('.') - self.write(common.qname(*names[1:])) + self.write(common.qname(*names[1:], column=True)) else: - self.write(common.qname(*names)) + self.write(common.qname(*names, column=True)) def visit_ExprOutputVar(self, node: pgast.ExprOutputVar) -> None: self.visit(node.expr) def visit_ColumnDef(self, node: pgast.ColumnDef) -> None: - self.write(common.quote_ident(node.name)) + self.write(common.quote_col(node.name)) if node.typename: self.write(' ') self.visit(node.typename) diff --git a/edb/pgsql/common.py b/edb/pgsql/common.py index 21506996477..8304c0f067f 100644 --- a/edb/pgsql/common.py +++ b/edb/pgsql/common.py @@ -69,10 +69,17 @@ def _quote_ident(string: str) -> str: return '"' + string.replace('"', '""') + '"' -def quote_ident(ident: str | pgast.Star, *, force=False) -> str: +def quote_ident(ident: str | pgast.Star, *, force=False, column=False) -> str: if isinstance(ident, pgast.Star): return "*" - return _quote_ident(ident) if needs_quoting(ident) or force else ident + return ( + _quote_ident(ident) + if needs_quoting(ident, column=column) or force else ident + ) + + +def quote_col(ident: str | pgast.Star) -> str: + return quote_ident(ident, column=True) def quote_bytea_literal(data: bytes) -> str: @@ -85,7 +92,7 @@ def quote_bytea_literal(data: bytes) -> str: return "''::bytea" -def needs_quoting(string: str) -> bool: +def needs_quoting(string: str, column: bool=False) -> bool: isalnum = (string and not string[0].isdecimal() and string.replace('_', 'a').isalnum()) return ( @@ -94,13 +101,15 @@ def needs_quoting(string: str) -> bool: pg_keywords.RESERVED_KEYWORD] or string.lower() in pg_keywords.by_type[ pg_keywords.TYPE_FUNC_NAME_KEYWORD] or + (column and string.lower() in pg_keywords.by_type[ + pg_keywords.COL_NAME_KEYWORD]) or string.lower() != string ) -def qname(*parts: str | pgast.Star) -> str: +def qname(*parts: str | pgast.Star, column: bool=False) -> str: assert len(parts) <= 3, parts - return '.'.join([quote_ident(q) for q in parts]) + return '.'.join([quote_ident(q, column=column) for q in parts]) def quote_type(type_: Tuple[str, ...] | str): diff --git a/edb/pgsql/dbops/functions.py b/edb/pgsql/dbops/functions.py index 1d6c0aca6c5..82ece65155d 100644 --- a/edb/pgsql/dbops/functions.py +++ b/edb/pgsql/dbops/functions.py @@ -23,7 +23,7 @@ from typing import * from ..common import qname as qn -from ..common import quote_ident as qi +from ..common import quote_col as qc from ..common import quote_literal as ql from ..common import quote_type as qt @@ -73,7 +73,7 @@ def __init__(self, name, args=None): self.args = args def code(self, block: base.PLBlock) -> str: - args = f"ARRAY[{','.join(qi(a) for a in self.args)}]" + args = f"ARRAY[{','.join(qc(a) for a in self.args)}]" return textwrap.dedent(f'''\ SELECT @@ -111,7 +111,7 @@ def format_args( if isinstance(arg, tuple): if arg[0] is not None: - arg_expr += qn(arg[0]) + arg_expr += qn(arg[0], column=True) if len(arg) > 1: arg_expr += ' ' + qt(arg[1]) if include_defaults: diff --git a/tests/test_edgeql_ddl.py b/tests/test_edgeql_ddl.py index 05334684229..16f5834940d 100644 --- a/tests/test_edgeql_ddl.py +++ b/tests/test_edgeql_ddl.py @@ -4921,6 +4921,18 @@ async def test_edgeql_ddl_function_37(self): ) self.assertEqual(val, 1) + async def test_edgeql_ddl_function_38(self): + await self.con.execute(''' + create function myFuncFailA(character: int64) -> float64 + using ( + select 2.3 + ); + create function myFuncFailB(interval: str) -> float64 + using ( + select 2.3 + ); + ''') + async def test_edgeql_ddl_function_rename_01(self): await self.con.execute(""" CREATE FUNCTION foo(s: str) -> str {