From 55887865faad03d424645509ea4936977ec069bb Mon Sep 17 00:00:00 2001 From: Ashank Tomar Date: Sat, 9 Dec 2023 20:16:53 +0530 Subject: [PATCH] Json mode nested models (#263) --- instructor/patch.py | 4 ++ tests/openai/evals/test_nested_structures.py | 49 ++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/instructor/patch.py b/instructor/patch.py index 89aec3d8e..2a94e4636 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -107,6 +107,10 @@ def handle_response_model( the parsed objects in json that match the following json_schema (do not deviate at all and its okay if you cant be exact):\n {response_model.model_json_schema()['properties']} """ + # Check for nested models + if "$defs" in response_model.model_json_schema(): + message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}" + new_kwargs["messages"].append( { "role": "assistant", diff --git a/tests/openai/evals/test_nested_structures.py b/tests/openai/evals/test_nested_structures.py index 1e5ddc92f..f38ca2bd2 100644 --- a/tests/openai/evals/test_nested_structures.py +++ b/tests/openai/evals/test_nested_structures.py @@ -47,3 +47,52 @@ def test_nested(mode): assert {x.name.lower() for x in resp.items} == {"apple", "bread", "milk"} assert {x.price for x in resp.items} == {0.5, 2.0, 1.5} assert resp.customer.lower() == "jason" + + +class Book(BaseModel): + title: str + author: str + genre: str + isbn: str + + +class LibraryRecord(BaseModel): + books: List[Book] = Field(..., default_factory=list) + visitor: str + library_id: str + + +@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS, Mode.MD_JSON]) +def test_complex_nested_model(mode): + client = instructor.patch(OpenAI(), mode=mode) + + content = """ + Library visit details: + Visitor: Jason + Library ID: LIB123456 + Books checked out: + - Title: The Great Adventure, Author: Jane Doe, Genre: Fantasy, ISBN: 1234567890 + - Title: History of Tomorrow, Author: John Smith, Genre: Non-Fiction, ISBN: 0987654321 + """ + + resp = client.chat.completions.create( + model="gpt-3.5-turbo-1106", + response_model=LibraryRecord, + messages=[ + { + "role": "user", + "content": content, + }, + ], + ) + + assert resp.visitor.lower() == "jason" + assert resp.library_id == "LIB123456" + assert len(resp.books) == 2 + assert {book.title for book in resp.books} == { + "The Great Adventure", + "History of Tomorrow", + } + assert {book.author for book in resp.books} == {"Jane Doe", "John Smith"} + assert {book.genre for book in resp.books} == {"Fantasy", "Non-Fiction"} + assert {book.isbn for book in resp.books} == {"1234567890", "0987654321"}