Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions mindsdb_sql_parser/ast/mindsdb/create_database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
from mindsdb_sql_parser.ast.base import ASTNode
from mindsdb_sql_parser.utils import indent
from mindsdb_sql_parser.utils import indent, dump_json


class CreateDatabase(ASTNode):
Expand Down Expand Up @@ -49,6 +48,6 @@ def get_string(self, *args, **kwargs):

parameters_str = ''
if self.parameters:
parameters_str = f', PARAMETERS = {json.dumps(self.parameters)}'
parameters_str = f', PARAMETERS = {dump_json(self.parameters)}'
out_str = f'CREATE{replace_str} DATABASE {"IF NOT EXISTS " if self.if_not_exists else ""}{self.name.to_string()} {engine_str}{parameters_str}'
return out_str
7 changes: 3 additions & 4 deletions mindsdb_sql_parser/ast/mindsdb/create_predictor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
from mindsdb_sql_parser.ast.base import ASTNode
from mindsdb_sql_parser.utils import indent
from mindsdb_sql_parser.utils import indent, dump_json
from mindsdb_sql_parser.ast.select import Identifier
from mindsdb_sql_parser.ast.select.operation import Object

Expand Down Expand Up @@ -101,13 +100,13 @@ def get_string(self, *args, **kwargs):
for key, value in self.using.items():
if isinstance(value, Object):
args = [
f'{k}={json.dumps(v)}'
f'{k}={dump_json(v)}'
for k, v in value.params.items()
]
args_str = ', '.join(args)
value = f'{value.type}({args_str})'
else:
value = json.dumps(value)
value = dump_json(value)

using_ar.append(f'{Identifier(key).to_string()}={value}')

Expand Down
8 changes: 4 additions & 4 deletions mindsdb_sql_parser/ast/select/select.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List, Union
import json
from mindsdb_sql_parser.ast.base import ASTNode
from mindsdb_sql_parser.utils import indent
from mindsdb_sql_parser.utils import indent, dump_json
from mindsdb_sql_parser.ast.select.operation import Object


class Select(ASTNode):

