From 2731c08b8f3bb4791c322b773a738479db88a016 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 21 May 2024 15:45:54 +0100 Subject: [PATCH] Stop putting the return type in with the other parameters --- src/transformers/utils/chat_template_utils.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 917fd632c35f7..af217eb844e5f 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -88,18 +88,20 @@ def get_json_schema(func): main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc) json_schema = _convert_type_hints_to_json_schema(func) + if (return_dict := json_schema["properties"].pop("return", None)) is not None: + if return_doc is not None: # We allow a missing return docstring since most templates ignore it + return_dict["description"] = return_doc for arg in json_schema["properties"]: - if arg == "return": - if return_doc is not None: # We allow a missing return docstring since most templates ignore it - json_schema["properties"][arg]["description"] = return_doc - continue - elif arg not in param_descriptions: + if arg not in param_descriptions: raise ValueError( f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" ) json_schema["properties"][arg]["description"] = param_descriptions[arg] - return {"name": func.__name__, "description": main_doc, "parameters": json_schema} + output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} + if return_dict is not None: + output["return"] = return_dict + return output def add_json_schema(func): @@ -247,10 +249,10 @@ def _parse_type_hint(hint): elif origin is dict: # The JSON equivalent to a dict is 'object', which mandates that all keys are strings # However, we can specify the type of the dict values with "additionalProperties" - return { - "type": "object", - "additionalProperties": _parse_type_hint(get_args(hint)[1]), - } + out = {"type": "object"} + if len(get_args(hint)) == 2: + out["additionalProperties"] = _parse_type_hint(get_args(hint)[1]) + return out else: raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) else: