Skip to content

Commit

Permalink
Support 'trace.xxx' filter string in search_traces() API (#12193)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Signed-off-by: Yuki Watanabe <31463517+B-Step62@users.noreply.github.com>
Co-authored-by: Harutaka Kawamura <hkawamura0130@gmail.com>
  • Loading branch information
B-Step62 and harupy committed May 31, 2024
1 parent ba2e67f commit bd1ecf5
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 3 deletions.
27 changes: 26 additions & 1 deletion mlflow/utils/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,8 +1500,20 @@ class SearchTraceUtils(SearchUtils):
"execution_time",
}

# Base entity-type identifiers that may prefix a key in a trace filter string.
_REQUEST_METADATA_IDENTIFIER = "request_metadata"
_TAG_IDENTIFIER = "tag"
_ATTRIBUTE_IDENTIFIER = "attribute"

# These are aliases for the base identifiers
# e.g. trace.status is equivalent to attribute.status
_ALTERNATE_IDENTIFIERS = {
    "tags": _TAG_IDENTIFIER,
    "attributes": _ATTRIBUTE_IDENTIFIER,
    "trace": _ATTRIBUTE_IDENTIFIER,
}
# Canonical identifiers only (no aliases).
_IDENTIFIERS = {_TAG_IDENTIFIER, _REQUEST_METADATA_IDENTIFIER, _ATTRIBUTE_IDENTIFIER}
# Every prefix accepted from the user: canonical identifiers plus their aliases.
_VALID_IDENTIFIERS = _IDENTIFIERS | set(_ALTERNATE_IDENTIFIERS)

# NOTE(review): presumably the attribute keys for which the IN comparator is
# accepted -- confirm against the code that consumes this set.
SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS = {"name", "status", "request_id", "run_id"}

# Some search keys are defined differently in the DB models.
Expand Down Expand Up @@ -1600,6 +1612,19 @@ def is_request_metadata(cls, key_type, comparator):
return True
return False

@classmethod
def _valid_entity_type(cls, entity_type):
    """
    Validate an entity-type prefix and resolve it to its base identifier.

    The raw value is first passed through ``cls._trim_backticks``. If the
    result is a key of ``cls._ALTERNATE_IDENTIFIERS`` the identifier it
    aliases is returned; any other member of ``cls._VALID_IDENTIFIERS`` is
    returned as-is.

    Args:
        entity_type: The entity-type portion of a filter key.

    Returns:
        The base identifier corresponding to ``entity_type``.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when the trimmed
            value is not one of ``cls._VALID_IDENTIFIERS``.
    """
    entity_type = cls._trim_backticks(entity_type)

    # Accepted names: map an alias to its base identifier, or fall back to
    # the name itself when it is already a base identifier.
    if entity_type in cls._VALID_IDENTIFIERS:
        return cls._ALTERNATE_IDENTIFIERS.get(entity_type, entity_type)

    raise MlflowException(
        f"Invalid entity type '{entity_type}'. Valid values are {cls._VALID_IDENTIFIERS}",
        error_code=INVALID_PARAMETER_VALUE,
    )

@classmethod
def _get_sort_key(cls, order_by_list):
order_by = []
Expand Down
35 changes: 35 additions & 0 deletions tests/store/tracking/test_file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2963,6 +2963,20 @@ def test_search_traces_filter(generate_trace_infos):
_validate_search_traces(store, [exp_id], "status LIKE 'O%'", trace_infos[:2][::-1])
_validate_search_traces(store, [exp_id], "status ILIKE 'ok'", trace_infos[:2][::-1])

# filter by status w/ attributes. or trace. prefix
_validate_search_traces(
store,
[exp_id],
"trace.status = 'ERROR'",
trace_infos[2:5][::-1],
)
_validate_search_traces(
store,
[exp_id],
"attributes.status IN ('IN_PROGRESS', 'OK')",
(trace_infos[:2] + trace_infos[5:])[::-1],
)

# filter by timestamp
for timestamp_key in ["timestamp", "timestamp_ms"]:
_validate_search_traces(store, [exp_id], f"{timestamp_key} < 10", trace_infos[:1])
Expand Down Expand Up @@ -3034,6 +3048,27 @@ def test_search_traces_filter(generate_trace_infos):
)


