Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support file paths for model_config in langchain #11843

Merged
merged 48 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
767e92e
initial
annzhang-db Apr 24, 2024
464d885
set chain
annzhang-db Apr 24, 2024
0190276
format
annzhang-db Apr 24, 2024
208ab5e
format again
annzhang-db Apr 24, 2024
f417867
update
annzhang-db Apr 24, 2024
c3b769b
docstring
annzhang-db Apr 24, 2024
d052868
update
annzhang-db Apr 24, 2024
776e425
update chain.py
annzhang-db Apr 24, 2024
5702feb
catch all exceptions
annzhang-db Apr 24, 2024
6ced8b2
check code_paths existence
annzhang-db Apr 25, 2024
62c38e8
tests
annzhang-db Apr 25, 2024
5bca18b
exception
annzhang-db Apr 25, 2024
f9ac68c
update
annzhang-db Apr 26, 2024
e319bb0
use model_code_dir_subpath
annzhang-db Apr 26, 2024
e8643d8
add test for different name
annzhang-db Apr 26, 2024
67c4d5b
format
annzhang-db Apr 26, 2024
6a48b9a
leave code_paths as none
annzhang-db Apr 26, 2024
c7bb50a
.py suffix
annzhang-db Apr 26, 2024
c34d9e0
rework temp file
annzhang-db Apr 26, 2024
48fd538
remove import
annzhang-db Apr 26, 2024
5b10dd5
remove set_chain
annzhang-db Apr 26, 2024
9335648
add back code_paths validation
annzhang-db Apr 26, 2024
246d963
format
annzhang-db Apr 26, 2024
f588fa8
Merge remote-tracking branch 'upstream/master' into langchain-log-model
annzhang-db Apr 26, 2024
e73030f
initial
annzhang-db Apr 27, 2024
4c59208
update
annzhang-db Apr 27, 2024
cdd34ac
update config
annzhang-db Apr 27, 2024
8882490
no pyfunc
annzhang-db Apr 27, 2024
a1e6f92
remove use of _rag_utils
annzhang-db Apr 27, 2024
e235203
move global
annzhang-db Apr 27, 2024
2c7c5dc
backwards compatible with code_paths[0]
annzhang-db Apr 27, 2024
adeb7c5
format
annzhang-db Apr 27, 2024
ea4d0c4
remove extra pyfunc code
annzhang-db Apr 27, 2024
f3ae8b7
import _set_model_config correctly
annzhang-db Apr 29, 2024
555974b
backwards compatibility
annzhang-db Apr 29, 2024
db36e4d
update chain file
annzhang-db Apr 29, 2024
03ca1f2
fix
annzhang-db Apr 29, 2024
5a7f2de
update
annzhang-db Apr 29, 2024
1fc13c3
Merge remote-tracking branch 'upstream/master' into model-config-file
annzhang-db Apr 29, 2024
e422641
fix test_model_config
annzhang-db Apr 30, 2024
091a0b7
update
annzhang-db Apr 30, 2024
df5e50e
fix
annzhang-db Apr 30, 2024
9ae5daa
format
annzhang-db Apr 30, 2024
5d141bf
address comments
annzhang-db Apr 30, 2024
f1d5efc
docstring
annzhang-db Apr 30, 2024
e3a74c0
update documentation
annzhang-db Apr 30, 2024
600ac5d
Merge remote-tracking branch 'base/master' into model-config-file
mlflow-automation Apr 30, 2024
f876a43
Autoformat: https://github.com/mlflow/mlflow/actions/runs/8898481395
mlflow-automation Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 55 additions & 24 deletions mlflow/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
_update_langchain_model_config,
patched_inference,
)
from mlflow.langchain._rag_utils import _CODE_CONFIG, _CODE_PATH, _set_config_path
from mlflow.langchain._rag_utils import _CODE_CONFIG, _CODE_PATH
from mlflow.langchain.databricks_dependencies import (
_DATABRICKS_DEPENDENCY_KEY,
_detect_databricks_dependencies,
Expand All @@ -55,6 +55,7 @@
)
from mlflow.models import Model, ModelInputExample, ModelSignature, get_model_info
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.model_config import _set_model_config
from mlflow.models.signature import _infer_signature_from_input_example
from mlflow.models.utils import _save_example
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
Expand Down Expand Up @@ -85,7 +86,7 @@
_add_code_from_conf_to_system_path,
_get_flavor_configuration,
_validate_and_copy_code_paths,
_validate_and_copy_model_code_path,
_validate_and_copy_model_code_and_config_paths,
_validate_and_prepare_target_save_path,
)
from mlflow.utils.requirements_utils import _get_pinned_requirement
Expand All @@ -94,6 +95,8 @@

