-
-
Notifications
You must be signed in to change notification settings - Fork 564
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
398 additions
and
264 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import time | ||
|
||
from typing import Iterable | ||
from openai import OpenAI | ||
from pydantic import BaseModel | ||
|
||
import instructor | ||
|
||
|
||
client = instructor.patch(OpenAI()) | ||
|
||
class User(BaseModel): | ||
name: str | ||
age: int | ||
|
||
def test_multi_user(): | ||
def stream_extract(input: str, cls) -> Iterable[User]: | ||
MultiUser = instructor.MultiTask(cls) | ||
completion = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
stream=True, | ||
functions=[MultiUser.openai_schema], | ||
function_call={"name": MultiUser.openai_schema["name"]}, | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are a perfect entity extraction system", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": ( | ||
f"Consider the data below:\n{input}" | ||
"Correctly segment it into entitites" | ||
"Make sure the JSON is correct" | ||
), | ||
}, | ||
], | ||
max_tokens=1000, | ||
) | ||
return MultiUser.from_streaming_response(completion) | ||
|
||
resp = [user for user in stream_extract(input="Jason is 20, Sarah is 30", cls=User)] | ||
assert len(resp) == 2 | ||
assert resp[0].name == "Jason" | ||
assert resp[0].age == 20 | ||
assert resp[1].name == "Sarah" | ||
assert resp[1].age == 30 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import pytest | ||
import instructor | ||
|
||
from instructor import llm_validator | ||
from typing_extensions import Annotated | ||
from pydantic import field_validator, BaseModel, BeforeValidator, ValidationError | ||
from openai import OpenAI, AsyncOpenAI | ||
|
||
client = instructor.patch(OpenAI()) | ||
aclient = instructor.patch(AsyncOpenAI()) | ||
|
||
|
||
class UserExtract(BaseModel): | ||
name: str | ||
age: int | ||
|
||
@field_validator("name") | ||
@classmethod | ||
def validate_name(cls, v): | ||
if v.upper() != v: | ||
raise ValueError("Name should be uppercase") | ||
return v | ||
|
||
def test_runmodel_validator(): | ||
|
||
model = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=UserExtract, | ||
max_retries=2, | ||
messages=[ | ||
{"role": "user", "content": "Extract jason is 25 years old"}, | ||
], | ||
) | ||
assert isinstance(model, UserExtract), "Should be instance of UserExtract" | ||
assert model.name == "JASON" | ||
assert hasattr( | ||
model, "_raw_response" | ||
), "The raw response should be available from OpenAI" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_runmodel_async_validator(): | ||
model = await aclient.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=UserExtract, | ||
max_retries=2, | ||
messages=[ | ||
{"role": "user", "content": "Extract jason is 25 years old"}, | ||
], | ||
) | ||
assert isinstance(model, UserExtract), "Should be instance of UserExtract" | ||
assert model.name == "JASON" | ||
assert hasattr( | ||
model, "_raw_response" | ||
), "The raw response should be available from OpenAI" | ||
|
||
|
||
class UserExtractSimple(BaseModel): | ||
name: str | ||
age: int | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_async_runmodel(): | ||
model = await aclient.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=UserExtractSimple, | ||
messages=[ | ||
{"role": "user", "content": "Extract jason is 25 years old"}, | ||
], | ||
) | ||
assert isinstance(model, UserExtractSimple), "Should be instance of UserExtractSimple" | ||
assert model.name.lower() == "jason" | ||
assert hasattr( | ||
model, "_raw_response" | ||
), "The raw response should be available from OpenAI" | ||
|
||
|
||
def test_runmodel(): | ||
|
||
model = client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
response_model=UserExtractSimple, | ||
messages=[ | ||
{"role": "user", "content": "Extract jason is 25 years old"}, | ||
], | ||
) | ||
assert isinstance(model, UserExtractSimple), "Should be instance of UserExtractSimple" | ||
assert model.name.lower() == "jason" | ||
assert hasattr( | ||
model, "_raw_response" | ||
), "The raw response should be available from OpenAI" | ||
|
||
|
||
def test_runmodel_validator_error(): | ||
|
||
class QuestionAnswerNoEvil(BaseModel): | ||
question: str | ||
answer: Annotated[ | ||
str, | ||
BeforeValidator(llm_validator("don't say objectionable things", openai_client=client)) | ||
] | ||
|
||
with pytest.raises(ValidationError): | ||
QuestionAnswerNoEvil( | ||
question="What is the meaning of life?", | ||
answer="The meaning of life is to be evil and steal", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters