Skip to content

Commit

Permalink
Add markdown yaml model to optimize input tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
ssonal committed May 16, 2024
1 parent 28deb93 commit 7c668e0
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 1 deletion.
2 changes: 2 additions & 0 deletions instructor/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ def from_openai(
instructor.Mode.JSON,
instructor.Mode.JSON_SCHEMA,
instructor.Mode.MD_JSON,
instructor.Mode.MD_YAML,
}

if provider in {Provider.OPENAI}:
Expand All @@ -397,6 +398,7 @@ def from_openai(
instructor.Mode.FUNCTIONS,
instructor.Mode.PARALLEL_TOOLS,
instructor.Mode.MD_JSON,
instructor.Mode.MD_YAML,
}

if isinstance(client, openai.OpenAI):
Expand Down
24 changes: 23 additions & 1 deletion instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

from instructor.exceptions import IncompleteOutputException
from instructor.mode import Mode
from instructor.utils import classproperty, extract_json_from_codeblock
from instructor.utils import (
classproperty,
extract_json_from_codeblock,
extract_json_from_yaml_codeblock,
)

T = TypeVar("T")

Expand Down Expand Up @@ -116,6 +120,8 @@ def from_response(

if mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
return cls.parse_json(completion, validation_context, strict)
if mode in {Mode.MD_YAML}:
return cls.parse_yaml(completion, validation_context, strict)

raise ValueError(f"Invalid patch mode: {mode}")

Expand Down Expand Up @@ -226,6 +232,22 @@ def parse_json(
strict=strict,
)

@classmethod
def parse_yaml(
    cls: type[BaseModel],
    completion: ChatCompletion,
    validation_context: Optional[dict[str, Any]] = None,
    strict: Optional[bool] = None,
) -> BaseModel:
    """Validate a markdown-YAML completion against this model.

    The assistant message is expected to carry a ```yaml fenced
    codeblock; the helper converts that YAML payload into a JSON
    string, which pydantic then validates into an instance of ``cls``.
    """
    raw_content = completion.choices[0].message.content
    json_payload = extract_json_from_yaml_codeblock(raw_content if raw_content else "")
    return cls.model_validate_json(
        json_payload, context=validation_context, strict=strict
    )


def openai_schema(cls: type[BaseModel]) -> OpenAISchema:
if not issubclass(cls, BaseModel):
Expand Down
1 change: 1 addition & 0 deletions instructor/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Mode(enum.Enum):
MISTRAL_TOOLS = "mistral_tools"
JSON = "json_mode"
MD_JSON = "markdown_json_mode"
MD_YAML = "markdown_yaml_mode"
JSON_SCHEMA = "json_schema_mode"
ANTHROPIC_TOOLS = "anthropic_tools"
ANTHROPIC_JSON = "anthropic_json"
Expand Down
33 changes: 33 additions & 0 deletions instructor/process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pydantic import BaseModel

import json
import yaml
import inspect
import logging
from typing import (
Expand Down Expand Up @@ -234,6 +235,38 @@ def handle_response_model(
"type": "function",
"function": {"name": response_model.openai_schema["name"]},
}
elif mode in {Mode.MD_YAML}:
message = dedent(
f"""
As a genius expert, your task is to understand the content and provide
the parsed objects in yaml that match the following yaml_schema:\n
{yaml.dump(response_model.model_json_schema(), indent=2)}
Make sure to return an instance of the YAML, not the schema itself
"""
)

new_kwargs["messages"].append(
{
"role": "user",
"content": "Return the correct YAML response within a ```yaml codeblock. not the YAML_SCHEMA",
},
)

# check that the first message is a system message
# if it is not, add a system message to the beginning
if new_kwargs["messages"][0]["role"] != "system":
new_kwargs["messages"].insert(
0,
{
"role": "system",
"content": message,
},
)
# if it is, system append the schema to the end
else:
new_kwargs["messages"][0]["content"] += f"\n\n{message}"
elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}:
# If its a JSON Mode we need to massage the prompt a bit
# in order to get the response we want in a json format
Expand Down
5 changes: 5 additions & 0 deletions instructor/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception):
"role": "user",
"content": f"Correct your JSON ONLY RESPONSE, based on the following errors:\n{exception}",
}
elif mode == Mode.MD_YAML:
yield {
"role": "user",
"content": f"Correct your YAML ONLY RESPONSE, based on the following errors:\n{exception}",
}
else:
yield {
"role": "user",
Expand Down
11 changes: 11 additions & 0 deletions instructor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect
import json
import yaml
import logging
from collections.abc import AsyncGenerator, Generator, Iterable
from typing import (
Expand Down Expand Up @@ -70,6 +71,16 @@ def extract_json_from_codeblock(content: str) -> str:
return content[first_paren : last_paren + 1]


def extract_json_from_yaml_codeblock(content: str) -> str:
    """Extract a ```yaml fenced block from *content* and re-encode it as JSON.

    Returns the JSON serialization of the parsed YAML payload, or an empty
    string when no ```yaml fence is present (the historical sentinel, kept
    for backward compatibility).  A missing closing fence is now tolerated:
    everything after the opening fence is treated as the payload, since
    models frequently drop the trailing ``` marker.

    Raises:
        yaml.YAMLError: if the fenced payload is not valid YAML.
    """
    fence = "```yaml"
    start = content.find(fence)
    if start == -1:
        # No fenced block at all -- preserve the original "" return value
        # so downstream error handling is unchanged for this case.
        return ""
    payload_start = start + len(fence)  # was a magic `+ 7` offset
    end = content.find("```", payload_start)
    # Tolerate an unterminated fence by consuming the remainder of the
    # message instead of silently returning "" (which previously surfaced
    # as an opaque pydantic validation error on an empty string).
    payload = content[payload_start : end if end != -1 else len(content)]
    parsed = yaml.safe_load(payload.strip())
    return json.dumps(parsed)


def extract_json_from_stream(chunks: Iterable[str]) -> Generator[str, None, None]:
capturing = False
brace_count = 0
Expand Down
28 changes: 28 additions & 0 deletions tests/llm/test_openai/test_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,34 @@ class Order(BaseModel):
customer: str


def test_yaml(client):
    """Parse an order description via MD_YAML mode and verify every field."""
    patched = instructor.patch(client, mode=instructor.Mode.MD_YAML)
    content = """
    Order Details:
    Customer: Jason
    Items:
    Name: Apple, Price: 0.50
    Name: Bread, Price: 2.00
    Name: Milk, Price: 1.50
    """

    resp = patched.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=Order,
        messages=[{"role": "user", "content": content}],
    )

    names = {item.name.lower() for item in resp.items}
    prices = {item.price for item in resp.items}
    assert len(resp.items) == 3
    assert names == {"apple", "bread", "milk"}
    assert prices == {0.5, 2.0, 1.5}
    assert resp.customer.lower() == "jason"


@pytest.mark.parametrize("model, mode", product(models, modes))
def test_nested(model, mode, client):
client = instructor.patch(client, mode=mode)
Expand Down

0 comments on commit 7c668e0

Please sign in to comment.