diff --git a/openai_function_calling/function.py b/openai_function_calling/function.py index cf0de34..c921497 100644 --- a/openai_function_calling/function.py +++ b/openai_function_calling/function.py @@ -26,6 +26,7 @@ class FunctionDict(TypedDict): name: str description: str parameters: ParametersDict + strict: NotRequired[bool] class Function: @@ -37,6 +38,7 @@ def __init__( description: str, parameters: list[Parameter] | None = None, required_parameters: list[str] | None = None, + strict: bool | None = None, ) -> None: """Create a new function instance. @@ -46,12 +48,14 @@ def __init__( parameters: A list of parameters. required_parameters: A list of parameter names that are required to run the\ function. + strict: If the function should enforce strict parameters. """ self.name: str = name self.description: str = description self.parameters: list[Parameter] = parameters or [] self.required_parameters: list[str] = required_parameters or [] + self.strict: bool | None = strict self.validate() @@ -110,6 +114,9 @@ def to_json_schema(self) -> FunctionDict: }, } + if self.strict is not None: + output_dict["strict"] = self.strict + if self.required_parameters is None or len(self.required_parameters) == 0: return output_dict diff --git a/tests/test_function.py b/tests/test_function.py index 5d14ced..4314783 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -343,3 +343,35 @@ def test_merge_with_no_parameters_does_not_add_any() -> None: get_current_weather_function.merge(get_tomorrows_weather_function) assert len(get_current_weather_function.parameters) == 0 + + +def test_function_with_strict_true_includes_strict_in_output() -> None: + func = Function( + name="example_function", + description="An example function", + strict=True, + ) + func_dict: FunctionDict = func.to_json_schema() + + assert func_dict.get("strict") is True + + +def test_function_with_strict_false_includes_strict_in_output() -> None: + func = Function( + name="example_function", + description="An example function", + strict=False, + ) + func_dict: FunctionDict = func.to_json_schema() + + assert func_dict.get("strict") is False + + +def test_function_without_strict_excludes_strict_in_output() -> None: + func = Function( + name="example_function", + description="An example function", + ) + func_dict: FunctionDict = func.to_json_schema() + + assert "strict" not in func_dict