Skip to content

Commit

Permalink
Add an extra test for very complex defs and docstrings and clean ever…
Browse files Browse the repository at this point in the history
…ything up for it
  • Loading branch information
Rocketknight1 committed May 24, 2024
1 parent 437af03 commit 0fc549d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
8 changes: 4 additions & 4 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
args_split_re = re.compile(r"(?:^|\n)\s*(\w+)\s*(?:\(\w+\))?:\s*(.*?)\s*(?=\n\s*\w|\Z)", re.DOTALL)
args_split_re = re.compile(r"(?:^|\n)\s*(\w+)\s*(?:\([\w\s\[\],.*]+\))?:\s*(.*?)\s*(?=\n\s*\w|\Z)", re.DOTALL)
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)


Expand Down Expand Up @@ -90,10 +90,10 @@ def _parse_type_hint(hint):
if not args:
return {"type": "array"}
if len(args) == 1:
breakpoint()
raise ValueError(
f"The type hint {hint.replace('typing.', '')} is a Tuple with a single element, which we do not "
"support as it is rarely necessary. If this input can contain more than one element, we recommend "
f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
"we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
"more than one element, we recommend "
"using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
"pass the element directly."
)
Expand Down
34 changes: 30 additions & 4 deletions tests/utils/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def fn(x: int, y: int):

def test_everything_all_at_once(self):
def fn(
x: str, y: Optional[List[Union[int, str]]], z: Tuple[Union[int, str]] = (42, "hello")
x: str, y: Optional[List[Union[int, str]]], z: Tuple[Union[int, str], str] = (42, "hello")
) -> Tuple[int, str]:
"""
Test function with multiple args, and docstring args that we have to strip out.
Expand All @@ -428,9 +428,9 @@ def fn(
description and also contains
(choices: ["a", "b", "c"])
y (List[int, str], *optional*): The second input. It's a big list with a single-line description.
y (List[Union[int, str], *optional*): The second input. It's a big list with a single-line description.
z (Tuple[int, str]): The third input. It's some kind of tuple with a default arg.
z (Tuple[Union[int, str], str]): The third input. It's some kind of tuple with a default arg.
Returns:
The output. The return description is also a big multiline
Expand All @@ -439,5 +439,31 @@ def fn(
pass

schema = get_json_schema(fn)
breakpoint()
expected_schema = {
"name": "fn",
"description": "Test function with multiple args, and docstring args that we have to strip out.",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "string", "description": "The first input. It's got a big multiline"},
"y": {
"type": "array",
"items": {"type": ["integer", "string"]},
"nullable": True,
"description": "The second input. It's a big list with a single-line description.",
},
"z": {
"type": "array",
"prefixItems": [{"type": ["integer", "string"]}, {"type": "string"}],
"description": "The third input. It's some kind of tuple with a default arg.",
},
},
"required": ["x", "y"],
},
"return": {
"type": "array",
"prefixItems": [{"type": "integer"}, {"type": "string"}],
"description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
},
}
self.assertEqual(schema, expected_schema)

0 comments on commit 0fc549d

Please sign in to comment.