Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Truncate long tag and param values to maximum length #11208

Merged
merged 4 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlflow/environment_variables.py
Expand Up @@ -517,3 +517,11 @@ def get(self):
#: logging handlers and formatters.
#: (default: ``True``)
MLFLOW_CONFIGURE_LOGGING = _BooleanEnvironmentVariable("MLFLOW_LOGGING_CONFIGURE_LOGGING", True)

#: If set to True, the following entities will be truncated to their maximum length:
#: - Param value
#: - Tag value
#: If set to False, an exception will be raised if the length of the entity exceeds the maximum
#: length.
#: (default: ``True``)
MLFLOW_TRUNCATE_LONG_VALUES = _BooleanEnvironmentVariable("MLFLOW_TRUNCATE_LONG_VALUES", True)
4 changes: 2 additions & 2 deletions mlflow/store/tracking/file_store.py
Expand Up @@ -946,7 +946,7 @@ def _writeable_value(self, tag_value):

def log_param(self, run_id, param):
_validate_run_id(run_id)
_validate_param(param.key, param.value)
param = _validate_param(param.key, param.value)
run_info = self._get_run_info(run_id)
check_run_is_active(run_info)
self._log_run_param(run_info, param)
Expand Down Expand Up @@ -1046,7 +1046,7 @@ def _overwrite_run_info(self, run_info, deleted_time=None):

def log_batch(self, run_id, metrics, params, tags):
_validate_run_id(run_id)
_validate_batch_log_data(metrics, params, tags)
metrics, params, tags = _validate_batch_log_data(metrics, params, tags)
_validate_batch_log_limits(metrics, params, tags)
_validate_param_keys_unique(params)
run_info = self._get_run_info(run_id)
Expand Down
9 changes: 4 additions & 5 deletions mlflow/store/tracking/sqlalchemy_store.py
Expand Up @@ -1030,7 +1030,7 @@ def _search_datasets(self, experiment_ids):
]

def log_param(self, run_id, param):
_validate_param(param.key, param.value)
param = _validate_param(param.key, param.value)
with self.ManagedSessionMaker() as session:
run = self._get_run(run_uuid=run_id, session=session)
self._check_run_is_active(run)
Expand Down Expand Up @@ -1139,7 +1139,7 @@ def set_tag(self, run_id, tag):
tag: RunTag instance to log.
"""
with self.ManagedSessionMaker() as session:
_validate_tag(tag.key, tag.value)
tag = _validate_tag(tag.key, tag.value)
run = self._get_run(run_uuid=run_id, session=session)
self._check_run_is_active(run)
if tag.key == MLFLOW_RUN_NAME:
Expand All @@ -1160,8 +1160,7 @@ def _set_tags(self, run_id, tags):
if not tags:
return

for tag in tags:
_validate_tag(tag.key, tag.value)
tags = [_validate_tag(t.key, t.value) for t in tags]

with self.ManagedSessionMaker() as session:
run = self._get_run(run_uuid=run_id, session=session)
Expand Down Expand Up @@ -1331,7 +1330,7 @@ def compute_next_token(current_size):

def log_batch(self, run_id, metrics, params, tags):
_validate_run_id(run_id)
_validate_batch_log_data(metrics, params, tags)
metrics, params, tags = _validate_batch_log_data(metrics, params, tags)
_validate_batch_log_limits(metrics, params, tags)
_validate_param_keys_unique(params)

Expand Down
51 changes: 36 additions & 15 deletions mlflow/utils/validation.py
@@ -1,17 +1,21 @@
"""
Utilities for validating user inputs such as metric names and parameter names.
"""
import logging
import numbers
import posixpath
import re
from typing import List

from mlflow.entities import Dataset, DatasetInput, InputTag
from mlflow.entities import Dataset, DatasetInput, InputTag, Param, RunTag
from mlflow.environment_variables import MLFLOW_TRUNCATE_LONG_VALUES
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.utils.string_utils import is_string_type

_logger = logging.getLogger(__name__)

# Regex for valid param and metric names: may only contain slashes, alphanumerics,
# underscores, periods, dashes, and spaces.
_VALID_PARAM_AND_METRIC_NAMES = re.compile(r"^[/\w.\- ]*$")
Expand Down Expand Up @@ -172,17 +176,21 @@ def _validate_param(key, value):
isn't.
"""
_validate_param_name(key)
_validate_length_limit("Param key", MAX_ENTITY_KEY_LENGTH, key)
_validate_length_limit("Param value", MAX_PARAM_VAL_LENGTH, value)
return Param(
_validate_length_limit("Param key", MAX_ENTITY_KEY_LENGTH, key),
_validate_length_limit("Param value", MAX_PARAM_VAL_LENGTH, value, truncate=True),
)


