Skip to content

Commit

Permalink
Add Tuple support
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed May 23, 2024
1 parent 9b62df8 commit 575929c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 11 deletions.
25 changes: 18 additions & 7 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,24 @@ def _parse_type_hint(hint):
return_dict["nullable"] = True
return return_dict
elif origin is tuple:
raise ValueError(
"This helper does not parse Tuple types, as they are usually used to indicate that "
"each position is associated with a specific type, and this requires JSON schemas "
"that are not supported by most templates. We recommend "
"either using List instead for arguments where this is appropriate, or "
"splitting arguments with Tuple types into multiple arguments that take single inputs."
)
if not get_args(hint):
return {"type": "array"}
if len(get_args(hint)) == 1:
raise ValueError(
"Tuple type hints should only be used when the argument has a fixed length and each "
f"element has a specific type. The hint {hint} indicates a Tuple of length 1. "
"This should be replaced with an unwrapped type hint instead like "
f"{get_args(hint)[0]}. Alternatively, if the "
"function can actually take a tuple with multiple elements, please either indicate "
f"each element type (e.g. Tuple[{get_args(hint)[0]}, {get_args(hint)[0]}]), "
f"or if the input can be variable length, use List[{get_args(hint)[0]}] instead."
)
if ... in get_args(hint):
raise ValueError(
"'...' is not supported in Tuple type hints. Use List[] types for variable-length"
" inputs instead."
)
return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in get_args(hint)]}
elif origin is dict:
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
# However, we can specify the type of the dict values with "additionalProperties"
Expand Down
75 changes: 71 additions & 4 deletions tests/utils/test_chat_template_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from transformers.utils import get_json_schema

Expand Down Expand Up @@ -234,9 +234,10 @@ def fn(x: int) -> int:
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}, "return": {"type": "integer"}},
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
"return": {"type": "integer"},
}
self.assertEqual(schema, expected_schema)

Expand All @@ -254,17 +255,83 @@ def fn(x: int) -> int:
"""
return x

schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
"return": {"type": "integer", "description": "The output"},
}
self.assertEqual(schema, expected_schema)

def test_tuple(self):
    """Check that a fixed-length ``Tuple`` hint is rendered as an array schema with ``prefixItems``."""

    def fn(x: Tuple[int, str]):
        """
        Test function
        Args:
        x: The input
        Returns:
        The output
        """
        return x

    schema = get_json_schema(fn)
    # Each tuple position gets its own entry in "prefixItems", in order.
    # "x" must appear exactly once in "properties": a duplicate key in a dict literal is
    # silently overridden by the later entry, which would hide a stale expectation.
    # No "return" entry belongs inside "properties"; the other tests in this file expect
    # return information at the top level of the schema, and fn has no return annotation.
    expected_schema = {
        "name": "fn",
        "description": "Test function",
        "parameters": {
            "type": "object",
            "properties": {
                "x": {
                    "type": "array",
                    "prefixItems": [{"type": "integer"}, {"type": "string"}],
                    "description": "The input",
                }
            },
            "required": ["x"],
        },
    }
    self.assertEqual(schema, expected_schema)

def test_single_element_tuple_fails(self):
    """A ``Tuple`` hint with exactly one element must be rejected with ``ValueError``."""

    def fn(x: Tuple[int]):
        """
        Test function
        Args:
        x: The input
        Returns:
        The output
        """
        return x

    # A length-1 tuple should be written as the bare type, or as List[type] when the
    # input can vary in length, so schema generation is expected to refuse it.
    self.assertRaises(ValueError, get_json_schema, fn)

def test_ellipsis_type_fails(self):
    """A ``Tuple[type, ...]`` hint must be rejected with ``ValueError``."""

    def fn(x: Tuple[int, ...]):
        """
        Test function
        Args:
        x: The input
        Returns:
        The output
        """
        return x

    # Variable-length inputs are meant to be declared as List[type]; the ellipsis form
    # of Tuple is expected to be refused by schema generation.
    self.assertRaises(ValueError, get_json_schema, fn)

0 comments on commit 575929c

Please sign in to comment.