Implement ChatModel (pyfunc subclass) #10820
```diff
@@ -227,6 +227,7 @@
 import yaml

 import mlflow
+import mlflow.pyfunc.loaders
 import mlflow.pyfunc.model
 from mlflow.environment_variables import (
     _MLFLOW_TESTING,
@@ -254,6 +255,7 @@
     RESOURCE_DOES_NOT_EXIST,
 )
 from mlflow.pyfunc.model import (
+    ChatModel,
     PythonModel,
     PythonModelContext,  # noqa: F401
     _log_warning_if_params_not_in_predict_signature,
@@ -263,6 +265,14 @@
 )
 from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
 from mlflow.tracking.artifact_utils import _download_artifact_from_uri
+from mlflow.types.llm import (
+    CHAT_MODEL_INPUT_EXAMPLE,
+    CHAT_MODEL_INPUT_SCHEMA,
+    CHAT_MODEL_OUTPUT_SCHEMA,
+    ChatMessage,
+    ChatParams,
+    ChatResponse,
+)
 from mlflow.utils import (
     PYTHON_VERSION,
     _is_in_ipython_notebook,
```
```diff
@@ -1991,6 +2001,13 @@ def predict(model_input: List[str]) -> List[str]:

     hints = None
     if signature is not None:
+        if isinstance(python_model, ChatModel):
+            raise MlflowException(
+                "ChatModel subclasses have a standard signature that is set "
+                "automatically. Please remove the `signature` parameter from "
+                "the call to log_model() or save_model().",
+                error_code=INVALID_PARAMETER_VALUE,
+            )
         mlflow_model.signature = signature
     elif python_model is not None:
         if callable(python_model):
```
```diff
@@ -1999,6 +2016,25 @@ def predict(model_input: List[str]) -> List[str]:
                 python_model, input_arg_index, input_example=input_example
             ):
                 mlflow_model.signature = signature
+        elif isinstance(python_model, ChatModel):
+            mlflow_model.signature = ModelSignature(
+                CHAT_MODEL_INPUT_SCHEMA,
+                CHAT_MODEL_OUTPUT_SCHEMA,
+            )
+            input_example = CHAT_MODEL_INPUT_EXAMPLE
+
+            # perform output validation and throw if
+            # output is not coercable to ChatResponse
+            messages = [ChatMessage(**m) for m in input_example["messages"]]
+            params = ChatParams(**{k: v for k, v in input_example.items() if k != "messages"})
+            output = python_model.predict(None, messages, params)
+
+            if not isinstance(output, ChatResponse):
+                raise MlflowException(
+                    "Failed to save ChatModel. Please ensure that the model's predict() method "
+                    "returns a ChatResponse object. If your predict() method currently returns "
+                    "a dict, you can instantiate a ChatResponse by unpacking the output, e.g. "
+                    "`ChatResponse(**output)`",
+                )
         elif isinstance(python_model, PythonModel):
             input_arg_index = 1  # second argument
             if signature := _infer_signature_from_type_hints(
```

Review discussion on the save-time `predict()` call (`output = python_model.predict(None, messages, params)`):

> is it a problem to perform inference during saving? i saw we do it when trying to infer output signature, but since this is kind of an LLM-specific API, inference can be kind of expensive. if it is a concern, maybe we can just skip output validation entirely (as far as i can tell, there wouldn't be another way to ensure the return type of `predict()`, since the input example only specifies the inputs).

> I think there are some risks:

> +1 We shouldn't predict while saving the model; the error message would be confusing.

> from discussion offline, we'll keep the predict since we do it in transformers/other places already for output signature inference. i'll do some more testing here to make sure it's not a confusing experience.
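From the user's perspective, the two branches above behave roughly as follows. This is a sketch only: `MyChatModel` and `some_signature` are hypothetical placeholders, not names from this PR.

```python
import mlflow

# 1. Passing an explicit signature with a ChatModel subclass is rejected:
with mlflow.start_run():
    mlflow.pyfunc.log_model(
        "chat_model",
        python_model=MyChatModel(),  # hypothetical ChatModel subclass
        signature=some_signature,    # raises MlflowException (INVALID_PARAMETER_VALUE)
    )

# 2. With no signature, MLflow applies the standard chat signature and calls
#    predict() once on CHAT_MODEL_INPUT_EXAMPLE to verify that the output is a
#    ChatResponse before the model is saved.
with mlflow.start_run():
    mlflow.pyfunc.log_model("chat_model", python_model=MyChatModel())
```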
```diff
@@ -0,0 +1 @@
+import mlflow.pyfunc.loaders.chat_model  # noqa: F401
```

Review discussion:

> is this necessary? or does python load all files in the subdirectory into the module by default?

> Python doesn't. If we want to do `mlflow.pyfunc.loaders.chat_model`, the submodule has to be imported explicitly.
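A quick illustration of the import behavior discussed above; `pkg` and `sub` are hypothetical names:

```python
# Given a package layout:
#   pkg/__init__.py   (empty)
#   pkg/sub.py
import pkg

# Importing a package does NOT import its submodules, so this raises
# AttributeError unless pkg/__init__.py imports the submodule itself:
pkg.sub  # AttributeError: module 'pkg' has no attribute 'sub'

# After an explicit import, the submodule is reachable as an attribute:
import pkg.sub

pkg.sub  # works
```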
```diff
@@ -0,0 +1,69 @@
+from typing import Any, Dict, Optional
+
+from mlflow.exceptions import MlflowException
+from mlflow.protos.databricks_pb2 import INTERNAL_ERROR
+from mlflow.pyfunc.model import (
+    _load_context_model_and_signature,
+)
+from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse
+from mlflow.utils.annotations import experimental
+
+
+def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
+    context, chat_model, signature = _load_context_model_and_signature(model_path, model_config)
+    return _ChatModelPyfuncWrapper(chat_model=chat_model, context=context, signature=signature)
+
+
+@experimental
+class _ChatModelPyfuncWrapper:
+    """
+    Wrapper class that converts dict inputs to pydantic objects accepted by :class:`~ChatModel`.
+    """
+
+    def __init__(self, chat_model, context, signature):
+        """
+        Args:
+            chat_model: An instance of a subclass of :class:`~ChatModel`.
+            context: A :class:`~PythonModelContext` instance containing artifacts that
+                     ``chat_model`` may use when performing inference.
+            signature: :class:`~ModelSignature` instance describing model input and output.
+        """
+        self.chat_model = chat_model
+        self.context = context
+        self.signature = signature
+
+    def _convert_input(self, model_input):
+        # model_input should be correct from signature validation, so just convert it to dict here
+        dict_input = {key: value[0] for key, value in model_input.to_dict(orient="list").items()}
+
+        messages = [ChatMessage(**message) for message in dict_input.pop("messages", [])]
+        params = ChatParams(**dict_input)
+
+        return messages, params
+
+    def predict(
+        self, model_input: Dict[str, Any], params: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            model_input: Model input data in the form of a chat request.
+            params: Additional parameters to pass to the model for inference.
+                    Unused in this implementation, as the params are handled
+                    via ``self._convert_input()``.
+
+        Returns:
+            Model predictions in :py:class:`~ChatResponse` format.
+        """
+        messages, params = self._convert_input(model_input)
+        response = self.chat_model.predict(self.context, messages, params)
+
+        if not isinstance(response, ChatResponse):
+            # shouldn't happen since there is validation at save time ensuring that
+            # the output is a ChatResponse, so raise an exception if it isn't
+            raise MlflowException(
+                "Model returned an invalid response. Expected a ChatResponse, but "
+                f"got {type(response)} instead.",
+                error_code=INTERNAL_ERROR,
+            )
+
+        return response.to_dict()
```

Review discussion on `_load_pyfunc`:

> This looks the same as PythonModel's `_load_pyfunc` function (except the wrapper it returns); could we reuse the function and extract the final class as a parameter?

> refactored the common part to `_load_context_model_and_signature`

Review discussion on `_convert_input`:

> Does `to_dict(orient="list")` guarantee this format?

> it seems so: https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_dict.html#pandas.DataFrame.to_dict but i'm kind of new to pandas; is there something else i should use?
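To make the `to_dict` question above concrete, here is a small sketch of the conversion for a single-row chat request. The column names follow the chat input schema in this PR; the payload values are illustrative:

```python
import pandas as pd

# A chat request as it arrives after pyfunc signature validation: one row,
# with each column holding the corresponding request field.
df = pd.DataFrame(
    {
        "messages": [[{"role": "user", "content": "hello"}]],
        "temperature": [0.7],
        "max_tokens": [100],
    }
)

# orient="list" maps each column name to the list of its values, so value[0]
# recovers the single row's payload for each field.
as_dict = df.to_dict(orient="list")
# {'messages': [[{'role': 'user', 'content': 'hello'}]],
#  'temperature': [0.7], 'max_tokens': [100]}

dict_input = {key: value[0] for key, value in as_dict.items()}
# {'messages': [{'role': 'user', 'content': 'hello'}],
#  'temperature': 0.7, 'max_tokens': 100}
```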
```diff
@@ -21,6 +21,7 @@
 from mlflow.models.model import MLMODEL_FILE_NAME
 from mlflow.models.signature import _extract_type_hints
 from mlflow.tracking.artifact_utils import _download_artifact_from_uri
+from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse
 from mlflow.utils.annotations import experimental
 from mlflow.utils.environment import (
     _CONDA_ENV_FILE_NAME,
```
```diff
@@ -192,6 +193,40 @@ def model_config(self):
         return self._model_config


+@experimental
+class ChatModel(PythonModel, metaclass=ABCMeta):
+    """
+    A subclass of :class:`~PythonModel` that makes it more convenient to implement models
+    that are compatible with popular LLM chat APIs. By subclassing :class:`~ChatModel`,
+    users can create MLflow models with a ``predict()`` method that is more convenient
+    for chat tasks than the generic :class:`~PythonModel` API. ChatModels automatically
+    define input/output signatures and an input example, so manually specifying these values
+    when calling :func:`mlflow.pyfunc.save_model() <mlflow.pyfunc.save_model>` is not necessary.
+
+    See the documentation of the ``predict()`` method below for details on the parameters and
+    outputs that are expected by the ``ChatModel`` API.
+    """
+
+    @abstractmethod
+    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
+        """
+        Evaluates a chat input and produces a chat output.
+
+        Args:
+            messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
+                A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`
+                objects representing chat history.
+            params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
+                A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object
+                containing various parameters used to modify model behavior during
+                inference.
+
+        Returns:
+            A :py:class:`ChatResponse <mlflow.types.llm.ChatResponse>` object containing
+            the model's response(s), as well as other metadata.
+        """
+
+
 def _save_model_with_class_artifacts_params(
     path,
     python_model,
```
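To ground the new API, here is a minimal sketch of a subclass. The `ChatChoice` import and the exact `ChatResponse` constructor fields are assumptions based on the OpenAI-style schema in `mlflow.types.llm`, not something this diff confirms:

```python
from typing import List

from mlflow.pyfunc import ChatModel  # assumed public export location
from mlflow.types.llm import ChatChoice, ChatMessage, ChatParams, ChatResponse


class EchoModel(ChatModel):
    def predict(
        self, context, messages: List[ChatMessage], params: ChatParams
    ) -> ChatResponse:
        # A real implementation would call an LLM here; this one just echoes
        # the last message back, honoring the required ChatResponse return type.
        return ChatResponse(
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=messages[-1].content),
                    finish_reason="stop",
                )
            ],
        )
```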
```diff
@@ -306,7 +341,7 @@ def _save_model_with_class_artifacts_params(

     mlflow.pyfunc.add_to_model(
         model=mlflow_model,
-        loader_module=__name__,
+        loader_module=_get_pyfunc_loader_module(python_model),
         code=saved_code_subpath,
         conda_env=_CONDA_ENV_FILE_NAME,
         python_env=_PYTHON_ENV_FILE_NAME,
```
```diff
@@ -351,7 +386,9 @@ def _save_model_with_class_artifacts_params(
     _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))


-def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
+def _load_context_model_and_signature(
+    model_path: str, model_config: Optional[Dict[str, Any]] = None
+):
     pyfunc_config = _get_flavor_configuration(
         model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
     )
@@ -391,6 +428,12 @@ def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None)
     context = PythonModelContext(artifacts=artifacts, model_config=model_config)
     python_model.load_context(context=context)
     signature = mlflow.models.Model.load(model_path).signature
+
+    return context, python_model, signature
+
+
+def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
+    context, python_model, signature = _load_context_model_and_signature(model_path, model_config)
     return _PythonModelPyfuncWrapper(
         python_model=python_model, context=context, signature=signature
     )
```
```diff
@@ -467,3 +510,9 @@ def predict(self, model_input, params: Optional[Dict[str, Any]] = None):
             )
         _log_warning_if_params_not_in_predict_signature(_logger, params)
         return self.python_model.predict(self.context, self._convert_input(model_input))
+
+
+def _get_pyfunc_loader_module(python_model):
+    if isinstance(python_model, ChatModel):
+        return mlflow.pyfunc.loaders.chat_model.__name__
+    return __name__
```

Review discussion on lines +515 to +518 (`_get_pyfunc_loader_module`):

> what do we think of adding new pyfunc loaders to the `mlflow.pyfunc.loaders` module?

> Sounds good to me.
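For context, a rough sketch of how the `loader_module` value is used at load time. This is a simplification of MLflow's pyfunc loading internals (the real `mlflow.pyfunc.load_model` also handles code paths, data paths, and model config), so treat the function below as illustrative:

```python
import importlib


def _load_model_from_flavor_conf(local_path: str, flavor_conf: dict):
    # The MLmodel file's pyfunc flavor config records which module to load,
    # e.g. "mlflow.pyfunc.model" for plain PythonModels, or
    # "mlflow.pyfunc.loaders.chat_model" for ChatModel subclasses.
    loader_module = importlib.import_module(flavor_conf["loader_module"])

    # Each loader module exposes a _load_pyfunc entrypoint that returns a
    # wrapper object with a predict() method.
    return loader_module._load_pyfunc(local_path)
```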
Review discussion on the signature handling in `save_model()`:

> Can we do any validation/warning if a customer specifies a custom signature with ChatModel? If it doesn't comply with our pydantic schema, we may want to reject here rather than at runtime.

> oh yes that's true, i'll throw a warning to say that the signature will be overridden and that it must conform to the spec

> woah actually this brought up a bug in my implementation: if the user specifies a signature, the model actually doesn't get saved as a ChatModel due to the `elif` in line 2005 above. i guess it's `elif` because this block contains a lot of validation/signature inference logic that we can skip if the user provides the signature themself. however, for `ChatModel` we always want to do these validations (e.g. output validation). cc @B-Step62: what do you think about raising an exception when trying to save a ChatModel subclass with a signature? another way is making a separate block for ChatModels.

> Nice finding! I agree with throwing. A warning on the happy path can be easily overlooked and is almost invisible in automated environments.
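A rough sketch of the two options discussed above; `_set_and_validate_chat_signature` is a hypothetical helper standing in for the signature-assignment and output-validation logic shown earlier in the diff:

```python
from mlflow.exceptions import MlflowException
from mlflow.pyfunc.model import ChatModel


def _option_1(python_model, mlflow_model, signature):
    # Option 1 (what the diff implements): reject explicit signatures up
    # front, so the ChatModel branch below always runs its validations.
    if signature is not None:
        if isinstance(python_model, ChatModel):
            raise MlflowException("ChatModel signatures are set automatically ...")
        mlflow_model.signature = signature
    elif isinstance(python_model, ChatModel):
        _set_and_validate_chat_signature(python_model, mlflow_model)  # hypothetical


def _option_2(python_model, mlflow_model, signature):
    # Option 2 (the alternative): give ChatModel its own block ahead of the
    # generic signature logic, so its validations run unconditionally.
    if isinstance(python_model, ChatModel):
        _set_and_validate_chat_signature(python_model, mlflow_model)  # hypothetical
    elif signature is not None:
        mlflow_model.signature = signature
```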