Skip to content

Commit

Permalink
Refactor code and add test cases for dump_message function
Browse files Browse the repository at this point in the history
  • Loading branch information
Guiforge committed Dec 1, 2023
1 parent 11d23de commit 8307f6e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 17 deletions.
28 changes: 16 additions & 12 deletions instructor/patch.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import inspect
import json
import warnings
from collections.abc import Iterable
from functools import wraps
from instructor.dsl.multitask import MultiTask, MultiTaskBase
from json import JSONDecodeError
from typing import get_origin, get_args, Callable, Optional, Type, Union, TYPE_CHECKING
from collections.abc import Iterable
from typing import Callable, Optional, Type, Union, get_args, get_origin

from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageParam,
)
from pydantic import BaseModel, ValidationError

if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam

from .function_calls import OpenAISchema, openai_schema, Mode
from instructor.dsl.multitask import MultiTask, MultiTaskBase

import warnings
from .function_calls import Mode, OpenAISchema, openai_schema

OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
Expand All @@ -41,13 +43,15 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
if it isn't used.
"""
ret: ChatCompletionMessageParam = {"role": message.role, "content": message.content}
ret: ChatCompletionMessageParam = {
"role": message.role,
"content": message.content or "",
}
if message.tool_calls is not None:
ret["content"] += message.tool_calls
ret["content"] += json.dumps(message.model_dump()["tool_calls"])
return ret



def handle_response_model(
*,
response_model: Type[BaseModel],
Expand Down
69 changes: 64 additions & 5 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import functools
import pytest
import instructor

from openai import OpenAI, AsyncOpenAI

import pytest
from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)

from instructor.patch import is_async, wrap_chatcompletion, OVERRIDE_DOCS
import instructor
from instructor.patch import OVERRIDE_DOCS, dump_message, is_async, wrap_chatcompletion


def test_patch_completes_successfully():
Expand Down Expand Up @@ -66,3 +70,58 @@ def test_override_docs():
assert (
"response_model" in OVERRIDE_DOCS
), "response_model should be in OVERRIDE_DOCS"


# Shared fixture: both tool-call scenarios exercise the same single call.
_TOOL_CALLS = [
    ChatCompletionMessageToolCall(
        id="test_tool",
        function=Function(arguments="", name="test_tool"),
        type="function",
    )
]

# JSON dump of _TOOL_CALLS as dump_message appends it to the content.
_TOOL_CALLS_JSON = (
    '[{"id": "test_tool", "function": {"arguments": "", "name": "test_tool"},'
    ' "type": "function"}]'
)


@pytest.mark.parametrize(
    "message, expected",
    [
        # Content and tool calls: the JSON dump is appended after the text.
        (
            ChatCompletionMessage(
                role="assistant", content="Hello, world!", tool_calls=_TOOL_CALLS
            ),
            {"role": "assistant", "content": "Hello, world!" + _TOOL_CALLS_JSON},
        ),
        # Tool calls only: None content becomes "" before the JSON dump.
        (
            ChatCompletionMessage(
                role="assistant", content=None, tool_calls=_TOOL_CALLS
            ),
            {"role": "assistant", "content": _TOOL_CALLS_JSON},
        ),
        # Neither content nor tool calls: content collapses to "".
        (
            ChatCompletionMessage(role="assistant", content=None),
            {"role": "assistant", "content": ""},
        ),
    ],
)
def test_dump_message(
    message: ChatCompletionMessage, expected: ChatCompletionMessageParam
):
    """dump_message should turn a ChatCompletionMessage into a request-safe dict."""
    assert dump_message(message) == expected

0 comments on commit 8307f6e

Please sign in to comment.