
Implement ChatModel (pyfunc subclass) #10820

Merged
merged 12 commits on Feb 2, 2024

Conversation

daniellok-db (Collaborator) commented Jan 15, 2024

🛠 DevTools 🛠

Open in GitHub Codespaces

Install mlflow from this PR

pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10820/merge

Checkout with GitHub CLI

gh pr checkout 10820

Related Issues/PRs

What changes are proposed in this pull request?

This PR adds the ChatModel subclass to make it easier for users to implement and serve chat models. The ChatModel class requires users to implement a predict method with the following signature (corresponding to the OpenAI chat request format):

class MyChatModel(mlflow.pyfunc.ChatModel):
    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
        # user-defined behavior

This makes it so that the user doesn't have to implement any parsing logic, and can directly work with the pydantic objects that are passed in. Additionally, input/output signatures and an input example are automatically provided.
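
As a rough sketch (not part of this PR's test plan), the automatically attached signature and input example should make a saved ChatModel queryable through the standard pyfunc API with a plain OpenAI-style request dict; the "chat-model" path and the message below are just placeholders reused from the example further down:

import mlflow

loaded = mlflow.pyfunc.load_model("chat-model")
response = loaded.predict(
    {"messages": [{"role": "user", "content": "Hello!"}]}
)
print(response)  # dict in the OpenAI chat completion response format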

To support this, we implement a new custom loader for these types of models, defined in mlflow.pyfunc.loaders.chat_model. This loader wraps the ChatModel in a _ChatModelPyfuncWrapper class that accepts the standard chat request format, and breaks it up into messages and params for the user.
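
For illustration, here is a minimal sketch of the splitting step described above. It mirrors the conversion code quoted later in this PR, but split_chat_request is a hypothetical name rather than the wrapper's actual method; it assumes the incoming dict has a "messages" key and that every remaining key is a chat parameter:

from mlflow.types.llm import ChatMessage, ChatParams

def split_chat_request(request: dict):
    # convert the raw message dicts into typed ChatMessage objects
    messages = [ChatMessage(**m) for m in request["messages"]]
    # everything else (temperature, max_tokens, ...) becomes ChatParams
    params = ChatParams(**{k: v for k, v in request.items() if k != "messages"})
    return messages, params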

How is this PR tested?

  • Existing unit/integration tests
  • New unit/integration tests
  • Manual tests

Ran the following to create a chat model:

import json
from typing import List

import mlflow
from mlflow.types.llm import ChatMessage, ChatParams, ChatResponse


class TestChatModel(mlflow.pyfunc.ChatModel):
    def predict(self, context, messages: List[ChatMessage], params: ChatParams) -> ChatResponse:
        mock_response = {
            "id": "123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": json.dumps([m.model_dump(exclude_none=True) for m in messages]),
                    },
                    "finish_reason": "stop",
                },
                {
                    "index": 1,
                    "message": {
                        "role": "user",
                        "content": params.model_dump_json(exclude_none=True),
                    },
                    "finish_reason": "stop",
                },
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 10,
                "total_tokens": 20,
            },
        }
        return ChatResponse(**mock_response)

mlflow.pyfunc.save_model(
    path="chat-model",
    python_model=TestChatModel(),
)

Then on the command line:

$ mlflow models serve -m chat-model

$ curl http://127.0.0.1:5000/invocations -H 'Content-Type: application/json' -d '{ "messages": [ { "role": "system", "content": "You are a helpful assistant" }, { "role": "user", "content": "Hello!" } ] }' | jq

{
  "id": "123",
  "object": "chat.completion",
  "created": 1677652288,
  "model": "MyChatModel",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "[{\"role\": \"system\", \"content\": \"You are a helpful assistant\"}, {\"role\": \"user\", \"content\": \"Hello!\"}]"
      },
      "finish_reason": "stop"
    },
    {
      "index": 1,
      "message": {
        "role": "user",
        "content": "{\"temperature\":1.0,\"n\":1,\"stream\":false}"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 10,
    "completion_tokens": 10,
    "total_tokens": 20
  }
}

Also tried viewing the model in the MLflow UI:

Validate that the MLmodel file looks as expected
Screenshot 2024-01-15 at 12.57.49 PM

Validate that the signature looks correct:

Screen.Recording.2024-01-15.at.12.57.57.PM.mov

Does this PR require documentation update?

Requires a tutorial, but we can work on this in a follow-up PR

  • No. You can skip the rest of this section.
  • Yes. I've updated:
    • Examples
    • API references
    • Instructions

Release Notes

Is this a user-facing change?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release notes for MLflow users.

Added the ChatModel pyfunc class, which allows for more convenient definition of chat models conforming to the OpenAI request/response format.

What component(s), interfaces, languages, and integrations does this PR affect?

Components

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/deployments: MLflow Deployments client APIs, server, and third-party Deployments integrations
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/recipes: Recipes, Recipe APIs, Recipe configs, Recipe Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

Interface

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

Language

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

Integrations

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

How should the PR be classified in the release notes? Choose one:

  • rn/none - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
  • rn/breaking-change - The PR will be mentioned in the "Breaking Changes" section
  • rn/feature - A new user-facing feature worth mentioning in the release notes
  • rn/bug-fix - A user-facing bug fix worth mentioning in the release notes
  • rn/documentation - A user-facing documentation change worth mentioning in the release notes

github-actions bot commented Jan 15, 2024

Documentation preview for 4075860 will be available here when this CircleCI job completes successfully.

github-actions bot added the area/models (MLmodel format, model serialization/deserialization, flavors) and rn/feature (Mention under Features in Changelogs) labels on Jan 15, 2024
dbczumar self-requested a review on January 25, 2024 02:08
mlflow/types/llm.py (outdated review threads, resolved)
usage: TokenUsageStats
object: str = "chat.completion"
created: int = field(default_factory=lambda: int(time.time()))
id: str = field(default_factory=lambda: str(uuid.uuid4()))
Member:

Is there any constraint for id (e.g. must start with "chatcmpl-") in OpenAI?

Collaborator:

No, but I don't think we should make up random IDs here that don't have meaning. Can we leave this as None for now?

Collaborator Author:

changed it to None! my initial thought was that if people want the ID to have meaning, they can specify it directly when instantiating ChatRequest (e.g. ChatRequest(id=meaningful_id, ...) still works), but for people who just want a UUID, the default saves them a couple of lines of code

# is not supported, so the code here is a little ugly.


class _BaseDataclass:
Collaborator Author:

all this validation logic is mainly to support the output validation done here.

input validation shouldn't really be an issue, because it's handled by signature validation.

Comment on lines 65 to 70
:param role: The role of the entity that sent the message (e.g. ``"user"``, ``"system"``).
:type role: str
:param content: The content of the message.
:type content: str
:param name: The name of the entity that sent the message. **Optional**
:type name: str
Collaborator Author:

i can unindent all of this stuff to be consistent with the rest of the codebase, but i like the way docstrings look when they're aligned haha

Member:

Can we use the google docstring style?

Comment on lines +506 to +518
def _get_pyfunc_loader_module(python_model):
    if isinstance(python_model, ChatModel):
        return mlflow.pyfunc.loaders.chat_model.__name__
    return __name__
Collaborator Author:

what do we think of adding new pyfunc loaders to the mlflow.pyfunc.loaders module? i think it would be a clean way for us to implement future custom loaders (e.g. for RAGModel, CompletionModel).

Member:

Sounds good to me.

@@ -0,0 +1 @@
import mlflow.pyfunc.loaders.chat_model # noqa: F401
Collaborator Author:

is this necessary? or does python load all files in the subdirectory into the module by default?

Member:

Python doesn't. If we want to do from mlflow.pyfunc.loaders import chat_model, we need this line, otherwise we don't.

# output is not coercable to ChatResponse
messages = [ChatMessage(**m) for m in input_example["messages"]]
params = ChatParams(**{k: v for k, v in input_example.items() if k != "messages"})
output = python_model.predict(None, messages, params)
Collaborator Author:

is it a problem to perform inference during saving? i saw we do it when trying to infer output signature, but since this is kind of an LLM-specific API, inference can be kind of expensive. the input example specifies max_tokens=10, so hopefully it isn't too bad.

if it is a concern, maybe we can just skip output validation entirely (as far as i can tell, there wouldn't be another way to ensure the return type of the predict() method is actually a ChatResponse).

Member:

I think there are some risks:

  1. It may take a while (e.g. a few seconds) for the API request to finish.
  2. No guarantee that the LLM service is healthy. If OpenAI is down, this line would throw.

Collaborator:

+1 We shouldn't predict while saving the model, the error message would be confusing.

Collaborator Author:

from discussion offline, we'll keep the predict since we do it in transformers/other places already for output signature inference. i'll do some more testing here to make sure it's not a confusing experience


from mlflow.utils.model_utils import _get_flavor_configuration


def _load_pyfunc(model_path: str, model_config: Optional[Dict[str, Any]] = None):
Collaborator:

This looks the same as PythonModel's _load_pyfunc function (except for the wrapper it returns). Could we reuse the function and take the final class as a parameter?

Collaborator Author:

refactored the common part to _load_context_model_and_signature


def _convert_input(self, model_input):
    # model_input should be correct from signature validation, so just convert it to dict here
    dict_input = {key: value[0] for key, value in model_input.to_dict(orient="list").items()}
Collaborator:

Does to_dict accept orient param?

Collaborator Author:

it seems so: https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_dict.html#pandas.DataFrame.to_dict

but i'm kind of new to pandas; is there something else i should use?
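
For reference, a small standalone example (not MLflow code) of what that to_dict(orient="list") conversion does to a one-row request DataFrame:

import pandas as pd

# a one-row DataFrame whose columns are the chat request fields
model_input = pd.DataFrame(
    [{"messages": [{"role": "user", "content": "Hello!"}], "temperature": 1.0, "n": 1}]
)

# orient="list" maps each column name to a list of its values (one per row);
# taking index 0 recovers the single request's value for each field
dict_input = {key: value[0] for key, value in model_input.to_dict(orient="list").items()}
# {'messages': [{'role': 'user', 'content': 'Hello!'}], 'temperature': 1.0, 'n': 1}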

Comment on lines 49 to 51
elif all(isinstance(v, cls) for v in values):
    pass
else:
Collaborator:

Suggested change
-elif all(isinstance(v, cls) for v in values):
-    pass
-else:
+elif any(not isinstance(v, cls) for v in values):

Comment on lines 170 to 157
if not isinstance(self.message, ChatMessage):
    self.message = ChatMessage(**self.message)
Collaborator:

This might raise an error if self.message is not a dictionary

self._validate_field("model", str, True)
self._convert_dataclass_list("choices", ChatChoice)
if not isinstance(self.usage, TokenUsageStats):
    self.usage = TokenUsageStats(**self.usage)
Collaborator:

same here

Collaborator Author:

changed this to check for dict, and to throw a ValueError afterwards if the field is not an instance of the expected type
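
A hedged sketch of the check-dict-then-raise pattern described above (the helper and field names here are illustrative, not the PR's actual implementation):

def _convert_dataclass(self, field_name, cls):
    value = getattr(self, field_name)
    if isinstance(value, dict):
        # coerce a plain dict into the expected dataclass
        setattr(self, field_name, cls(**value))
    elif not isinstance(value, cls):
        # neither a dict nor an instance of the expected type
        raise ValueError(
            f"Expected {field_name} to be a dict or {cls.__name__}, got {type(value).__name__}"
        )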

total_tokens: int

def __post_init__(self):
    self._validate_field("prompt_tokens", int, True)
Member:

Is defining this as a required set of fields going to preclude using this interface in transformers?

Collaborator Author:

i don't think it will preclude that, because we can generate these stats automatically for the user using the transformer's tokenizer. however, i can make it not required if it's a concern! it was unclear from the spec which fields are required and which are not.

Signed-off-by: Daniel Lok <daniel.lok@databricks.com>
@@ -1999,6 +2009,25 @@ def predict(model_input: List[str]) -> List[str]:
python_model, input_arg_index, input_example=input_example
):
mlflow_model.signature = signature
elif isinstance(python_model, ChatModel):
B-Step62 (Collaborator), Feb 2, 2024:

Can we do any validation/warning if a customer specifies a custom signature with ChatModel? If it doesn't comply with our pydantic schema, we may want to reject it here rather than at runtime.

Collaborator Author:

oh yes that's true, i'll throw a warning to say that the signature will be overridden and that it must conform to the spec

Collaborator Author:

woah actually this brought up a bug in my implementation: if the user specifies a signature, the model actually doesn't get saved as a ChatModel due to the elif in line 2005 above. i guess it's elif because this block contains a lot of validation/signature inference logic that we can skip if the user provides the signature themselves. however, for ChatModel we always want to run these validations (e.g. output validation)

cc @B-Step62 what do you think about raising an exception when trying to save a ChatModel subclass with a signature, e.g:

if signature is not None:
  if isinstance(python_model, ChatModel):
    raise MlflowException("ChatModel subclasses specify a signature automatically, please remove the provided signature from the log_model() or save_model() call.")
  mlflow_model.signature = signature
elif python_model is not None:
  # no change from this PR

another way is making a separate block for ChatModels, e.g:

if isinstance(python_model, ChatModel):
  # move ChatModel logic to this block
  ...
elif signature is not None:
  # no change
  ...
elif python_model is not None:
  # no change
  ...

Collaborator:

Nice finding! I agree with throwing. A warning on the happy path can easily be overlooked and is almost invisible in automated environments.

if isinstance(response, ChatResponse):
    return response.to_dict()

# shouldn't happen since there is validation at save time ensuring that
Collaborator:

Should we raise instead? I'm not sure ignoring unexpected behavior is beneficial.


return messages, params

def predict(self, model_input: ChatRequest, params: Optional[Dict[str, Any]] = None):
B-Step62 (Collaborator), Feb 2, 2024:

Suggested change
-def predict(self, model_input: ChatRequest, params: Optional[Dict[str, Any]] = None):
+def predict(self, model_input: Dict[str, Any], params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:

super-nit

Collaborator Author:

ah yes that's true haha, it won't be a ChatRequest when coming in

Signed-off-by: Daniel Lok <daniel.lok@databricks.com>
B-Step62 (Collaborator) left a review comment:

Left one very tiny comment, but otherwise LGTM! Awesome idea, it's always better to have typed objects than to handle dicts everywhere :)

messages, params = self._convert_input(model_input)
response = self.chat_model.predict(self.context, messages, params)

if isinstance(response, ChatResponse):
Collaborator:


    if not isinstance(response, ChatResponse):
        raise MlflowException(...)

    return response.to_dict()

super-minor thing, but probably a more common way to structure the block

assert isinstance(response.choices[0].message, ChatMessage)


def to_dict_converts_nested_dataclasses():
Member:

Suggested change
-def to_dict_converts_nested_dataclasses():
+def test_to_dict_converts_nested_dataclasses():

assert not isinstance(response["choices"][0]["message"], ChatMessage)


def to_dict_excludes_nones():
Member:

Suggested change
-def to_dict_excludes_nones():
+def test_to_dict_excludes_nones():


def to_dict_converts_nested_dataclasses():
response = ChatResponse(**MOCK_RESPONSE).to_dict()
assert not isinstance(response["choices"][0], ChatChoice)
Member:

what's the expected class? dict?

Collaborator Author:

yup it should be dict, i guess i should just assert that haha

Signed-off-by: Daniel Lok <daniel.lok@databricks.com>
@daniellok-db daniellok-db merged commit bf141c7 into mlflow:master Feb 2, 2024
36 checks passed
ernestwong-db pushed a commit to ernestwong-db/mlflow that referenced this pull request Feb 6, 2024
Signed-off-by: Daniel Lok <daniel.lok@databricks.com>
Signed-off-by: ernestwong-db <ernest.wong@databricks.com>
lu-wang-dl pushed a commit to lu-wang-dl/mlflow that referenced this pull request Feb 6, 2024
Signed-off-by: Daniel Lok <daniel.lok@databricks.com>
Signed-off-by: lu-wang-dl <lu.wang@databricks.com>