-
Notifications
You must be signed in to change notification settings - Fork 4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support file paths for model_config in langchain #11843
Changes from all commits
767e92e
464d885
0190276
208ab5e
f417867
c3b769b
d052868
776e425
5702feb
6ced8b2
62c38e8
5bca18b
f9ac68c
e319bb0
e8643d8
67c4d5b
6a48b9a
c7bb50a
c34d9e0
48fd538
5b10dd5
9335648
246d963
f588fa8
e73030f
4c59208
cdd34ac
8882490
a1e6f92
e235203
2c7c5dc
adeb7c5
ea4d0c4
f3ae8b7
555974b
db36e4d
03ca1f2
5a7f2de
1fc13c3
e422641
091a0b7
df5e50e
9ae5daa
5d141bf
f1d5efc
e3a74c0
600ac5d
f876a43
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,7 @@ | |
) | ||
from mlflow.models import Model, ModelInputExample, ModelSignature, get_model_info | ||
from mlflow.models.model import MLMODEL_FILE_NAME | ||
from mlflow.models.model_config import _set_model_config | ||
from mlflow.models.signature import _infer_signature_from_input_example | ||
from mlflow.models.utils import _save_example | ||
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS | ||
|
@@ -85,7 +86,7 @@ | |
_add_code_from_conf_to_system_path, | ||
_get_flavor_configuration, | ||
_validate_and_copy_code_paths, | ||
_validate_and_copy_model_code_path, | ||
_validate_and_copy_model_code_and_config_paths, | ||
_validate_and_prepare_target_save_path, | ||
) | ||
from mlflow.utils.requirements_utils import _get_pinned_requirement | ||
|
@@ -94,6 +95,8 @@ | |
|
||
FLAVOR_NAME = "langchain" | ||
_MODEL_TYPE_KEY = "model_type" | ||
_MODEL_CODE_CONFIG = "model_config" | ||
_MODEL_CODE_PATH = "model_code_path" | ||
|
||
|
||
def get_default_pip_requirements(): | ||
|
@@ -133,6 +136,7 @@ def save_model( | |
loader_fn=None, | ||
persist_dir=None, | ||
example_no_conversion=False, | ||
model_config=None, | ||
): | ||
""" | ||
Save a LangChain model to a path on the local file system. | ||
|
@@ -225,6 +229,11 @@ def load_retriever(persist_directory): | |
|
||
See a complete example in examples/langchain/retrieval_qa_chain.py. | ||
example_no_conversion: {{ example_no_conversion }} | ||
model_config: The model configuration to apply to the model if saving model as code. This | ||
configuration is available during model loading. | ||
|
||
.. Note:: Experimental: This parameter may change or be removed in a future | ||
release without warning. | ||
""" | ||
import langchain | ||
from langchain.schema import BaseRetriever | ||
|
@@ -235,6 +244,8 @@ def load_retriever(persist_directory): | |
|
||
path = os.path.abspath(path) | ||
_validate_and_prepare_target_save_path(path) | ||
|
||
model_config_path = None | ||
model_code_path = None | ||
if isinstance(lc_model, str): | ||
# The LangChain model is defined as Python code located in the file at the path | ||
|
@@ -249,17 +260,26 @@ def load_retriever(persist_directory): | |
"file path or a databricks notebook file path containing the code for defining " | ||
"the chain instance." | ||
) | ||
if code_paths and len(code_paths) > 1: | ||
raise mlflow.MlflowException.invalid_parameter_value( | ||
"When the model is a string, and if the code_paths are specified, " | ||
"it should contain only one path." | ||
"This config path is used to set config.yml file path " | ||
"for the model. This path should be passed in via the code_paths. " | ||
f"Current code paths: {code_paths}" | ||
) | ||
|
||
if isinstance(model_config, str): | ||
if os.path.exists(model_config): | ||
model_config_path = model_config | ||
else: | ||
raise mlflow.MlflowException.invalid_parameter_value( | ||
f"If the provided model_config '{model_config}' is a string, it must be a " | ||
"valid yaml file path containing the configuration for the model." | ||
) | ||
# TODO: deal with dicts properly as well | ||
|
||
if not model_config: | ||
# If the model_config is not provided we fallback to getting the config path | ||
# from code_paths so that is backwards compatible. | ||
if code_paths and len(code_paths) == 1 and os.path.exists(code_paths[0]): | ||
model_config_path = code_paths[0] | ||
|
||
_validate_and_copy_model_code_and_config_paths(lc_model, model_config_path, path) | ||
|
||
code_dir_subpath = _validate_and_copy_code_paths(code_paths, path) | ||
model_code_dir_subpath = _validate_and_copy_model_code_path(model_code_path, path) | ||
|
||
if signature is None: | ||
if input_example is not None: | ||
|
@@ -314,19 +334,19 @@ def load_retriever(persist_directory): | |
**model_data_kwargs, | ||
} | ||
else: | ||
# TODO: use model_config instead | ||
# If the model is a string, we expect the code_path which is ideally config.yml | ||
# would be used in the model. We set the code_path here so it can be set | ||
# globally when the model is loaded with the local path. So the consumer | ||
# can use that path instead of the config.yml path when the model is loaded | ||
# TODO: what if model_config is not a string / file path? | ||
flavor_conf = ( | ||
{_CODE_CONFIG: code_paths[0], _CODE_PATH: lc_model} | ||
if code_paths and len(code_paths) >= 1 | ||
else {_CODE_CONFIG: None, _CODE_PATH: lc_model} | ||
{_MODEL_CODE_CONFIG: model_config_path, _MODEL_CODE_PATH: lc_model} | ||
if model_config_path | ||
else {_MODEL_CODE_CONFIG: None, _MODEL_CODE_PATH: lc_model} | ||
) | ||
model_data_kwargs = {} | ||
|
||
# TODO: pass model_config | ||
# TODO: Pass file paths for model_config when it is supported in pyfunc | ||
pyfunc.add_to_model( | ||
mlflow_model, | ||
loader_module="mlflow.langchain", | ||
|
@@ -335,17 +355,17 @@ def load_retriever(persist_directory): | |
code=code_dir_subpath, | ||
predict_stream_fn="predict_stream", | ||
streamable=streamable, | ||
model_code=model_code_dir_subpath, | ||
model_code_path=model_code_path, | ||
model_config=None if isinstance(model_config, str) else model_config, | ||
**model_data_kwargs, | ||
) | ||
|
||
if Version(langchain.__version__) >= Version("0.0.311"): | ||
checker_model = lc_model | ||
if isinstance(lc_model, str): | ||
# TODO: use model_config instead of code_paths[0] | ||
checker_model = ( | ||
_load_model_code_path(lc_model, code_paths[0]) | ||
if code_paths and len(code_paths) >= 1 | ||
_load_model_code_path(lc_model, model_config_path) | ||
if model_config_path | ||
else _load_model_code_path(lc_model) | ||
) | ||
|
||
|
@@ -356,7 +376,6 @@ def load_retriever(persist_directory): | |
FLAVOR_NAME, | ||
langchain_version=langchain.__version__, | ||
code=code_dir_subpath, | ||
model_code=model_code_dir_subpath, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. model code path is directly captured in the flavor_conf |
||
streamable=streamable, | ||
**flavor_conf, | ||
) | ||
|
@@ -408,6 +427,7 @@ def log_model( | |
persist_dir=None, | ||
example_no_conversion=False, | ||
run_id=None, | ||
model_config=None, | ||
): | ||
""" | ||
Log a LangChain model as an MLflow artifact for the current run. | ||
|
@@ -511,6 +531,11 @@ def load_retriever(persist_directory): | |
run_id: run_id to associate with this model version. If specified, we resume the | ||
run and log the model to that run. Otherwise, a new run is created. | ||
Default to None. | ||
model_config: The model configuration to apply to the model if saving model as code. This | ||
configuration is available during model loading. | ||
|
||
.. Note:: Experimental: This parameter may change or be removed in a future | ||
release without warning. | ||
|
||
Returns: | ||
A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the | ||
|
@@ -535,6 +560,7 @@ def load_retriever(persist_directory): | |
persist_dir=persist_dir, | ||
example_no_conversion=example_no_conversion, | ||
run_id=run_id, | ||
model_config=model_config, | ||
) | ||
|
||
|
||
|
@@ -709,7 +735,9 @@ def predict_stream( | |
Returns: | ||
An iterator of model prediction chunks. | ||
""" | ||
from mlflow.langchain.api_request_parallel_processor import process_stream_request | ||
from mlflow.langchain.api_request_parallel_processor import ( | ||
process_stream_request, | ||
) | ||
|
||
data = self._prepare_predict_stream_messages(data) | ||
return process_stream_request( | ||
|
@@ -735,7 +763,9 @@ def _predict_stream_with_callbacks( | |
Returns: | ||
An iterator of model prediction chunks. | ||
""" | ||
from mlflow.langchain.api_request_parallel_processor import process_stream_request | ||
from mlflow.langchain.api_request_parallel_processor import ( | ||
process_stream_request, | ||
) | ||
|
||
data = self._prepare_predict_stream_messages(data) | ||
return process_stream_request( | ||
|
@@ -821,7 +851,12 @@ def _load_pyfunc(path): | |
|
||
def _load_model_from_local_fs(local_model_path): | ||
flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME) | ||
if _CODE_CONFIG in flavor_conf: | ||
if _MODEL_CODE_PATH in flavor_conf: | ||
code_path = flavor_conf.get(_MODEL_CODE_PATH) | ||
config_path = flavor_conf.get(_MODEL_CODE_CONFIG, None) | ||
return _load_model_code_path(code_path, config_path) | ||
# Code for backwards compatibility, relies on RAG utils - remove in the future | ||
annzhang-db marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elif _CODE_CONFIG in flavor_conf: | ||
path = flavor_conf.get(_CODE_CONFIG) | ||
flavor_code_config = flavor_conf.get(FLAVOR_CONFIG_CODE) | ||
if path is not None: | ||
|
@@ -877,10 +912,14 @@ def load_model(model_uri, dst_path=None): | |
|
||
@contextmanager | ||
def _config_path_context(code_path: Optional[str] = None): | ||
annzhang-db marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_set_model_config(code_path) | ||
# set rag utils global for backwards compatibility | ||
_set_config_path(code_path) | ||
try: | ||
yield | ||
finally: | ||
_set_model_config(None) | ||
# unset rag utils global for backwards compatibility | ||
_set_config_path(None) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,21 +3,21 @@ | |
|
||
import yaml | ||
|
||
import mlflow | ||
__mlflow_model_config__ = None | ||
|
||
|
||
# TODO: Let ModelConfig take in a dictionary instead of a file path | ||
class ModelConfig: | ||
""" | ||
ModelConfig used in code to read a YAML configuration file, and this configuration file can be | ||
overridden when logging a model. | ||
""" | ||
|
||
def __init__(self, *, development_config: Optional[str] = None): | ||
# TODO: Update global path after we pass in paths using model_config | ||
_mlflow_rag_config_path = getattr( | ||
mlflow.langchain._rag_utils, "__databricks_rag_config_path__", None | ||
) | ||
self.config_path = _mlflow_rag_config_path or development_config | ||
config = globals().get("__mlflow_model_config__", None) | ||
# backwards compatibility | ||
rag_config = globals().get("__databricks_rag_config_path__", None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is this global set? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will get set when loading the model - I think if the versions are out of sync there could be some issues. Let's leave this in? I also added setting both globals for now. |
||
self.config_path = config or rag_config or development_config | ||
|
||
if not self.config_path: | ||
raise FileNotFoundError("Config file is None. Please provide a valid path.") | ||
|
@@ -50,3 +50,7 @@ def get(self, key): | |
return config_data[key] | ||
else: | ||
raise KeyError(f"Key '{key}' not found in configuration: {config_data}.") | ||
|
||
|
||
def _set_model_config(model_config): | ||
globals()["__mlflow_model_config__"] = model_config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
model_config
param doc is lost