# ai.messages

> Define message types

In [None]:
#| default_exp ai.messages

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from dataclasses import dataclass
import json
from typing import List, Optional, Union
from typing_extensions import Literal, TypeAlias
from pydantic import BaseModel
#
# The type definitions below are adapted from the OpenAI API specification,
# but are declared as dataclasses and pydantic models instead of TypedDicts.
#


# Audio payload carried by an `input_audio` content part.
@dataclass
class InputAudio:
    data: str
    """The audio bytes, encoded as a Base64 string."""

    format: Literal["wav", "mp3"]
    """Encoding format of the audio held in `data`."""


# Content part that attaches an audio payload to a message.
@dataclass
class PartInputAudio:
    input_audio: InputAudio
    """The audio payload of this part."""

    type: Optional[Literal["input_audio"]] = "input_audio"
    """Kind of content part; defaults to `input_audio`."""


# Reference to an image, plus an optional level-of-detail hint.
@dataclass
class ImageURL:
    url: str
    """A URL pointing at the image, or the image itself as base64 encoded data."""

    detail: Optional[Literal["auto", "low", "high"]] = None
    """Detail level requested for the image; `None` when not specified."""


# Content part that attaches an image to a message.
@dataclass
class PartImage:
    image_url: ImageURL
    """Where the image comes from."""

    type: Optional[Literal["image_url"]] = "image_url"
    """Kind of content part; defaults to `image_url`."""


# Content part holding plain text.
@dataclass
class PartText:
    text: str
    """The text carried by this part."""

    type: Optional[Literal["text"]] = "text"
    """Kind of content part; defaults to `text`."""


# Any of the content parts defined above: text, image or input audio.
ContentPart: TypeAlias = Union[PartText, PartImage, PartInputAudio]


# Message authored by the user; its content is plain text or a list of parts.
class UserMessage(BaseModel):
    content: Union[str, List[ContentPart]]
    """What the user sent: a plain string or a list of content parts."""

    role: Literal["user"] = "user"
    """Author role of this message; fixed to `user`."""

    def addPart(self, part: ContentPart):
        """
        Append `part` to the message content.

        String content is first wrapped in a single `PartText`, so the content
        becomes a list. Content that is neither a string nor a list is replaced
        by an empty list before the part is appended.
        """
        existing = self.content
        if isinstance(existing, str):
            self.content = [PartText(text=existing)]
        elif not isinstance(existing, list):
            self.content = []
        self.content.append(part)


# Message that carries the output of a tool call.
class ToolResponseMessage(BaseModel):
    content: Union[str, List[PartText]]
    """Tool output, either as a plain string or as a list of text parts."""

    role: Literal["tool"] = "tool"
    """Author role of this message; fixed to `tool`."""

    tool_call_id: Optional[str] = None
    """ID of the tool call this message answers; `None` when not given."""


# Location and target of a URL citation inside a message.
class AnnotationURLCitation(BaseModel):
    end_index: int
    """Index of the last character of the URL citation in the message."""

    start_index: int
    """Index of the first character of the URL citation in the message."""

    title: str
    """Title of the cited web resource."""

    url: str
    """URL of the cited web resource."""


# Annotation attached to an assistant message; only URL citations are modelled.
class Annotation(BaseModel):
    type: Literal["url_citation"]
    """Kind of annotation. Always `url_citation`."""

    url_citation: AnnotationURLCitation
    """The URL citation, as produced when using web search."""


# Audio response attached to an assistant message.
class AssistantAudio(BaseModel):
    id: str
    """Unique identifier of this audio response."""

    data: str
    """
    Audio bytes generated by the model, Base64 encoded, in the format that was
    specified in the request.
    """

    expires_at: int
    """
    Unix timestamp (in seconds) after which this audio response is no longer
    accessible on the server for use in multi-turn conversations.
    """

    transcript: str
    """Transcript of the audio generated by the model."""


# Name and raw arguments of a function the model wants to call.
class Function(BaseModel):
    arguments: str
    """
    Arguments for the function call, generated by the model in JSON format.
    The model does not always produce valid JSON and may hallucinate
    parameters that your function schema does not define, so validate the
    arguments in your code before calling your function.
    """

    name: str
    """Name of the function to call."""


# A single tool invocation requested by the model.
class ToolCall(BaseModel):
    id: str
    """ID of the tool call."""

    function: Function
    """The function the model called."""

    type: Literal["function"]
    """Kind of tool. Currently, only `function` is supported."""


