Skip to content

Commit

Permalink
remove function decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl committed Dec 9, 2023
1 parent 68a6732 commit 0aeb41c
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 129 deletions.
3 changes: 1 addition & 2 deletions instructor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from .distil import FinetuneFormat, Instructions
from .dsl import CitationMixin, Maybe, MultiTask, llm_validator, openai_moderation
from .function_calls import OpenAISchema, openai_function, openai_schema, Mode
from .function_calls import OpenAISchema, openai_schema, Mode
from .patch import apatch, patch

__all__ = [
"OpenAISchema",
"openai_function",
"CitationMixin",
"MultiTask",
"Maybe",
Expand Down
83 changes: 0 additions & 83 deletions instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,89 +16,6 @@ class Mode(enum.Enum):
MD_JSON: str = "markdown_json_mode"


class openai_function:
"""
Decorator to convert a function into an OpenAI function.
This decorator will convert a function into an OpenAI function. The
function will be validated using pydantic and the schema will be
generated from the function signature.
Example:
```python
@openai_function
def sum(a: int, b: int) -> int:
return a + b
completion = openai.ChatCompletion.create(
...
messages=[{
"content": "What is 1 + 1?",
"role": "user"
}]
)
sum.from_response(completion)
# 2
```
"""

def __init__(self, func: Callable) -> None:
self.func = func
self.validate_func = validate_arguments(func)
self.docstring = parse(self.func.__doc__ or "")

parameters = self.validate_func.model.model_json_schema()
parameters["properties"] = {
k: v
for k, v in parameters["properties"].items()
if k not in ("v__duplicate_kwargs", "args", "kwargs")
}
for param in self.docstring.params:
if (name := param.arg_name) in parameters["properties"] and (
description := param.description
):
parameters["properties"][name]["description"] = description
parameters["required"] = sorted(
k for k, v in parameters["properties"].items() if "default" not in v
)
self.openai_schema = {
"name": self.func.__name__,
"description": self.docstring.short_description,
"parameters": parameters,
}
self.model = self.validate_func.model

def __call__(self, *args: Any, **kwargs: Any) -> Any:
@wraps(self.func)
def wrapper(*args, **kwargs):
return self.validate_func(*args, **kwargs)

return wrapper(*args, **kwargs)

def from_response(self, completion, throw_error=True, strict: bool = None):
"""
Parse the response from OpenAI's API and return the function call
Parameters:
completion (openai.ChatCompletion): The response from OpenAI's API
throw_error (bool): Whether to throw an error if the response does not contain a function call
Returns:
result (any): result of the function call
"""
message = completion["choices"][0]["message"]

if throw_error:
assert "function_call" in message, "No function call detected"
assert (
message["function_call"]["name"] == self.openai_schema["name"]
), "Function name does not match"

function_call = message["function_call"]
arguments = json.loads(function_call["arguments"], strict=strict)
return self.validate_func(**arguments)


class OpenAISchema(BaseModel):
"""
Augments a Pydantic model with OpenAI's schema for function calling
Expand Down
45 changes: 1 addition & 44 deletions tests/test_function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from pydantic import BaseModel

from instructor import openai_schema, OpenAISchema, openai_function
from instructor import openai_schema, OpenAISchema


def test_openai_schema():
Expand Down Expand Up @@ -42,46 +42,3 @@ class Dummy(OpenAISchema):
Dummy.openai_schema["description"]
== "Correctly extracted `Dummy` with all the required parameters with correct types"
)


def test_openai_function():
@openai_function
def get_current_weather(
location: str, format: Literal["celsius", "fahrenheit"] = "celsius"
):
"""
Gets the current weather in a given location, use this function for any questions related to the weather
Parameters
----------
location
The city to get the weather, e.g. San Francisco. Guess the location from user messages
format
A string with the full content of what the given role said
"""

@openai_function
def get_current_weather_no_format_docstring(
location: str, format: Literal["celsius", "fahrenheit"] = "celsius"
):
"""
Gets the current weather in a given location, use this function for any questions related to the weather
Parameters
----------
location
The city to get the weather, e.g. San Francisco. Guess the location from user messages
"""

scheme_missing_param = get_current_weather_no_format_docstring.openai_schema
assert (
scheme_missing_param["parameters"]["properties"]["location"]["description"]
== "The city to get the weather, e.g. San Francisco. Guess the location from user messages"
)
assert scheme_missing_param["parameters"]["properties"]["format"]["enum"] == [
"celsius",
"fahrenheit",
]
with pytest.raises(KeyError, match="description"):
scheme_missing_param["parameters"]["properties"]["format"]["description"]

0 comments on commit 0aeb41c

Please sign in to comment.