ruff
jxnl committed Nov 28, 2023
1 parent 47fa89e commit e6f8cae
Showing 4 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions examples/streaming_multitask/streaming_multitask.py
@@ -18,6 +18,7 @@ class User(BaseModel):

 Users = Iterable[User]
 
+
 def stream_extract(input: str) -> Users:
     return client.chat.completions.create(
         model="gpt-4-0613",
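The example is truncated at the model argument. For orientation, here is a minimal, self-contained sketch of how a stream_extract helper like this is typically defined and consumed; the prompt, messages, and the consuming loop are illustrative additions, not part of the diff:

import instructor
from typing import Iterable
from openai import OpenAI
from pydantic import BaseModel

client = instructor.patch(OpenAI())

class User(BaseModel):
    name: str
    age: int

Users = Iterable[User]

def stream_extract(input: str) -> Users:
    # stream=True plus an Iterable response model triggers the
    # multitask streaming path patched in instructor/patch.py below.
    return client.chat.completions.create(
        model="gpt-4-0613",
        stream=True,
        response_model=Users,
        messages=[{"role": "user", "content": f"Extract all users: {input}"}],
    )

# Each User is yielded as soon as it can be parsed from the stream.
for user in stream_extract("Jason is 20, Sarah is 30"):
    print(user.name, user.age)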
4 changes: 1 addition & 3 deletions instructor/function_calls.py
@@ -205,9 +205,7 @@ def from_response(
             cls (OpenAISchema): An instance of the class
         """
         if stream_multitask:
-            return cls.from_streaming_response(
-                completion, mode
-            )
+            return cls.from_streaming_response(completion, mode)
 
         message = completion.choices[0].message
 
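The collapsed call above is the early-return dispatch in OpenAISchema.from_response: when the caller flags a streaming multitask extraction, parsing is handed off before any single message is read. A simplified sketch of that control flow (the real method takes additional validation parameters):

from pydantic import BaseModel

class OpenAISchema(BaseModel):
    @classmethod
    def from_response(cls, completion, mode, stream_multitask=False):
        # Streaming multitask: delegate to the chunk-by-chunk parser
        # instead of reading one completed message.
        if stream_multitask:
            return cls.from_streaming_response(completion, mode)

        message = completion.choices[0].message
        # ...parse a single function/tool/JSON response from `message`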
16 changes: 12 additions & 4 deletions instructor/patch.py
@@ -57,9 +57,13 @@ def handle_response_model(
         response_model = MultiTask(iterable_element_class)
     if not issubclass(response_model, OpenAISchema):
         response_model = openai_schema(response_model)  # type: ignore
 
-    if new_kwargs.get("stream", False) and not issubclass(response_model, MultiTaskBase):
-        raise NotImplementedError("stream=True is not supported when using response_model parameter for non-iterables")
+    if new_kwargs.get("stream", False) and not issubclass(
+        response_model, MultiTaskBase
+    ):
+        raise NotImplementedError(
+            "stream=True is not supported when using response_model parameter for non-iterables"
+        )
 
     if mode == Mode.FUNCTIONS:
         new_kwargs["functions"] = [response_model.openai_schema]  # type: ignore
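The reflowed guard is behavior-preserving: stream=True stays limited to response models that were wrapped into a MultiTask iterable. A sketch of what a caller sees, assuming a patched client and the User model from the tests below:

# Raises NotImplementedError: a single User cannot be streamed.
client.chat.completions.create(
    model="gpt-3.5-turbo",
    stream=True,
    response_model=User,
    messages=[{"role": "user", "content": "Jason is 20"}],
)

# Accepted: Iterable[User] is wrapped into a MultiTask model, which
# subclasses MultiTaskBase and knows how to parse a token stream.
client.chat.completions.create(
    model="gpt-3.5-turbo",
    stream=True,
    response_model=Iterable[User],
    messages=[{"role": "user", "content": "Jason is 20, Sarah is 30"}],
)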
@@ -123,7 +127,11 @@ def process_response(
     if response_model is not None:
         is_model_multitask = issubclass(response_model, MultiTaskBase)
         model = response_model.from_response(
-            response, validation_context=validation_context, strict=strict, mode=mode, stream_multitask=stream and is_model_multitask
+            response,
+            validation_context=validation_context,
+            strict=strict,
+            mode=mode,
+            stream_multitask=stream and is_model_multitask,
         )
         if not stream:
             model._raw_response = response
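For readers new to the MultiTaskBase checks above: MultiTask(iterable_element_class) converts an element type like User into a schema the completion can fill. Roughly, it is assumed to build a wrapper of this shape; the name MultiUser and the tasks field illustrate the pattern and are not copied from the library:

from typing import List
from pydantic import BaseModel

class MultiUser(BaseModel):
    # Collects every extracted User; MultiTaskBase contributes the
    # from_streaming_response classmethod used in from_response above.
    tasks: List[User]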
9 changes: 7 additions & 2 deletions tests/openai/test_multitask.py
@@ -1,4 +1,3 @@
-
 from typing import Iterable
 from openai import OpenAI
 from pydantic import BaseModel
@@ -11,12 +10,13 @@ class User(BaseModel):
     name: str
     age: int
 
-Users = Iterable[User]
+
+Users = Iterable[User]
 
 
 def test_multi_user_function_mode():
     client = instructor.patch(OpenAI(), mode=Mode.FUNCTIONS)
 
     def stream_extract(input: str) -> Iterable[User]:
         return client.chat.completions.create(
             model="gpt-3.5-turbo",
@@ -46,8 +46,10 @@ def stream_extract(input: str) -> Iterable[User]:
     assert resp[1].name == "Sarah"
     assert resp[1].age == 30
 
+
 def test_multi_user_json_mode():
     client = instructor.patch(OpenAI(), mode=Mode.JSON)
+
     def stream_extract(input: str) -> Iterable[User]:
         return client.chat.completions.create(
             model="gpt-3.5-turbo-1106",
@@ -74,8 +76,10 @@ def stream_extract(input: str) -> Iterable[User]:
     assert resp[1].name == "Sarah"
     assert resp[1].age == 30
 
+
 def test_multi_user_tools_mode():
     client = instructor.patch(OpenAI(), mode=Mode.TOOLS)
+
     def stream_extract(input: str) -> Iterable[User]:
         return client.chat.completions.create(
             model="gpt-3.5-turbo-1106",
@@ -102,6 +106,7 @@ def stream_extract(input: str) -> Iterable[User]:
     assert resp[1].name == "Sarah"
     assert resp[1].age == 30
 
+
 def test_multi_user_legacy():
     def stream_extract(input: str, cls) -> Iterable[User]:
         client = instructor.patch(OpenAI())
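The three streaming tests share one pattern, which these whitespace-only changes leave untouched: patch a client in a given mode, request Iterable[User] with stream=True, and materialize the stream. A condensed sketch of that shared shape; the helper name run_mode and the import path for Mode are illustrative, and the real tests duplicate the body per mode:

import instructor
from instructor import Mode  # assumed import path for Mode
from typing import Iterable
from openai import OpenAI
from pydantic import BaseModel

class User(BaseModel):
    name: str
    age: int

def run_mode(mode: Mode) -> None:
    client = instructor.patch(OpenAI(), mode=mode)

    def stream_extract(input: str) -> Iterable[User]:
        return client.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            stream=True,
            response_model=Iterable[User],
            messages=[{"role": "user", "content": f"Extract users: {input}"}],
        )

    # The stream is materialized so the parsed users can be asserted on.
    resp = list(stream_extract("Jason is 20, Sarah is 30"))
    assert resp[0].name == "Jason" and resp[0].age == 20
    assert resp[1].name == "Sarah" and resp[1].age == 30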
