In [None]:
# Reload imported modules (e.g. iagent.ai.messages) before each cell runs,
# so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# ai.responses

> Define chat-completion response types

In [None]:
#| default_exp ai.responses

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from iagent.ai.messages import AssistantMessage
from typing import List, Optional
from typing_extensions import Literal
from pydantic import BaseModel
#
# These type definitions are adapted from the OpenAI API specification
# so that we can avoid taking on a large dependency.
#


class CompletionUsage(BaseModel):
    """Token accounting reported for a single chat-completion request.

    Keys beyond the three declared below are kept as extra attributes
    (``extra="allow"``) rather than discarded.
    """

    model_config = dict(extra="allow")

    # Tokens produced by the model in its generated reply.
    completion_tokens: int

    # Tokens consumed by the input prompt.
    prompt_tokens: int

    # Total tokens for the request (prompt + completion), as reported by the
    # provider; the sum is not validated here.
    total_tokens: int


class Choice(BaseModel):
    """A single completion alternative inside ``ChatCompletionResponse.choices``.

    Keys not declared below are retained as extra attributes rather than
    discarded, because the model is configured with ``extra="allow"``.
    """

    model_config = dict(extra="allow")

    # Why the model stopped generating tokens:
    #   "stop"           - natural stop point or a caller-provided stop sequence
    #   "length"         - the maximum token count from the request was reached
    #   "content_filter" - content was omitted because of a content-filter flag
    #   "tool_calls"     - the model called a tool
    #   "function_call"  - (deprecated) the model called a function
    # Any other value fails validation.
    finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"]

    # Position of this choice within the parent response's `choices` list.
    index: int

    # The assistant message the model generated for this choice.
    message: AssistantMessage

class ChatCompletionResponse(BaseModel):
    """Top-level body of a chat-completion response, modelled on the OpenAI format.

    Any keys beyond the ones declared here are retained as extra attributes
    (``extra="allow"``) rather than discarded.
    """

    model_config = dict(extra="allow")

    # Unique identifier assigned to this chat completion.
    id: str

    # The generated alternatives; there can be more than one when `n` > 1.
    choices: List[Choice]

    # Creation time of the completion as a Unix timestamp, in seconds.
    created: int

    # Name of the model that produced the completion.
    model: str

    # Token-usage statistics for the request; None when absent.
    usage: Optional[CompletionUsage] = None


In [None]:
import json

from pydantic import ValidationError

from iagent.ai.messages import Function, ToolCall, Annotation, AnnotationURLCitation

# Basic ChatCompletionResponse with minimal fields
basic_response = ChatCompletionResponse(
    id="chatcmpl-123",
    choices=[
        Choice(
            finish_reason="stop",
            index=0,
            message=AssistantMessage(content="Hello, how can I help you today?")
        )
    ],
    created=1677858242,
    model="gpt-4"
)
assert basic_response.id == "chatcmpl-123"
assert len(basic_response.choices) == 1
assert basic_response.choices[0].finish_reason == "stop"
assert basic_response.choices[0].index == 0
assert basic_response.choices[0].message.content == "Hello, how can I help you today?"
assert basic_response.created == 1677858242
assert basic_response.model == "gpt-4"
assert basic_response.usage is None

# ChatCompletionResponse with usage information
response_with_usage = ChatCompletionResponse(
    id="chatcmpl-456",
    choices=[
        Choice(
            finish_reason="stop",
            index=0,
            message=AssistantMessage(content="The weather in San Francisco is sunny.")
        )
    ],
    created=1677858300,
    model="gpt-3.5-turbo",
    usage=CompletionUsage(
        prompt_tokens=10,
        completion_tokens=8,
        total_tokens=18
    )
)
assert response_with_usage.id == "chatcmpl-456"
assert response_with_usage.usage.prompt_tokens == 10
assert response_with_usage.usage.completion_tokens == 8
assert response_with_usage.usage.total_tokens == 18

# ChatCompletionResponse with multiple choices
multi_choice_response = ChatCompletionResponse(
    id="chatcmpl-789",
    choices=[
        Choice(
            finish_reason="stop",
            index=0,
            message=AssistantMessage(content="Option 1: Go to the park.")
        ),
        Choice(
            finish_reason="stop",
            index=1,
            message=AssistantMessage(content="Option 2: Stay home and read a book.")
        )
    ],
    created=1677858400,
    model="gpt-4"
)
assert multi_choice_response.id == "chatcmpl-789"
assert len(multi_choice_response.choices) == 2
assert multi_choice_response.choices[0].index == 0
assert multi_choice_response.choices[1].index == 1
assert "Option 1" in multi_choice_response.choices[0].message.content
assert "Option 2" in multi_choice_response.choices[1].message.content

