Skip to content

Commit

Permalink
Part I: Remove serializes (#6177)
Browse files Browse the repository at this point in the history
* remove serializse

* lint

* add changeset

* lint

* fix test

* fix tests

* fix final test

* fix tests

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
2 people authored and dawoodkhan82 committed Oct 31, 2023
1 parent 76a06bd commit 4ee9c4d
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 75 deletions.
5 changes: 5 additions & 0 deletions .changeset/clever-brooms-lay.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Part I: Remove serializes
4 changes: 3 additions & 1 deletion gradio/blocks.py
Expand Up @@ -38,7 +38,7 @@
)
from gradio.blocks_events import BlocksEvents, BlocksMeta
from gradio.context import Context
from gradio.data_classes import FileData
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.events import (
EventData,
EventListener,
Expand Down Expand Up @@ -392,6 +392,8 @@ def postprocess_update_dict(
)
if postprocess:
attr_dict["value"] = block.postprocess(update_dict["value"])
if isinstance(attr_dict["value"], (GradioModel, GradioRootModel)):
attr_dict["value"] = attr_dict["value"].model_dump()
else:
attr_dict["value"] = update_dict["value"]
return attr_dict
Expand Down
14 changes: 0 additions & 14 deletions gradio/component_meta.py
Expand Up @@ -8,7 +8,6 @@

from jinja2 import Template

from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import EventListener
from gradio.exceptions import ComponentDefinitionError
from gradio.utils import no_raise_exception
Expand Down Expand Up @@ -57,17 +56,6 @@ def {{ event }}(self,
'''


def serializes(f):
@wraps(f)
def serialize(*args, **kwds):
output = f(*args, **kwds)
if isinstance(output, (GradioRootModel, GradioModel)):
output = output.model_dump()
return output

return serialize


def create_pyi(class_code: str, events: list[EventListener | str]):
template = Template(INTERFACE_TEMPLATE)
events = [e if isinstance(e, str) else e.event_name for e in events]
Expand Down Expand Up @@ -193,8 +181,6 @@ def __new__(cls, name, bases, attrs):
attrs[event] = trigger.listener
if "EVENTS" in attrs:
attrs["EVENTS"] = new_events
if "postprocess" in attrs:
attrs["postprocess"] = serializes(attrs["postprocess"])
component_class = super().__new__(cls, name, bases, attrs)
create_or_modify_pyi(component_class, name, events)
return component_class
3 changes: 3 additions & 0 deletions gradio/components/clear_button.py
Expand Up @@ -8,6 +8,7 @@
from gradio_client.documentation import document, set_documentation_group

from gradio.components import Button, Component
from gradio.data_classes import GradioModel, GradioRootModel

set_documentation_group("component")

Expand Down Expand Up @@ -72,6 +73,8 @@ def add(self, components: None | Component | list[Component]) -> ClearButton:
none_values = []
for component in components:
none = component.postprocess(None)
if isinstance(none, (GradioModel, GradioRootModel)):
none = none.model_dump()
none_values.append(none)
clear_values = json.dumps(none_values)
self.click(None, [], components, js=f"() => {clear_values}")
Expand Down
2 changes: 2 additions & 0 deletions gradio/helpers.py
Expand Up @@ -444,6 +444,8 @@ def merge_generated_values_into_output(
if len(components) > 1:
chunk = chunk[output_index]
processed_chunk = output_component.postprocess(chunk)
if isinstance(processed_chunk, (GradioModel, GradioRootModel)):
processed_chunk = processed_chunk.model_dump()
binary_chunks.append(
output_component.stream_output(processed_chunk, "", i == 0)[0]
)
Expand Down
7 changes: 5 additions & 2 deletions gradio/processing_utils.py
Expand Up @@ -20,7 +20,7 @@
from PIL import Image, ImageOps, PngImagePlugin

from gradio import wasm_utils
from gradio.data_classes import FileData
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.utils import abspath, is_in_or_equal

with warnings.catch_warnings():
Expand Down Expand Up @@ -278,7 +278,7 @@ def move_files_to_cache(data: Any, block: Component):
Runs after postprocess and before preprocess.
Args:
data: The input or output data for a component.
data: The input or output data for a component. Can be a dictionary or a dataclass
block: The component
"""

Expand All @@ -288,6 +288,9 @@ def _move_to_cache(d: dict):
payload.path = temp_file_path
return payload.model_dump()

if isinstance(data, (GradioRootModel, GradioModel)):
data = data.model_dump()

return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)


Expand Down
10 changes: 9 additions & 1 deletion test/test_blocks.py
Expand Up @@ -23,6 +23,7 @@
from PIL import Image

import gradio as gr
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import SelectData
from gradio.exceptions import DuplicateBlockError
from gradio.networking import Server, get_first_available_port
Expand Down Expand Up @@ -512,8 +513,15 @@ def test_blocks_do_not_filter_none_values_from_updates(self, io_components):
output = demo.postprocess_data(
0, [gr.update(value=None) for _ in io_components], state=None
)

def process_and_dump(component):
output = component.postprocess(None)
if isinstance(output, (GradioModel, GradioRootModel)):
output = output.model_dump()
return output

assert all(
o["value"] == c.postprocess(None) for o, c in zip(output, io_components)
o["value"] == process_and_dump(c) for o, c in zip(output, io_components)
)

def test_blocks_does_not_replace_keyword_literal(self):
Expand Down

0 comments on commit 4ee9c4d

Please sign in to comment.