## Build a tool-calling model with `mlflow.pyfunc.ChatModel`

<a href="https://raw.githubusercontent.com/mlflow/mlflow/master/docs/source/llms/notebooks/chat-model-tool-calling.ipynb" class="notebook-download-btn"><i class="fas fa-download"></i>Download this Notebook</a>

Welcome to this notebook tutorial on building a simple tool-calling model using the [`mlflow.pyfunc.ChatModel`](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatModel) wrapper. `ChatModel` is a subclass of MLflow's highly customizable [PythonModel](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.PythonModel), designed specifically to make GenAI workflows easier.

Briefly, here are some of the benefits of using ChatModel:
1. No need to define a complex signature! Chat models often accept complex inputs with many levels of nesting, which can be cumbersome to define yourself.
2. Support for JSON / dict inputs (no need to wrap inputs or convert them to a pandas DataFrame)
3. Dataclasses for expected inputs / outputs for a better development experience

For a more in-depth exploration of ChatModel, please check out the [detailed guide](https://mlflow.org/docs/latest/llms/chat-model-guide/index.html).

In this tutorial, we'll build a simple OpenAI wrapper that makes use of the tool-calling support introduced in MLflow 2.17.0.

### Environment setup

First, let's set up the environment. We'll need the OpenAI Python SDK, as well as MLflow >= 2.17.0. We'll also need to set our OpenAI API key in order to use the SDK.

```python
%pip install 'mlflow>=2.17.0' openai -qq
```

```
Note: you may need to restart the kernel to use updated packages.
```


```python
import os
from getpass import getpass

os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")
```

### Step 1: Creating the tool definition

Since our model will be a tool-calling model, we'll define a `get_weather` tool that we can pass to the OpenAI model. We do this by using [`mlflow.types.llm.FunctionToolDefinition`](https://mlflow.org/docs/latest/python_api/mlflow.types.html#mlflow.types.llm.FunctionToolDefinition) to describe the parameters that our tool accepts. The format of this dataclass is aligned with the OpenAI function-calling spec:

```python
from mlflow.types.llm import (
    FunctionToolDefinition,
    ParamProperty,
    ToolParamsSchema,
)

# a sample tool definition. we use the `FunctionToolDefinition`
# class to describe the name and expected params for the tool.
# for this example, we're defining a simple tool that returns
# the weather for a given city.
weather_tool = FunctionToolDefinition(
    name="get_weather",
    description="Get weather information",
    parameters=ToolParamsSchema(
        {
            "city": ParamProperty(
                type="string",
                description="City name to get weather information for",
            ),
        }
    ),
    # make sure to call `to_tool_definition()` to convert the `FunctionToolDefinition`
    # to a `ToolDefinition` object. this step is necessary to normalize the data format,
    # as multiple types of tools (besides just functions) might be available in the future.
).to_tool_definition()
```
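For reference, the tool definition that OpenAI's `tools` parameter receives is a plain dictionary. The stdlib-only sketch below builds that structure by hand, assuming the OpenAI function-calling format; `make_function_tool` is a hypothetical helper (not an MLflow API), and the exact keys MLflow emits from `weather_tool.to_dict()` may differ slightly, for example by omitting unset optional fields such as `required`.

```python
import json
from typing import Any, Dict


def make_function_tool(
    name: str, description: str, properties: Dict[str, Dict[str, str]]
) -> Dict[str, Any]:
    """Hypothetical helper: build an OpenAI-style function tool definition."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            # JSON Schema object describing the arguments the tool accepts
            "parameters": {
                "type": "object",
                "properties": properties,
            },
        },
    }


weather_tool_dict = make_function_tool(
    name="get_weather",
    description="Get weather information",
    properties={
        "city": {
            "type": "string",
            "description": "City name to get weather information for",
        },
    },
)

print(json.dumps(weather_tool_dict, indent=2))
```

Printing `weather_tool.to_dict()` next to this is a quick way to check what will actually be sent to OpenAI.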

### Step 2: Implementing the tool

Now that we have a definition for the tool, we need to actually implement it. For the purposes of this tutorial, we're just going to mock a response, but the implementation can be arbitrary; for example, you might make an API call to an actual weather service.

```python
def get_weather(city: str) -> str:
    return f"It's sunny in {city}, with a temperature of 20C"
```
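OpenAI returns a tool call's arguments as a JSON-encoded string, so before the tool can run, it has to be looked up by name and its arguments decoded. Below is a minimal sketch of that dispatch step, assuming a plain dictionary registry; `TOOL_REGISTRY` and `run_tool_call` are illustrative names, not MLflow or OpenAI APIs.

```python
import json
from typing import Callable, Dict


def get_weather(city: str) -> str:
    return f"It's sunny in {city}, with a temperature of 20C"


# hypothetical registry mapping tool names (as declared in the tool
# definition) to the Python functions that implement them
TOOL_REGISTRY: Dict[str, Callable[..., str]] = {
    "get_weather": get_weather,
}


def run_tool_call(name: str, arguments: str) -> str:
    """Decode the JSON arguments string and call the matching tool."""
    if name not in TOOL_REGISTRY:
        raise ValueError(f"Unknown tool: {name}")
    kwargs = json.loads(arguments)
    return TOOL_REGISTRY[name](**kwargs)


print(run_tool_call("get_weather", '{"city": "Singapore"}'))
# → It's sunny in Singapore, with a temperature of 20C
```

Because the decoded arguments are unpacked as keyword arguments, the property names in the tool definition (here `city`) must match the function's parameter names.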

### Step 3: Defining our model

Now we're ready to create our model! We'll subclass `mlflow.pyfunc.ChatModel`, which involves defining a `predict()` method that accepts the following arguments:

1. `context`: [PythonModelContext](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.PythonModelContext) (not used in this tutorial)
2. `messages`: List\[[ChatMessage](https://mlflow.org/docs/latest/python_api/mlflow.types.html#mlflow.types.llm.ChatMessage)\]. This is the chat input that the model uses for generation.
3. `params`: [ChatParams](https://mlflow.org/docs/latest/python_api/mlflow.types.html#mlflow.types.llm.ChatParams). These are common parameters used to configure the chat model, e.g. `temperature`, `max_tokens`, etc. This is also where tool specifications passed in by the caller can be found (in this tutorial, we attach our tool definition in the model's constructor instead).

For the implementation, we'll simply forward the user's input to OpenAI and offer the `get_weather` tool for the LLM to use if it chooses to. If we receive a tool call request, we'll call the `get_weather()` function and send its output back to OpenAI, which then generates the final answer. We'll use what we defined in the previous two steps to do this.
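Concretely, the follow-up request to OpenAI carries the original conversation, then the assistant message containing the tool call(s), then one `tool` message per call whose `tool_call_id` echoes that call's `id`. The sketch below builds this list with plain dictionaries, assuming OpenAI's Chat Completions message layout; the `call_abc123` ID and the `append_tool_results` helper are made up for illustration.

```python
from typing import Any, Dict, List

Message = Dict[str, Any]


def append_tool_results(
    messages: List[Message], assistant_message: Message, results: Dict[str, str]
) -> List[Message]:
    """Return a new history: the original turns, the assistant's tool-call
    message, then one "tool" message per call. `results` maps each
    tool_call_id to the string the tool returned."""
    history = list(messages)  # copy so the caller's list is not mutated
    history.append(assistant_message)
    for call in assistant_message["tool_calls"]:
        history.append(
            {"role": "tool", "tool_call_id": call["id"], "content": results[call["id"]]}
        )
    return history


conversation = [
    {"role": "system", "content": "Please use the provided tools to answer user queries."},
    {"role": "user", "content": "What's the weather in Singapore?"},
]

# the shape of an assistant tool-call message (hypothetical ID)
assistant_tool_call = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Singapore"}'},
        }
    ],
}

followup = append_tool_results(
    conversation,
    assistant_tool_call,
    {"call_abc123": "It's sunny in Singapore, with a temperature of 20C"},
)
print([m["role"] for m in followup])
# → ['system', 'user', 'assistant', 'tool']
```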

```python
import json
from typing import List

from openai import OpenAI

import mlflow
from mlflow.types.llm import (
    ChatMessage,
    ChatParams,
    ChatResponse,
    FunctionToolDefinition,
    ParamProperty,
    ToolParamsSchema,
)

# replace this with your own MLflow tracking URI
mlflow.set_tracking_uri("http://localhost:5000")


class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        # from step 1 above, our tool definition
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
        ).to_tool_definition()

        # OpenAI expects tools to be provided as a list of dictionaries
        self.tools = [weather_tool.to_dict()]

    # from step 2 above, the implementation of the tool
    def get_weather(self, city: str) -> str:
        return f"It's sunny in {city}, with a temperature of 20C"

    # the core method that needs to be implemented. this function
    # will be called every time a user sends messages to our model
    def predict(self, context, messages: List[ChatMessage], params: ChatParams):
        # instantiate the OpenAI client (reads OPENAI_API_KEY from the environment)
        client = OpenAI()

        # convert the messages to a format that the OpenAI API expects
        messages = [m.to_dict() for m in messages]

        # call the OpenAI API
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            # pass the tools in the request
            tools=self.tools,
        )

        # if OpenAI returns a tool-calling response, then we call
        # our tool. otherwise, we just return the response as is
        tool_calls = response.choices[0].message.tool_calls
        if tool_calls:
            print("Received a tool call, calling the weather tool...")

            # append the assistant's tool-call message first, so that the
            # tool responses below can be matched to the calls they answer
            messages.append(response.choices[0].message)

            # OpenAI may return several tool calls in one response (e.g. one
            # per city), and every tool_call_id needs a matching tool message.
            # we only provide the model with one tool, so we can assume every
            # call is for the weather tool. if we had more, we'd need to check
            # `tool_call.function.name` before dispatching
            for tool_call in tool_calls:
                city = json.loads(tool_call.function.arguments)["city"]
                tool_response = ChatMessage(
                    role="tool",
                    content=self.get_weather(city),
                    tool_call_id=tool_call.id,
                ).to_dict()
                messages.append(tool_response)

            # send another request to the API with the tool responses included
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                tools=self.tools,
            )

        # return the result as a ChatResponse, as this
        # is the expected output of the predict method
        return ChatResponse.from_dict(response.to_dict())
```
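The `predict()` method above handles a single round of tool calls. If you expect the LLM to chain several tool calls before answering, one common generalization is a loop with a round limit. Here is a stdlib-only sketch of that control flow; `fake_llm` is a hypothetical stand-in for `client.chat.completions.create` so the snippet runs without an API key, and `TOOLS` / `run_with_tools` are illustrative names. This is not how `WeatherModel` behaves as written.

```python
import json
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


def get_weather(city: str) -> str:
    return f"It's sunny in {city}, with a temperature of 20C"


# hypothetical registry: tool name -> implementation
TOOLS: Dict[str, Callable[..., str]] = {"get_weather": get_weather}


def fake_llm(messages: List[Message]) -> Message:
    """Offline stand-in for the OpenAI API: asks for the weather tool on
    the first turn, then answers using the tool output."""
    if messages[-1]["role"] != "tool":
        return {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"city": "Singapore"}'},
                }
            ],
        }
    return {"role": "assistant", "content": f"Tool said: {messages[-1]['content']}"}


def run_with_tools(
    llm: Callable[[List[Message]], Message],
    messages: List[Message],
    max_rounds: int = 5,
) -> Message:
    """Call the LLM until it stops requesting tools, at most max_rounds times."""
    history = list(messages)
    for _ in range(max_rounds):
        reply = llm(history)
        if not reply.get("tool_calls"):
            return reply  # plain answer, no more tools needed
        history.append(reply)
        # answer every tool call, dispatching on the requested tool name
        for call in reply["tool_calls"]:
            fn = TOOLS[call["function"]["name"]]
            result = fn(**json.loads(call["function"]["arguments"]))
            history.append({"role": "tool", "tool_call_id": call["id"], "content": result})
    raise RuntimeError(f"No final answer after {max_rounds} rounds of tool calls")


answer = run_with_tools(fake_llm, [{"role": "user", "content": "What's the weather in Singapore?"}])
print(answer["content"])
# → Tool said: It's sunny in Singapore, with a temperature of 20C
```

Capping the number of rounds guards against an LLM that keeps requesting tools indefinitely. To use this pattern in a real `ChatModel`, you would adapt the OpenAI response objects to the dictionary shape used here.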

### Step 4: Logging the model

Next, we need to log the model. This saves the model as an artifact in MLflow Tracking, and allows us to load and serve it later on.

(Note: this is a fundamental pattern in MLflow. To learn more, check out the [Quickstart guide](https://mlflow.org/docs/latest/getting-started/intro-quickstart/index.html)!)

To do this, we need to:

1. Define an input example to inform users about the input we expect
2. Instantiate the model
3. Call `mlflow.pyfunc.log_model()` with the above as arguments

Since the tool definition is already attached to the model in `WeatherModel.__init__()`, the input example only needs to contain chat messages. We'll use a system prompt that asks the model to use the provided tools, along with a user question that should trigger a call to `get_weather`:

```python
# messages to use as input examples
messages = [
    {"role": "system", "content": "Please use the provided tools to answer user queries."},
    {"role": "user", "content": "What's the weather in Singapore?"},
]

input_example = {
    "messages": messages,
}

# instantiate the model
model = WeatherModel()

# log the model
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="weather-model",
        python_model=model,
        input_example=input_example,
    )
```

```
2024/10/25 11:39:22 INFO mlflow.pyfunc: Predicting on input example to validate output
Received a tool call, calling the weather tool...
Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 216.81it/s]
Received a tool call, calling the weather tool...
2024/10/25 11:39:29 INFO mlflow.tracking._tracking_service.client: 🏃 View run gaudy-rook-875 at: http://localhost:5000/#/experiments/0/runs/567216afacbe4cd9bf2fcbe019e71b5b.
2024/10/25 11:39:29 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5000/#/experiments/0.
```


### Using the model for generations

Now that the model is logged, our work is more or less done! To use the model for predictions, let's load it back using `mlflow.pyfunc.load_model()`.


```python
import mlflow

# Load the previously logged ChatModel
tool_model = mlflow.pyfunc.load_model(model_info.model_uri)

system_prompt = {
    "role": "system",
    "content": "Please use the provided tools to answer user queries.",
}

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in Singapore?"},
]

# Call the model's predict method
response = tool_model.predict({"messages": messages})
print(response["choices"][0]["message"]["content"])

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in San Francisco?"},
]

# Generating another response
response = tool_model.predict({"messages": messages})
print(response["choices"][0]["message"]["content"])
```

```
Downloading artifacts: 100%|██████████| 7/7 [00:00<00:00, 160.13it/s]
Received a tool call, calling the weather tool...
The weather in Singapore is sunny, with a temperature of 20°C.
Received a tool call, calling the weather tool...
The weather in San Francisco is sunny, with a temperature of 20°C.
```
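Each `predict()` call is stateless: `WeatherModel` only sees the `messages` passed in, which is why every call above sends its own list. To ask a follow-up question, one approach is to append the assistant's reply to the history before adding the next user turn. Below is a minimal sketch, assuming the response shape indexed above (`response["choices"][0]["message"]`); `continue_conversation` is a hypothetical helper, and `fake_predict` stands in for `tool_model.predict` so the snippet runs offline.

```python
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


def continue_conversation(
    predict: Callable[[Dict[str, Any]], Dict[str, Any]],
    history: List[Message],
    user_text: str,
) -> List[Message]:
    """Send the full history plus a new user turn; return the extended history."""
    history = history + [{"role": "user", "content": user_text}]
    response = predict({"messages": history})
    reply = response["choices"][0]["message"]
    # keep only the fields the next request needs
    return history + [{"role": reply["role"], "content": reply["content"]}]


def fake_predict(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Offline stand-in for tool_model.predict: reports how many messages it received."""
    n = len(payload["messages"])
    return {"choices": [{"index": 0, "message": {"role": "assistant", "content": f"saw {n} messages"}}]}


history = [{"role": "system", "content": "Please use the provided tools to answer user queries."}]
history = continue_conversation(fake_predict, history, "What's the weather in Singapore?")
history = continue_conversation(fake_predict, history, "And in Tokyo?")
print(history[-1]["content"])
# → saw 4 messages
```

With the real model, pass `tool_model.predict` in place of `fake_predict`. The intermediate tool-call messages stay inside `WeatherModel.predict()`, so the client-side history only contains system, user, and assistant turns.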


### Serving the model

MLflow also allows you to serve models using the `mlflow models serve` CLI tool. In another terminal shell, run the following (you might have to set the `MLFLOW_TRACKING_URI` environment variable to `http://localhost:5000`, or to wherever your tracking server is located):

```sh
$ export OPENAI_API_KEY=<YOUR OPENAI API KEY>
$ mlflow models serve -m <MODEL_URI> -p 8000
```

Replace `<MODEL_URI>` with the `model_info.model_uri` value from Step 4. This starts serving the model on `http://localhost:8000`, and the model can be queried via a POST request to the `/invocations` route.

```python
import requests

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in Tokyo?"},
]

requests.post("http://localhost:8000/invocations", json={"messages": messages}).json()
```

```python
{'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': 'The weather in Tokyo is sunny, with a temperature of 20°C.'},
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 100, 'completion_tokens': 16, 'total_tokens': 116},
 'id': 'chatcmpl-AM5lXVHQqsuAzI1F0OoZx8dYZsogq',
 'model': 'gpt-4o-mini-2024-07-18',
 'object': 'chat.completion',
 'created': 1729828743}
```
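If you'd rather avoid the `requests` dependency, the same call can be made with Python's standard library. This is a sketch under stated assumptions rather than part of the original tutorial: `build_invocation_request`, `invoke`, and `summarize` are hypothetical helper names, and the server is assumed to be running on `http://localhost:8000` as above. `summarize` is exercised against the response shown above, so the snippet runs without a server.

```python
import json
import urllib.request
from typing import Any, Dict, List, Tuple

URL = "http://localhost:8000/invocations"


def build_invocation_request(messages: List[Dict[str, Any]], url: str = URL) -> urllib.request.Request:
    """Build a JSON POST request with the same body as the `requests` call above."""
    body = json.dumps({"messages": messages}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )


def invoke(messages: List[Dict[str, Any]], url: str = URL) -> Dict[str, Any]:
    """Send the request and decode the JSON reply (needs a running server)."""
    with urllib.request.urlopen(build_invocation_request(messages, url)) as resp:
        return json.loads(resp.read().decode("utf-8"))


def summarize(response: Dict[str, Any]) -> Tuple[str, int]:
    """Return the first choice's text and the total token count."""
    return response["choices"][0]["message"]["content"], response["usage"]["total_tokens"]


# the response shown above, reused here so the sketch runs offline
sample = {
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "The weather in Tokyo is sunny, with a temperature of 20°C."},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 100, "completion_tokens": 16, "total_tokens": 116},
}

text, tokens = summarize(sample)
print(text, tokens)
```

With the server running, `summarize(invoke(messages))` should return the reply text and token count for a live request.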

### Conclusion

In this tutorial, we covered how to use MLflow's `ChatModel` class to create a convenient OpenAI wrapper that supports tool calling. Though the use case was simple, the concepts covered here can be extended to support more complex functionality.

If you're looking to dive deeper into building quality GenAI apps, you might also be interested in checking out [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html), an observability tool you can use to trace the execution of arbitrary functions (such as your tool calls).