Skip to content

Commit

Permalink
Merge PR #90
Browse files Browse the repository at this point in the history
  • Loading branch information
lelit committed Sep 18, 2021
2 parents 0fc11c5 + 881854a commit ef5aa48
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
3 changes: 3 additions & 0 deletions pglast/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def workhorse(args):
split_string_literals_threshold=args.split_string_literals,
special_functions=args.special_functions,
comma_at_eoln=args.comma_at_eoln,
remove_pg_catalog_from_functions=args.remove_pg_catalog_from_functions,
semicolon_after_last_statement=args.semicolon_after_last_statement)
except Error as e:
print()
Expand Down Expand Up @@ -82,6 +83,8 @@ def main(options=None):
default=False, help="preserve comments in the statement")
parser.add_argument('-S', '--statement',
help='the SQL statement')
parser.add_argument('-F', '--remove-pg_catalog-from-functions', action='store_true', default=False,
help='remove pg_catalog from functions')
parser.add_argument('infile', nargs='?', type=argparse.FileType(),
help='a file containing the SQL statement to be pretty-printed,'
' by default stdin, when not specified with --statement option')
Expand Down
25 changes: 22 additions & 3 deletions pglast/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ class RawStream(OutputStream):
``False`` by default, when ``True`` add a semicolon after the last statement,
otherwise it is emitted only as a separator between multiple statements
:param comments: optional sequence of tuples with the comments extracted from the statement
:param bool remove_pg_catalog_from_functions:
``False`` by default, when ``True`` remove the pg_catalog schema from functions
This augments :class:`OutputStream` and implements the basic machinery needed to serialize
the *parse tree* produced by :func:`~.parser.parse_sql()` back to a textual representation,
Expand All @@ -127,7 +129,7 @@ class RawStream(OutputStream):

def __init__(self, expression_level=0, separate_statements=1, special_functions=False,
comma_at_eoln=False, semicolon_after_last_statement=False,
comments=None):
comments=None, remove_pg_catalog_from_functions=False):
super().__init__()
self.current_column = 0
self.expression_level = expression_level
Expand All @@ -136,6 +138,7 @@ def __init__(self, expression_level=0, separate_statements=1, special_functions=
self.comma_at_eoln = comma_at_eoln
self.semicolon_after_last_statement = semicolon_after_last_statement
self.comments = comments
self.remove_pg_catalog_from_functions = remove_pg_catalog_from_functions

def show(self, where=stderr): # pragma: no cover
"""Emit also current expression_level and a "pointer" showing current_column."""
Expand Down Expand Up @@ -256,7 +259,8 @@ def _concat_nodes(self, nodes, sep=' ', are_names=False):
"""

rawstream = RawStream(expression_level=self.expression_level,
special_functions=self.special_functions)
special_functions=self.special_functions,
remove_pg_catalog_from_functions=self.remove_pg_catalog_from_functions)
rawstream.print_list(nodes, sep, are_names=are_names, standalone_items=False)
return rawstream.getvalue()

Expand Down Expand Up @@ -364,10 +368,25 @@ def print_node(self, node, is_name=False, is_symbol=False):
printer(node, self)
self.separator()

def _is_pg_catalog_func(self, items):
return (
self.remove_pg_catalog_from_functions
and items.parent_attribute == 'funcname'
and len(items) > 1
and items[0].val.value == 'pg_catalog'
# The list contains all functions that cannot be found without an
# explicit pg_catalog schema. ie:
# position(a,b) is invalid but pg_catalog.position(a,b) is fine
and items[1].val.value not in ('position',)
)

def _print_items(self, items, sep, newline, are_names=False, is_symbol=False):
first = 1 if self._is_pg_catalog_func(items) else 0
last = len(items) - 1
for idx, item in enumerate(items):
if idx > 0:
if idx < first:
continue
if idx > first:
if sep == ',' and self.comma_at_eoln:
self.write(sep)
if newline:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,34 @@ def test_cli_workhorse():
assert output.getvalue() == ("SELECT EXTRACT(HOUR FROM t1.modtime)\n"
" , count(*)\n"
"FROM t1\n")

in_stmt = """\
select substring('123',2,3),
regexp_split_to_array('x,x,x', ','),
btrim('xxx'), trim('xxx'),
POSITION('hour' in trim(substring('xyz hour ',1,6)))
"""

with StringIO(in_stmt) as input:
with UnclosableStream() as output:
with redirect_stdin(input), redirect_stdout(output):
main(['--compact-lists-margin', '100'])
assert output.getvalue() == """\
SELECT pg_catalog.substring('123', 2, 3)
, regexp_split_to_array('x,x,x', ',')
, btrim('xxx')
, pg_catalog.btrim('xxx')
, pg_catalog.position(pg_catalog.btrim(pg_catalog.substring('xyz hour ', 1, 6)), 'hour')
"""

with StringIO(in_stmt) as input:
with UnclosableStream() as output:
with redirect_stdin(input), redirect_stdout(output):
main(['--remove-pg_catalog-from-functions', '--compact-lists-margin', '100'])
assert output.getvalue() == """\
SELECT substring('123', 2, 3)
, regexp_split_to_array('x,x,x', ',')
, btrim('xxx')
, btrim('xxx')
, pg_catalog.position(btrim(substring('xyz hour ', 1, 6)), 'hour')
"""

0 comments on commit ef5aa48

Please sign in to comment.