# Message produced by the model. Every payload field is optional, so a message
# may carry text, a refusal, annotations, audio, tool calls, or a mix of them.
class AssistantMessage(BaseModel):
    content: Optional[str] = None
    """The contents of the message."""

    refusal: Optional[str] = None
    """The refusal message generated by the model."""

    role: Literal["assistant"] = "assistant"
    """The role of the author of this message, in this case `assistant`."""

    annotations: Optional[List[Annotation]] = None
    """
    Annotations for the message, when applicable, as when using web search
    (each annotation is a URL citation).
    """

    audio: Optional[AssistantAudio] = None
    """
    If the audio output modality is requested, this object contains data about the
    audio response from the model.
    """

    tool_calls: Optional[List[ToolCall]] = None
    """The tool calls generated by the model, such as function calls."""

    # Accept keys that are not declared above instead of rejecting them, so
    # payloads carrying additional fields still validate.
    model_config = {
        "extra": "allow"
    }


# Any message kind that can appear in a conversation.
AnyMessage: TypeAlias = Union[UserMessage, AssistantMessage, ToolResponseMessage]

class HistoryMessage():
    """
    Ordered container for the messages of a conversation.

    Stores user, assistant and tool messages side by side, regardless of their
    type, and converts the whole history to and from serialized form.
    """

    messages: List[AnyMessage]
    """The list of messages in the conversation history."""

    def __init__(self, messages: Optional[List[AnyMessage]] = None):
        """
        Initialize the conversation history.

        Args:
            messages: Initial messages. The list is stored as-is (it is not
                copied). When omitted, the history starts out empty.
        """
        if messages is None:
            messages = []
        self.messages = messages

    def addMessage(self, message: AnyMessage):
        """
        Append a message to the end of the conversation history.
        """
        self.messages.append(message)

    def to_json(self):
        """
        Serialize the conversation history to JSON.

        Returns:
            A list with one JSON string per message, in history order.
        """
        return [message.model_dump_json() for message in self.messages]

    def to_dict(self):
        """
        Convert the conversation history to plain Python data.

        Returns:
            A list with one dict per message (the result of `model_dump()`),
            in history order.
        """
        return [message.model_dump() for message in self.messages]

    def from_json(self, json_data: List[str]):
        """
        Replace the conversation history with messages parsed from JSON.

        Each entry is parsed into the message class matching its `role`
        (`user`, `assistant` or `tool`). The history is only replaced once
        every entry has been parsed; if any entry fails, the existing history
        is left untouched.

        Args:
            json_data: One JSON string per message, as produced by `to_json`.

        Raises:
            ValueError: If an entry has a missing or unknown `role`. Errors
                raised while decoding or validating an entry propagate as-is.
        """
        models_by_role = {
            "user": UserMessage,
            "assistant": AssistantMessage,
            "tool": ToolResponseMessage,
        }

        # Collect into a local list so a bad entry cannot leave the history
        # half-loaded.
        loaded = []
        for message_json in json_data:
            # Peek at the role to decide which message class to validate with.
            role = json.loads(message_json).get("role")

            # Only string roles are looked up; any other value is reported as
            # an unknown role below rather than failing the dict lookup.
            model = models_by_role.get(role) if isinstance(role, str) else None
            if model is None:
                raise ValueError(f"Unknown message role: {role}")

            loaded.append(model.model_validate_json(message_json))

        self.messages = loaded
    

In [None]:
# A user message built from a plain string keeps the text and has role `user`.
userMessage = UserMessage(
    content="Write a one-sentence bedtime story about a unicorn.",
)
assert userMessage.role == "user"
assert userMessage.content == "Write a one-sentence bedtime story about a unicorn."

In [None]:
# A user message can combine several content parts (here: text plus image).
user = UserMessage(
    content=[
        PartText(text="What's in this image?"),
        PartImage(
            image_url=ImageURL(
                url="https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
            )
        ),
    ]
)
assert user.content[0].text == "What's in this image?"
assert user.content[1].type == "image_url"

In [None]:
# Appending a part to list content puts it at the end of the list.
user.addPart(PartText(text="This is a test"))
assert user.content[2].type == "text"
assert user.content[2].text == "This is a test"