FLAVOR_NAME = "langchain"
_MODEL_TYPE_KEY = "model_type"
_MODEL_CODE_CONFIG = "model_config"
_MODEL_CODE_PATH = "model_code_path"


def get_default_pip_requirements():
Expand Down Expand Up @@ -133,6 +136,7 @@ def save_model(
loader_fn=None,
persist_dir=None,
example_no_conversion=False,
model_config=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model_config param doc is lost

):
"""
Save a LangChain model to a path on the local file system.
Expand Down Expand Up @@ -235,6 +239,8 @@ def load_retriever(persist_directory):

path = os.path.abspath(path)
_validate_and_prepare_target_save_path(path)

model_config_path = None
model_code_path = None
if isinstance(lc_model, str):
# The LangChain model is defined as Python code located in the file at the path
Expand All @@ -249,17 +255,25 @@ def load_retriever(persist_directory):
"file path or a databricks notebook file path containing the code for defining "
"the chain instance."
)
if code_paths and len(code_paths) > 1:
raise mlflow.MlflowException.invalid_parameter_value(
"When the model is a string, and if the code_paths are specified, "
"it should contain only one path."
"This config path is used to set config.yml file path "
"for the model. This path should be passed in via the code_paths. "
f"Current code paths: {code_paths}"
)

if isinstance(model_config, str):
if os.path.exists(model_config):
model_config_path = model_config
else:
raise mlflow.MlflowException.invalid_parameter_value(
f"If the provided model_config '{model_config}' is a string, it must be a "
"valid yaml file path containing the configuration for the model."
)
# TODO: deal with dicts properly as well

if not model_config:
# for backwards compatibility
annzhang-db marked this conversation as resolved.
Show resolved Hide resolved
if code_paths and len(code_paths) == 1 and os.path.exists(code_paths[0]):
model_config_path = code_paths[0]

_validate_and_copy_model_code_and_config_paths(lc_model, model_config_path, path)

code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)
model_code_dir_subpath = _validate_and_copy_model_code_path(model_code_path, path)

if signature is None:
if input_example is not None:
Expand Down Expand Up @@ -319,14 +333,15 @@ def load_retriever(persist_directory):
# would be used in the model. We set the code_path here so it can be set
# globally when the model is loaded with the local path. So the consumer
# can use that path instead of the config.yml path when the model is loaded
# TODO: what if model_config is not a string / file path?
flavor_conf = (
{_CODE_CONFIG: code_paths[0], _CODE_PATH: lc_model}
if code_paths and len(code_paths) >= 1
else {_CODE_CONFIG: None, _CODE_PATH: lc_model}
{_MODEL_CODE_CONFIG: model_config_path, _MODEL_CODE_PATH: lc_model}
if model_config_path
else {_MODEL_CODE_CONFIG: None, _MODEL_CODE_PATH: lc_model}
)
model_data_kwargs = {}

# TODO: pass model_config
# TODO: Pass file paths for model_config when it is supported in pyfunc
pyfunc.add_to_model(
mlflow_model,
loader_module="mlflow.langchain",
Expand All @@ -335,7 +350,8 @@ def load_retriever(persist_directory):
code=code_dir_subpath,
predict_stream_fn="predict_stream",
streamable=streamable,
model_code=model_code_dir_subpath,
model_code=model_code_path,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think previously this was the path to the folder and now it's the path to the file. Do we want it to be the path of the folder, similar to before?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re: Weichen's comment, we decided to make this a path to the file unlike code_paths since we will only have 1) the code file and 2) the config file.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
model_code=model_code_path,
model_code_path=model_code_path,

to keep consistent with the key name in flavor_conf

model_config=None if isinstance(model_config, str) else model_config,
**model_data_kwargs,
)

Expand All @@ -344,8 +360,8 @@ def load_retriever(persist_directory):
if isinstance(lc_model, str):
# TODO: use model_config instead of code_paths[0]
annzhang-db marked this conversation as resolved.
Show resolved Hide resolved
checker_model = (
_load_model_code_path(lc_model, code_paths[0])
if code_paths and len(code_paths) >= 1
_load_model_code_path(lc_model, model_config_path)
if model_config_path
else _load_model_code_path(lc_model)
)

Expand All @@ -356,7 +372,6 @@ def load_retriever(persist_directory):
FLAVOR_NAME,
langchain_version=langchain.__version__,
code=code_dir_subpath,
model_code=model_code_dir_subpath,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model code path is directly captured in the flavor_conf

streamable=streamable,
**flavor_conf,
)
Expand Down Expand Up @@ -408,6 +423,7 @@ def log_model(
persist_dir=None,
example_no_conversion=False,
run_id=None,
model_config=None,
):
"""
Log a LangChain model as an MLflow artifact for the current run.
Expand Down Expand Up @@ -511,6 +527,11 @@ def load_retriever(persist_directory):
run_id: run_id to associate with this model version. If specified, we resume the
run and log the model to that run. Otherwise, a new run is created.
Default to None.
model_config: The model configuration to apply to the model. This configuration
is available during model loading.

