Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "sqlgpt-parser"
version = "0.0.1a4"
version = "0.0.1a5"
authors = [
{ name="luliwjc", email="chenxiaoxi_wjc@163.com" },
{ name="Ifffff", email="tingkai.ztk@antgroup.com" },
Expand Down
63 changes: 49 additions & 14 deletions sqlgpt_parser/format/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,24 +421,33 @@ def visit_list_expression(self, node, unmangle_names):
else:
return "(%s)" % self._join_expressions(node.values, unmangle_names)

def visit_window_func(self, node, unmangle_names):
    """Render a window-function call as SQL text.

    The result has the shape ``NAME(arg, ...)[ <ignore_null> NULLS] OVER <spec>``:
    the function name is upper-cased, each argument and the window
    specification are rendered through ``self.process``, and the
    ``... NULLS`` part is emitted only when ``node.ignore_null`` is truthy.
    """
    rendered_args = ", ".join(
        self.process(argument, unmangle_names) for argument in node.func_args
    )

    null_handling = ""
    if node.ignore_null:
        null_handling = f" {node.ignore_null} NULLS"

    over_clause = " OVER " + self.process(node.window_spec, unmangle_names)

    return "".join(
        (node.func_name.upper(), "(", rendered_args, ")", null_handling, over_clause)
    )

def visit_window_spec(self, node, unmangle_names):
parts = []
if node.window_name is not None:
return node.window_name

parts = []
if node.partition_by:
parts.append(
"PARTITION BY "
+ self._join_expressions(node.partition_by, unmangle_names)
)
self.process(node.partition_by, unmangle_names)
if node.order_by:
parts.append("ORDER BY " + format_sort_items(node.order_by, unmangle_names))
if node.frame:
parts.append(self.process(node.frame, unmangle_names))

if node.frame_clause:
parts.append(self.process(node.frame_clause, unmangle_names))
return '(' + ' '.join(parts) + ')'

def visit_window_frame(self, node, unmangle_names):
ret = node.type + " "
def visit_partition_by_clause(self, node, unmangle_names):
    """Render ``PARTITION BY`` followed by the joined expressions in ``node.items``."""
    partition_columns = self._join_expressions(node.items, unmangle_names)
    return "PARTITION BY " + partition_columns

def visit_frame_clause(self, node, unmangle_names):
    """Render a frame clause as ``<node.type> <frame range>``.

    The frame range is whatever ``self.process`` returns for
    ``node.frame_range``.
    """
    frame_type = node.type
    frame_range = self.process(node.frame_range, unmangle_names)
    return f"{frame_type} {frame_range}"

def visit_window_frame(self, node, unmangle_names):
ret = ""
if node.end:
ret += "BETWEEN %s AND %s" % (
self.process(node.start, unmangle_names),
Expand All @@ -449,6 +458,19 @@ def visit_window_frame(self, node, unmangle_names):

return ret

def visit_frame_bound(self, node, unmangle_names):
    """Render one boundary of a window frame.

    * ``node.type`` equal to ``ROW`` (case-insensitive) yields ``CURRENT ROW``.
    * Otherwise the result is ``<expr> <TYPE>``, where ``<expr>`` is the
      processed ``node.expr`` or the keyword ``UNBOUNDED`` when ``node.expr``
      is ``None``, and ``<TYPE>`` is ``node.type`` upper-cased.
    """
    bound_type = node.type.upper()
    if bound_type == "ROW":
        return "CURRENT ROW"

    if node.expr is not None:
        offset = self.process(node.expr, unmangle_names)
    else:
        # No trailing space here: the separator is supplied by the f-string
        # below, otherwise the output contains a doubled space.
        offset = "UNBOUNDED"
    return f"{offset} {bound_type}"

def visit_frame_expr(self, node, unmangle_names):
    """Render a frame expression by delegating to ``self.process`` on ``node.value``."""
    rendered_value = self.process(node.value, unmangle_names)
    return rendered_value

def visit_single_column(self, node, indent):
format_expression(node.expression)

Expand All @@ -468,6 +490,9 @@ def visit_match_against_expression(self, node, unmangle_names):
full_text_search_modifier = full_text_search_modifier.upper()
return f"MATCH({columns}) AGAINST ({self.process(node.expr, unmangle_names)}{full_text_search_modifier})"

def visit_sound_like(self, node, unmangle_names):
    """Render ``<left> SOUNDS LIKE <right>`` from the first two ``node.arguments``.

    ``unmangle_names`` is forwarded to ``self.process`` for both operands so
    that nested expressions are rendered with the same name-handling setting
    as the rest of the statement.
    """
    left_operand = self.process(node.arguments[0], unmangle_names)
    right_operand = self.process(node.arguments[1], unmangle_names)
    return f"{left_operand} SOUNDS LIKE {right_operand}"

def _format_binary_expression(self, operator, left, right, unmangle_names):
return "%s %s %s" % (
self.process(left, unmangle_names),
Expand Down Expand Up @@ -689,13 +714,14 @@ def visit_table_subquery(self, node, indent):
return None

def visit_union(self, node, indent):
all = node.all
for i, relation in enumerate(node.relations):
self._process_relation(relation, indent)
self.builder.append("\n")
if i != len(node.relations) - 1:
if all:
if node.all:
self._append(indent, "UNION ALL")
elif node.distinct:
self._append(indent, "UNION DISTINCT")
else:
self._append(indent, "UNION")
self.builder.append("\n")
Expand All @@ -704,7 +730,12 @@ def visit_union(self, node, indent):

def visit_except(self, node, indent):
    """Emit ``<left> EXCEPT [ALL | DISTINCT] <right>`` into the builder.

    The left relation is processed first, then exactly one set-operator
    keyword is appended via ``self._append``, then the right relation is
    processed. ``ALL`` takes precedence over ``DISTINCT``; with neither flag
    set a bare ``EXCEPT`` is emitted. Always returns ``None``.
    """
    self._process_relation(node.left, indent)

    # Truthiness checks, matching visit_union, so a falsy-but-not-None flag
    # (for example False) does not select the modifier. The keyword is written
    # once; the earlier `"EXCEPT " + "ALL " if ... else ""` append is dropped
    # because operator precedence made it write "EXCEPT ALL " or nothing.
    if node.all:
        keyword = "EXCEPT ALL"
    elif node.distinct:
        keyword = "EXCEPT DISTINCT"
    else:
        keyword = "EXCEPT"
    self._append(indent, keyword)

    self._process_relation(node.right, indent)

    return None
Expand Down Expand Up @@ -756,7 +787,11 @@ def visit_intersect(self, node, indent):
relations = [
self._process_relation(relation, indent) for relation in node.relations
]
intersect = "INTERSECT " + "ALL " if not node.distinct else ""
intersect = "INTERSECT"
if node.all is not None:
intersect += " ALL"
elif node.distinct is not None:
intersect += " DISTINCT"
self.builder.append(intersect.join(relations))
return None

Expand Down
22 changes: 10 additions & 12 deletions sqlgpt_parser/parser/mysql_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ def p_alias_opt(p):
if p.slice[1].type == "alias":
p[0] = p[1]
else:
p[0] = ()
p[0] = []


def p_alias(p):
Expand All @@ -1215,9 +1215,9 @@ def p_alias(p):
| AS string_lit
| string_lit"""
if len(p) == 3:
p[0] = (p[1], p[2])
p[0] = [p[1], p[2]]
else:
p[0] = p[1]
p[0] = [p[1]]


def p_expression(p):
Expand Down Expand Up @@ -1570,7 +1570,7 @@ def p_window_func_call(p):
| ROW_NUMBER LPAREN RPAREN over_clause
"""
length = len(p)
window_spec = p[-1]
window_spec = p[length-1]
args = []
ignore_null = None

Expand Down Expand Up @@ -1711,7 +1711,10 @@ def p_frame_start(p):
| frame_expr PRECEDING
| frame_expr FOLLOWING
"""
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
if p.slice[1].type == 'frame_expr':
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
else:
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)


def p_frame_end(p):
Expand All @@ -1726,13 +1729,8 @@ def p_frame_between(p):

def p_frame_expr(p):
r"""frame_expr : figure
| QM
| INTERVAL expression time_unit
|"""
if len(p) == 4:
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[2], unit=p[3])
else:
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])
| time_interval"""
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])


def p_lead_lag_info_opt(p):
Expand Down
2,968 changes: 1,483 additions & 1,485 deletions sqlgpt_parser/parser/mysql_parser/parser_table.py

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions sqlgpt_parser/parser/oceanbase_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ def p_alias_opt(p):
if p.slice[1].type == "alias":
p[0] = p[1]
else:
p[0] = ()
p[0] = []


def p_alias(p):
Expand All @@ -1212,9 +1212,9 @@ def p_alias(p):
| AS string_lit
| string_lit"""
if len(p) == 3:
p[0] = (p[1], p[2])
p[0] = [p[1], p[2]]
else:
p[0] = p[1]
p[0] = [p[1]]


def p_expression(p):
Expand Down Expand Up @@ -1649,7 +1649,7 @@ def p_window_func_call(p):
| ROW_NUMBER LPAREN RPAREN over_clause
"""
length = len(p)
window_spec = p[-1]
window_spec = p[length-1]
args = []
ignore_null = None

Expand Down Expand Up @@ -1790,7 +1790,10 @@ def p_frame_start(p):
| frame_expr PRECEDING
| frame_expr FOLLOWING
"""
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
if p.slice[1].type == 'frame_expr':
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
else:
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)


def p_frame_end(p):
Expand All @@ -1802,12 +1805,9 @@ def p_frame_between(p):
r"""frame_between : BETWEEN frame_start AND frame_end"""
p[0] = WindowFrame(p.lineno(1), p.lexpos(1), start=p[2], end=p[4])


def p_frame_expr(p):
r"""frame_expr : figure
| QM
| time_interval
|"""
| time_interval"""
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])