def __init__(self,
Expand Down Expand Up @@ -158,15 +158,15 @@ def get_string(self, *args, **kwargs):
for key, value in self.using.items():
if isinstance(value, Object):
args = [
f'{k}={json.dumps(v)}'
f'{k}={dump_json(v)}'
for k, v in value.params.items()
]
args_str = ', '.join(args)
value = f'{value.type}({args_str})'
if isinstance(value, Identifier):
value = value.to_string()
else:
value = json.dumps(value)
value = dump_json(value)

using_ar.append(f'{Identifier(key).to_string()}={value}')

Expand Down
2 changes: 1 addition & 1 deletion mindsdb_sql_parser/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def INTEGER(self, t):
def QUOTE_STRING(self, t):
return t

@_(r'"(?:\\.|[^"])*"')
@_(r'"(?:\\.|[^"])*(?:""(?:\\.|[^"])*)*"')
def DQUOTE_STRING(self, t):
return t

Expand Down
6 changes: 3 additions & 3 deletions mindsdb_sql_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from mindsdb_sql_parser.exceptions import ParsingException
from mindsdb_sql_parser.ast.mindsdb.retrain_predictor import RetrainPredictor
from mindsdb_sql_parser.ast.mindsdb.finetune_predictor import FinetunePredictor
from mindsdb_sql_parser.utils import ensure_select_keyword_order, JoinType, tokens_to_string
from mindsdb_sql_parser.utils import ensure_select_keyword_order, JoinType, tokens_to_string, unquote
from mindsdb_sql_parser.logger import ParserLogger

from mindsdb_sql_parser.lexer import MindsDBLexer
Expand Down Expand Up @@ -2024,11 +2024,11 @@ def integer(self, p):

@_('QUOTE_STRING')
def quote_string(self, p):
return p[0].replace('\\"', '"').replace("\\'", "'").replace("''", "'").strip('\'')
return unquote(p[0]).strip('\'')

@_('DQUOTE_STRING')
def dquote_string(self, p):
return p[0].replace('\\"', '"').replace("\\'", "'").strip('\"')
return unquote(p[0], is_double_quoted=True).strip('\"')

# for raw query

Expand Down
54 changes: 54 additions & 0 deletions mindsdb_sql_parser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,57 @@ def tokens_to_string(tokens):
# last line
content += line
return content


def unquote(s, is_double_quoted=False):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correctness: The new unquote() function processes escape sequences differently than the old inline implementation, potentially breaking existing code with backslash-containing strings.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance: Multiple string replacements in unquote() could be optimized

s = s.replace('\\"', '"').replace("\\'", "'")
if is_double_quoted:
s = s.replace('""', '"')
else:
s = s.replace("''", "'")
return s


def dump_json(obj) -> str:
'''
dump dict into json-like string using:
- single quotes for strings
- the same quoting rules as `unquote` function
'''


if isinstance(obj, dict):
items = []
for k, v in obj.items():
# keys must be strings in JSON
if not isinstance(k, str):
k = str(k)
items.append(f'{dump_json(k)}: {dump_json(v)}')
return "{" + ", ".join(items) + "}"

if isinstance(obj, (list, tuple)):
items = [
dump_json(i) for i in obj
]
return "[" + ", ".join(items) + "]"

if isinstance(obj, str):
obj = obj.replace("'", "''")
return f"'{obj}'"

if isinstance(obj, (int, float)):
if obj != obj: # NaN
return "null"
if obj == float('inf'):
return "null"
if obj == float('-inf'):
return "null"
return str(obj)

if obj is None:
return "null"

if isinstance(obj, bool):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Security: The dump_json() function lacks circular reference detection, making it vulnerable to infinite recursion attacks that could crash the application.

📝 Committable Code Suggestion

‼️ Ensure you review the code suggestion before committing it to the branch. Make sure it replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
if isinstance(obj, bool):
def dump_json(obj, _seen=None, _depth=0, max_depth=100) -> str:
'''
Secure version of dump_json with circular reference detection and recursion limits.
Serializes Python objects to JSON with single quotes for strings.
'''
# Initialize seen set for first call
if _seen is None:
_seen = set()
# Check recursion depth
if _depth > max_depth:
raise RecursionError(f"Maximum recursion depth exceeded ({max_depth})")
# Handle None
if obj is None:
return 'null'
# Handle basic types
if isinstance(obj, (int, float, bool)):
return str(obj).lower()
# Handle strings
if isinstance(obj, str):
# Escape single quotes and backslashes
escaped = obj.replace("\\", "\\\\").replace("'", "\\'")
return f"'{escaped}'"
# Handle lists
if isinstance(obj, (list, tuple)):
items = []
for item in obj:
items.append(dump_json(item, _seen.copy(), _depth + 1, max_depth))
return f"[{', '.join(items)}]"
# Handle dictionaries
if isinstance(obj, dict):
# Check for circular references
obj_id = id(obj)
if obj_id in _seen:
raise ValueError("Circular reference detected in object")
_seen.add(obj_id)
items = []
for key, value in obj.items():
key_str = dump_json(key, _seen.copy(), _depth + 1, max_depth)
value_str = dump_json(value, _seen.copy(), _depth + 1, max_depth)
items.append(f"{key_str}: {value_str}")
return f"{{{', '.join(items)}}}"
# Handle other objects by converting to string
try:
# Limit string size to prevent DoS
obj_str = str(obj)
if len(obj_str) > 10000: # Reasonable limit
obj_str = obj_str[:10000] + "...(truncated)"
return f"'{obj_str}'"
except Exception as e:
return f"'<Object representation error: {str(e)}>'"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Security: No limits on output size or input validation for malicious objects, allowing memory exhaustion attacks through objects with malicious str methods.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style: Inconsistent docstring format in dump_json() function

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance: String concatenation in dump_json() could use list joining for better performance

return "true" if obj else "false"

return dump_json(str(obj))
29 changes: 27 additions & 2 deletions tests/test_mindsdb/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def test_create_project(self):
assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()


def test_create_database_using(self):

sql = "CREATE DATABASE db using ENGINE = 'mysql', PARAMETERS = {'A': 1}"
Expand All @@ -130,7 +129,6 @@ def test_create_database_using(self):
assert str(ast).lower() == str(expected_ast).lower()
assert ast.to_tree() == expected_ast.to_tree()


def test_alter_database(self):
sql = "ALTER DATABASE db PARAMETERS = {'A': 1, 'B': 2}"
ast = parse_sql(sql)
Expand All @@ -139,3 +137,30 @@ def test_alter_database(self):

assert str(ast) == str(expected_ast)
assert ast.to_tree() == expected_ast.to_tree()

def test_parser_render(self):

value = "a dm\\in123\"_.,';:!@#$%^&*()\n<>`{}[]"

'''
quoting rules:
' => '' (in single quoted strings)
" => "" (in double quoted strings)
'''
for symbol in ("'", '"'):
sql = f"""
CREATE DATABASE db WITH engine = 'postgres'
PARAMETERS = {{
'password': {symbol}{value.replace(symbol, symbol * 2)}{symbol}
}}
"""

# check parsing
query = parse_sql(sql)
assert query.parameters['password'] == value

# check render
sql2 = str(query)
query2 = parse_sql(sql2)
assert query2.parameters['password'] == value