diff --git a/instructor/client.py b/instructor/client.py
index b21916ddd..fa4aeb978 100644
--- a/instructor/client.py
+++ b/instructor/client.py
@@ -388,6 +388,7 @@ def from_openai(
         instructor.Mode.JSON,
         instructor.Mode.JSON_SCHEMA,
         instructor.Mode.MD_JSON,
+        instructor.Mode.MD_YAML,
     }
 
     if provider in {Provider.OPENAI}:
@@ -397,6 +398,7 @@ def from_openai(
             instructor.Mode.FUNCTIONS,
             instructor.Mode.PARALLEL_TOOLS,
             instructor.Mode.MD_JSON,
+            instructor.Mode.MD_YAML,
         }
 
     if isinstance(client, openai.OpenAI):
diff --git a/instructor/function_calls.py b/instructor/function_calls.py
index 28de8efe1..40175be97 100644
--- a/instructor/function_calls.py
+++ b/instructor/function_calls.py
@@ -15,7 +15,11 @@
 
 from instructor.exceptions import IncompleteOutputException
 from instructor.mode import Mode
-from instructor.utils import classproperty, extract_json_from_codeblock
+from instructor.utils import (
+    classproperty,
+    extract_json_from_codeblock,
+    extract_json_from_yaml_codeblock,
+)
 
 T = TypeVar("T")
 
@@ -116,6 +120,8 @@ def from_response(
 
         if mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
             return cls.parse_json(completion, validation_context, strict)
+        if mode in {Mode.MD_YAML}:
+            return cls.parse_yaml(completion, validation_context, strict)
 
         raise ValueError(f"Invalid patch mode: {mode}")
 
@@ -226,6 +232,22 @@ def parse_json(
             strict=strict,
         )
 
+    @classmethod
+    def parse_yaml(
+        cls: type[BaseModel],
+        completion: ChatCompletion,
+        validation_context: Optional[dict[str, Any]] = None,
+        strict: Optional[bool] = None,
+    ) -> BaseModel:
+        message = completion.choices[0].message.content or ""
+        message = extract_json_from_yaml_codeblock(message)
+
+        return cls.model_validate_json(
+            message,
+            context=validation_context,
+            strict=strict,
+        )
+
 
 def openai_schema(cls: type[BaseModel]) -> OpenAISchema:
     if not issubclass(cls, BaseModel):
diff --git a/instructor/mode.py b/instructor/mode.py
index f40d02dc2..bdf180725 100644
--- a/instructor/mode.py
+++ b/instructor/mode.py
@@ -11,6 +11,7 @@ class Mode(enum.Enum):
     MISTRAL_TOOLS = "mistral_tools"
     JSON = "json_mode"
     MD_JSON = "markdown_json_mode"
+    MD_YAML = "markdown_yaml_mode"
     JSON_SCHEMA = "json_schema_mode"
     ANTHROPIC_TOOLS = "anthropic_tools"
     ANTHROPIC_JSON = "anthropic_json"
diff --git a/instructor/process_response.py b/instructor/process_response.py
index 16c24ad94..e154a46f8 100644
--- a/instructor/process_response.py
+++ b/instructor/process_response.py
@@ -13,6 +13,7 @@
 from pydantic import BaseModel
 
 import json
+import yaml
 import inspect
 import logging
 from typing import (
@@ -234,6 +235,38 @@ def handle_response_model(
             "type": "function",
             "function": {"name": response_model.openai_schema["name"]},
         }
+    elif mode in {Mode.MD_YAML}:
+        message = dedent(
+            f"""
+            As a genius expert, your task is to understand the content and provide
+            the parsed objects in yaml that match the following yaml_schema:\n
+
+            {yaml.dump(response_model.model_json_schema(), indent=2)}
+
+            Make sure to return an instance of the YAML, not the schema itself
+            """
+        )
+
+        new_kwargs["messages"].append(
+            {
+                "role": "user",
+                "content": "Return the correct YAML response within a ```yaml codeblock, not the YAML_SCHEMA",
+            },
+        )
+
+        # check that the first message is a system message
+        # if it is not, add a system message to the beginning
+        if new_kwargs["messages"][0]["role"] != "system":
+            new_kwargs["messages"].insert(
+                0,
+                {
+                    "role": "system",
+                    "content": message,
+                },
+            )
+        # if it is, append the schema to the end of the system message
+        else:
+            new_kwargs["messages"][0]["content"] += f"\n\n{message}"
     elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}:
         # If its a JSON Mode we need to massage the prompt a bit
         # in order to get the response we want in a json format
diff --git a/instructor/retry.py b/instructor/retry.py
index 0524ec93b..9cee559ec 100644
--- a/instructor/retry.py
+++ b/instructor/retry.py
@@ -115,6 +115,11 @@ def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception):
             "role": "user",
             "content": f"Correct your JSON ONLY RESPONSE, based on the following errors:\n{exception}",
         }
