Fix openai flavor implementation to support formatting messages with prediction input (mlflow#8291)
* Fix openai implementation
* default value
* remove default
* rename
* reuse _has_content_and_role
* rename key
* improve error message
* Address comments
* Address comments
* Fix spark UDF example
* Doc fixes
* Reorganize docs

---------

Signed-off-by: harupy <hkawamura0130@gmail.com>
Signed-off-by: Larry O’Brien <larry.obrien@databricks.com>
harupy authored and Larry O’Brien committed May 10, 2023
1 parent 171a0f6 commit 4a13941
Showing 9 changed files with 770 additions and 122 deletions.
2 changes: 2 additions & 0 deletions docs/source/python_api/index.rst
@@ -8,8 +8,10 @@ exposed in the :py:mod:`mlflow` module, so we recommend starting there.

.. toctree::
:glob:
:maxdepth: 1

*
openai/index.rst


See also the :ref:`index of all functions and classes<genindex>`.
File renamed without changes.
258 changes: 258 additions & 0 deletions docs/source/python_api/openai/messages.rst
@@ -0,0 +1,258 @@
:orphan:

.. _mlflow.openai.messages:

Supported ``messages`` formats for the OpenAI chat completion task
==================================================================

This document covers the following:

- Supported ``messages`` formats for the OpenAI chat completion task in the ``openai`` flavor.
- The logged model signature for each format.
- The payload sent to the OpenAI chat completion API for each format.
- The expected prediction input types for each format.


``messages`` with variables
---------------------------

The ``messages`` argument accepts a list of dictionaries with ``role`` and ``content`` keys. The
``content`` field in each message can contain variables (i.e., named format fields). When the logged
model is loaded and makes a prediction, each variable is replaced with the corresponding value from
the prediction input.
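
The substitution can be pictured as ordinary Python named-field string formatting. Below is a minimal
sketch of that behavior; the ``fill_variables`` helper is hypothetical, for illustration only, and is
not part of the ``openai`` flavor's API:

```python
# Hypothetical sketch: how named format fields in ``content`` could be filled
# from one prediction-input row. The flavor performs an equivalent substitution
# internally when the loaded pyfunc model predicts.
def fill_variables(messages, row):
    # Each message's content acts as a str.format template with named
    # fields such as {adjective}.
    return [
        {"role": m["role"], "content": m["content"].format(**row)}
        for m in messages
    ]


messages = [{"role": "user", "content": "Tell me a {adjective} joke"}]
print(fill_variables(messages, {"adjective": "funny"}))
# [{'role': 'user', 'content': 'Tell me a funny joke'}]
```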

Single variable
~~~~~~~~~~~~~~~

.. code-block:: python

    import mlflow
    import openai

    with mlflow.start_run():
        model_info = mlflow.openai.log_model(
            artifact_path="model",
            model="gpt-3.5-turbo",
            task=openai.ChatCompletion,
            messages=[
                {
                    "role": "user",
                    "content": "Tell me a {adjective} joke",
                    #           ^^^^^^^^^^
                    #           variable
                },
                # Can contain more messages
            ],
        )

    model = mlflow.pyfunc.load_model(model_info.model_uri)
    print(model.predict([{"adjective": "funny"}]))

Logged model signature:

.. code-block:: python

    {
        "inputs": [{"type": "string"}],
        "outputs": [{"type": "string"}],
    }

Expected prediction input types:

.. code-block:: python

    # A list of dictionaries with an 'adjective' key
    [{"adjective": "funny"}, ...]

    # A list of strings
    ["funny", ...]

Payload sent to the OpenAI chat completion API:

.. code-block:: python

    {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "user",
                "content": "Tell me a funny joke",
            }
        ],
    }

Multiple variables
~~~~~~~~~~~~~~~~~~

.. code-block:: python

    import mlflow
    import openai

    with mlflow.start_run():
        model_info = mlflow.openai.log_model(
            artifact_path="model",
            model="gpt-3.5-turbo",
            task=openai.ChatCompletion,
            messages=[
                {
                    "role": "user",
                    "content": "Tell me a {adjective} joke about {thing}.",
                    #           ^^^^^^^^^^            ^^^^^^^
                    #           variable              another variable
                },
                # Can contain more messages
            ],
        )

    model = mlflow.pyfunc.load_model(model_info.model_uri)
    print(model.predict([{"adjective": "funny", "thing": "vim"}]))

Logged model signature:

.. code-block:: python

    {
        "inputs": [
            {"name": "adjective", "type": "string"},
            {"name": "thing", "type": "string"},
        ],
        "outputs": [{"type": "string"}],
    }

Expected prediction input types:

.. code-block:: python

    # A list of dictionaries with 'adjective' and 'thing' keys
    [{"adjective": "funny", "thing": "vim"}, ...]

Payload sent to the OpenAI chat completion API:

.. code-block:: python

    {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "user",
                "content": "Tell me a funny joke about vim.",
            }
        ],
    }

``messages`` without variables
------------------------------

If the logged ``messages`` contain no variables, the prediction input will be *appended* to them as
a new message with ``role = user``.

.. code-block:: python

    import mlflow
    import openai

    with mlflow.start_run():
        model_info = mlflow.openai.log_model(
            artifact_path="model",
            model="gpt-3.5-turbo",
            task=openai.ChatCompletion,
            messages=[
                {
                    "role": "system",
                    "content": "You're a frontend engineer.",
                }
            ],
        )

    model = mlflow.pyfunc.load_model(model_info.model_uri)
    print(model.predict(["Tell me a funny joke."]))

Logged model signature:

.. code-block:: python

    {
        "inputs": [{"type": "string"}],
        "outputs": [{"type": "string"}],
    }

Expected prediction input types:

- A list of dictionaries, each with a single key
- A list of strings
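
Both accepted input forms resolve to the same appended user message. The following is a minimal
sketch of that normalization, assuming a hypothetical ``build_messages`` helper (not the flavor's
actual internal function):

```python
# Hypothetical sketch: normalizing one prediction-input row into the final
# message list when the logged ``messages`` contain no variables.
logged_messages = [{"role": "system", "content": "You're a frontend engineer."}]


def build_messages(row):
    # A single-key dictionary and a plain string are treated the same way:
    # the value becomes the content of an appended user message.
    content = next(iter(row.values())) if isinstance(row, dict) else row
    return logged_messages + [{"role": "user", "content": content}]


# Both forms produce an identical payload.
assert build_messages("Tell me a funny joke.") == build_messages(
    {"query": "Tell me a funny joke."}
)
```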

Payload sent to the OpenAI chat completion API:

.. code-block:: python

    {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "system",
                "content": "You're a frontend engineer.",
            },
            {
                "role": "user",
                "content": "Tell me a funny joke.",
            },
        ],
    }

No ``messages``
---------------

The ``messages`` argument is optional and can be omitted. If omitted, the prediction input will be
sent to the OpenAI chat completion API as-is with ``role = user``.

.. code-block:: python

    import mlflow
    import openai

    with mlflow.start_run():
        model_info = mlflow.openai.log_model(
            artifact_path="model",
            model="gpt-3.5-turbo",
            task=openai.ChatCompletion,
        )

    model = mlflow.pyfunc.load_model(model_info.model_uri)
    print(model.predict(["Tell me a funny joke."]))

Logged model signature:

.. code-block:: python

    {
        "inputs": [{"type": "string"}],
        "outputs": [{"type": "string"}],
    }

Expected prediction input types:

.. code-block:: python

    # A list of dictionaries, each with a single key
    [{"<any key>": "Tell me a funny joke."}, ...]

    # A list of strings
    ["Tell me a funny joke.", ...]

Payload sent to the OpenAI chat completion API:

.. code-block:: python

    {
        "model": "gpt-3.5-turbo",
        "messages": [
            {
                "role": "user",
                "content": "Tell me a funny joke.",
            }
        ],
    }