Commit: Fix openai flavor implementation to support formatting messages with prediction input (mlflow#8291)

Signed-off-by: harupy <hkawamura0130@gmail.com>
Signed-off-by: Larry O’Brien <larry.obrien@databricks.com>

9 changed files with 770 additions and 122 deletions.
:orphan:

.. _mlflow.openai.messages:

Supported ``messages`` formats for OpenAI chat completion task
==============================================================

This document covers the following:

- Supported ``messages`` formats for the OpenAI chat completion task in the ``openai`` flavor.
- The logged model signature for each format.
- The payload sent to the OpenAI chat completion API for each format.
- The expected prediction input types for each format.

``messages`` with variables
---------------------------

The ``messages`` argument accepts a list of dictionaries with ``role`` and ``content`` keys. The
``content`` field in each message can contain variables (i.e., named format fields). When the
logged model is loaded and makes a prediction, each variable is replaced with the corresponding
value from the prediction input.
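Conceptually, this substitution behaves like Python's built-in named format fields. The sketch below illustrates the idea with plain ``str.format``; it is not the flavor's actual implementation, which may differ in details such as validation and error handling.

```python
# Sketch: filling named format fields in each message's "content" from one
# prediction-input row, using plain str.format. Illustrative only; not the
# actual openai-flavor code.

messages = [{"role": "user", "content": "Tell me a {adjective} joke"}]
row = {"adjective": "funny"}

filled = [
    {"role": m["role"], "content": m["content"].format(**row)} for m in messages
]
print(filled)
# → [{'role': 'user', 'content': 'Tell me a funny joke'}]
```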
Single variable
~~~~~~~~~~~~~~~

.. code-block:: python

    import mlflow
    import openai

    with mlflow.start_run():
        model_info = mlflow.openai.log_model(
            artifact_path="model",
            model="gpt-3.5-turbo",
            task=openai.ChatCompletion,
            messages=[
                {
                    "role": "user",
                    "content": "Tell me a {adjective} joke",
                    #           ^^^^^^^^^^^^
                    #           variable
                },
                # Can contain more messages
            ],
        )

    model = mlflow.pyfunc.load_model(model_info.model_uri)
    print(model.predict([{"adjective": "funny"}]))

Logged model signature:

.. code-block:: python

    {
        "inputs": [{"type": "string"}],
        "outputs": [{"type": "string"}],
    }

Expected prediction input types:

.. code-block:: python

    # A list of dictionaries with an 'adjective' key
    [{"adjective": "funny"}, ...]

    # A list of strings
    ["funny", ...]

Payload sent to the OpenAI chat completion API:

.. code-block:: python

    {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "user",
                "content": "Tell me a funny joke",
            }
        ],
    }
Multiple variables
~~~~~~~~~~~~~~~~~~

.. code-block:: python

    import mlflow
    import openai

    with mlflow.start_run():
        model_info = mlflow.openai.log_model(
            artifact_path="model",
            model="gpt-3.5-turbo",
            task=openai.ChatCompletion,
            messages=[
                {
                    "role": "user",
                    "content": "Tell me a {adjective} joke about {thing}.",
                    #           ^^^^^^^^^^^^           ^^^^^^^^^
                    #           variable               another variable
                },
                # Can contain more messages
            ],
        )

    model = mlflow.pyfunc.load_model(model_info.model_uri)
    print(model.predict([{"adjective": "funny", "thing": "vim"}]))

Logged model signature:

.. code-block:: python

    {
        "inputs": [
            {"name": "adjective", "type": "string"},
            {"name": "thing", "type": "string"},
        ],
        "outputs": [{"type": "string"}],
    }

Expected prediction input types:

.. code-block:: python

    # A list of dictionaries with 'adjective' and 'thing' keys
    [{"adjective": "funny", "thing": "vim"}, ...]

Payload sent to the OpenAI chat completion API:

.. code-block:: python

    {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "user",
                "content": "Tell me a funny joke about vim.",
            }
        ],
    }
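The named signature above (one input per variable) can be derived by scanning the message templates for format fields. A minimal sketch of that discovery step, using the standard library's ``string.Formatter`` (illustrative only; not necessarily how the flavor does it):

```python
import string

# Sketch: collect the named format fields appearing in message templates.
# This is conceptually how per-variable signature inputs can be derived.

templates = ["Tell me a {adjective} joke about {thing}."]
variables = []
for template in templates:
    # parse() yields (literal_text, field_name, format_spec, conversion)
    for _, field_name, _, _ in string.Formatter().parse(template):
        if field_name:
            variables.append(field_name)
print(variables)
# → ['adjective', 'thing']
```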
``messages`` without variables
------------------------------

If the logged ``messages`` contain no variables, the prediction input is *appended* to them as a
new message with ``role = "user"``.

.. code-block:: python

    import mlflow
    import openai

    with mlflow.start_run():
        model_info = mlflow.openai.log_model(
            artifact_path="model",
            model="gpt-3.5-turbo",
            task=openai.ChatCompletion,
            messages=[
                {
                    "role": "system",
                    "content": "You're a frontend engineer.",
                }
            ],
        )

    model = mlflow.pyfunc.load_model(model_info.model_uri)
    print(model.predict(["Tell me a funny joke."]))

Logged model signature:

.. code-block:: python

    {
        "inputs": [{"type": "string"}],
        "outputs": [{"type": "string"}],
    }

Expected prediction input types:

- A list of dictionaries with a single key
- A list of strings

Payload sent to the OpenAI chat completion API:

.. code-block:: python

    {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "system",
                "content": "You're a frontend engineer.",
            },
            {
                "role": "user",
                "content": "Tell me a funny joke.",
            },
        ],
    }
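The append behavior can be sketched as follows. The ``build_payload`` helper here is hypothetical, written only to illustrate how the payload above is assembled from the logged messages plus one input string; it is not part of the ``openai`` flavor's API.

```python
# Sketch of the append behavior: when the logged messages contain no
# variables, each prediction-input string becomes a new trailing message
# with role "user". Hypothetical helper; not the flavor's actual code.

logged_messages = [{"role": "system", "content": "You're a frontend engineer."}]


def build_payload(prediction_input: str) -> dict:
    return {
        "model": "gpt-3.5-turbo",
        "messages": logged_messages
        + [{"role": "user", "content": prediction_input}],
    }


print(build_payload("Tell me a funny joke."))
```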
No ``messages``
---------------

The ``messages`` argument is optional and can be omitted. If it is omitted, each element of the
prediction input is sent to the OpenAI chat completion API as-is, as a single message with
``role = "user"``.

.. code-block:: python

    import mlflow
    import openai

    with mlflow.start_run():
        model_info = mlflow.openai.log_model(
            artifact_path="model",
            model="gpt-3.5-turbo",
            task=openai.ChatCompletion,
        )

    model = mlflow.pyfunc.load_model(model_info.model_uri)
    print(model.predict(["Tell me a funny joke."]))

Logged model signature:

.. code-block:: python

    {
        "inputs": [{"type": "string"}],
        "outputs": [{"type": "string"}],
    }

Expected prediction input types:

.. code-block:: python

    # A list of dictionaries with a single key
    [{"<any key>": "Tell me a funny joke."}, ...]

    # A list of strings
    ["Tell me a funny joke.", ...]

Payload sent to the OpenAI chat completion API:

.. code-block:: python

    {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "user",
                "content": "Tell me a funny joke.",
            }
        ],
    }
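Since either a plain string or a single-key dictionary is accepted here, the input normalization can be sketched as below. ``to_user_message`` is a hypothetical helper for illustration, not the flavor's actual implementation.

```python
# Sketch of input normalization for the no-messages case: each element may
# be a plain string or a single-key dict; either way, its value becomes the
# lone user message. Hypothetical helper; not the flavor's actual code.


def to_user_message(item) -> dict:
    content = next(iter(item.values())) if isinstance(item, dict) else item
    return {"role": "user", "content": content}


print(to_user_message("Tell me a funny joke."))
print(to_user_message({"<any key>": "Tell me a funny joke."}))
# Both produce: {'role': 'user', 'content': 'Tell me a funny joke.'}
```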