def _validate_tag(key, value):
"""
Check that a tag with the specified key & value is valid and raise an exception if it isn't.
"""
_validate_tag_name(key)
_validate_length_limit("Tag key", MAX_ENTITY_KEY_LENGTH, key)
_validate_length_limit("Tag value", MAX_TAG_VAL_LENGTH, value)
return RunTag(
_validate_length_limit("Tag key", MAX_ENTITY_KEY_LENGTH, key),
_validate_length_limit("Tag value", MAX_TAG_VAL_LENGTH, value, truncate=True),
)


def _validate_experiment_tag(key, value):
Expand Down Expand Up @@ -270,13 +278,25 @@ def _validate_tag_name(name):
)


def _validate_length_limit(entity_name, limit, value):
if value is not None and len(value) > limit:
raise MlflowException(
f"{entity_name} '{value[:250]}' had length {len(value)}, "
f"which exceeded length limit of {limit}",
error_code=INVALID_PARAMETER_VALUE,
def _validate_length_limit(entity_name, limit, value, *, truncate=False):
if value is None:
return None

if len(value) <= limit:
return value

if truncate and MLFLOW_TRUNCATE_LONG_VALUES.get():
_logger.warning(
f"{entity_name} '{value[:100]}...' ({len(value)} characters) is truncated to "
f"{limit} characters to meet the length limit."
)
return value[:limit]

raise MlflowException(
f"{entity_name} '{value[:250]}' had length {len(value)}, "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"{entity_name} '{value[:250]}' had length {len(value)}, "
f"{entity_name} '{value[:250]}...' had length {len(value)}, "

Copy link
Member Author

@harupy harupy Feb 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to break existing tests/code.

f"which exceeded length limit of {limit}",
error_code=INVALID_PARAMETER_VALUE,
)


def _validate_run_id(run_id):
Expand Down Expand Up @@ -317,10 +337,11 @@ def _validate_batch_log_limits(metrics, params, tags):
def _validate_batch_log_data(metrics, params, tags):
for metric in metrics:
_validate_metric(metric.key, metric.value, metric.timestamp, metric.step)
for param in params:
_validate_param(param.key, param.value)
for tag in tags:
_validate_tag(tag.key, tag.value)
return (
metrics,
[_validate_param(p.key, p.value) for p in params],
[_validate_tag(t.key, t.value) for t in tags],
)


def _validate_batch_log_api_req(json_req):
Expand Down
12 changes: 10 additions & 2 deletions tests/store/tracking/test_file_store.py
Expand Up @@ -1522,17 +1522,21 @@ def test_log_param_enforces_value_immutability(store):
assert run.data.params[param_name] == "value1"


def test_log_param_max_length_value(store):
def test_log_param_max_length_value(store, monkeypatch):
param_name = "new param"
param_value = "x" * 6000
_, exp_data, _ = _create_root(store)
run_id = exp_data[FileStore.DEFAULT_EXPERIMENT_ID]["runs"][0]
store.log_param(run_id, Param(param_name, param_value))
run = store.get_run(run_id)
assert run.data.params[param_name] == param_value
monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "false")
with pytest.raises(MlflowException, match="exceeded length"):
store.log_param(run_id, Param(param_name, "x" * 6001))

monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true")
store.log_param(run_id, Param(param_name, "x" * 6001))


def test_weird_metric_names(store):
WEIRD_METRIC_NAME = "this is/a weird/but valid metric"
Expand Down Expand Up @@ -1798,7 +1802,7 @@ def test_log_batch(store):
_verify_logged(store, run_id, metric_entities, param_entities, tag_entities)


def test_log_batch_max_length_value(store):
def test_log_batch_max_length_value(store, monkeypatch):
param_entities = [Param("long param", "x" * 6000), Param("short param", "xyz")]
expected_param_entities = [
Param("long param", "x" * 6000),
Expand All @@ -1814,10 +1818,14 @@ def test_log_batch_max_length_value(store):
store.log_batch(run.info.run_id, (), param_entities, ())
_verify_logged(store, run.info.run_id, (), expected_param_entities, ())

monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "false")
param_entities = [Param("long param", "x" * 6001), Param("short param", "xyz")]
with pytest.raises(MlflowException, match="exceeded length"):
store.log_batch(run.info.run_id, (), param_entities, ())

monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true")
store.log_batch(run.info.run_id, (), param_entities, ())


def test_log_batch_internal_error(store):
# Verify that internal errors during log_batch result in MlflowExceptions
Expand Down
19 changes: 16 additions & 3 deletions tests/store/tracking/test_sqlalchemy_store.py
Expand Up @@ -1154,17 +1154,21 @@ def test_log_null_param(store: SqlAlchemyStore):
reason="large string parameters are sent as TEXT/NTEXT; "
"see tests/db/compose.yml for details",
)
def test_log_param_max_length_value(store: SqlAlchemyStore):
def test_log_param_max_length_value(store: SqlAlchemyStore, monkeypatch):
run = _run_factory(store)
tkey = "blahmetric"
tval = "x" * 6000
param = entities.Param(tkey, tval)
store.log_param(run.info.run_id, param)
run = store.get_run(run.info.run_id)
assert run.data.params[tkey] == str(tval)
monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "false")
with pytest.raises(MlflowException, match="exceeded length"):
store.log_param(run.info.run_id, entities.Param(tkey, "x" * 6001))

monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true")
store.log_param(run.info.run_id, entities.Param(tkey, "x" * 6001))


def test_set_experiment_tag(store: SqlAlchemyStore):
exp_id = _create_experiments(store, "setExperimentTagExp")
Expand Down Expand Up @@ -1206,7 +1210,7 @@ def test_set_experiment_tag(store: SqlAlchemyStore):
store.set_experiment_tag(exp_id, entities.ExperimentTag("should", "notset"))


def test_set_tag(store: SqlAlchemyStore):
def test_set_tag(store: SqlAlchemyStore, monkeypatch):
run = _run_factory(store)

tkey = "test tag"
Expand All @@ -1218,8 +1222,13 @@ def test_set_tag(store: SqlAlchemyStore):
# Overwriting tags is allowed
store.set_tag(run.info.run_id, new_tag)
# test setting tags that are too long fails.
monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "false")
with pytest.raises(MlflowException, match="exceeded length limit of 5000"):
store.set_tag(run.info.run_id, entities.RunTag("longTagKey", "a" * 5001))

monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true")
store.set_tag(run.info.run_id, entities.RunTag("longTagKey", "a" * 5001))

# test can set tags that are somewhat long
store.set_tag(run.info.run_id, entities.RunTag("longTagKey", "a" * 4999))
run = store.get_run(run.info.run_id)
Expand Down Expand Up @@ -2747,16 +2756,20 @@ def test_log_batch_null_metrics(store: SqlAlchemyStore):
assert exception_context.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_log_batch_params_max_length_value(store: SqlAlchemyStore):
def test_log_batch_params_max_length_value(store: SqlAlchemyStore, monkeypatch):
run = _run_factory(store)
param_entities = [Param("long param", "x" * 6000), Param("short param", "xyz")]
expected_param_entities = [Param("long param", "x" * 6000), Param("short param", "xyz")]
store.log_batch(run.info.run_id, [], param_entities, [])
_verify_logged(store, run.info.run_id, [], expected_param_entities, [])
param_entities = [Param("long param", "x" * 6001)]
monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "false")
with pytest.raises(MlflowException, match="exceeded length"):
store.log_batch(run.info.run_id, [], param_entities, [])

monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "true")
store.log_batch(run.info.run_id, [], param_entities, [])


def test_upgrade_cli_idempotence(store: SqlAlchemyStore):
# Repeatedly run `mlflow db upgrade` against our database, verifying that the command
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_validation.py
Expand Up @@ -219,7 +219,7 @@ def test_validate_batch_log_limits():
_validate_batch_log_limits([], [], too_many_tags[:100])


def test_validate_batch_log_data():
def test_validate_batch_log_data(monkeypatch):
metrics_with_bad_key = [
Metric("good-metric-key", 1.0, 0, 0),
Metric("super-long-bad-key" * 1000, 4.0, 0, 0),
Expand Down Expand Up @@ -258,6 +258,7 @@ def test_validate_batch_log_data():
"tags": [tags_with_bad_key, tags_with_bad_val],
}
good_kwargs = {"metrics": [], "params": [], "tags": []}
monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "false")
for arg_name, arg_values in bad_kwargs.items():
for arg_value in arg_values:
final_kwargs = copy.deepcopy(good_kwargs)
Expand Down