Skip to content

Commit

Permalink
Fix function params whose name collide with SQL type names (#6150)
Browse files Browse the repository at this point in the history
We need to make sure to quote them. Unfortunately we can't
*unconditionally* quote them in quote_ident, because when referring to
them as types or builtin functions they need to not be quoted.

Fixes #6062.
  • Loading branch information
msullivan committed Sep 23, 2023
1 parent 7e5ad53 commit a63aa13
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 17 deletions.
18 changes: 9 additions & 9 deletions edb/pgsql/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def gen_ctes(self, ctes: List[pgast.CommonTableExpr]) -> None:
if cte.aliascolnames:
self.write('(')
for (index, col_name) in enumerate(cte.aliascolnames):
self.write(common.qname(col_name))
self.write(common.qname(col_name, column=True))
if index + 1 < len(cte.aliascolnames):
self.write(',')
self.write(')')
Expand Down Expand Up @@ -578,18 +578,18 @@ def visit_ResTarget(self, node: pgast.ResTarget) -> None:
if node.indirection:
self._visit_indirection_ops(node.indirection)
if node.name:
self.write(' AS ' + common.quote_ident(node.name))
self.write(' AS ' + common.quote_col(node.name))

def visit_InsertTarget(self, node: pgast.InsertTarget) -> None:
self.write(common.quote_ident(node.name))
self.write(common.quote_col(node.name))

def visit_UpdateTarget(self, node: pgast.UpdateTarget) -> None:
if isinstance(node.name, list):
self.write('(')
self.write(', '.join(common.quote_ident(n) for n in node.name))
self.write(', '.join(common.quote_col(n) for n in node.name))
self.write(')')
else:
self.write(common.quote_ident(node.name))
self.write(common.quote_col(node.name))
if node.indirection:
self._visit_indirection_ops(node.indirection)
self.write(' = ')
Expand All @@ -599,7 +599,7 @@ def visit_Alias(self, node: pgast.Alias) -> None:
self.write(common.quote_ident(node.aliasname))
if node.colnames:
self.write('(')
self.write(', '.join(common.quote_ident(n) for n in node.colnames))
self.write(', '.join(common.quote_col(n) for n in node.colnames))
self.write(')')

def visit_Keyword(self, node: pgast.Keyword) -> None:
Expand Down Expand Up @@ -668,15 +668,15 @@ def visit_ColumnRef(self, node: pgast.ColumnRef) -> None:
self.write(names[0])
if len(names) > 1:
self.write('.')
self.write(common.qname(*names[1:]))
self.write(common.qname(*names[1:], column=True))
else:
self.write(common.qname(*names))
self.write(common.qname(*names, column=True))

def visit_ExprOutputVar(self, node: pgast.ExprOutputVar) -> None:
self.visit(node.expr)

def visit_ColumnDef(self, node: pgast.ColumnDef) -> None:
self.write(common.quote_ident(node.name))
self.write(common.quote_col(node.name))
if node.typename:
self.write(' ')
self.visit(node.typename)
Expand Down
19 changes: 14 additions & 5 deletions edb/pgsql/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,17 @@ def _quote_ident(string: str) -> str:
return '"' + string.replace('"', '""') + '"'


def quote_ident(ident: str | pgast.Star, *, force=False) -> str:
def quote_ident(ident: str | pgast.Star, *, force=False, column=False) -> str:
if isinstance(ident, pgast.Star):
return "*"
return _quote_ident(ident) if needs_quoting(ident) or force else ident
return (
_quote_ident(ident)
if needs_quoting(ident, column=column) or force else ident
)


def quote_col(ident: str | pgast.Star) -> str:
return quote_ident(ident, column=True)


def quote_bytea_literal(data: bytes) -> str:
Expand All @@ -85,7 +92,7 @@ def quote_bytea_literal(data: bytes) -> str:
return "''::bytea"


def needs_quoting(string: str) -> bool:
def needs_quoting(string: str, column: bool=False) -> bool:
isalnum = (string and not string[0].isdecimal() and
string.replace('_', 'a').isalnum())
return (
Expand All @@ -94,13 +101,15 @@ def needs_quoting(string: str) -> bool:
pg_keywords.RESERVED_KEYWORD] or
string.lower() in pg_keywords.by_type[
pg_keywords.TYPE_FUNC_NAME_KEYWORD] or
(column and string.lower() in pg_keywords.by_type[
pg_keywords.COL_NAME_KEYWORD]) or
string.lower() != string
)


def qname(*parts: str | pgast.Star) -> str:
def qname(*parts: str | pgast.Star, column: bool=False) -> str:
assert len(parts) <= 3, parts
return '.'.join([quote_ident(q) for q in parts])
return '.'.join([quote_ident(q, column=column) for q in parts])


def quote_type(type_: Tuple[str, ...] | str):
Expand Down
6 changes: 3 additions & 3 deletions edb/pgsql/dbops/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import *

from ..common import qname as qn
from ..common import quote_ident as qi
from ..common import quote_col as qc
from ..common import quote_literal as ql
from ..common import quote_type as qt

Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, name, args=None):
self.args = args

def code(self, block: base.PLBlock) -> str:
args = f"ARRAY[{','.join(qi(a) for a in self.args)}]"
args = f"ARRAY[{','.join(qc(a) for a in self.args)}]"

return textwrap.dedent(f'''\
SELECT
Expand Down Expand Up @@ -111,7 +111,7 @@ def format_args(

if isinstance(arg, tuple):
if arg[0] is not None:
arg_expr += qn(arg[0])
arg_expr += qn(arg[0], column=True)
if len(arg) > 1:
arg_expr += ' ' + qt(arg[1])
if include_defaults:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_edgeql_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4921,6 +4921,18 @@ async def test_edgeql_ddl_function_37(self):
)
self.assertEqual(val, 1)

async def test_edgeql_ddl_function_38(self):
await self.con.execute('''
create function myFuncFailA(character: int64) -> float64
using (
select 2.3
);
create function myFuncFailB(interval: str) -> float64
using (
select 2.3
);
''')

async def test_edgeql_ddl_function_rename_01(self):
await self.con.execute("""
CREATE FUNCTION foo(s: str) -> str {
Expand Down

0 comments on commit a63aa13

Please sign in to comment.