Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl committed Nov 20, 2023
1 parent b28fa47 commit 56cc9ed
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 124 deletions.
10 changes: 5 additions & 5 deletions instructor/dsl/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def tasks_from_chunks(cls, json_chunks):
@staticmethod
def extract_json(completion):
for chunk in completion:
if chunk["choices"]:
delta = chunk["choices"][0]["delta"]
if "function_call" in delta:
if "arguments" in delta["function_call"]:
yield delta["function_call"]["arguments"]
try:
if json_chunk := chunk.choices[0].delta.function_call.arguments:
yield json_chunk
except AttributeError:
pass

@staticmethod
def get_object(str, stack):
Expand Down
47 changes: 47 additions & 0 deletions tests/openai/test_multitask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import time

from typing import Iterable
from openai import OpenAI
from pydantic import BaseModel

import instructor


# Module-level patched client: instructor.patch wraps the OpenAI client so
# the tests below can pass extraction-related kwargs to .create (presumably
# response_model et al. — confirm against instructor.patch).
client = instructor.patch(OpenAI())

class User(BaseModel):
    """Schema for a single person entity extracted from free text."""

    name: str  # the person's name as extracted from the input text
    age: int  # the person's age in years

def test_multi_user():
    """End-to-end check: MultiTask streaming yields one User per entity."""

    def _stream_users(text: str, model_cls) -> Iterable[User]:
        # Wrap the single-entity model so multiple instances can be
        # extracted from one streamed function call.
        multi_schema = instructor.MultiTask(model_cls)
        stream = client.chat.completions.create(
            model="gpt-3.5-turbo",
            stream=True,
            functions=[multi_schema.openai_schema],
            function_call={"name": multi_schema.openai_schema["name"]},
            messages=[
                {
                    "role": "system",
                    "content": "You are a perfect entity extraction system",
                },
                {
                    "role": "user",
                    "content": (
                        f"Consider the data below:\n{text}"
                        "Correctly segment it into entitites"
                        "Make sure the JSON is correct"
                    ),
                },
            ],
            max_tokens=1000,
        )
        return multi_schema.from_streaming_response(stream)

    users = list(_stream_users(text="Jason is 20, Sarah is 30", model_cls=User))
    assert len(users) == 2
    assert users[0].name == "Jason"
    assert users[0].age == 20
    assert users[1].name == "Sarah"
    assert users[1].age == 30
108 changes: 108 additions & 0 deletions tests/openai/test_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import pytest
import instructor

from instructor import llm_validator
from typing_extensions import Annotated
from pydantic import field_validator, BaseModel, BeforeValidator, ValidationError
from openai import OpenAI, AsyncOpenAI

# Patched sync and async clients shared by every test in this module; the
# patch is what enables the response_model / max_retries kwargs used below.
client = instructor.patch(OpenAI())
aclient = instructor.patch(AsyncOpenAI())


class UserExtract(BaseModel):
    """Extraction target whose name field must arrive fully uppercase."""

    name: str
    age: int

    @field_validator("name")
    @classmethod
    def validate_name(cls, value):
        # Reject non-uppercase names so the patched client's retry loop
        # has to correct the model output.
        if value != value.upper():
            raise ValueError("Name should be uppercase")
        return value

def test_runmodel_validator():
    """The patched sync client retries until UserExtract's validator passes."""
    prompt = [{"role": "user", "content": "Extract jason is 25 years old"}]
    extracted = client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        max_retries=2,
        messages=prompt,
    )
    assert isinstance(extracted, UserExtract), "Should be instance of UserExtract"
    assert extracted.name == "JASON"
    assert hasattr(
        extracted, "_raw_response"
    ), "The raw response should be available from OpenAI"


@pytest.mark.asyncio
async def test_runmodel_async_validator():
    """The patched async client retries until UserExtract's validator passes."""
    prompt = [{"role": "user", "content": "Extract jason is 25 years old"}]
    extracted = await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        max_retries=2,
        messages=prompt,
    )
    assert isinstance(extracted, UserExtract), "Should be instance of UserExtract"
    assert extracted.name == "JASON"
    assert hasattr(
        extracted, "_raw_response"
    ), "The raw response should be available from OpenAI"


class UserExtractSimple(BaseModel):
    """Validator-free extraction target for the no-retry happy-path tests."""

    name: str  # extracted name, no casing constraint
    age: int  # extracted age in years


@pytest.mark.asyncio
async def test_async_runmodel():
    """Async patched client parses the response into UserExtractSimple."""
    extracted = await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtractSimple,
        messages=[{"role": "user", "content": "Extract jason is 25 years old"}],
    )
    assert isinstance(extracted, UserExtractSimple), "Should be instance of UserExtractSimple"
    assert extracted.name.lower() == "jason"
    assert hasattr(
        extracted, "_raw_response"
    ), "The raw response should be available from OpenAI"