+    elif mode == Mode.MD_YAML:
+        yield {
+            "role": "user",
+            "content": f"Correct your YAML ONLY RESPONSE, based on the following errors:\n{exception}",
+        }
     else:
         yield {
             "role": "user",
diff --git a/instructor/utils.py b/instructor/utils.py
index 9f9448ed4..e128a98cc 100644
--- a/instructor/utils.py
+++ b/instructor/utils.py
@@ -2,6 +2,7 @@
 
 import inspect
 import json
+import yaml
 import logging
 from collections.abc import AsyncGenerator, Generator, Iterable
 from typing import (
@@ -70,6 +71,25 @@ def extract_json_from_codeblock(content: str) -> str:
         return content[first_paren : last_paren + 1]
     return content
 
 
+def extract_json_from_yaml_codeblock(content: str) -> str:
+    """Extract YAML from a ```yaml codeblock and re-serialize it as JSON.
+
+    Falls back to parsing the whole message as YAML when no fenced block is
+    found, mirroring the bare-text fallback of extract_json_from_codeblock.
+    Returns "" for unparseable YAML so that downstream pydantic validation
+    raises a retryable ValidationError instead of an uncaught YAMLError.
+    """
+    yaml_start = content.find("```yaml")
+    if yaml_start != -1:
+        yaml_end = content.find("```", yaml_start + 7)
+        if yaml_end != -1:
+            content = content[yaml_start + 7 : yaml_end]
+    try:
+        return json.dumps(yaml.safe_load(content.strip()))
+    except yaml.YAMLError:
+        return ""
+
+
 def extract_json_from_stream(chunks: Iterable[str]) -> Generator[str, None, None]:
     capturing = False
     brace_count = 0
diff --git a/tests/llm/test_openai/test_modes.py b/tests/llm/test_openai/test_modes.py
index 08b30a010..c7c8c3c05 100644
--- a/tests/llm/test_openai/test_modes.py
+++ b/tests/llm/test_openai/test_modes.py
@@ -17,6 +17,34 @@ class Order(BaseModel):
     customer: str
 
 
+def test_yaml(client):
+    client = instructor.patch(client, mode=instructor.Mode.MD_YAML)
+    content = """
+    Order Details:
+    Customer: Jason
+    Items:
+
+    Name: Apple, Price: 0.50
+    Name: Bread, Price: 2.00
+    Name: Milk, Price: 1.50
+    """
+
+    resp = client.chat.completions.create(
+        model="gpt-3.5-turbo",
+        response_model=Order,
+        messages=[
+            {
+                "role": "user",
+                "content": content,
+            },
+        ],
+    )
+    assert len(resp.items) == 3
+    assert {x.name.lower() for x in resp.items} == {"apple", "bread", "milk"}
+    assert {x.price for x in resp.items} == {0.5, 2.0, 1.5}
+    assert resp.customer.lower() == "jason"
+
+
 @pytest.mark.parametrize("model, mode", product(models, modes))
 def test_nested(model, mode, client):
     client = instructor.patch(client, mode=mode)