
Add validation for model code #11844

Merged
merged 34 commits on Apr 30, 2024
Changes from 30 commits
Commits
34 commits
767e92e
initial
annzhang-db Apr 24, 2024
464d885
set chain
annzhang-db Apr 24, 2024
0190276
format
annzhang-db Apr 24, 2024
208ab5e
format again
annzhang-db Apr 24, 2024
f417867
update
annzhang-db Apr 24, 2024
c3b769b
docstring
annzhang-db Apr 24, 2024
d052868
update
annzhang-db Apr 24, 2024
776e425
update chain.py
annzhang-db Apr 24, 2024
5702feb
catch all exceptions
annzhang-db Apr 24, 2024
6ced8b2
check code_paths existence
annzhang-db Apr 25, 2024
62c38e8
tests
annzhang-db Apr 25, 2024
5bca18b
exception
annzhang-db Apr 25, 2024
f9ac68c
update
annzhang-db Apr 26, 2024
e319bb0
use model_code_dir_subpath
annzhang-db Apr 26, 2024
e8643d8
add test for different name
annzhang-db Apr 26, 2024
67c4d5b
format
annzhang-db Apr 26, 2024
6a48b9a
leave code_paths as none
annzhang-db Apr 26, 2024
c7bb50a
.py suffix
annzhang-db Apr 26, 2024
c34d9e0
rework temp file
annzhang-db Apr 26, 2024
48fd538
remove import
annzhang-db Apr 26, 2024
5b10dd5
remove set_chain
annzhang-db Apr 26, 2024
9335648
add back code_paths validation
annzhang-db Apr 26, 2024
246d963
format
annzhang-db Apr 26, 2024
f588fa8
Merge remote-tracking branch 'upstream/master' into langchain-log-model
annzhang-db Apr 26, 2024
94d5341
none check
annzhang-db Apr 27, 2024
22af799
format
annzhang-db Apr 27, 2024
5cd30d2
dbutils
annzhang-db Apr 27, 2024
3619537
look for magic commands
annzhang-db Apr 27, 2024
9482f09
format
annzhang-db Apr 27, 2024
315b0d4
add test
annzhang-db Apr 27, 2024
70634f0
format
annzhang-db Apr 27, 2024
8aafa87
improve regex match
annzhang-db Apr 29, 2024
53f2d79
Merge remote-tracking branch 'upstream/master' into code-validation
annzhang-db Apr 29, 2024
c87ed7d
fix test
annzhang-db Apr 30, 2024
25 changes: 17 additions & 8 deletions mlflow/langchain/__init__.py
Collaborator Author (annzhang-db):
This PR includes changes from #11817. Look here for clean diff: annzhang-db#2
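For orientation, a minimal sketch of the logging pattern this validation targets. The file names, the run context, and the single extra code_paths entry are illustrative assumptions, not part of this PR:

import mlflow

# Hypothetical usage: lc_model is a string path to a Python file (or a
# Databricks notebook path) that defines the chain. The validation added here
# checks that the path exists and that the code can run outside a notebook.
with mlflow.start_run():
    mlflow.langchain.log_model(
        lc_model="chain.py",        # file defining the chain instance
        artifact_path="model",
        code_paths=["config.yml"],  # at most one extra path when lc_model is a string
    )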

@@ -81,9 +81,11 @@
from mlflow.utils.file_utils import get_total_file_size, write_to
from mlflow.utils.model_utils import (
FLAVOR_CONFIG_CODE,
FLAVOR_CONFIG_MODEL_CODE,
_add_code_from_conf_to_system_path,
_get_flavor_configuration,
_validate_and_copy_code_paths,
_validate_and_copy_model_code_path,
_validate_and_prepare_target_save_path,
)
from mlflow.utils.requirements_utils import _get_pinned_requirement
@@ -233,21 +235,21 @@ def load_retriever(persist_directory):

path = os.path.abspath(path)
_validate_and_prepare_target_save_path(path)
formatted_code_path = code_paths[:] if code_paths else []
model_code_path = None
if isinstance(lc_model, str):
# The LangChain model is defined as Python code located in the file at the path
# specified by `lc_model`. Verify that the path exists and, if so, copy it to the
# model directory along with any other specified code modules

if os.path.exists(lc_model):
formatted_code_path.append(lc_model)
model_code_path = lc_model
else:
raise mlflow.MlflowException.invalid_parameter_value(
f"If the {lc_model} is a string, it must be a valid python "
"file path containing the code for defining the chain instance."
f"If the provided model '{lc_model}' is a string, it must be a valid python "
"file path or a databricks notebook file path containing the code for defining "
"the chain instance."
)

if len(code_paths) > 1:
if code_paths and len(code_paths) > 1:
raise mlflow.MlflowException.invalid_parameter_value(
"When the model is a string, and if the code_paths are specified, "
"it should contain only one path."
@@ -256,7 +258,8 @@ def load_retriever(persist_directory):
f"Current code paths: {code_paths}"
)

code_dir_subpath = _validate_and_copy_code_paths(formatted_code_path, path)
code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)
model_code_dir_subpath = _validate_and_copy_model_code_path(model_code_path, path)

if signature is None:
if input_example is not None:
@@ -311,6 +314,7 @@ def load_retriever(persist_directory):
**model_data_kwargs,
}
else:
# TODO: use model_config instead
# If the model is a string, we expect the code_path which is ideally config.yml
# would be used in the model. We set the code_path here so it can be set
# globally when the model is loaded with the local path. So the consumer
@@ -322,6 +326,7 @@ def load_retriever(persist_directory):
)
model_data_kwargs = {}

# TODO: pass model_config
pyfunc.add_to_model(
mlflow_model,
loader_module="mlflow.langchain",
@@ -330,12 +335,14 @@
code=code_dir_subpath,
predict_stream_fn="predict_stream",
streamable=streamable,
model_code=model_code_dir_subpath,
**model_data_kwargs,
)

if Version(langchain.__version__) >= Version("0.0.311"):
checker_model = lc_model
if isinstance(lc_model, str):
# TODO: use model_config instead of code_paths[0]
checker_model = (
_load_model_code_path(lc_model, code_paths[0])
if code_paths and len(code_paths) >= 1
@@ -349,6 +356,7 @@
FLAVOR_NAME,
langchain_version=langchain.__version__,
code=code_dir_subpath,
model_code=model_code_dir_subpath,
streamable=streamable,
**flavor_conf,
)
@@ -814,9 +822,10 @@ def _load_model_from_local_fs(local_model_path):
config_path = None

flavor_code_path = flavor_conf.get(_CODE_PATH, "chain.py")
flavor_model_code_config = flavor_conf.get(FLAVOR_CONFIG_MODEL_CODE)
code_path = os.path.join(
local_model_path,
flavor_code_config,
flavor_model_code_config,
os.path.basename(flavor_code_path),
)

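For context, a rough sketch of how the copied model code file is resolved from the flavor configuration at load time; the key names and paths below are assumptions based on this diff, not a definitive API:

import os

# Assumed flavor configuration shape after this change (illustrative values).
flavor_conf = {"model_code": "model_code", "code_path": "chain.py"}
local_model_path = "/tmp/downloaded_model"  # hypothetical local copy of the model

# Mirrors the loading logic above: join the local model path, the model_code
# subdirectory recorded in the flavor config, and the code file's basename.
code_path = os.path.join(
    local_model_path,
    flavor_conf["model_code"],
    os.path.basename(flavor_conf.get("code_path", "chain.py")),
)
print(code_path)  # /tmp/downloaded_model/model_code/chain.py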
64 changes: 59 additions & 5 deletions mlflow/langchain/utils.py
@@ -1,12 +1,14 @@
"""Utility functions for mlflow.langchain."""

import base64
import contextlib
import importlib
import json
import logging
import os
import re
import shutil
import tempfile
import types
import warnings
from functools import lru_cache
@@ -20,6 +22,7 @@

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.utils import _validate_model_code_from_notebook
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
from mlflow.utils.class_utils import _get_class_from_string

@@ -264,12 +267,41 @@ def safe_import_and_add(module_name, class_name):
for llm_name in ["Databricks", "Mlflow"]:
try_adding_llm(langchain.llms, llm_name)

for chat_model_name in ["ChatDatabricks", "ChatMlflow", "ChatOpenAI", "AzureChatOpenAI"]:
for chat_model_name in [
"ChatDatabricks",
"ChatMlflow",
"ChatOpenAI",
"AzureChatOpenAI",
]:
try_adding_llm(langchain.chat_models, chat_model_name)

return supported_llms


def _get_temp_file_with_content(file_name: str, content: str, content_format) -> str:
"""
Write the contents to a temporary file and return the path to that file.

Args:
file_name: The name of the file to be created.
content: The contents to be written to the file.
content_format: The file mode used when writing the content (e.g. "w" or "wb").

Returns:
The string path to the file containing the chain model code.
"""
# Get the temporary directory path
temp_dir = tempfile.gettempdir()

# Construct the full path where the temporary file will be created
temp_file_path = os.path.join(temp_dir, file_name)

# Create and write to the file
with open(temp_file_path, content_format) as tmp_file:
tmp_file.write(content)

return temp_file_path
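A quick, assumed usage of this helper; the file name and content are illustrative:

from mlflow.langchain.utils import _get_temp_file_with_content

# Writes the bytes to <system temp dir>/lc_model.py and returns that path.
path = _get_temp_file_with_content("lc_model.py", b"chain = ...", "wb")
print(path)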


def _validate_and_wrap_lc_model(lc_model, loader_fn):
import langchain.agents.agent
import langchain.chains.base
@@ -278,13 +310,35 @@ def _validate_and_wrap_lc_model(lc_model, loader_fn):
import langchain.llms.openai
import langchain.schema

# lc_model is a file path
if isinstance(lc_model, str):
if os.path.basename(os.path.abspath(lc_model)) != "chain.py":
if not os.path.exists(lc_model):
raise mlflow.MlflowException.invalid_parameter_value(
f"If {lc_model} is a string, it must be the path to a file "
"named `chain.py` on the local filesystem."
f"If the provided model '{lc_model}' is a string, it must be a valid python "
"file path or a databricks notebook file path containing the code for defining "
"the chain instance."
)
return lc_model

try:
with open(lc_model) as _:
return lc_model
except Exception:
try:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import ExportFormat

w = WorkspaceClient()
response = w.workspace.export(path=lc_model, format=ExportFormat.SOURCE)
decoded_content = base64.b64decode(response.content)
_validate_model_code_from_notebook(decoded_content)

return _get_temp_file_with_content("lc_model.py", decoded_content, "wb")
except Exception:
raise mlflow.MlflowException.invalid_parameter_value(
f"If the provided model '{lc_model}' is a string, it must be a valid python "
"file path or a databricks notebook file path containing the code for defining "
"the chain instance."
)

if not isinstance(lc_model, supported_lc_types()):
raise mlflow.MlflowException.invalid_parameter_value(
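A standalone sketch of the fallback added to _validate_and_wrap_lc_model above: when lc_model is not a readable local file, the code is exported from the Databricks workspace, decoded, validated, and written to a temporary file. It assumes databricks-sdk is installed and simplifies the error handling; it is not a drop-in replacement for the MLflow helpers.

import base64
import os
import tempfile

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import ExportFormat


def resolve_model_code(lc_model: str) -> str:
    """Return the path to a local .py file containing the model code."""
    try:
        with open(lc_model):
            return lc_model  # already a readable local Python file
    except OSError:
        # Not readable locally: assume a workspace notebook path and export
        # its source (the SDK returns base64-encoded content).
        w = WorkspaceClient()
        response = w.workspace.export(path=lc_model, format=ExportFormat.SOURCE)
        decoded = base64.b64decode(response.content)
        # MLflow additionally runs _validate_model_code_from_notebook(decoded) here.
        temp_path = os.path.join(tempfile.gettempdir(), "lc_model.py")
        with open(temp_path, "wb") as f:
            f.write(decoded)
        return temp_path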
65 changes: 65 additions & 0 deletions mlflow/models/utils.py
@@ -3,6 +3,7 @@
import json
import logging
import os
import re
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

@@ -1389,3 +1390,67 @@ def convert_complex_types_pyspark_to_pandas(value, dataType):
if converter:
return converter(value)
return value


def _is_in_comment(line, start):
"""
Check if the code at the index "start" of the line is in a comment.

Limitations: This function does not handle multi-line comments, and the # symbol could be in a
string, or otherwise not indicate a comment.
"""
return "#" in line[:start]


def _is_in_string_only(line, search_string):
"""
Check if the search_string appears on the line only inside string literals.

Limitations: This function does not handle multi-line strings.
"""
# Regex for matching double quotes and everything inside
double_quotes_regex = r"\"(\\.|[^\"])*\""

# Regex for matching single quotes and everything inside
single_quotes_regex = r"\'(\\.|[^\'])*\'"

# Regex for matching search_string exactly
search_string_regex = rf"({re.escape(search_string)})"

# Concatenate the patterns using the OR operator '|'
# This will match left to right: quoted strings first, search_string last
pattern = double_quotes_regex + r"|" + single_quotes_regex + r"|" + search_string_regex

# Iterate through all matches in the line
for match in re.finditer(pattern, line):
# If the regex matched on the search_string, we know that it did not match in quotes since
# that is the order. So we know that the search_string exists outside of quotes
# (at least once).
if match.group() == search_string:
return False
return True
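To make the matching order above concrete, a small hypothetical illustration (assuming the two helpers are importable from mlflow.models.utils):

from mlflow.models.utils import _is_in_comment, _is_in_string_only

line = 'name = "dbutils"  # dbutils also appears in this comment'
_is_in_comment(line, line.index("dbutils"))          # False: first hit is before the '#'
_is_in_string_only('x = "dbutils"', "dbutils")       # True: only inside the quoted string
_is_in_string_only("dbutils.fs.ls('/')", "dbutils")  # False: bare usage outside quotes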


def _validate_model_code_from_notebook(code):
"""
Validate there isn't any code that would work in a notebook but not as an exported Python file.
For now, this checks for dbutils and magic commands.
"""
error_message = (
"The model file uses 'dbutils' command which is not supported. To ensure your code "
"functions correctly, remove or comment out usage of 'dbutils' command."
)

for line in code.splitlines():
for match in re.finditer(r"\bdbutils\b", line):
start = match.start()
if not _is_in_comment(line, start) and not _is_in_string_only(line, "dbutils"):
raise ValueError(error_message)

magic_regex = r"# MAGIC %\S+.*"
if re.search(magic_regex, code):
_logger.warning(
"The model file uses magic commands which have been commented out. To ensure your code "
"functions correctly, make sure that it does not rely on these magic commands for."
"correctness."
)
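A brief, hypothetical illustration of the new check; the sample code strings are made up and assume the helper is importable as added in this diff:

from mlflow.models.utils import _validate_model_code_from_notebook

ok_string = 'print("dbutils is just text here")'              # only inside a string literal
ok_comment = "# dbutils is referenced only in this comment"   # only inside a comment
bad_code = "token = dbutils.secrets.get('scope', 'key')"      # bare usage

_validate_model_code_from_notebook(ok_string)   # passes
_validate_model_code_from_notebook(ok_comment)  # passes
try:
    _validate_model_code_from_notebook(bad_code)
except ValueError as err:
    print(err)  # explains that 'dbutils' is not supported in exported model code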
20 changes: 19 additions & 1 deletion mlflow/utils/model_utils.py
@@ -20,6 +20,7 @@
from mlflow.utils.uri import append_to_uri_path

FLAVOR_CONFIG_CODE = "code"
FLAVOR_CONFIG_MODEL_CODE = "model_code"


def _get_all_flavor_configurations(model_path):
@@ -162,6 +163,20 @@ def _validate_and_copy_code_paths(code_paths, path, default_subpath="code"):
return code_dir_subpath


def _validate_and_copy_model_code_path(code_path, path, default_subpath="model_code"):
"""Copies the model code from code_path to a directory.

Args:
code_path: A file containing model code that should be logged as an artifact.
path: The local model path.
default_subpath: The default directory name used to store model code artifacts.
"""
if code_path:
return _validate_and_copy_code_paths([code_path], path, default_subpath)
else:
return None
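As an assumed usage example of the helper above (paths are hypothetical; the return value mirrors the default_subpath behavior of _validate_and_copy_code_paths):

from mlflow.utils.model_utils import _validate_and_copy_model_code_path

# Copies chain.py into <model_path>/model_code/ and returns "model_code";
# returns None when no model code path is provided.
subpath = _validate_and_copy_model_code_path("chain.py", "/tmp/my_model")
print(subpath)  # "model_code"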


def _add_code_to_system_path(code_path):
sys.path = [code_path] + sys.path

@@ -217,7 +232,10 @@ def _validate_onnx_session_options(onnx_session_options):
f"Value for key {key} in onnx_session_options should be a dict, "
"not {type(value)}"
)
elif key == "execution_mode" and value.upper() not in ["PARALLEL", "SEQUENTIAL"]:
elif key == "execution_mode" and value.upper() not in [
"PARALLEL",
"SEQUENTIAL",
]:
raise ValueError(
f"Value for key {key} in onnx_session_options should be "
f"'parallel' or 'sequential', not {value}"