diff --git a/.changeset/clever-brooms-lay.md b/.changeset/clever-brooms-lay.md new file mode 100644 index 000000000000..ef99dbbf09bc --- /dev/null +++ b/.changeset/clever-brooms-lay.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Part I: Remove serializes diff --git a/gradio/blocks.py b/gradio/blocks.py index a3b5bc07a998..47380bf16255 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -38,7 +38,7 @@ ) from gradio.blocks_events import BlocksEvents, BlocksMeta from gradio.context import Context -from gradio.data_classes import FileData +from gradio.data_classes import FileData, GradioModel, GradioRootModel from gradio.events import ( EventData, EventListener, @@ -392,6 +392,8 @@ def postprocess_update_dict( ) if postprocess: attr_dict["value"] = block.postprocess(update_dict["value"]) + if isinstance(attr_dict["value"], (GradioModel, GradioRootModel)): + attr_dict["value"] = attr_dict["value"].model_dump() else: attr_dict["value"] = update_dict["value"] return attr_dict diff --git a/gradio/component_meta.py b/gradio/component_meta.py index 538ee1b20038..6b4c15fcb1a8 100644 --- a/gradio/component_meta.py +++ b/gradio/component_meta.py @@ -8,7 +8,6 @@ from jinja2 import Template -from gradio.data_classes import GradioModel, GradioRootModel from gradio.events import EventListener from gradio.exceptions import ComponentDefinitionError from gradio.utils import no_raise_exception @@ -57,17 +56,6 @@ def {{ event }}(self, ''' -def serializes(f): - @wraps(f) - def serialize(*args, **kwds): - output = f(*args, **kwds) - if isinstance(output, (GradioRootModel, GradioModel)): - output = output.model_dump() - return output - - return serialize - - def create_pyi(class_code: str, events: list[EventListener | str]): template = Template(INTERFACE_TEMPLATE) events = [e if isinstance(e, str) else e.event_name for e in events] @@ -193,8 +181,6 @@ def __new__(cls, name, bases, attrs): attrs[event] = trigger.listener if "EVENTS" in attrs: attrs["EVENTS"] = new_events - if "postprocess" in attrs: - attrs["postprocess"] = serializes(attrs["postprocess"]) component_class = super().__new__(cls, name, bases, attrs) create_or_modify_pyi(component_class, name, events) return component_class diff --git a/gradio/components/clear_button.py b/gradio/components/clear_button.py index 4257b706888d..26037570473e 100644 --- a/gradio/components/clear_button.py +++ b/gradio/components/clear_button.py @@ -8,6 +8,7 @@ from gradio_client.documentation import document, set_documentation_group from gradio.components import Button, Component +from gradio.data_classes import GradioModel, GradioRootModel set_documentation_group("component") @@ -72,6 +73,8 @@ def add(self, components: None | Component | list[Component]) -> ClearButton: none_values = [] for component in components: none = component.postprocess(None) + if isinstance(none, (GradioModel, GradioRootModel)): + none = none.model_dump() none_values.append(none) clear_values = json.dumps(none_values) self.click(None, [], components, js=f"() => {clear_values}") diff --git a/gradio/helpers.py b/gradio/helpers.py index 61a74944bd64..92c21c7c564c 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -444,6 +444,8 @@ def merge_generated_values_into_output( if len(components) > 1: chunk = chunk[output_index] processed_chunk = output_component.postprocess(chunk) + if isinstance(processed_chunk, (GradioModel, GradioRootModel)): + processed_chunk = processed_chunk.model_dump() binary_chunks.append( output_component.stream_output(processed_chunk, "", i == 0)[0] ) diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 7b0d5175ae35..c230272b4e72 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -20,7 +20,7 @@ from PIL import Image, ImageOps, PngImagePlugin from gradio import wasm_utils -from gradio.data_classes import FileData +from gradio.data_classes import FileData, GradioModel, GradioRootModel from gradio.utils import abspath, is_in_or_equal with warnings.catch_warnings(): @@ -278,7 +278,7 @@ def move_files_to_cache(data: Any, block: Component): Runs after postprocess and before preprocess. Args: - data: The input or output data for a component. + data: The input or output data for a component. Can be a dictionary or a dataclass block: The component """ @@ -288,6 +288,9 @@ def _move_to_cache(d: dict): payload.path = temp_file_path return payload.model_dump() + if isinstance(data, (GradioRootModel, GradioModel)): + data = data.model_dump() + return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj) diff --git a/test/test_blocks.py b/test/test_blocks.py index 9972a76d912c..6164ad9500ef 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -23,6 +23,7 @@ from PIL import Image import gradio as gr +from gradio.data_classes import GradioModel, GradioRootModel from gradio.events import SelectData from gradio.exceptions import DuplicateBlockError from gradio.networking import Server, get_first_available_port @@ -512,8 +513,15 @@ def test_blocks_do_not_filter_none_values_from_updates(self, io_components): output = demo.postprocess_data( 0, [gr.update(value=None) for _ in io_components], state=None ) + + def process_and_dump(component): + output = component.postprocess(None) + if isinstance(output, (GradioModel, GradioRootModel)): + output = output.model_dump() + return output + assert all( - o["value"] == c.postprocess(None) for o, c in zip(output, io_components) + o["value"] == process_and_dump(c) for o, c in zip(output, io_components) ) def test_blocks_does_not_replace_keyword_literal(self): diff --git a/test/test_components.py b/test/test_components.py index b7c2f66a23ea..1b1731e14511 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -604,7 +604,9 @@ def test_component_functions(self, gradio_temp_dir): # Output functionalities image_output = gr.Image(type="pil") - processed_image = image_output.postprocess(PIL.Image.open(img["path"])) + processed_image = image_output.postprocess( + PIL.Image.open(img["path"]) + ).model_dump() assert processed_image is not None if processed_image is not None: processed = PIL.Image.open(cast(dict, processed_image).get("path", "")) @@ -683,7 +685,7 @@ def test_postprocess_altair(self): color="Origin", ) ) - out = gr.Plot().postprocess(chart) + out = gr.Plot().postprocess(chart).model_dump() assert isinstance(out["plot"], str) assert out["plot"] == chart.to_json() @@ -752,7 +754,9 @@ def test_component_functions(self, gradio_temp_dir): deepcopy(media_data.BASE64_AUDIO)["data"] ) audio_output = gr.Audio(type="filepath") - assert filecmp.cmp(y_audio.name, audio_output.postprocess(y_audio.name)["path"]) + assert filecmp.cmp( + y_audio.name, audio_output.postprocess(y_audio.name).model_dump()["path"] + ) assert audio_output.get_config() == { "autoplay": False, "name": "audio", @@ -780,8 +784,8 @@ def test_component_functions(self, gradio_temp_dir): "_selectable": False, } - output1 = audio_output.postprocess(y_audio.name) - output2 = audio_output.postprocess(Path(y_audio.name)) + output1 = audio_output.postprocess(y_audio.name).model_dump() + output2 = audio_output.postprocess(Path(y_audio.name)).model_dump() assert output1 == output2 def test_default_value_postprocess(self): @@ -837,7 +841,7 @@ def test_prepost_process_to_mp3(self, gradio_temp_dir): assert output.endswith("mp3") output = audio_input.postprocess( (48000, np.random.randint(-256, 256, (5, 3)).astype(np.int16)) - ) + ).model_dump() assert output["path"].endswith("mp3") @@ -1041,15 +1045,15 @@ def test_postprocess(self): postprocess """ dataframe_output = gr.Dataframe() - output = dataframe_output.postprocess([]) + output = dataframe_output.postprocess([]).model_dump() assert output == {"data": [[]], "headers": [], "metadata": None} - output = dataframe_output.postprocess(np.zeros((2, 2))) + output = dataframe_output.postprocess(np.zeros((2, 2))).model_dump() assert output == { "data": [[0, 0], [0, 0]], "headers": ["1", "2"], "metadata": None, } - output = dataframe_output.postprocess([[1, 3, 5]]) + output = dataframe_output.postprocess([[1, 3, 5]]).model_dump() assert output == { "data": [[1, 3, 5]], "headers": ["1", "2", "3"], @@ -1057,7 +1061,7 @@ def test_postprocess(self): } output = dataframe_output.postprocess( pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"]) - ) + ).model_dump() assert output == { "headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]], @@ -1068,14 +1072,16 @@ def test_postprocess(self): # When the headers don't match the data dataframe_output = gr.Dataframe(headers=["one", "two", "three"]) - output = dataframe_output.postprocess([[2, True], [3, True]]) + output = dataframe_output.postprocess([[2, True], [3, True]]).model_dump() assert output == { "headers": ["one", "two"], "data": [[2, True], [3, True]], "metadata": None, } dataframe_output = gr.Dataframe(headers=["one", "two", "three"]) - output = dataframe_output.postprocess([[2, True, "ab", 4], [3, True, "cd", 5]]) + output = dataframe_output.postprocess( + [[2, True, "ab", 4], [3, True, "cd", 5]] + ).model_dump() assert output == { "headers": ["one", "two", "three", "4"], "data": [[2, True, "ab", 4], [3, True, "cd", 5]], @@ -1098,7 +1104,7 @@ def test_dataframe_postprocess_all_types(self): component = gr.Dataframe( datatype=["date", "date", "number", "number", "bool", "markdown"] ) - output = component.postprocess(df) + output = component.postprocess(df).model_dump() assert output == { "headers": list(df.columns), "data": [ @@ -1130,7 +1136,7 @@ def test_dataframe_postprocess_only_dates(self): } ) component = gr.Dataframe(datatype=["date", "date"]) - output = component.postprocess(df) + output = component.postprocess(df).model_dump() assert output == { "headers": list(df.columns), "data": [ @@ -1156,7 +1162,7 @@ def test_dataframe_postprocess_styler(self): } ) s = df.style.format(precision=1, decimal=",") - output = component.postprocess(s) + output = component.postprocess(s).model_dump() assert output == { "data": [ ["Adam", 1.1, 800], @@ -1204,7 +1210,7 @@ def test_dataframe_postprocess_styler(self): ) t = df.style.highlight_max(color="lightgreen", axis=0) - output = component.postprocess(t) + output = component.postprocess(t).model_dump() assert output == { "data": [ [14, 5, 20, 14, 23], @@ -1362,23 +1368,25 @@ def test_component_functions(self): y_vid_path = "test/test_files/video_sample.mp4" subtitles_path = "test/test_files/s1.srt" video_output = gr.Video() - output1 = video_output.postprocess(y_vid_path)["video"]["path"] + output1 = video_output.postprocess(y_vid_path).model_dump()["video"]["path"] assert output1.endswith("mp4") - output2 = video_output.postprocess(y_vid_path)["video"]["path"] + output2 = video_output.postprocess(y_vid_path).model_dump()["video"]["path"] assert output1 == output2 assert ( - video_output.postprocess(y_vid_path)["video"]["orig_name"] + video_output.postprocess(y_vid_path).model_dump()["video"]["orig_name"] == "video_sample.mp4" ) - output_with_subtitles = video_output.postprocess((y_vid_path, subtitles_path)) + output_with_subtitles = video_output.postprocess( + (y_vid_path, subtitles_path) + ).model_dump() assert output_with_subtitles["subtitles"]["path"].endswith(".vtt") p_video = gr.Video() video_with_subtitle = gr.Video() - postprocessed_video = p_video.postprocess(Path(y_vid_path)) + postprocessed_video = p_video.postprocess(Path(y_vid_path)).model_dump() postprocessed_video_with_subtitle = video_with_subtitle.postprocess( (Path(y_vid_path), Path(subtitles_path)) - ) + ).model_dump() processed_video = { "video": { @@ -1443,7 +1451,7 @@ def test_video_postprocess_converts_to_playable_format(self): bad_vid = str(test_file_dir / "bad_video_sample.mp4") assert not processing_utils.video_is_playable(bad_vid) shutil.copy(bad_vid, tmp_not_playable_vid.name) - output = gr.Video().postprocess(tmp_not_playable_vid.name) + output = gr.Video().postprocess(tmp_not_playable_vid.name).model_dump() assert processing_utils.video_is_playable(output["video"]["path"]) # This file has a playable codec but not a playable container @@ -1453,7 +1461,7 @@ def test_video_postprocess_converts_to_playable_format(self): bad_vid = str(test_file_dir / "playable_but_bad_container.mkv") assert not processing_utils.video_is_playable(bad_vid) shutil.copy(bad_vid, tmp_not_playable_vid.name) - output = gr.Video().postprocess(tmp_not_playable_vid.name) + output = gr.Video().postprocess(tmp_not_playable_vid.name).model_dump() assert processing_utils.video_is_playable(output["video"]["path"]) @patch("pathlib.Path.exists", MagicMock(return_value=False)) @@ -1516,11 +1524,11 @@ def test_component_functions(self): """ y = "happy" label_output = gr.Label() - label = label_output.postprocess(y) + label = label_output.postprocess(y).model_dump() assert label == {"label": "happy", "confidences": None} y = {3: 0.7, 1: 0.2, 0: 0.1} - label = label_output.postprocess(y) + label = label_output.postprocess(y).model_dump() assert label == { "label": 3, "confidences": [ @@ -1530,7 +1538,7 @@ def test_component_functions(self): ], } label_output = gr.Label(num_top_classes=2) - label = label_output.postprocess(y) + label = label_output.postprocess(y).model_dump() assert label == { "label": 3, @@ -1540,11 +1548,11 @@ def test_component_functions(self): ], } with pytest.raises(ValueError): - label_output.postprocess([1, 2, 3]) + label_output.postprocess([1, 2, 3]).model_dump() test_file_dir = Path(Path(__file__).parent, "test_files") path = str(Path(test_file_dir, "test_label_json.json")) - label_dict = label_output.postprocess(path) + label_dict = label_output.postprocess(path).model_dump() assert label_dict["label"] == "web site" assert label_output.get_config() == { @@ -1616,7 +1624,7 @@ def test_postprocess(self): {"token": "Berlin", "class_or_confidence": "LOC"}, {"token": "", "class_or_confidence": None}, ] - result_ = component.postprocess(value) + result_ = component.postprocess(value).model_dump() assert result == result_ text = "Wolfgang lives in Berlin" @@ -1624,7 +1632,9 @@ def test_postprocess(self): {"entity": "PER", "start": 0, "end": 8}, {"entity": "LOC", "start": 18, "end": 24}, ] - result_ = component.postprocess({"text": text, "entities": entities}) + result_ = component.postprocess( + {"text": text, "entities": entities} + ).model_dump() assert result == result_ text = "Wolfgang lives in Berlin" @@ -1632,7 +1642,9 @@ def test_postprocess(self): {"entity_group": "PER", "start": 0, "end": 8}, {"entity": "LOC", "start": 18, "end": 24}, ] - result_ = component.postprocess({"text": text, "entities": entities}) + result_ = component.postprocess( + {"text": text, "entities": entities} + ).model_dump() assert result == result_ # Test split entity is merged when combine adjacent is set @@ -1649,12 +1661,16 @@ def test_postprocess(self): {"token": " lives in ", "class_or_confidence": None}, {"token": "Berlin", "class_or_confidence": "LOC"}, ] - result_ = component.postprocess({"text": text, "entities": entities}) + result_ = component.postprocess( + {"text": text, "entities": entities} + ).model_dump() assert result != result_ assert result_after_merge != result_ component = gr.HighlightedText(combine_adjacent=True) - result_ = component.postprocess({"text": text, "entities": entities}) + result_ = component.postprocess( + {"text": text, "entities": entities} + ).model_dump() assert result_after_merge == result_ component = gr.HighlightedText() @@ -1664,19 +1680,25 @@ def test_postprocess(self): {"entity": "LOC", "start": 18, "end": 24}, {"entity": "PER", "start": 0, "end": 8}, ] - result_ = component.postprocess({"text": text, "entities": entities}) + result_ = component.postprocess( + {"text": text, "entities": entities} + ).model_dump() assert result == result_ text = "I live there" entities = [] - result_ = component.postprocess({"text": text, "entities": entities}) + result_ = component.postprocess( + {"text": text, "entities": entities} + ).model_dump() assert [{"token": text, "class_or_confidence": None}] == result_ text = "Wolfgang" entities = [ {"entity": "PER", "start": 0, "end": 8}, ] - result_ = component.postprocess({"text": text, "entities": entities}) + result_ = component.postprocess( + {"text": text, "entities": entities} + ).model_dump() assert [ {"token": "", "class_or_confidence": None}, {"token": text, "class_or_confidence": "PER"}, @@ -1750,7 +1772,7 @@ def test_postprocess(self): mask2[10:20, 10:20] = 1 input = (img, [(mask1, "mask1"), (mask2, "mask2")]) - result = component.postprocess(input) + result = component.postprocess(input).model_dump() base_img_out = PIL.Image.open(result["image"]["path"]) @@ -1806,9 +1828,9 @@ def test_component_functions(self): Postprocess, get_config """ chatbot = gr.Chatbot() - assert chatbot.postprocess([["You are **cool**\nand fun", "so are *you*"]]) == [ - ("You are **cool**\nand fun", "so are *you*") - ] + assert chatbot.postprocess( + [["You are **cool**\nand fun", "so are *you*"]] + ).model_dump() == [("You are **cool**\nand fun", "so are *you*")] multimodal_msg = [ [("test/test_files/video_sample.mp4",), "cool video"], @@ -1818,7 +1840,7 @@ def test_component_functions(self): [(Path("test/test_files/audio_sample.wav"),), "cool audio"], [(Path("test/test_files/bus.png"), "A bus"), "cool pic"], ] - postprocessed_multimodal_msg = chatbot.postprocess(multimodal_msg) + postprocessed_multimodal_msg = chatbot.postprocess(multimodal_msg).model_dump() for msg in postprocessed_multimodal_msg: assert "file" in msg[0] assert msg[1] in {"cool video", "cool audio", "cool pic"} @@ -2059,7 +2081,9 @@ def test_gallery(self, mock_uuid): client_utils.encode_file_to_base64(Path(test_file_dir, "cheetah1.jpg")), ] - postprocessed_gallery = gallery.postprocess([Path("test/test_files/bus.png")]) + postprocessed_gallery = gallery.postprocess( + [Path("test/test_files/bus.png")] + ).model_dump() processed_gallery = [ { "image": { @@ -2217,7 +2241,7 @@ def test_no_color(self): title="Car Data", x_title="Horse", ) - output = plot.postprocess(cars) + output = plot.postprocess(cars).model_dump() assert sorted(output.keys()) == ["chart", "plot", "type"] config = json.loads(output["plot"]) assert config["encoding"]["x"]["field"] == "Horsepower" @@ -2231,7 +2255,7 @@ def test_no_interactive(self): plot = gr.ScatterPlot( x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False ) - output = plot.postprocess(cars) + output = plot.postprocess(cars).model_dump() assert sorted(output.keys()) == ["chart", "plot", "type"] config = json.loads(output["plot"]) assert "selection" not in config @@ -2240,7 +2264,7 @@ def test_height_width(self): plot = gr.ScatterPlot( x="Horsepower", y="Miles_per_Gallon", height=100, width=200 ) - output = plot.postprocess(cars) + output = plot.postprocess(cars).model_dump() assert sorted(output.keys()) == ["chart", "plot", "type"] config = json.loads(output["plot"]) assert config["height"] == 100 @@ -2250,7 +2274,7 @@ def test_xlim_ylim(self): plot = gr.ScatterPlot( x="Horsepower", y="Miles_per_Gallon", x_lim=[200, 400], y_lim=[300, 500] ) - output = plot.postprocess(cars) + output = plot.postprocess(cars).model_dump() config = json.loads(output["plot"]) assert config["encoding"]["x"]["scale"] == {"domain": [200, 400]} assert config["encoding"]["y"]["scale"] == {"domain": [300, 500]} @@ -2263,7 +2287,7 @@ def test_color_encoding(self): title="Car Data", color="Origin", ) - output = plot.postprocess(cars) + output = plot.postprocess(cars).model_dump() config = json.loads(output["plot"]) assert config["encoding"]["color"]["field"] == "Origin" assert config["encoding"]["color"]["scale"] == { @@ -2281,7 +2305,7 @@ def test_two_encodings(self): color="Acceleration", shape="Origin", ) - output = plot.postprocess(cars) + output = plot.postprocess(cars).model_dump() config = json.loads(output["plot"]) assert config["encoding"]["color"]["field"] == "Acceleration" assert config["encoding"]["color"]["scale"] == { @@ -2309,7 +2333,7 @@ def test_legend_position(self): size_legend_title="Accel", size_legend_position="none", ) - output = plot.postprocess(cars) + output = plot.postprocess(cars).model_dump() config = json.loads(output["plot"]) assert config["encoding"]["color"]["legend"] is None assert config["encoding"]["shape"]["legend"] is None @@ -2375,7 +2399,7 @@ def test_no_color(self): title="Stock Performance", x_title="Trading Day", ) - output = plot.postprocess(stocks) + output = plot.postprocess(stocks).model_dump() assert sorted(output.keys()) == ["chart", "plot", "type"] config = json.loads(output["plot"]) for layer in config["layer"]: @@ -2390,7 +2414,7 @@ def test_no_color(self): def test_height_width(self): plot = gr.LinePlot(x="date", y="price", height=100, width=200) - output = plot.postprocess(stocks) + output = plot.postprocess(stocks).model_dump() assert sorted(output.keys()) == ["chart", "plot", "type"] config = json.loads(output["plot"]) assert config["height"] == 100 @@ -2398,7 +2422,7 @@ def test_height_width(self): def test_xlim_ylim(self): plot = gr.LinePlot(x="date", y="price", x_lim=[200, 400], y_lim=[300, 500]) - output = plot.postprocess(stocks) + output = plot.postprocess(stocks).model_dump() config = json.loads(output["plot"]) for layer in config["layer"]: assert layer["encoding"]["x"]["scale"] == {"domain": [200, 400]} @@ -2408,7 +2432,7 @@ def test_color_encoding(self): plot = gr.LinePlot( x="date", y="price", tooltip="symbol", color="symbol", overlay_point=True ) - output = plot.postprocess(stocks) + output = plot.postprocess(stocks).model_dump() config = json.loads(output["plot"]) for layer in config["layer"]: assert layer["encoding"]["color"]["field"] == "symbol" @@ -2480,7 +2504,7 @@ def test_no_color(self): x_title="Variable A", sort="x", ) - output = plot.postprocess(simple) + output = plot.postprocess(simple).model_dump() assert sorted(output.keys()) == ["chart", "plot", "type"] assert output["chart"] == "bar" config = json.loads(output["plot"]) @@ -2496,7 +2520,7 @@ def test_no_color(self): def test_height_width(self): plot = gr.BarPlot(x="a", y="b", height=100, width=200) - output = plot.postprocess(simple) + output = plot.postprocess(simple).model_dump() assert sorted(output.keys()) == ["chart", "plot", "type"] config = json.loads(output["plot"]) assert config["height"] == 100 @@ -2504,7 +2528,7 @@ def test_height_width(self): def test_ylim(self): plot = gr.BarPlot(x="a", y="b", y_lim=[15, 100]) - output = plot.postprocess(simple) + output = plot.postprocess(simple).model_dump() config = json.loads(output["plot"]) assert config["encoding"]["y"]["scale"] == {"domain": [15, 100]}