Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl committed Mar 28, 2024
1 parent 16c36ca commit 6d23442
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 8 deletions.
32 changes: 32 additions & 0 deletions examples/classification/test_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from examples.planning.run import extract_person, extract_people, Person


@pytest.mark.asyncio
async def test_extract_person():
# Test the extract_person function with a known input
text = "John is 45 years old"
expected_person = Person(name="John", age=45)
person = await extract_person(text)
assert (
person == expected_person
), "The extracted person does not match the expected person"


@pytest.mark.asyncio
@pytest.mark.parametrize(
"names_and_ages, expected_people",
[
(
["Alice is 30 years old", "Bob is 24 years old"],
[Person(name="Alice", age=30), Person(name="Bob", age=24)],
)
],
)
async def test_extract_people(names_and_ages, expected_people):
# Test the extract_people function with a list of known inputs
people = await extract_people(names_and_ages)
assert (
people == expected_people
), "The extracted people do not match the expected people"
1 change: 1 addition & 0 deletions examples/match_language/run_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
class GeneratedSummary(BaseModel):
summary: str


async def summarize_text(text: str):
response = await client.chat.completions.create(
model="gpt-3.5-turbo",
Expand Down
3 changes: 2 additions & 1 deletion examples/match_language/run_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class GeneratedSummary(BaseModel):
)
summary: str


async def summarize_text(text: str):
response = await client.chat.completions.create(
model="gpt-3.5-turbo",
Expand Down Expand Up @@ -94,4 +95,4 @@ async def main():
Source: de, Summary: de, Match: True, Detected: de
Source: hi, Summary: hi, Match: True, Detected: hi
Source: ja, Summary: ja, Match: True, Detected: ja
"""
"""
4 changes: 2 additions & 2 deletions instructor/anthropic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def _add_params(
field_type = details.get(
"type", "unknown"
) # Might be better to fail here if there is no type since pydantic models require types

if "array" in field_type and "items" not in details:
raise ValueError("Invalid array item.")

# Check for nested List
if "array" in field_type and "$ref" in details["items"]:
type_element.text = f"List[{details['title']}]"
list_found = True
nested_list_found = True
nested_list_found = True
# Check for non-nested List
elif "array" in field_type and "type" in details["items"]:
type_element.text = f"List[{details['items']['type']}]"
Expand Down
1 change: 0 additions & 1 deletion instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def anthropic_schema(cls) -> str:
for line in parseString(json_to_xml(cls)).toprettyxml().splitlines()[1:]
)


@classmethod
def from_response(
cls,
Expand Down
8 changes: 4 additions & 4 deletions tests/anthropic/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ class User(BaseModel):
name: str
age: int
family: List[str]

resp = create(
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
max_tokens=1024,
max_retries=0,
messages=[
Expand All @@ -81,13 +81,13 @@ class User(BaseModel):
],
response_model=User,
)

assert isinstance(resp, User)
assert isinstance(resp.family, List)
for member in resp.family:
assert isinstance(member, str)


def test_nested_list():
class Properties(BaseModel):
key: str
Expand Down

0 comments on commit 6d23442

Please sign in to comment.