From a725f75a894da0dc9ed564c448a66898fc30a73b Mon Sep 17 00:00:00 2001 From: Vikram Oberoi Date: Wed, 1 May 2024 17:50:33 -0400 Subject: [PATCH] ANTHROPIC_JSON: allow control characters in JSON strings if strict=False Addresses #612. Anthropic's models regularly have control characters in their strings, producing invalid JSON that causes validation to fail. These changes merge Pydantic's non-strict semantics with those of `json.loads` for ANTHROPIC_JSON mode only. In the event that the client passes in strict=False, control characters will also be allowed in JSON strings. --- instructor/function_calls.py | 26 +++++++++++---- tests/test_function_calls.py | 63 ++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 9 deletions(-) diff --git a/instructor/function_calls.py b/instructor/function_calls.py index abffe7124..28de8efe1 100644 --- a/instructor/function_calls.py +++ b/instructor/function_calls.py @@ -1,14 +1,21 @@ +import json import logging from functools import wraps from typing import Annotated, Any, Optional, TypeVar, cast from docstring_parser import parse from openai.types.chat import ChatCompletion -from pydantic import BaseModel, Field, TypeAdapter, ConfigDict, create_model # type: ignore - remove once Pydantic is updated +from pydantic import ( # type: ignore - remove once Pydantic is updated + BaseModel, + ConfigDict, + Field, + TypeAdapter, + create_model, +) + from instructor.exceptions import IncompleteOutputException from instructor.mode import Mode -from instructor.utils import extract_json_from_codeblock, classproperty - +from instructor.utils import classproperty, extract_json_from_codeblock T = TypeVar("T") @@ -141,9 +148,16 @@ def parse_anthropic_json( text = completion.content[0].text extra_text = extract_json_from_codeblock(text) - return cls.model_validate_json( - extra_text, context=validation_context, strict=strict - ) + + if strict: + return cls.model_validate_json( + extra_text, context=validation_context, strict=True + ) + else: + # Allow control characters. + parsed = json.loads(extra_text, strict=False) + # Pydantic non-strict: https://docs.pydantic.dev/latest/concepts/strict_mode/ + return cls.model_validate(parsed, context=validation_context, strict=False) @classmethod def parse_cohere_tools( diff --git a/tests/test_function_calls.py b/tests/test_function_calls.py index 71dfe568d..bcedb9ef2 100644 --- a/tests/test_function_calls.py +++ b/tests/test_function_calls.py @@ -1,13 +1,14 @@ from typing import TypeVar + import pytest -from pydantic import BaseModel +from anthropic.types import Message, Usage from openai.resources.chat.completions import ChatCompletion +from pydantic import BaseModel, ValidationError -from instructor import openai_schema, OpenAISchema import instructor +from instructor import OpenAISchema, openai_schema from instructor.exceptions import IncompleteOutputException - T = TypeVar("T") @@ -51,6 +52,24 @@ def mock_completion(request: T) -> ChatCompletion: return completion +@pytest.fixture # type: ignore[misc] +def mock_anthropic_message(request: T) -> Message: + data_content = '{\n"data": "Claude says hi"\n}' + if hasattr(request, "param"): + data_content = request.param.get("data_content", data_content) + return Message( + id="test_id", + content=[{ "type": "text", "text": data_content }], + model="claude-3-haiku-20240307", + role="assistant", + stop_reason="end_turn", + stop_sequence=None, + type="message", + usage=Usage( + input_tokens=100, + output_tokens=100, + ) + ) def test_openai_schema() -> None: @openai_schema @@ -122,3 +141,41 @@ def test_incomplete_output_exception_raise( ) -> None: with pytest.raises(IncompleteOutputException): test_model.from_response(mock_completion, mode=instructor.Mode.FUNCTIONS) + +def test_anthropic_no_exception( + test_model: type[OpenAISchema], mock_anthropic_message: Message +) -> None: + test_model_instance = test_model.from_response(mock_anthropic_message, mode=instructor.Mode.ANTHROPIC_JSON) + assert test_model_instance.data == "Claude says hi" + +@pytest.mark.parametrize( + "mock_anthropic_message", + [{"data_content": '{\n"data": "Claude likes\ncontrol\ncharacters"\n}'}], + indirect=True, +) # type: ignore[misc] +def test_control_characters_not_allowed_in_anthropic_json_strict_mode( + test_model: type[OpenAISchema], mock_anthropic_message: Message +) -> None: + with pytest.raises(ValidationError) as exc_info: + test_model.from_response( + mock_anthropic_message, mode=instructor.Mode.ANTHROPIC_JSON, strict=True + ) + + # https://docs.pydantic.dev/latest/errors/validation_errors/#json_invalid + exc = exc_info.value + assert len(exc.errors()) == 1 + assert exc.errors()[0]["type"] == "json_invalid" + assert "control character" in exc.errors()[0]["msg"] + +@pytest.mark.parametrize( + "mock_anthropic_message", + [{"data_content": '{\n"data": "Claude likes\ncontrol\ncharacters"\n}'}], + indirect=True, +) # type: ignore[misc] +def test_control_characters_allowed_in_anthropic_json_non_strict_mode( + test_model: type[OpenAISchema], mock_anthropic_message: Message +) -> None: + test_model_instance = test_model.from_response( + mock_anthropic_message, mode=instructor.Mode.ANTHROPIC_JSON, strict=False + ) + assert test_model_instance.data == "Claude likes\ncontrol\ncharacters"