Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IN comparison parsing #8268

Merged
2 commits merged on Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions mlflow/utils/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,36 +45,61 @@ def _ilike(string, pattern):

def _join_in_comparison_tokens(tokens):
"""
If a given list of tokens matches the pattern of an IN comparison or a NOT IN comparison,
Find each sequence of tokens that matches the pattern of an IN comparison or a NOT IN comparison and
join the tokens into a single Comparison token. Otherwise, return the original list of tokens.
"""
if Version(sqlparse.__version__) < Version("0.4.4"):
# In sqlparse < 0.4.4, IN is already treated as a comparison, so we don't need to join tokens
return tokens

tokens = [t for t in tokens if not t.is_whitespace]
num_tokens = len(tokens)
# IN
if num_tokens == 3:
first, second, third = tokens
non_whitespace_tokens = [t for t in tokens if not t.is_whitespace]
joined_tokens = []
num_tokens = len(non_whitespace_tokens)
iterator = enumerate(non_whitespace_tokens)
while elem := next(iterator, None):
index, first = elem
# We need at least 3 tokens to form an IN comparison or a NOT IN comparison
if num_tokens - index < 3:
joined_tokens.extend(non_whitespace_tokens[index:])
break

# Wait until we encounter an identifier token
if not isinstance(first, Identifier):
joined_tokens.append(first)
continue

(_, second) = next(iterator)
(_, third) = next(iterator)

# IN
if (
isinstance(first, Identifier)
and second.match(ttype=TokenType.Keyword, values=["IN"])
and isinstance(third, Parenthesis)
):
return [Comparison(TokenList(tokens))]
# NOT IN
elif num_tokens == 4:
first, second, third, fourth = tokens
joined_tokens.append(Comparison(TokenList([first, second, third])))
continue

(_, fourth) = next(iterator, (None, None))
if fourth is None:
joined_tokens.extend([first, second, third])
break

# NOT IN
if (
isinstance(first, Identifier)
and second.match(ttype=TokenType.Keyword, values=["NOT"])
and third.match(ttype=TokenType.Keyword, values=["IN"])
and isinstance(fourth, Parenthesis)
):
return [Comparison(TokenList([first, Token(TokenType.Keyword, "NOT IN"), fourth]))]
joined_tokens.append(
Comparison(TokenList([first, Token(TokenType.Keyword, "NOT IN"), fourth]))
)
continue

joined_tokens.extend([first, second, third, fourth])

return tokens
return joined_tokens


class SearchUtils:
Expand Down
5 changes: 5 additions & 0 deletions tests/store/model_registry/test_file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,11 @@ def search_versions(filter_string):
# search IN operator with right-hand side value containing whitespaces
assert set(search_versions(f"run_id IN ('{run_id_1}', '{run_id_2}')")) == {1, 2, 3}

# search IN operator with other conditions
assert set(
search_versions(f"version_number=2 AND run_id IN ('{run_id_1.upper()}','{run_id_2}')")
) == {2}

# search using the IN operator with malformed lists should raise exceptions
with pytest.raises(
MlflowException,
Expand Down
5 changes: 5 additions & 0 deletions tests/store/model_registry/test_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,11 @@ def search_versions(filter_string, max_results=10, order_by=None, page_token=Non
# search IN operator is case sensitive
assert set(search_versions(f"run_id IN ('{run_id_1.upper()}','{run_id_2}')")) == {2, 3}

# search IN operator with other conditions
assert set(
search_versions(f"version_number=2 AND run_id IN ('{run_id_1.upper()}','{run_id_2}')")
) == {2}

# search IN operator with right-hand side value containing whitespaces
assert set(search_versions(f"run_id IN ('{run_id_1}', '{run_id_2}')")) == {1, 2, 3}

Expand Down
7 changes: 7 additions & 0 deletions tests/store/tracking/test_file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,13 @@ def test_search_runs_run_id(store):
)
assert [r.info.run_id for r in result] == [run_id2]

result = store.search_runs(
[exp_id],
filter_string=f"run_name = '{run1.info.run_name}' AND run_id IN ('{run_id1}')",
run_view_type=ViewType.ACTIVE_ONLY,
)
assert [r.info.run_id for r in result] == [run_id1]

for filter_string in [
f"attributes.run_id IN ('{run_id1}','{run_id2}')",
f"attributes.run_id IN ('{run_id1}', '{run_id2}')",
Expand Down
7 changes: 7 additions & 0 deletions tests/store/tracking/test_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2110,6 +2110,13 @@ def test_search_runs_run_id(self):
run_view_type=ViewType.ACTIVE_ONLY,
)

result = self.store.search_runs(
[exp_id],
filter_string=f"run_name = '{run1.info.run_name}' AND run_id IN ('{run_id1}')",
run_view_type=ViewType.ACTIVE_ONLY,
)
assert [r.info.run_id for r in result] == [run_id1]

for filter_string in [
f"attributes.run_id IN ('{run_id1}','{run_id2}')",
f"attributes.run_id IN ('{run_id1}', '{run_id2}')",
Expand Down