Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions mindsdb_sql_parser/ast/select/select.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List, Union
import json
from mindsdb_sql_parser.ast.base import ASTNode
from mindsdb_sql_parser.utils import indent
Expand All @@ -7,7 +8,7 @@ class Select(ASTNode):

def __init__(self,
targets,
distinct=False,
distinct: Union[List, bool] = False,
from_table=None,
where=None,
group_by=None,
Expand All @@ -22,6 +23,8 @@ def __init__(self,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.targets = targets

# if it is list: SELECT DISTINCT ON (c1, c2) ...
self.distinct = distinct
self.from_table = from_table
self.where = where
Expand Down Expand Up @@ -105,8 +108,12 @@ def get_string(self, *args, **kwargs):

out_str += "SELECT"

if self.distinct:
if self.distinct is True:
out_str += ' DISTINCT'
elif isinstance(self.distinct, list):
distinct_str = ', '.join([c.to_string() for c in self.distinct])

out_str += f' DISTINCT ON ({distinct_str})'

targets_str = ', '.join([out.to_string() for out in self.targets])
out_str += f' {targets_str}'
Expand Down
5 changes: 5 additions & 0 deletions mindsdb_sql_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,11 @@ def select(self, p):
targets = p.result_columns
return Select(targets=targets, distinct=True)

@_('SELECT DISTINCT ON LPAREN expr_list RPAREN result_columns')
def select(self, p):
targets = p.result_columns
return Select(targets=targets, distinct=p.expr_list)

@_('SELECT result_columns')
def select(self, p):
targets = p.result_columns
Expand Down
26 changes: 26 additions & 0 deletions tests/test_base_sql/test_select_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,32 @@ def test_select_distinct(self):
assert str(parse_sql(sql)) == sql
assert parse_sql(sql).distinct

def test_select_distinct_on(self):
# single column with parts, star
sql = """SELECT DISTINCT ON (t1.column1) * FROM t1"""

expected_ast = Select(
targets=[Star()],
from_table=Identifier('t1'),
distinct=[Identifier('t1.column1')]
)
ast = parse_sql(sql)
assert str(ast) == str(expected_ast)
assert ast.to_tree() == expected_ast.to_tree()

# many columns without parts, not star

sql = """SELECT DISTINCT ON (column1, column2) column3, column4 FROM t1"""

expected_ast = Select(
targets=[Identifier('column3'), Identifier('column4')],
from_table=Identifier('t1'),
distinct=[Identifier('column1'), Identifier('column2')]
)
ast = parse_sql(sql)
assert str(ast) == str(expected_ast)
assert ast.to_tree() == expected_ast.to_tree()

def test_select_multiple_from_table(self):
sql = f'SELECT column1, column2, 1 AS renamed_constant FROM tab'
ast = parse_sql(sql)
Expand Down