Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix function params whose name collide with SQL type names #6150

Merged
merged 1 commit into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 9 additions & 9 deletions edb/pgsql/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def gen_ctes(self, ctes: List[pgast.CommonTableExpr]) -> None:
if cte.aliascolnames:
self.write('(')
for (index, col_name) in enumerate(cte.aliascolnames):
self.write(common.qname(col_name))
self.write(common.qname(col_name, column=True))
if index + 1 < len(cte.aliascolnames):
self.write(',')
self.write(')')
Expand Down Expand Up @@ -578,18 +578,18 @@ def visit_ResTarget(self, node: pgast.ResTarget) -> None:
if node.indirection:
self._visit_indirection_ops(node.indirection)
if node.name:
self.write(' AS ' + common.quote_ident(node.name))
self.write(' AS ' + common.quote_col(node.name))

def visit_InsertTarget(self, node: pgast.InsertTarget) -> None:
self.write(common.quote_ident(node.name))
self.write(common.quote_col(node.name))

def visit_UpdateTarget(self, node: pgast.UpdateTarget) -> None:
if isinstance(node.name, list):
self.write('(')
self.write(', '.join(common.quote_ident(n) for n in node.name))
self.write(', '.join(common.quote_col(n) for n in node.name))
self.write(')')
else:
self.write(common.quote_ident(node.name))
self.write(common.quote_col(node.name))
if node.indirection:
self._visit_indirection_ops(node.indirection)
self.write(' = ')
Expand All @@ -599,7 +599,7 @@ def visit_Alias(self, node: pgast.Alias) -> None:
self.write(common.quote_ident(node.aliasname))
if node.colnames:
self.write('(')
self.write(', '.join(common.quote_ident(n) for n in node.colnames))
self.write(', '.join(common.quote_col(n) for n in node.colnames))
self.write(')')

def visit_Keyword(self, node: pgast.Keyword) -> None:
Expand Down Expand Up @@ -668,15 +668,15 @@ def visit_ColumnRef(self, node: pgast.ColumnRef) -> None:
self.write(names[0])
if len(names) > 1:
self.write('.')
self.write(common.qname(*names[1:]))
self.write(common.qname(*names[1:], column=True))
else:
self.write(common.qname(*names))
self.write(common.qname(*names, column=True))

def visit_ExprOutputVar(self, node: pgast.ExprOutputVar) -> None:
self.visit(node.expr)

def visit_ColumnDef(self, node: pgast.ColumnDef) -> None:
self.write(common.quote_ident(node.name))
self.write(common.quote_col(node.name))
if node.typename:
self.write(' ')
self.visit(node.typename)
Expand Down
19 changes: 14 additions & 5 deletions edb/pgsql/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,17 @@ def _quote_ident(string: str) -> str:
return '"' + string.replace('"', '""') + '"'


def quote_ident(ident: str | pgast.Star, *, force=False) -> str:
def quote_ident(ident: str | pgast.Star, *, force=False, column=False) -> str:
if isinstance(ident, pgast.Star):
return "*"
return _quote_ident(ident) if needs_quoting(ident) or force else ident
return (
_quote_ident(ident)
if needs_quoting(ident, column=column) or force else ident
)


def quote_col(ident: str | pgast.Star) -> str:
return quote_ident(ident, column=True)


def quote_bytea_literal(data: bytes) -> str:
Expand All @@ -85,7 +92,7 @@ def quote_bytea_literal(data: bytes) -> str:
return "''::bytea"


def needs_quoting(string: str) -> bool:
def needs_quoting(string: str, column: bool=False) -> bool:
isalnum = (string and not string[0].isdecimal() and
string.replace('_', 'a').isalnum())
return (
Expand All @@ -94,13 +101,15 @@ def needs_quoting(string: str) -> bool:
pg_keywords.RESERVED_KEYWORD] or
string.lower() in pg_keywords.by_type[
pg_keywords.TYPE_FUNC_NAME_KEYWORD] or
(column and string.lower() in pg_keywords.by_type[
pg_keywords.COL_NAME_KEYWORD]) or
string.lower() != string
)


def qname(*parts: str | pgast.Star) -> str:
def qname(*parts: str | pgast.Star, column: bool=False) -> str:
assert len(parts) <= 3, parts
return '.'.join([quote_ident(q) for q in parts])
return '.'.join([quote_ident(q, column=column) for q in parts])


def quote_type(type_: Tuple[str, ...] | str):
Expand Down
6 changes: 3 additions & 3 deletions edb/pgsql/dbops/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import *

from ..common import qname as qn
from ..common import quote_ident as qi
from ..common import quote_col as qc
from ..common import quote_literal as ql
from ..common import quote_type as qt

Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, name, args=None):
self.args = args

def code(self, block: base.PLBlock) -> str:
args = f"ARRAY[{','.join(qi(a) for a in self.args)}]"
args = f"ARRAY[{','.join(qc(a) for a in self.args)}]"

return textwrap.dedent(f'''\
SELECT
Expand Down Expand Up @@ -111,7 +111,7 @@ def format_args(

if isinstance(arg, tuple):
if arg[0] is not None:
arg_expr += qn(arg[0])
arg_expr += qn(arg[0], column=True)
if len(arg) > 1:
arg_expr += ' ' + qt(arg[1])
if include_defaults:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_edgeql_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4921,6 +4921,18 @@ async def test_edgeql_ddl_function_37(self):
)
self.assertEqual(val, 1)

async def test_edgeql_ddl_function_38(self):
await self.con.execute('''
create function myFuncFailA(character: int64) -> float64
using (
select 2.3
);
create function myFuncFailB(interval: str) -> float64
using (
select 2.3
);
''')

async def test_edgeql_ddl_function_rename_01(self):
await self.con.execute("""
CREATE FUNCTION foo(s: str) -> str {
Expand Down