Expand Down
3,790 changes: 1,894 additions & 1,896 deletions sqlgpt_parser/parser/oceanbase_parser/parser_table.py

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions sqlgpt_parser/parser/odps_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ def p_alias_opt(p):
if p.slice[1].type == "alias":
p[0] = p[1]
else:
p[0] = ()
p[0] = []


def p_alias(p):
Expand All @@ -1214,7 +1214,7 @@ def p_alias(p):
| AS string_lit
| string_lit"""
if len(p) == 3:
p[0] = (p[1], p[2])
p[0] = [p[1], p[2]]
else:
p[0] = p[1]

Expand Down Expand Up @@ -1657,7 +1657,7 @@ def p_window_func_call(p):
| ROW_NUMBER LPAREN RPAREN over_clause
"""
length = len(p)
window_spec = p[-1]
window_spec = p[length-1]
args = []
ignore_null = None

Expand Down Expand Up @@ -1798,8 +1798,10 @@ def p_frame_start(p):
| frame_expr PRECEDING
| frame_expr FOLLOWING
"""
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])

if p.slice[1].type == 'frame_expr':
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=p[1])
else:
p[0] = FrameBound(p.lineno(1), p.lexpos(1), type=p[2], expr=None)

def p_frame_end(p):
r"""frame_end : frame_start"""
Expand All @@ -1810,12 +1812,9 @@ def p_frame_between(p):
r"""frame_between : BETWEEN frame_start AND frame_end"""
p[0] = WindowFrame(p.lineno(1), p.lexpos(1), start=p[2], end=p[4])


def p_frame_expr(p):
r"""frame_expr : figure
| QM
| time_interval
|"""
| time_interval"""
p[0] = FrameExpr(p.lineno(1), p.lexpos(1), value=p[1])


Expand Down
3,798 changes: 1,898 additions & 1,900 deletions sqlgpt_parser/parser/odps_parser/parser_table.py

Large diffs are not rendered by default.

Loading