Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validation for model code #11844

Merged
merged 34 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
767e92e
initial
annzhang-db Apr 24, 2024
464d885
set chain
annzhang-db Apr 24, 2024
0190276
format
annzhang-db Apr 24, 2024
208ab5e
format again
annzhang-db Apr 24, 2024
f417867
update
annzhang-db Apr 24, 2024
c3b769b
docstring
annzhang-db Apr 24, 2024
d052868
update
annzhang-db Apr 24, 2024
776e425
update chain.py
annzhang-db Apr 24, 2024
5702feb
catch all exceptions
annzhang-db Apr 24, 2024
6ced8b2
check code_paths existence
annzhang-db Apr 25, 2024
62c38e8
tests
annzhang-db Apr 25, 2024
5bca18b
exception
annzhang-db Apr 25, 2024
f9ac68c
update
annzhang-db Apr 26, 2024
e319bb0
use model_code_dir_subpath
annzhang-db Apr 26, 2024
e8643d8
add test for different name
annzhang-db Apr 26, 2024
67c4d5b
format
annzhang-db Apr 26, 2024
6a48b9a
leave code_paths as none
annzhang-db Apr 26, 2024
c7bb50a
.py suffix
annzhang-db Apr 26, 2024
c34d9e0
rework temp file
annzhang-db Apr 26, 2024
48fd538
remove import
annzhang-db Apr 26, 2024
5b10dd5
remove set_chain
annzhang-db Apr 26, 2024
9335648
add back code_paths validation
annzhang-db Apr 26, 2024
246d963
format
annzhang-db Apr 26, 2024
f588fa8
Merge remote-tracking branch 'upstream/master' into langchain-log-model
annzhang-db Apr 26, 2024
94d5341
none check
annzhang-db Apr 27, 2024
22af799
format
annzhang-db Apr 27, 2024
5cd30d2
dbutils
annzhang-db Apr 27, 2024
3619537
look for magic commands
annzhang-db Apr 27, 2024
9482f09
format
annzhang-db Apr 27, 2024
315b0d4
add test
annzhang-db Apr 27, 2024
70634f0
format
annzhang-db Apr 27, 2024
8aafa87
improve regex match
annzhang-db Apr 29, 2024
53f2d79
Merge remote-tracking branch 'upstream/master' into code-validation
annzhang-db Apr 29, 2024
c87ed7d
fix test
annzhang-db Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlflow/langchain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.utils import _validate_model_code_from_notebook
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
from mlflow.utils.class_utils import _get_class_from_string

Expand Down Expand Up @@ -329,7 +330,7 @@ def _validate_and_wrap_lc_model(lc_model, loader_fn):
w = WorkspaceClient()
response = w.workspace.export(path=lc_model, format=ExportFormat.SOURCE)
decoded_content = base64.b64decode(response.content)
# TODO: code validation
_validate_model_code_from_notebook(decoded_content)

return _get_temp_file_with_content("lc_model.py", decoded_content, "wb")
except Exception:
Expand Down
65 changes: 65 additions & 0 deletions mlflow/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
import re
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -1389,3 +1390,67 @@ def convert_complex_types_pyspark_to_pandas(value, dataType):
if converter:
return converter(value)
return value


def _is_in_comment(line, start):
"""
Check if the code at the index "start" of the line is in a comment.

Limitations: This function does not handle multi-line comments, and the # symbol could be in a
string, or otherwise not indicate a comment.
"""
return "#" in line[:start]


def _is_in_string_only(line, search_string):
"""
Check is the search_string

Limitations: This function does not handle multi-line strings.
"""
# Regex for matching double quotes and everything inside
double_quotes_regex = r"\"(\\.|[^\"])*\""

# Regex for matching single quotes and everything inside
single_quotes_regex = r"\'(\\.|[^\'])*\'"

# Regex for matching search_string exactly
search_string_regex = rf"({re.escape(search_string)})"

# Concatenate the patterns using the OR operator '|'
# This will matches left to right - on quotes first, search_string last
pattern = double_quotes_regex + r"|" + single_quotes_regex + r"|" + search_string_regex

# Iterate through all matches in the line
for match in re.finditer(pattern, line):
# If the regex matched on the search_string, we know that it did not match in quotes since
# that is the order. So we know that the search_string exists outside of quotes
# (at least once).
if match.group() == search_string:
return False
return True


def _validate_model_code_from_notebook(code):
"""
Validate there isn't any code that would work in a notebook but not as exported Python file.
For now, this checks for dbutils and magic commands.
"""
error_message = (
"The model file uses 'dbutils' command which is not supported. To ensure your code "
"functions correctly, remove or comment out usage of 'dbutils' command."
)

for line in code.splitlines():
for match in re.finditer(r"\bdbutils\b", line):
start = match.start()
if not _is_in_comment(line, start) and not _is_in_string_only(line, "dbutils"):
raise ValueError(error_message)

magic_regex = r"^# MAGIC %\S+.*"
if re.search(magic_regex, code, re.MULTILINE):
_logger.warning(
"The model file uses magic commands which have been commented out. To ensure your code "
"functions correctly, make sure that it does not rely on these magic commands for."
"correctness."
)
32 changes: 32 additions & 0 deletions tests/models/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_enforce_datatype,
_enforce_object,
_enforce_property,
_validate_model_code_from_notebook,
get_model_version_from_model_uri,
)
from mlflow.types import DataType
Expand Down Expand Up @@ -429,3 +430,34 @@ def test_enforce_array_with_errors():
)
),
)


def test_model_code_validation():
    """
    Exercise the notebook code validator on three inputs: a live ``dbutils`` call (must raise),
    commented-out ``dbutils`` plus a magic command (must warn once), and plain Python (no error).
    """
    code_with_live_dbutils = """
dbutils.library.restartPython()
some_python_variable = 5
"""

    code_with_commented_magic = """
# dbutils.library.restartPython()
# MAGIC %run ../wheel_installer
"""

    plain_python_code = """
some_valid_python_code = "valid"
"""

    expected_error = "The model file uses 'dbutils' command which is not supported."
    expected_warning = (
        "The model file uses magic commands which have been commented out. To ensure your code "
        "functions correctly, make sure that it does not rely on these magic commands for."
        "correctness."
    )

    # A dbutils call outside of comments and strings is rejected.
    with pytest.raises(ValueError, match=expected_error):
        _validate_model_code_from_notebook(code_with_live_dbutils)

    # Commented-out dbutils is tolerated; the magic command only triggers a warning.
    with mock.patch("mlflow.models.utils._logger.warning") as warning_spy:
        _validate_model_code_from_notebook(code_with_commented_magic)
        warning_spy.assert_called_once_with(expected_warning)

    # Ordinary code passes silently.
    _validate_model_code_from_notebook(plain_python_code)