def test_runmodel():
    """Sync patched client parses the response into UserExtractSimple."""
    extracted = client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtractSimple,
        messages=[{"role": "user", "content": "Extract jason is 25 years old"}],
    )
    assert isinstance(extracted, UserExtractSimple), "Should be instance of UserExtractSimple"
    assert extracted.name.lower() == "jason"
    assert hasattr(
        extracted, "_raw_response"
    ), "The raw response should be available from OpenAI"


def test_runmodel_validator_error():
    """An llm_validator-guarded field raises ValidationError on bad content."""

    class QuestionAnswerNoEvil(BaseModel):
        question: str
        answer: Annotated[
            str,
            BeforeValidator(
                llm_validator("don't say objectionable things", openai_client=client)
            ),
        ]

    with pytest.raises(ValidationError):
        QuestionAnswerNoEvil(
            question="What is the meaning of life?",
            answer="The meaning of life is to be evil and steal",
        )
119 changes: 0 additions & 119 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,77 +10,6 @@

from instructor.patch import is_async, wrap_chatcompletion, OVERRIDE_DOCS

# Patched sync and async clients used by the tests in this module; patching
# is what enables the response_model / max_retries kwargs passed to .create.
client = instructor.patch(OpenAI())
aclient = instructor.patch(AsyncOpenAI())


@pytest.mark.asyncio
async def test_async_runmodel():
    """Async patched client parses the response into a local UserExtract."""

    class UserExtract(BaseModel):
        name: str
        age: int

    extracted = await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        messages=[{"role": "user", "content": "Extract jason is 25 years old"}],
    )
    assert isinstance(extracted, UserExtract), "Should be instance of UserExtract"
    assert extracted.name.lower() == "jason"
    assert hasattr(
        extracted, "_raw_response"
    ), "The raw response should be available from OpenAI"


def test_runmodel():
    """Sync patched client parses the response into a local UserExtract."""

    class UserExtract(BaseModel):
        name: str
        age: int

    extracted = client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        messages=[{"role": "user", "content": "Extract jason is 25 years old"}],
    )
    assert isinstance(extracted, UserExtract), "Should be instance of UserExtract"
    assert extracted.name.lower() == "jason"
    assert hasattr(
        extracted, "_raw_response"
    ), "The raw response should be available from OpenAI"


def test_runmodel_validator():
    """Retries reshape the output until the local uppercase validator accepts."""
    from pydantic import field_validator

    class UserExtract(BaseModel):
        name: str
        age: int

        @field_validator("name")
        @classmethod
        def validate_name(cls, value):
            # Force a retry whenever the name is not already all-uppercase.
            if value != value.upper():
                raise ValueError("Name should be uppercase")
            return value

    extracted = client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        max_retries=2,
        messages=[{"role": "user", "content": "Extract jason is 25 years old"}],
    )
    assert isinstance(extracted, UserExtract), "Should be instance of UserExtract"
    assert extracted.name == "JASON"
    assert hasattr(
        extracted, "_raw_response"
    ), "The raw response should be available from OpenAI"


def test_patch_completes_successfully():
    """Smoke test: patching a freshly constructed client must not raise."""
    instructor.patch(OpenAI())
Expand Down Expand Up @@ -135,53 +64,5 @@ def wrapped_function():

assert is_async(wrapped_function) is True


@pytest.mark.asyncio
async def test_async_runmodel_validator():
    """An apatch-ed async client retries until the uppercase validator passes."""
    aclient = instructor.apatch(AsyncOpenAI())
    from pydantic import field_validator

    class UserExtract(BaseModel):
        name: str
        age: int

        @field_validator("name")
        @classmethod
        def validate_name(cls, value):
            # Force a retry whenever the name is not already all-uppercase.
            if value != value.upper():
                raise ValueError("Name should be uppercase")
            return value

    extracted = await aclient.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        max_retries=2,
        messages=[{"role": "user", "content": "Extract jason is 25 years old"}],
    )
    assert isinstance(extracted, UserExtract), "Should be instance of UserExtract"
    assert extracted.name == "JASON"
    assert hasattr(
        extracted, "_raw_response"
    ), "The raw response should be available from OpenAI"


def test_runmodel_validator_error():
    """An llm_validator-guarded field raises ValidationError on bad content."""

    class QuestionAnswerNoEvil(BaseModel):
        question: str
        answer: Annotated[
            str,
            BeforeValidator(
                llm_validator("don't say objectionable things", openai_client=client)
            ),
        ]

    with pytest.raises(ValidationError):
        QuestionAnswerNoEvil(
            question="What is the meaning of life?",
            answer="The meaning of life is to be evil and steal",
        )

def test_override_docs():
    """The docstring override blob must document the injected response_model kwarg."""
    assert "response_model" in OVERRIDE_DOCS, "response_model should be in OVERRIDE_DOCS"

0 comments on commit 56cc9ed

Please sign in to comment.