In [None]:
# --- Content only: every optional field stays unset --------------------------
assistant_msg = AssistantMessage(
    content="Once upon a time, a magical unicorn with a rainbow mane granted wishes to all the children in the village."
)
assert assistant_msg.role == "assistant"
assert assistant_msg.content == "Once upon a time, a magical unicorn with a rainbow mane granted wishes to all the children in the village."
assert all(
    getattr(assistant_msg, unset) is None
    for unset in ("refusal", "annotations", "audio", "tool_calls")
)

# --- Refusal without content ------------------------------------------------
refusal_msg = AssistantMessage(
    refusal="I cannot generate content that violates ethical guidelines.",
)
assert refusal_msg.role == "assistant"
assert refusal_msg.content is None
assert refusal_msg.refusal == "I cannot generate content that violates ethical guidelines."

# --- Annotations (URL citations) --------------------------------------------
annotation = Annotation(
    type="url_citation",
    url_citation=AnnotationURLCitation(
        start_index=10,
        end_index=20,
        title="Example Resource",
        url="https://example.com/resource",
    ),
)
annotated_msg = AssistantMessage(
    content="According to research from Example Resource, unicorns symbolize purity and grace.",
    annotations=[annotation],
)
assert annotated_msg.role == "assistant"
assert annotated_msg.annotations[0].type == "url_citation"
assert annotated_msg.annotations[0].url_citation.title == "Example Resource"
assert annotated_msg.annotations[0].url_citation.url == "https://example.com/resource"

# --- Tool calls -------------------------------------------------------------
tool_call = ToolCall(
    id="call_abc123",
    type="function",
    function=Function(
        name="get_weather",
        arguments='{"location": "San Francisco", "unit": "celsius"}',
    ),
)
tool_call_msg = AssistantMessage(
    content="I'll check the weather for you.",
    tool_calls=[tool_call],
)
assert tool_call_msg.role == "assistant"
assert tool_call_msg.tool_calls[0].id == "call_abc123"
assert tool_call_msg.tool_calls[0].function.name == "get_weather"
assert "San Francisco" in tool_call_msg.tool_calls[0].function.arguments

# --- Audio ------------------------------------------------------------------
audio_msg = AssistantMessage(
    content="Here's your bedtime story about unicorns.",
    audio=AssistantAudio(
        id="audio_xyz789",
        data="base64encodedaudiodata...",
        expires_at=1714583467,
        transcript="Once upon a time, there was a magical unicorn...",
    ),
)
assert audio_msg.role == "assistant"
assert audio_msg.audio.id == "audio_xyz789"
assert audio_msg.audio.transcript.startswith("Once upon a time")
assert audio_msg.audio.expires_at == 1714583467


In [None]:
# Test serialization/deserialization with Pydantic v2
import json

# Round-trip a minimal message through JSON.
original_msg = AssistantMessage(content="Hello, how can I help you today?")
json_str = original_msg.model_dump_json()
deserialized_msg = AssistantMessage.model_validate_json(json_str)
assert deserialized_msg.content == original_msg.content
assert deserialized_msg.role == original_msg.role

# Round-trip a message that also carries tool calls and annotations.
complex_msg = AssistantMessage(
    content="I found some information for you.",
    tool_calls=[
        ToolCall(
            id="tool_123",
            type="function",
            function=Function(
                name="search_database",
                arguments='{"query": "unicorn myths", "limit": 5}'
            )
        )
    ],
    annotations=[
        Annotation(
            type="url_citation",
            url_citation=AnnotationURLCitation(
                start_index=5,
                end_index=25,
                title="Mythical Creatures Database",
                url="https://example.com/myths/unicorn"
            )
        )
    ]
)

complex_json = complex_msg.model_dump_json()
restored_msg = AssistantMessage.model_validate_json(complex_json)

# Verify all fields were preserved
assert restored_msg.content == complex_msg.content
assert restored_msg.role == complex_msg.role
assert restored_msg.tool_calls[0].id == complex_msg.tool_calls[0].id
assert restored_msg.tool_calls[0].function.name == complex_msg.tool_calls[0].function.name
assert restored_msg.annotations[0].url_citation.title == complex_msg.annotations[0].url_citation.title

# Extra fields are accepted because AssistantMessage sets `extra` to "allow".
extra_fields_json = '{"content": "Hello", "role": "assistant", "custom_field": "value"}'
extra_fields_msg = AssistantMessage.model_validate_json(extra_fields_json)
assert extra_fields_msg.content == "Hello"
assert extra_fields_msg.role == "assistant"
# Accepting the payload is not enough: the undeclared field must also be kept.
assert extra_fields_msg.model_extra == {"custom_field": "value"}
assert extra_fields_msg.custom_field == "value"