.. Note:: Experimental: This parameter may change or be removed in a future
release without warning.

Returns:
A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
Expand All @@ -535,6 +556,7 @@ def load_retriever(persist_directory):
persist_dir=persist_dir,
example_no_conversion=example_no_conversion,
run_id=run_id,
model_config=model_config,
)


Expand Down Expand Up @@ -691,7 +713,9 @@ def predict_stream(
Returns:
An iterator of model prediction chunks.
"""
from mlflow.langchain.api_request_parallel_processor import process_stream_request
from mlflow.langchain.api_request_parallel_processor import (
process_stream_request,
)

if isinstance(data, list):
raise MlflowException("LangChain model predict_stream only supports single input.")
Expand Down Expand Up @@ -720,7 +744,9 @@ def _predict_stream_with_callbacks(
Returns:
An iterator of model prediction chunks.
"""
from mlflow.langchain.api_request_parallel_processor import process_stream_request
from mlflow.langchain.api_request_parallel_processor import (
process_stream_request,
)

if isinstance(data, list):
raise MlflowException("LangChain model predict_stream only supports single input.")
Expand Down Expand Up @@ -809,7 +835,12 @@ def _load_pyfunc(path):

def _load_model_from_local_fs(local_model_path):
flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
if _CODE_CONFIG in flavor_conf:
if _MODEL_CODE_PATH in flavor_conf:
code_path = flavor_conf.get(_MODEL_CODE_PATH)
config_path = flavor_conf.get(_MODEL_CODE_CONFIG, None)
return _load_model_code_path(code_path, config_path)
# Code for backwards compatibility, relies on RAG utils - remove in the future
annzhang-db marked this conversation as resolved.
Show resolved Hide resolved
elif _CODE_CONFIG in flavor_conf:
path = flavor_conf.get(_CODE_CONFIG)
flavor_code_config = flavor_conf.get(FLAVOR_CONFIG_CODE)
if path is not None:
Expand Down Expand Up @@ -865,11 +896,11 @@ def load_model(model_uri, dst_path=None):

@contextmanager
def _config_path_context(code_path: Optional[str] = None):
annzhang-db marked this conversation as resolved.
Show resolved Hide resolved
_set_config_path(code_path)
_set_model_config(code_path)
try:
yield
finally:
_set_config_path(None)
_set_model_config(None)


# In the Python's module caching mechanism, which by default, prevents the
Expand Down
16 changes: 10 additions & 6 deletions mlflow/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@

import yaml

import mlflow
__mlflow_model_config__ = None


# TODO: Let ModelConfig take in a dictionary instead of a file path
class ModelConfig:
"""
ModelConfig used in code to read a YAML configuration file, and this configuration file can be
overridden when logging a model.
"""

def __init__(self, *, development_config: Optional[str] = None):
# TODO: Update global path after we pass in paths using model_config
_mlflow_rag_config_path = getattr(
mlflow.langchain._rag_utils, "__databricks_rag_config_path__", None
)
self.config_path = _mlflow_rag_config_path or development_config
config = globals().get("__mlflow_model_config__", None)
# backwards compatibility
rag_config = globals().get("__databricks_rag_config_path__", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this global set?
Do we need to use mlflow.langchain._rag_utils or something?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will get set when loading the model - I think if the versions are out of sync there could be some issues. Let's leave this in? I also added setting both globals for now.

self.config_path = config or rag_config or development_config

if not self.config_path:
raise FileNotFoundError("Config file is None. Please provide a valid path.")
Expand Down Expand Up @@ -50,3 +50,7 @@ def get(self, key):
return config_data[key]
else:
raise KeyError(f"Key '{key}' not found in configuration: {config_data}.")


def _set_model_config(model_config):
    """Publish ``model_config`` through this module's ``__mlflow_model_config__`` global.

    ``ModelConfig.__init__`` looks this global up first when deciding which
    configuration file to read, so setting it here overrides any
    ``development_config`` passed to ``ModelConfig``.

    Args:
        model_config: The value to store. Passing ``None`` clears a previously
            stored value.
    """
    globals()["__mlflow_model_config__"] = model_config
39 changes: 33 additions & 6 deletions mlflow/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,45 @@ def _validate_and_copy_code_paths(code_paths, path, default_subpath="code"):
return code_dir_subpath


def _validate_and_copy_model_code_path(code_path, path, default_subpath="model_code"):
def _validate_path_exists(path, name):
    """Raise if ``path`` is provided but does not exist on the local filesystem.

    Args:
        path: Filesystem path to check. Falsy values (``None`` or ``""``) are treated as
            "not provided" and pass without any check.
        name: Short label for the kind of path (callers in this file pass ``"code"`` or
            ``"config"``). It is interpolated into the error message.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when ``path`` is truthy and
            ``os.path.exists(path)`` is false.
    """
    if path and not os.path.exists(path):
        raise MlflowException(
            message=(
                f"Failed to copy the specified {name} path '{path}' into the model "
                # Keep the space outside the braces: whitespace inside a replacement
                # field is discarded, which would glue the label to the word "path".
                f"artifacts. The specified {name} path does not exist. Please specify a valid "
                f"{name} path and try again."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )


def _validate_and_copy_model_code_and_config_paths(code_path, config_path, path):
    """Copy the model code file and, when given, the model config file into ``path``.

    Both paths are first checked with ``_validate_path_exists``; the code file is then
    copied, followed by the config file if one was provided.

    Args:
        code_path: A file containing model code that should be logged as an artifact.
        config_path: A file containing model config that should be logged as an artifact.
            May be ``None``, in which case only the code file is copied.
        path: The local model path.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if a provided path does not
            exist, or if copying raises an ``OSError``.
    """
    _validate_path_exists(code_path, "code")
    _validate_path_exists(config_path, "config")
    try:
        _copy_file_or_tree(src=code_path, dst=path)
        if config_path:
            _copy_file_or_tree(src=config_path, dst=path)
    except OSError as e:
        # A common error is code-paths includes Databricks Notebook. We include it in error
        # message when running in Databricks, but not in other envs to avoid confusion.
        example = ", such as Databricks Notebooks" if is_in_databricks_runtime() else ""
        raise MlflowException(
            # No trailing comma after the last fragment: it would turn the message
            # into a 1-tuple instead of a single string.
            message=(
                f"Failed to copy the specified code path '{code_path}' into the model "
                "artifacts. It appears that your code path includes file(s) that cannot "
                f"be copied{example}. Please specify a code path that does not include "
                "such files and try again."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        ) from e


def _add_code_to_system_path(code_path):
Expand Down
2 changes: 1 addition & 1 deletion tests/langchain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _llm_type(self) -> str:
return FakeChatModel(endpoint=endpoint)


config_path = mlflow.langchain._rag_utils.__databricks_rag_config_path__
config_path = mlflow.models.model_config.__mlflow_model_config__
assert os.path.exists(config_path)

with open(config_path) as f:
Expand Down