@pytest.mark.parametrize(
    ("filter_string", "error"),
    [
        ("invalid", r"Invalid clause\(s\) in filter string"),
        ("name = 'foo' AND invalid", r"Invalid clause\(s\) in filter string"),
        ("foo.bar = 'baz'", r"Invalid entity type 'foo'"),
        ("invalid = 'foo'", r"Invalid attribute key 'invalid'"),
        ("trace.tags.foo = 'bar'", r"Invalid attribute key 'tags\.foo'"),
        # TODO: This should raise
        # ("trace.status < 'OK'", r"Invalid comparator '<'"),
    ],
)
def test_search_traces_invalid_filter(generate_trace_infos, filter_string, error):
    """Check that a malformed filter string makes ``search_traces`` raise.

    ``error`` is a regex that the raised ``MlflowException`` message must match.
    """
    trace_store, experiment_id = generate_trace_infos.store, generate_trace_infos.exp_id

    # The search itself must fail; the regex pins which validation step rejected it.
    with pytest.raises(MlflowException, match=error):
        trace_store.search_traces([experiment_id], filter_string)


def test_search_traces_order(generate_trace_infos):
trace_infos = generate_trace_infos.trace_infos
store = generate_trace_infos.store
Expand Down
31 changes: 29 additions & 2 deletions tests/store/tracking/test_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4029,12 +4029,17 @@ def test_search_traces_order_by(store_with_traces, order_by, expected_ids):
# Search by status
("status = 'OK'", ["tr-4", "tr-3", "tr-0"]),
("status != 'OK'", ["tr-2", "tr-1"]),
("attributes.status = 'OK'", ["tr-4", "tr-3", "tr-0"]),
("attributes.name != 'aaa'", ["tr-4", "tr-3", "tr-2", "tr-0"]),
("trace.status = 'OK'", ["tr-4", "tr-3", "tr-0"]),
("trace.name != 'aaa'", ["tr-4", "tr-3", "tr-2", "tr-0"]),
# Search by timestamp
("`timestamp` >= 1 AND execution_time < 10", ["tr-2", "tr-1"]),
# Search by tag
("tags.fruit = 'apple'", ["tr-2", "tr-1"]),
("tag.fruit = 'apple'", ["tr-2", "tr-1"]),
("tag.color LIKE 're%'", ["tr-1"]),
# tags is an alias for tag
("tags.fruit = 'apple' and tags.color != 'red'", ["tr-2"]),
("tags.color LIKE 're%'", ["tr-1"]),
# Search by request metadata
("run_id = 'run0'", ["tr-0"]),
],
Expand All @@ -4053,6 +4058,28 @@ def test_search_traces_with_filter(store_with_traces, filter_string, expected_id
assert actual_ids == expected_ids


@pytest.mark.parametrize(
    ("filter_string", "error"),
    [
        ("invalid", r"Invalid clause\(s\) in filter string"),
        ("name = 'foo' AND invalid", r"Invalid clause\(s\) in filter string"),
        ("foo.bar = 'baz'", r"Invalid entity type 'foo'"),
        ("invalid = 'foo'", r"Invalid attribute key 'invalid'"),
        ("trace.tags.foo = 'bar'", r"Invalid attribute key 'tags\.foo'"),
        ("trace.status < 'OK'", r"Invalid comparator '<'"),
    ],
)
def test_search_traces_with_invalid_filter(store_with_traces, filter_string, error):
    """Check that a malformed filter string makes ``search_traces`` raise.

    The search spans the experiments named "exp1" and "exp2"; ``error`` is a
    regex that the raised ``MlflowException`` message must match.
    """
    # Resolve the experiment ids up front, in the order exp1 then exp2.
    experiment_ids = [
        store_with_traces.get_experiment_by_name(experiment_name).experiment_id
        for experiment_name in ("exp1", "exp2")
    ]

    with pytest.raises(MlflowException, match=error):
        store_with_traces.search_traces(
            experiment_ids=experiment_ids,
            filter_string=filter_string,
        )


def test_search_traces_raise_if_max_results_arg_is_invalid(store):
with pytest.raises(MlflowException, match="Invalid value for request parameter"):
store.search_traces(experiment_ids=[], max_results=50001)
Expand Down

0 comments on commit bd1ecf5

Please sign in to comment.