# ChatCompletionResponse with tool calls
tool_call_response = ChatCompletionResponse(
    id="chatcmpl-tool123",
    choices=[
        Choice(
            finish_reason="tool_calls",
            index=0,
            message=AssistantMessage(
                content="I'll check the weather for you.",
                tool_calls=[
                    ToolCall(
                        id="call_abc123",
                        type="function",
                        function=Function(
                            name="get_weather",
                            arguments='{"location": "San Francisco", "unit": "celsius"}'
                        )
                    )
                ]
            )
        )
    ],
    created=1677858500,
    model="gpt-4"
)
assert tool_call_response.id == "chatcmpl-tool123"
assert tool_call_response.choices[0].finish_reason == "tool_calls"
assert tool_call_response.choices[0].message.tool_calls[0].function.name == "get_weather"
assert "San Francisco" in tool_call_response.choices[0].message.tool_calls[0].function.arguments

# ChatCompletionResponse with content filter
content_filter_response = ChatCompletionResponse(
    id="chatcmpl-filter456",
    choices=[
        Choice(
            finish_reason="content_filter",
            index=0,
            message=AssistantMessage(
                content=None,
                refusal="I cannot provide the requested content as it violates content policies."
            )
        )
    ],
    created=1677858600,
    model="gpt-4"
)
assert content_filter_response.id == "chatcmpl-filter456"
assert content_filter_response.choices[0].finish_reason == "content_filter"
assert content_filter_response.choices[0].message.content is None
assert "violates content policies" in content_filter_response.choices[0].message.refusal

# Test serialization/deserialization

# Serialize and deserialize a basic response
json_str = basic_response.model_dump_json()
# Inspect the raw JSON payload too; the unset optional `usage` is emitted as null
json_payload = json.loads(json_str)
assert json_payload["id"] == "chatcmpl-123"
assert json_payload["model"] == "gpt-4"
assert json_payload["usage"] is None
deserialized_response = ChatCompletionResponse.model_validate_json(json_str)
assert deserialized_response.id == basic_response.id
assert deserialized_response.model == basic_response.model
assert deserialized_response.choices[0].message.content == basic_response.choices[0].message.content

# Test a response carrying annotations and usage information
complex_response = ChatCompletionResponse(
    id="chatcmpl-complex789",
    choices=[
        Choice(
            finish_reason="stop",
            index=0,
            message=AssistantMessage(
                content="Here's the information you requested.",
                annotations=[
                    Annotation(
                        type="url_citation",
                        url_citation=AnnotationURLCitation(
                            start_index=5,
                            end_index=25,
                            title="Research Paper",
                            url="https://example.com/research"
                        )
                    )
                ]
            )
        )
    ],
    created=1677858700,
    model="gpt-4",
    usage=CompletionUsage(
        prompt_tokens=15,
        completion_tokens=10,
        total_tokens=25
    )
)

# Serialize and deserialize the complex response
complex_json = complex_response.model_dump_json()
restored_response = ChatCompletionResponse.model_validate_json(complex_json)

# Verify all fields were preserved
assert restored_response.id == complex_response.id
assert restored_response.model == complex_response.model
assert restored_response.choices[0].message.content == complex_response.choices[0].message.content
assert restored_response.choices[0].message.annotations[0].url_citation.title == "Research Paper"
assert restored_response.usage.total_tokens == 25

# Test with extra fields. model_config sets extra="allow", so unknown keys must be
# *retained*, not just tolerated. Under pydantic's default (extra="ignore") the parse
# would also succeed and silently drop them, so check the retained values explicitly.
extra_fields_json = '''
{
    "id": "chatcmpl-extra123",
    "choices": [
        {
            "finish_reason": "stop",
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "Hello there!"
            },
            "custom_choice_field": "value"
        }
    ],
    "created": 1677858800,
    "model": "gpt-4",
    "custom_response_field": "value"
}
'''
extra_fields_response = ChatCompletionResponse.model_validate_json(extra_fields_json)
assert extra_fields_response.id == "chatcmpl-extra123"
assert extra_fields_response.choices[0].message.content == "Hello there!"
assert extra_fields_response.model_extra == {"custom_response_field": "value"}
assert extra_fields_response.choices[0].model_extra == {"custom_choice_field": "value"}

# Extra fields must also survive a serialization round trip
extra_round_trip = ChatCompletionResponse.model_validate_json(extra_fields_response.model_dump_json())
assert extra_round_trip.model_extra == {"custom_response_field": "value"}
assert extra_round_trip.choices[0].model_extra == {"custom_choice_field": "value"}

# Test with different finish_reason values
finish_reasons = ["stop", "length", "tool_calls", "content_filter", "function_call"]
for reason in finish_reasons:
    response = ChatCompletionResponse(
        id=f"chatcmpl-{reason}",
        choices=[
            Choice(
                finish_reason=reason,
                index=0,
                message=AssistantMessage(content="Test message")
            )
        ],
        created=1677858900,
        model="gpt-4"
    )
    assert response.choices[0].finish_reason == reason

# A finish_reason outside the Literal set must be rejected
try:
    Choice(
        finish_reason="not_a_reason",
        index=0,
        message=AssistantMessage(content="Test message")
    )
except ValidationError:
    pass
else:
    raise AssertionError("Choice accepted an unknown finish_reason")


In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()  # write this notebook's `#| export` cells out to the library module