Skip to content

Commit

Permalink
Add validation for model code (#11844)
Browse files Browse the repository at this point in the history
Signed-off-by: Ann Zhang <ann.zhang@databricks.com>
  • Loading branch information
annzhang-db committed Apr 30, 2024
1 parent a990a30 commit e90aa19
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mlflow/langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.utils import _validate_model_code_from_notebook
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
from mlflow.utils.class_utils import _get_class_from_string

Expand Down Expand Up @@ -329,7 +330,7 @@ def _validate_and_wrap_lc_model(lc_model, loader_fn):
w = WorkspaceClient()
response = w.workspace.export(path=lc_model, format=ExportFormat.SOURCE)
decoded_content = base64.b64decode(response.content)
# TODO: code validation
_validate_model_code_from_notebook(decoded_content)

return _get_temp_file_with_content("lc_model.py", decoded_content, "wb")
except Exception:
Expand Down
65 changes: 65 additions & 0 deletions mlflow/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
import re
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -1389,3 +1390,67 @@ def convert_complex_types_pyspark_to_pandas(value, dataType):
if converter:
return converter(value)
return value


def _is_in_comment(line, start):
"""
Check if the code at the index "start" of the line is in a comment.
Limitations: This function does not handle multi-line comments, and the # symbol could be in a
string, or otherwise not indicate a comment.
"""
return "#" in line[:start]


def _is_in_string_only(line, search_string):
"""
Check is the search_string
Limitations: This function does not handle multi-line strings.
"""
# Regex for matching double quotes and everything inside
double_quotes_regex = r"\"(\\.|[^\"])*\""

# Regex for matching single quotes and everything inside
single_quotes_regex = r"\'(\\.|[^\'])*\'"

# Regex for matching search_string exactly
search_string_regex = rf"({re.escape(search_string)})"

# Concatenate the patterns using the OR operator '|'
# This will matches left to right - on quotes first, search_string last
pattern = double_quotes_regex + r"|" + single_quotes_regex + r"|" + search_string_regex

# Iterate through all matches in the line
for match in re.finditer(pattern, line):
# If the regex matched on the search_string, we know that it did not match in quotes since
# that is the order. So we know that the search_string exists outside of quotes
# (at least once).
if match.group() == search_string:
return False
return True


def _validate_model_code_from_notebook(code):
"""
Validate there isn't any code that would work in a notebook but not as exported Python file.
For now, this checks for dbutils and magic commands.
"""
error_message = (
"The model file uses 'dbutils' command which is not supported. To ensure your code "
"functions correctly, remove or comment out usage of 'dbutils' command."
)

for line in code.splitlines():
for match in re.finditer(r"\bdbutils\b", line):
start = match.start()
if not _is_in_comment(line, start) and not _is_in_string_only(line, "dbutils"):
raise ValueError(error_message)

magic_regex = r"^# MAGIC %\S+.*"
if re.search(magic_regex, code, re.MULTILINE):
_logger.warning(
"The model file uses magic commands which have been commented out. To ensure your code "
"functions correctly, make sure that it does not rely on these magic commands for."
"correctness."
)
24 changes: 24 additions & 0 deletions tests/models/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_enforce_datatype,
_enforce_object,
_enforce_property,
_validate_model_code_from_notebook,
get_model_version_from_model_uri,
)
from mlflow.types import DataType
Expand Down Expand Up @@ -429,3 +430,26 @@ def test_enforce_array_with_errors():
)
),
)


def test_model_code_validation():
invalid_code = "dbutils.library.restartPython()\nsome_python_variable = 5"

warning_code = "# dbutils.library.restartPython()\n# MAGIC %run ../wheel_installer"

valid_code = "some_valid_python_code = 'valid'"

with pytest.raises(
ValueError, match="The model file uses 'dbutils' command which is not supported."
):
_validate_model_code_from_notebook(invalid_code)

with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
_validate_model_code_from_notebook(warning_code)
mock_warning.assert_called_once_with(
"The model file uses magic commands which have been commented out. To ensure your code "
"functions correctly, make sure that it does not rely on these magic commands for."
"correctness."
)

_validate_model_code_from_notebook(valid_code)

0 comments on commit e90aa19

Please sign in to comment.