From 0c571c044035989d6fe33fc01fee63d1780635cb Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Wed, 18 Oct 2023 13:53:46 -0400 Subject: [PATCH] Add json schema unit tests (#5970) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add tests * add changeset * Fix tests * api-info * Add test * Add test * Add email tests * 3.8 fix 🙄 --------- Co-authored-by: gradio-pr-bot --- .changeset/dull-adults-study.md | 6 + client/python/gradio_client/utils.py | 40 ++-- client/python/test/test_client.py | 24 +-- client/python/test/test_utils.py | 2 +- test/requirements.txt | 9 +- test/test_api_info.py | 266 +++++++++++++++++++++++++++ 6 files changed, 312 insertions(+), 35 deletions(-) create mode 100644 .changeset/dull-adults-study.md create mode 100644 test/test_api_info.py diff --git a/.changeset/dull-adults-study.md b/.changeset/dull-adults-study.md new file mode 100644 index 000000000000..82ee8b0e4948 --- /dev/null +++ b/.changeset/dull-adults-study.md @@ -0,0 +1,6 @@ +--- +"gradio": minor +"gradio_client": minor +--- + +feat:Add json schema unit tests diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index 9996015aad90..0e6677a317d7 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -542,8 +542,6 @@ class APIInfoParseError(ValueError): def get_type(schema: dict): - if not isinstance(schema, dict): - breakpoint() if "const" in schema: return "const" if "enum" in schema: @@ -556,6 +554,10 @@ def get_type(schema: dict): return "oneOf" elif schema.get("anyOf"): return "anyOf" + elif schema.get("allOf"): + return "allOf" + elif "type" not in schema: + return {} else: raise APIInfoParseError(f"Cannot parse type for {schema}") @@ -574,7 +576,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str: return "Any" type_ = get_type(schema) if type_ == {}: - if "json" in schema["description"]: + if "json" in schema.get("description", {}): return "Dict[Any, Any]" else: return "Any" @@ -593,14 +595,19 @@ def _json_schema_to_python_type(schema: Any, defs) -> str: elif type_ == "boolean": return "bool" elif type_ == "number": - return "int | float" + return "float" elif type_ == "array": - items = schema.get("items") + items = schema.get("items", []) if "prefixItems" in items: elements = ", ".join( [_json_schema_to_python_type(i, defs) for i in items["prefixItems"]] ) return f"Tuple[{elements}]" + elif "prefixItems" in schema: + elements = ", ".join( + [_json_schema_to_python_type(i, defs) for i in schema["prefixItems"]] + ) + return f"Tuple[{elements}]" else: elements = _json_schema_to_python_type(items, defs) return f"List[{elements}]" @@ -609,22 +616,27 @@ def _json_schema_to_python_type(schema: Any, defs) -> str: def get_desc(v): return f" ({v.get('description')})" if v.get("description") else "" - if "additionalProperties" in schema: - return f"Dict[str, {_json_schema_to_python_type(schema['additionalProperties'], defs)}]" + props = schema.get("properties", {}) - props = schema.get("properties") + des = [ + f"{n}: {_json_schema_to_python_type(v, defs)}{get_desc(v)}" + for n, v in props.items() + if n != "$defs" + ] - des = ", ".join( - [ - f"{n}: {_json_schema_to_python_type(v, defs)}{get_desc(v)}" - for n, v in props.items() - if n != "$defs" + if "additionalProperties" in schema: + des += [ + f"str, {_json_schema_to_python_type(schema['additionalProperties'], defs)}" ] - ) + des = ", ".join(des) return f"Dict({des})" elif type_ in ["oneOf", "anyOf"]: desc = " | ".join([_json_schema_to_python_type(i, defs) for i in schema[type_]]) return desc + elif type_ == "allOf": + data = ", ".join(_json_schema_to_python_type(i, defs) for i in schema[type_]) + desc = f"All[{data}]" + return desc else: raise APIInfoParseError(f"Cannot parse schema {schema}") diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index f4cac6363965..0e7dcfe8ed84 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -619,7 +619,7 @@ def test_numerical_to_label_space(self): "label": "Age", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Slider", @@ -630,7 +630,7 @@ def test_numerical_to_label_space(self): "label": "Fare (british pounds)", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Slider", @@ -665,7 +665,7 @@ def test_numerical_to_label_space(self): "label": "Age", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Slider", @@ -676,7 +676,7 @@ def test_numerical_to_label_space(self): "label": "Fare (british pounds)", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Slider", @@ -711,7 +711,7 @@ def test_numerical_to_label_space(self): "label": "Age", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Slider", @@ -722,7 +722,7 @@ def test_numerical_to_label_space(self): "label": "Fare (british pounds)", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Slider", @@ -798,7 +798,7 @@ def test_fetch_fixed_version_space(self, calculator_demo): "label": "num1", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Number", @@ -822,7 +822,7 @@ def test_fetch_fixed_version_space(self, calculator_demo): "label": "num2", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Number", @@ -834,7 +834,7 @@ def test_fetch_fixed_version_space(self, calculator_demo): "label": "output", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Number", @@ -950,7 +950,7 @@ def test_layout_and_state_components_in_output( "label": "count", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Number", @@ -964,7 +964,7 @@ def test_layout_and_state_components_in_output( "label": "count", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Number", @@ -978,7 +978,7 @@ def test_layout_and_state_components_in_output( "label": "count", "type": {"type": "number"}, "python_type": { - "type": "int | float", + "type": "float", "description": "", }, "component": "Number", diff --git a/client/python/test/test_utils.py b/client/python/test/test_utils.py index 3208d42d1747..a99dc29bcff2 100644 --- a/client/python/test/test_utils.py +++ b/client/python/test/test_utils.py @@ -160,7 +160,7 @@ def test_json_schema_to_python_type(schema): elif schema == "BooleanSerializable": answer = "bool" elif schema == "NumberSerializable": - answer = "int | float" + answer = "float" elif schema == "ImgSerializable": answer = "str" elif schema == "FileSerializable": diff --git a/test/requirements.txt b/test/requirements.txt index c882256f499c..bd6b7caa92d0 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -232,13 +232,6 @@ traitlets==5.3.0 # matplotlib-inline transformers==4.20.1 # via -r requirements.in -typing-extensions==4.7.1 - # via - # black - # huggingface-hub - # pydantic - # starlette - # torch urllib3==1.26.10 # via # botocore @@ -247,6 +240,6 @@ vega-datasets==0.9.0 # via -r requirements.in wcwidth==0.2.5 # via prompt-toolkit - +pydantic[email] # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/test/test_api_info.py b/test/test_api_info.py new file mode 100644 index 000000000000..e0275ab5a8b3 --- /dev/null +++ b/test/test_api_info.py @@ -0,0 +1,266 @@ +from collections import namedtuple +from datetime import datetime, timedelta +from enum import Enum +from pathlib import Path +from typing import ClassVar, Dict, List, Literal, Optional, Set, Tuple, Union +from uuid import UUID + +import pytest +from gradio_client.utils import json_schema_to_python_type +from pydantic import Field, confloat, conint, conlist +from pydantic.networks import AnyUrl, EmailStr, IPvAnyAddress + +from gradio.data_classes import GradioModel, GradioRootModel + + +class StringModel(GradioModel): + data: str + answer: ClassVar = "Dict(data: str)" + + +class IntegerRootModel(GradioRootModel): + root: int + + answer: ClassVar = "int" + + +class FloatModel(GradioModel): + data: float + + answer: ClassVar = "Dict(data: float)" + + +class ListModel(GradioModel): + items: List[int] + + answer: ClassVar = "Dict(items: List[int])" + + +class DictModel(GradioModel): + data_dict: Dict[str, int] + + answer: ClassVar = "Dict(data_dict: Dict(str, int))" + + +class DictModel2(GradioModel): + data_dict: Dict[str, List[float]] + + answer: ClassVar = "Dict(data_dict: Dict(str, List[float]))" + + +class OptionalModel(GradioModel): + optional_data: Optional[int] + + answer: ClassVar = "Dict(optional_data: int | None)" + + +class ColorEnum(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +class EnumRootModel(GradioModel): + color: ColorEnum + + answer: ClassVar = "Dict(color: Literal[red, green, blue])" + + +class EmailModel(GradioModel): + email: EmailStr + + answer: ClassVar = "Dict(email: str)" + + +class RootWithNestedModel(GradioModel): + nested_int: IntegerRootModel + nested_enum: EnumRootModel + nested_dict: DictModel2 + + answer: ClassVar = "Dict(nested_int: int, nested_enum: Dict(color: Literal[red, green, blue]), nested_dict: Dict(data_dict: Dict(str, List[float])))" + + +class LessNestedModel(GradioModel): + nested_int: int + nested_enum: ColorEnum + nested_dict: Dict[str, List[Union[int, float]]] + + answer: ClassVar = "Dict(nested_int: int, nested_enum: Literal[red, green, blue], nested_dict: Dict(str, List[int | float]))" + + +class StatusModel(GradioModel): + status: Literal["active", "inactive"] + + answer: ClassVar = "Dict(status: Literal[active, inactive])" + + +class PointModel(GradioRootModel): + root: Tuple[float, float] + + answer: ClassVar = "Tuple[float, float]" + + +class UuidModel(GradioModel): + uuid: UUID + + answer: ClassVar = "Dict(uuid: str)" + + +class UrlModel(GradioModel): + url: AnyUrl + + answer: ClassVar = "Dict(url: str)" + + +class CustomFieldModel(GradioModel): + name: str = Field(..., title="Name of the item", max_length=50) + price: float = Field(..., title="Price of the item", gt=0) + + answer: ClassVar = "Dict(name: str, price: float)" + + +class DurationModel(GradioModel): + duration: timedelta + + answer: ClassVar = "Dict(duration: str)" + + +class IPv4Model(GradioModel): + ipv4_address: IPvAnyAddress + + answer: ClassVar = "Dict(ipv4_address: str)" + + +class DateTimeModel(GradioModel): + created_at: datetime + updated_at: datetime + + answer: ClassVar = "Dict(created_at: str, updated_at: str)" + + +class SetModel(GradioModel): + unique_numbers: Set[int] + + answer: ClassVar = "Dict(unique_numbers: List[int])" + + +class ItemModel(GradioModel): + name: str + price: float + + +class OrderModel(GradioModel): + items: List[ItemModel] + + answer: ClassVar = "Dict(items: List[Dict(name: str, price: float)])" + + +class TemperatureUnitEnum(Enum): + CELSIUS = "Celsius" + FAHRENHEIT = "Fahrenheit" + KELVIN = "Kelvin" + + +class TemperatureConversionModel(GradioModel): + temperature: confloat(ge=-273.15, le=1.416808) + from_unit: TemperatureUnitEnum + to_unit: TemperatureUnitEnum = Field(..., title="Target temperature unit") + + answer: ClassVar = "Dict(temperature: float, from_unit: Literal[Celsius, Fahrenheit, Kelvin], to_unit: All[Literal[Celsius, Fahrenheit, Kelvin]])" + + +class CartItemModel(GradioModel): + product_name: str = Field(..., title="Name of the product", max_length=50) + quantity: int = Field(..., title="Quantity of the product", ge=1) + price_per_unit: float = Field(..., title="Price per unit", gt=0) + + +class ShoppingCartModel(GradioModel): + items: List[CartItemModel] + + answer: ClassVar = "Dict(items: List[Dict(product_name: str, quantity: int, price_per_unit: float)])" + + +class CoordinateModel(GradioModel): + latitude: float + longitude: float + + +class PathModel(GradioModel): + coordinates: conlist(CoordinateModel, min_length=2, max_length=2) + + answer: ClassVar = ( + "Dict(coordinates: List[Dict(latitude: float, longitude: float)])" + ) + + +class CreditCardModel(GradioModel): + card_number: conint(ge=1, le=9999999999999999) + + answer: ClassVar = "Dict(card_number: int)" + + +class TupleListModel(GradioModel): + data: List[Tuple[int, str]] + + answer: ClassVar = "Dict(data: List[Tuple[int, str]]" + + +class PathListModel(GradioModel): + file_paths: List[Path] + + answer: ClassVar = "Dict(file_paths: List[str])" + + +class PostModel(GradioModel): + author: str + content: str + tags: List[str] + likes: int = 0 + + answer: ClassVar = "Dict(author: str, content: str, tags: List[str], likes: int)" + + +Person = namedtuple("Person", ["name", "age"]) + + +class NamedTupleDictionaryModel(GradioModel): + people: Dict[str, Person] + + answer: ClassVar = "Dict(people: Dict(str, Tuple[Any, Any]))" + + +MODELS = [ + StringModel, + IntegerRootModel, + FloatModel, + ListModel, + DictModel, + DictModel2, + OptionalModel, + EnumRootModel, + EmailModel, + RootWithNestedModel, + LessNestedModel, + StatusModel, + PointModel, + UuidModel, + UrlModel, + CustomFieldModel, + DurationModel, + IPv4Model, + DateTimeModel, + SetModel, + OrderModel, + TemperatureConversionModel, + ShoppingCartModel, + PathModel, + CreditCardModel, + PathListModel, + NamedTupleDictionaryModel, +] + + +@pytest.mark.parametrize("model", MODELS) +def test_api_info_for_model(model): + assert json_schema_to_python_type(model.model_json_schema()) == model.answer