-
-
Notifications
You must be signed in to change notification settings - Fork 514
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
69 additions
and
263 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,193 +1,98 @@ | ||
from openai import OpenAI | ||
from pydantic import BaseModel, Field | ||
from typing import List | ||
|
||
import pytest | ||
|
||
import instructor | ||
from instructor.function_calls import OpenAISchema, Mode | ||
from openai import OpenAI, AsyncOpenAI | ||
from instructor.function_calls import Mode | ||
|
||
|
||
class UserExtract(OpenAISchema): | ||
class Item(BaseModel): | ||
name: str | ||
age: int | ||
price: float | ||
|
||
|
||
def test_tool_call(): | ||
client = OpenAI() | ||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo-1106", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Extract jason is 25 years old, mary is 30 years old", | ||
}, | ||
], | ||
tools=[ | ||
{ | ||
"type": "function", | ||
"function": UserExtract.openai_schema, | ||
} | ||
], | ||
tool_choice={ | ||
"type": "function", | ||
"function": {"name": UserExtract.openai_schema["name"]}, | ||
}, | ||
) | ||
response_message = response.choices[0].message | ||
tool_calls = response_message.tool_calls | ||
assert len(tool_calls) == 1 | ||
assert tool_calls[0].function.name == "UserExtract" | ||
assert tool_calls[0].function | ||
user = UserExtract.from_response(response, mode=Mode.TOOLS) | ||
assert user.name.lower() == "jason" | ||
assert user.age == 25 | ||
|
||
|
||
def test_json_mode(): | ||
client = OpenAI() | ||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo-1106", | ||
response_format={"type": "json_object"}, | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "Extract jason is 25 years old", | ||
}, | ||
], | ||
) | ||
user = UserExtract.from_response(response, mode=Mode.JSON) | ||
assert user.name.lower() == "jason" | ||
assert user.age == 25 | ||
class Order(BaseModel): | ||
items: List[Item] = Field(..., default_factory=list) | ||
customer: str | ||
|
||
|
||
def test_markdown_json_mode(): | ||
client = OpenAI() | ||
response = client.chat.completions.create( | ||
model="gpt-3.5-turbo-1106", | ||
response_format={"type": "json_object"}, | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "Extract jason is 25 years old", | ||
}, | ||
], | ||
) | ||
user = UserExtract.from_response(response, mode=Mode.MD_JSON) | ||
assert user.name.lower() == "jason" | ||
assert user.age == 25 | ||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS, Mode.MD_JSON]) | ||
def test_nested(mode): | ||
client = instructor.patch(OpenAI(), mode=mode) | ||
|
||
content = """ | ||
Order Details: | ||
Customer: Jason | ||
Items: | ||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]) | ||
def test_mode(mode): | ||
client = OpenAI() | ||
client = instructor.patch(OpenAI(), mode=mode) | ||
user = client.chat.completions.create( | ||
Name: Apple, Price: 0.50 | ||
Name: Bread, Price: 2.00 | ||
Name: Milk, Price: 1.50 | ||
""" | ||
|
||
resp = client.chat.completions.create( | ||
model="gpt-3.5-turbo-1106", | ||
response_model=UserExtract, | ||
response_model=Order, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Extract jason is 25 years old", | ||
"content": content, | ||
}, | ||
], | ||
) | ||
assert user.name.lower() == "jason" | ||
assert user.age == 25 | ||
|
||
assert len(resp.items) == 3 | ||
assert {x.name.lower() for x in resp.items} == {"apple", "bread", "milk"} | ||
assert {x.price for x in resp.items} == {0.5, 2.0, 1.5} | ||
assert resp.customer.lower() == "jason" | ||
|
||
@pytest.mark.asyncio | ||
async def test_tool_call_async(): | ||
client = AsyncOpenAI() | ||
response = await client.chat.completions.create( | ||
model="gpt-3.5-turbo-1106", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Extract jason is 25 years old, mary is 30 years old", | ||
}, | ||
], | ||
tools=[ | ||
{ | ||
"type": "function", | ||
"function": UserExtract.openai_schema, | ||
} | ||
], | ||
tool_choice={ | ||
"type": "function", | ||
"function": {"name": UserExtract.openai_schema["name"]}, | ||
}, | ||
) | ||
response_message = response.choices[0].message | ||
tool_calls = response_message.tool_calls | ||
assert len(tool_calls) == 1 | ||
assert tool_calls[0].function.name == "UserExtract" | ||
assert tool_calls[0].function | ||
user = UserExtract.from_response(response, mode=Mode.TOOLS) | ||
assert user.name.lower() == "jason" | ||
assert user.age == 25 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_json_mode_async(): | ||
client = AsyncOpenAI() | ||
response = await client.chat.completions.create( | ||
model="gpt-3.5-turbo-1106", | ||
response_format={"type": "json_object"}, | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "Extract jason is 25 years old", | ||
}, | ||
], | ||
) | ||
user = UserExtract.from_response(response, mode=Mode.JSON) | ||
assert user.name.lower() == "jason" | ||
assert user.age == 25 | ||
|
||
class Book(BaseModel): | ||
title: str | ||
author: str | ||
genre: str | ||
isbn: str | ||
|
||
@pytest.mark.asyncio | ||
async def test_markdown_json_mode_async(): | ||
client = AsyncOpenAI() | ||
response = await client.chat.completions.create( | ||
model="gpt-3.5-turbo-1106", | ||
response_format={"type": "json_object"}, | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "Extract jason is 25 years old", | ||
}, | ||
], | ||
) | ||
user = UserExtract.from_response(response, mode=Mode.MD_JSON) | ||
assert user.name.lower() == "jason" | ||
assert user.age == 25 | ||
|
||
class LibraryRecord(BaseModel): | ||
books: List[Book] = Field(..., default_factory=list) | ||
visitor: str | ||
library_id: str | ||
|
||
|
||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS]) | ||
async def test_mode_async(mode): | ||
client = instructor.patch(AsyncOpenAI(), mode=mode) | ||
user = client.chat.completions.create( | ||
@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS, Mode.MD_JSON]) | ||
def test_complex_nested_model(mode): | ||
client = instructor.patch(OpenAI(), mode=mode) | ||
|
||
content = """ | ||
Library visit details: | ||
Visitor: Jason | ||
Library ID: LIB123456 | ||
Books checked out: | ||
- Title: The Great Adventure, Author: Jane Doe, Genre: Fantasy, ISBN: 1234567890 | ||
- Title: History of Tomorrow, Author: John Smith, Genre: Non-Fiction, ISBN: 0987654321 | ||
""" | ||
|
||
resp = client.chat.completions.create( | ||
model="gpt-3.5-turbo-1106", | ||
response_model=UserExtract, | ||
response_model=LibraryRecord, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Extract jason is 25 years old", | ||
"content": content, | ||
}, | ||
], | ||
) | ||
assert user.name.lower() == "jason" | ||
assert user.age == 25 | ||
|
||
assert resp.visitor.lower() == "jason" | ||
assert resp.library_id == "LIB123456" | ||
assert len(resp.books) == 2 | ||
assert {book.title for book in resp.books} == { | ||
"The Great Adventure", | ||
"History of Tomorrow", | ||
} | ||
assert {book.author for book in resp.books} == {"Jane Doe", "John Smith"} | ||
assert {book.genre for book in resp.books} == {"Fantasy", "Non-Fiction"} | ||
assert {book.isbn for book in resp.books} == {"1234567890", "0987654321"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
|
||
import pytest | ||
from pydantic import BaseModel | ||
|
||
|