Skip to content

Commit

Permalink
Stop putting the return type in with the other parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed May 22, 2024
1 parent 7bd7f4c commit 2731c08
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,20 @@ def get_json_schema(func):
main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)

json_schema = _convert_type_hints_to_json_schema(func)
if (return_dict := json_schema["properties"].pop("return", None)) is not None:
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
return_dict["description"] = return_doc
for arg in json_schema["properties"]:
if arg == "return":
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
json_schema["properties"][arg]["description"] = return_doc
continue
elif arg not in param_descriptions:
if arg not in param_descriptions:
raise ValueError(
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
)
json_schema["properties"][arg]["description"] = param_descriptions[arg]

return {"name": func.__name__, "description": main_doc, "parameters": json_schema}
output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
if return_dict is not None:
output["return"] = return_dict
return output


def add_json_schema(func):
Expand Down Expand Up @@ -247,10 +249,10 @@ def _parse_type_hint(hint):
elif origin is dict:
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
# However, we can specify the type of the dict values with "additionalProperties"
return {
"type": "object",
"additionalProperties": _parse_type_hint(get_args(hint)[1]),
}
out = {"type": "object"}
if len(get_args(hint)) == 2:
out["additionalProperties"] = _parse_type_hint(get_args(hint)[1])
return out
else:
raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
else:
Expand Down

0 comments on commit 2731c08

Please sign in to comment.