In [None]:
# --- String content ---------------------------------------------------------
tool_response_str = ToolResponseMessage(
    content="The current temperature is 72°F",
    tool_call_id="weather_tool_123",
)
assert tool_response_str.role == "tool"
assert tool_response_str.content == "The current temperature is 72°F"
assert tool_response_str.tool_call_id == "weather_tool_123"

# --- List of PartText -------------------------------------------------------
tool_response_parts = ToolResponseMessage(
    content=[
        PartText(text="Search results for 'unicorn':"),
        PartText(text="1. Unicorns in mythology and folklore"),
        PartText(text="2. The history of unicorn symbolism"),
    ],
    tool_call_id="search_tool_456",
)
assert tool_response_parts.role == "tool"
assert len(tool_response_parts.content) == 3
assert [part.text for part in tool_response_parts.content] == [
    "Search results for 'unicorn':",
    "1. Unicorns in mythology and folklore",
    "2. The history of unicorn symbolism",
]
assert tool_response_parts.content[0].type == "text"
assert tool_response_parts.tool_call_id == "search_tool_456"

# --- No tool_call_id --------------------------------------------------------
tool_response_no_id = ToolResponseMessage(content="Database query completed successfully.")
assert tool_response_no_id.role == "tool"
assert tool_response_no_id.content == "Database query completed successfully."
assert tool_response_no_id.tool_call_id is None


In [None]:
# Test serialization/deserialization
import json

# --- Round-trip with string content -----------------------------------------
original_msg = ToolResponseMessage(
    content="File has been saved.",
    tool_call_id="file_save_789",
)
json_str = original_msg.model_dump_json()
deserialized_msg = ToolResponseMessage.model_validate_json(json_str)
assert deserialized_msg.content == original_msg.content
assert deserialized_msg.role == original_msg.role
assert deserialized_msg.tool_call_id == original_msg.tool_call_id

# --- Round-trip with list content -------------------------------------------
original_list_msg = ToolResponseMessage(
    content=[
        PartText(text="Result 1"),
        PartText(text="Result 2"),
    ]
)
json_list_str = original_list_msg.model_dump_json()
deserialized_list_msg = ToolResponseMessage.model_validate_json(json_list_str)
assert len(deserialized_list_msg.content) == 2
assert deserialized_list_msg.content[0].text == "Result 1"
assert deserialized_list_msg.content[1].text == "Result 2"

# --- Parsing hand-written JSON with string content --------------------------
json_from_str = '{"content": "API response received", "role": "tool", "tool_call_id": "api_call_123"}'
msg_from_json = ToolResponseMessage.model_validate_json(json_from_str)
assert msg_from_json.content == "API response received"
assert msg_from_json.tool_call_id == "api_call_123"

# --- Parsing hand-written JSON with a list of text parts --------------------
complex_json = '{"content": [{"text": "Weather forecast:", "type": "text"}, {"text": "Sunny, 75°F", "type": "text"}], "role": "tool", "tool_call_id": "weather_forecast_123"}'
complex_msg = ToolResponseMessage.model_validate_json(complex_json)
assert len(complex_msg.content) == 2
assert complex_msg.content[0].text == "Weather forecast:"
assert complex_msg.content[1].text == "Sunny, 75°F"
assert complex_msg.tool_call_id == "weather_forecast_123"

# --- A ToolResponseMessage used through the AnyMessage alias ----------------
from typing import cast
message: AnyMessage = ToolResponseMessage(content="This is a tool response")
assert message.role == "tool"
tool_message = cast(ToolResponseMessage, message)
assert tool_message.content == "This is a tool response"

# --- Edge case: empty content list ------------------------------------------
empty_list_msg = ToolResponseMessage(content=[])
assert empty_list_msg.role == "tool"
assert len(empty_list_msg.content) == 0


In [None]:
def test_init_empty():
    """A history created without arguments holds no messages."""
    empty_history = HistoryMessage()
    assert empty_history.messages == []

test_init_empty()


In [None]:

def test_init_with_messages():
    """Messages handed to the constructor are kept, in order."""
    greeting = UserMessage(content="Hello")
    reply = AssistantMessage(content="Hi there!")

    history = HistoryMessage(messages=[greeting, reply])

    assert len(history.messages) == 2
    assert history.messages[0] == greeting
    assert history.messages[1] == reply

test_init_with_messages()


In [None]:

def test_add_message():
    """addMessage appends each message to the end of the history."""
    history = HistoryMessage()

    # First message.
    greeting = UserMessage(content="Hello")
    history.addMessage(greeting)
    assert len(history.messages) == 1
    assert history.messages[0] == greeting

    # Second message lands after the first.
    reply = AssistantMessage(content="Hi there!")
    history.addMessage(reply)
    assert len(history.messages) == 2
    assert history.messages[1] == reply

test_add_message()


In [None]:

def test_to_json():
    """to_json yields one JSON string per message, in history order."""
    history = HistoryMessage(
        messages=[
            UserMessage(content="Hello"),
            AssistantMessage(content="Hi there!"),
        ]
    )

    serialized = history.to_json()
    assert isinstance(serialized, list)
    assert len(serialized) == 2

    # Decode each JSON string and check what it carries.
    decoded_user = json.loads(serialized[0])
    decoded_assistant = json.loads(serialized[1])

    assert decoded_user["role"] == "user"
    assert decoded_user["content"] == "Hello"
    assert decoded_assistant["role"] == "assistant"
    assert decoded_assistant["content"] == "Hi there!"

test_to_json()


In [None]:
def test_from_json():
    """from_json rebuilds the messages from their JSON strings."""
    payloads = [
        json.dumps({"role": "user", "content": "Hello"}),
        json.dumps({"role": "assistant", "content": "Hi there!"}),
    ]

    history = HistoryMessage()
    history.from_json(payloads)

    assert len(history.messages) == 2
    assert history.messages[0].role == "user"
    assert history.messages[0].content == "Hello"
    assert history.messages[1].role == "assistant"
    assert history.messages[1].content == "Hi there!"

test_from_json()


In [None]:

def test_complex_message_types():
    """A history mixing all three message kinds survives a JSON round trip."""
    # User message made of a text part and an image part.
    question = UserMessage(
        content=[
            PartText(text="What's in this image?"),
            PartImage(image_url=ImageURL(url="https://example.com/image.jpg")),
        ]
    )

    # Assistant message that requests a tool call.
    weather_call = ToolCall(
        id="call_123",
        type="function",
        function=Function(
            name="get_weather",
            arguments='{"location": "Seattle", "unit": "celsius"}',
        ),
    )
    answer = AssistantMessage(
        content="I'll check the weather for you.",
        tool_calls=[weather_call],
    )

    # Tool message answering that call.
    tool_result = ToolResponseMessage(
        content="The weather in Seattle is 15°C and partly cloudy.",
        tool_call_id="call_123",
    )

    # Build the history one message at a time.
    history = HistoryMessage()
    for msg in (question, answer, tool_result):
        history.addMessage(msg)

    assert len(history.messages) == 3

    # Serialize, then load into a fresh history.
    serialized = history.to_json()
    new_history = HistoryMessage()
    new_history.from_json(serialized)

    assert len(new_history.messages) == 3
    restored_user, restored_assistant, restored_tool = new_history.messages

    assert restored_user.role == "user"
    assert len(restored_user.content) == 2
    assert restored_user.content[0].text == "What's in this image?"
    assert restored_user.content[1].type == "image_url"

    assert restored_assistant.role == "assistant"
    assert restored_assistant.content == "I'll check the weather for you."
    assert len(restored_assistant.tool_calls) == 1
    assert restored_assistant.tool_calls[0].id == "call_123"

    assert restored_tool.role == "tool"
    assert restored_tool.content == "The weather in Seattle is 15°C and partly cloudy."
    assert restored_tool.tool_call_id == "call_123"

test_complex_message_types()


In [None]:

def test_empty_content():
    """An assistant message with a refusal but no content round-trips intact."""
    refusal_only = AssistantMessage(refusal="I cannot answer that.")

    serialized = HistoryMessage([refusal_only]).to_json()

    new_history = HistoryMessage()
    new_history.from_json(serialized)

    restored = new_history.messages[0]
    assert restored.role == "assistant"
    assert restored.content is None
    assert restored.refusal == "I cannot answer that."

test_empty_content()


In [None]:
#| hide
import nbdev

nbdev.nbdev_export()