## Build a tool-calling model with `mlflow.pyfunc.ChatModel`

Welcome to the notebook tutorial on building a simple tool-calling model using the [`mlflow.pyfunc.ChatModel`](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatModel) wrapper. ChatModel is a subclass of MLflow's highly customizable [PythonModel](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.PythonModel), specifically designed to make GenAI workflows easier.

Briefly, here are some of the benefits of using ChatModel:
1. No need to define a complex signature! Chat models often accept complex inputs with many levels of nesting, and this can be cumbersome to define yourself.
2. Support for JSON / dict inputs (no need to wrap inputs or convert them to a pandas DataFrame).
3. Dataclasses for expected inputs / outputs, for a better development experience.
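To make the dataclass benefit concrete, here is a minimal, stdlib-only sketch of the pattern. `Message` is a hypothetical stand-in for MLflow's `ChatMessage` (the real class has more fields, such as tool calls); it only shows how a typed object can be built from, and converted back to, the plain dicts that ChatModel accepts.

```python
from dataclasses import asdict, dataclass
from typing import Optional


# Hypothetical stand-in for MLflow's ChatMessage; illustrative fields only.
@dataclass
class Message:
    role: str
    content: Optional[str] = None
    name: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict) -> "Message":
        # Build a typed object from a plain JSON-style dict.
        return cls(**data)

    def to_dict(self) -> dict:
        # Drop unset (None) fields so the dict stays compact.
        return {k: v for k, v in asdict(self).items() if v is not None}


raw = {"role": "user", "content": "What's the weather in Singapore?"}
msg = Message.from_dict(raw)
print(msg.role)       # attribute access instead of nested dict lookups
print(msg.to_dict())  # round-trips back to the original dict
```

Editors and type checkers can then flag a misspelled field such as `msg.contnet`, which a raw dict lookup would only surface at runtime.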

For a more in-depth exploration of ChatModel, please check out the [detailed guide](https://mlflow.org/docs/latest/llms/chat-model-guide/index.html).

In this tutorial, we'll build a simple OpenAI wrapper that makes use of tool-calling support (released in MLflow 2.17.0).

### Environment setup

First, let's set up the environment. We'll need the OpenAI Python SDK, as well as MLflow >= 2.17.0. We'll also need to set our OpenAI API key in order to use the SDK.

In [None]:
%pip install "mlflow>=2.17.0" openai

In [2]:
import os
from getpass import getpass

os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")
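Before moving on, it can help to confirm the setup. The sketch below uses only the standard library (`importlib.metadata`) to check that MLflow meets the 2.17.0 minimum, that the OpenAI SDK is installed, and that the API key is set. The helper names are hypothetical, and the version parsing is a simplification that ignores pre-release suffixes; `packaging.version` is the more robust choice if it is available.

```python
import os
import re
from importlib import metadata
from typing import List, Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    # Keep only the leading numeric release segment, e.g. "2.17.0rc0" -> (2, 17, 0).
    # Simplification: pre-release/dev suffixes are ignored.
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def check_environment(min_mlflow: str = "2.17.0") -> List[str]:
    # Collect human-readable problems instead of failing on the first one.
    problems: List[str] = []
    try:
        installed = metadata.version("mlflow")
        if version_tuple(installed) < version_tuple(min_mlflow):
            problems.append(f"mlflow {installed} is older than {min_mlflow}")
    except metadata.PackageNotFoundError:
        problems.append("mlflow is not installed")
    try:
        metadata.version("openai")
    except metadata.PackageNotFoundError:
        problems.append("openai is not installed")
    if not os.environ.get("OPENAI_API_KEY"):
        problems.append("OPENAI_API_KEY is not set")
    return problems


print(check_environment() or "environment looks ready")
```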

### Defining our model

The first step in this tutorial is to define our model by subclassing `mlflow.pyfunc.ChatModel`. This involves defining a `predict()` method that accepts the following arguments:

1. `context`: [PythonModelContext](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.PythonModelContext) (not used in this tutorial)
2. `messages`: List\[[ChatMessage](https://mlflow.org/docs/latest/python_api/mlflow.types.html#mlflow.types.llm.ChatMessage)\]. This is the chat input that the model uses for generation.
3. `params`: [ChatParams](https://mlflow.org/docs/latest/python_api/mlflow.types.html#mlflow.types.llm.ChatParams). These are commonly used parameters for configuring the chat model, e.g. `temperature`, `max_tokens`, etc. This is also where the tool specifications can be found.
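One way to picture how these arguments flow into the OpenAI call is a small pure function that assembles the request keyword arguments. This is a hypothetical helper operating on plain dicts (not part of MLflow or the OpenAI SDK); it assumes the messages and tools have already been converted with `to_dict()`, as the model below does.

```python
from typing import Any, Dict, List, Optional


def build_request(
    messages: List[Dict[str, Any]],
    tools: Optional[List[Dict[str, Any]]] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    model: str = "gpt-4o-mini",
) -> Dict[str, Any]:
    # Hypothetical helper: mirrors how predict() turns `messages` and fields
    # of `params` into keyword arguments for client.chat.completions.create().
    request: Dict[str, Any] = {"model": model, "messages": messages}
    optional = {"tools": tools, "temperature": temperature, "max_tokens": max_tokens}
    # Only forward the settings the caller actually provided.
    request.update({k: v for k, v in optional.items() if v is not None})
    return request


req = build_request(
    [{"role": "user", "content": "What's the weather in Singapore?"}],
    temperature=0.2,
)
print(sorted(req))  # ['messages', 'model', 'temperature']
```

Omitting unset values keeps the request minimal, so the API's own defaults apply to anything the caller did not specify.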

As mentioned earlier, our use case is fairly simple: we'll just forward the inputs to OpenAI. However, you can implement any arbitrary logic here for more complex use cases (e.g. pre-processing the inputs or post-processing the OpenAI response).
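As an illustration of that kind of logic, here is a minimal sketch of two hypothetical hooks working on plain dicts: one that prepends a default system prompt (the same prompt used in the input example later) when the caller did not supply one, and one that tidies the assistant's text in an OpenAI-style response dict. Inside `predict()` they could be applied to `[m.to_dict() for m in messages]` and `response.to_dict()` respectively.

```python
from typing import Any, Dict, List

DEFAULT_SYSTEM_PROMPT = "Please use the provided tools to answer user queries."


def preprocess(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Prepend a default system prompt if the conversation does not start with one.
    if messages and messages[0].get("role") == "system":
        return messages
    return [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + messages


def postprocess(response: Dict[str, Any]) -> Dict[str, Any]:
    # Strip surrounding whitespace from each choice's text content, if present
    # (content is None when the model only returns tool calls).
    for choice in response.get("choices", []):
        message = choice.get("message", {})
        if isinstance(message.get("content"), str):
            message["content"] = message["content"].strip()
    return response


msgs = preprocess([{"role": "user", "content": "What's the weather in Singapore?"}])
print(msgs[0]["role"])  # system
```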

In [3]:
from typing import List

from openai import OpenAI

import mlflow
from mlflow.types.llm import (
    ChatMessage,
    ChatParams,
    ChatResponse,
)

# replace this with your own MLflow tracking URI
mlflow.set_tracking_uri("http://localhost:5000")


class WeatherModel(mlflow.pyfunc.ChatModel):
    def predict(self, context, messages: List[ChatMessage], params: ChatParams):
        # instantiate the OpenAI client
        client = OpenAI()

        # check if the request contains any tool definitions
        tools = None
        if params.tools:
            # if so, call the `to_dict` method to convert them to dictionaries
            # that can be parsed by the OpenAI SDK. the format of MLflow tools
            # is compatible with OpenAI, so no further processing is needed.
            tools = [tool.to_dict() for tool in params.tools]

        # call the API
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[m.to_dict() for m in messages],
            # pass the tools in the request
            tools=tools,
        )

        # return a ChatResponse, as this is the expected output of the predict method
        return ChatResponse.from_dict(response.to_dict())

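Note that this model only forwards tool definitions; it does not execute tools. When OpenAI decides a tool is needed, the assistant message carries `tool_calls`, each with a JSON-encoded `arguments` string, and the caller is expected to run the tool and reply with a `role: "tool"` message. The sketch below shows that client-side step on plain dicts. `get_weather` is a hypothetical local function matching a weather tool with a `city` parameter, and the sample assistant message is hand-written for illustration, not real model output.

```python
import json
from typing import Any, Callable, Dict, List


def get_weather(city: str) -> str:
    # Hypothetical implementation; a real one would call a weather service.
    return f"It is sunny in {city}."


# Map tool names (as declared in the tool definition) to Python callables.
TOOLS: Dict[str, Callable[..., str]] = {"get_weather": get_weather}


def run_tool_calls(assistant_message: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Execute each requested tool call and build the `role: "tool"` replies."""
    results = []
    for call in assistant_message.get("tool_calls") or []:
        # Raises KeyError for a tool name we never registered.
        fn = TOOLS[call["function"]["name"]]
        # The model returns arguments as a JSON-encoded string, not a dict.
        kwargs = json.loads(call["function"]["arguments"])
        results.append(
            {"role": "tool", "tool_call_id": call["id"], "content": fn(**kwargs)}
        )
    return results


# Hand-written example of an assistant message requesting a tool call.
assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Singapore"}'},
        }
    ],
}
print(run_tool_calls(assistant_message))
```

The resulting tool messages would be appended to the conversation (after the assistant message) and sent back to the model for a final answer.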

### Logging the model

Next, we need to log the model. This saves the model as an artifact in MLflow Tracking and allows us to load and serve it later.

(Note: this is a fundamental pattern in MLflow. To learn more, check out the [Quickstart guide](https://mlflow.org/docs/latest/getting-started/intro-quickstart/index.html)!)

In order to do this, we need to do a few things:

1. Define an input example to inform users about the input we expect
2. Instantiate the model
3. Call `mlflow.pyfunc.log_model()` with the above as arguments
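For step 1, keep in mind that MLflow saves the input example alongside the model as JSON, so it should be built from plain dicts, lists, and strings rather than rich Python objects. A quick stdlib check like the hypothetical helper below can catch a non-serializable example before logging; `Opaque` is an illustrative dataclass standing in for any object that `json` cannot encode.

```python
import json
from dataclasses import dataclass
from typing import Any


def is_json_serializable(example: Any) -> bool:
    # Input examples must survive a JSON round trip; arbitrary objects
    # (e.g. dataclass instances) make json.dumps raise TypeError.
    try:
        json.dumps(example)
        return True
    except (TypeError, ValueError):
        return False


@dataclass
class Opaque:
    name: str


print(is_json_serializable({"messages": [{"role": "user", "content": "hi"}]}))  # True
print(is_json_serializable({"tools": [Opaque("get_weather")]}))  # False
```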

Of course, since this is a tool-calling tutorial, we'll have to include an example of how to define a tool. This is handled by instantiating a [FunctionToolDefinition](https://mlflow.org/docs/latest/python_api/mlflow.types.html#mlflow.types.llm.FunctionToolDefinition):

In [None]:
from mlflow.types.llm import (
    FunctionToolDefinition,
    ParamProperty,
    ToolParamsSchema,
)

# messages to use as input examples
messages = [
    {"role": "system", "content": "Please use the provided tools to answer user queries."},
    {"role": "user", "content": "What's the weather in Singapore?"},
]

# a sample tool definition. we use the `FunctionToolDefinition`
# class to describe the name and expected params for the tool.
weather_tool = FunctionToolDefinition(
    name="get_weather",
    description="Get weather information",
    parameters=ToolParamsSchema(
        {
            "city": ParamProperty(
                type="string",
                description="City name to get weather information for",
            ),
        }
    ),
    # make sure to call `to_tool_definition()` to convert the `FunctionToolDefinition`
    # to a `ToolDefinition` object. this step is necessary to normalize the data format,
    # as multiple types of tools (besides just functions) might be available in the future.
).to_tool_definition()

# the full input to the model includes both the messages and tools. see the
# docs for `mlflow.types.llm.ChatRequest` for full details on the input format.
# the tool is converted to a plain dict so the input example is JSON-serializable.
input_example = {
    "messages": messages,
    "tools": [weather_tool.to_dict()],
}

# instantiate the model
model = WeatherModel()

# log the model
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="weather-model",
        python_model=model,
        input_example=input_example,
    )
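Because MLflow's tool format is compatible with OpenAI's, it can help to see the wire format written out by hand. The dict below is a hand-written OpenAI-style spec for a weather tool with a single string parameter `city`; MLflow's own `to_dict()` output may differ in detail (for example, extra keys). The hypothetical `validate_arguments` helper then does a deliberately minimal check of a tool call's JSON arguments against the declared properties: unknown keys and string types only, not full JSON Schema validation.

```python
import json
from typing import Any, Dict, List

# Hand-written OpenAI-style tool spec; approximate, not MLflow's exact output.
weather_tool_dict: Dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get weather information",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name to get weather information for",
                },
            },
        },
    },
}


def validate_arguments(tool: Dict[str, Any], arguments: str) -> List[str]:
    # Minimal check: flag unknown keys and non-string values for string params.
    props = tool["function"]["parameters"]["properties"]
    problems = []
    for key, value in json.loads(arguments).items():
        if key not in props:
            problems.append(f"unexpected argument: {key}")
        elif props[key]["type"] == "string" and not isinstance(value, str):
            problems.append(f"{key} should be a string")
    return problems


print(validate_arguments(weather_tool_dict, '{"city": "Singapore"}'))  # []
print(validate_arguments(weather_tool_dict, '{"city": 42}'))  # ['city should be a string']
```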

### Using the model for generations

