Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Part I: Remove serializes #6177

Merged
merged 9 commits into from Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/clever-brooms-lay.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Part I: Remove serializes
4 changes: 3 additions & 1 deletion gradio/blocks.py
Expand Up @@ -38,7 +38,7 @@
)
from gradio.blocks_events import BlocksEvents, BlocksMeta
from gradio.context import Context
from gradio.data_classes import FileData
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.events import (
EventData,
EventListener,
Expand Down Expand Up @@ -392,6 +392,8 @@ def postprocess_update_dict(
)
if postprocess:
attr_dict["value"] = block.postprocess(update_dict["value"])
if isinstance(attr_dict["value"], (GradioModel, GradioRootModel)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, so pydantic won't serialize it automatically?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea it didn't work without me manually serializing the value when it was in an update dictionary. If it was a regular prediction, it would be fine.

attr_dict["value"] = attr_dict["value"].model_dump()
else:
attr_dict["value"] = update_dict["value"]
return attr_dict
Expand Down
14 changes: 0 additions & 14 deletions gradio/component_meta.py
Expand Up @@ -8,7 +8,6 @@

from jinja2 import Template

from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import EventListener
from gradio.exceptions import ComponentDefinitionError
from gradio.utils import no_raise_exception
Expand Down Expand Up @@ -56,17 +55,6 @@ def {{ event }}(self,
'''


def serializes(f):
@wraps(f)
def serialize(*args, **kwds):
output = f(*args, **kwds)
if isinstance(output, (GradioRootModel, GradioModel)):
output = output.model_dump()
return output

return serialize


def create_pyi(class_code: str, events: list[EventListener | str]):
template = Template(INTERFACE_TEMPLATE)
events = [e if isinstance(e, str) else e.event_name for e in events]
Expand Down Expand Up @@ -192,8 +180,6 @@ def __new__(cls, name, bases, attrs):
attrs[event] = trigger.listener
if "EVENTS" in attrs:
attrs["EVENTS"] = new_events
if "postprocess" in attrs:
attrs["postprocess"] = serializes(attrs["postprocess"])
component_class = super().__new__(cls, name, bases, attrs)
create_or_modify_pyi(component_class, name, events)
return component_class
3 changes: 3 additions & 0 deletions gradio/components/clear_button.py
Expand Up @@ -8,6 +8,7 @@
from gradio_client.documentation import document, set_documentation_group

from gradio.components import Button, Component
from gradio.data_classes import GradioModel, GradioRootModel

set_documentation_group("component")

Expand Down Expand Up @@ -72,6 +73,8 @@ def add(self, components: None | Component | list[Component]) -> ClearButton:
none_values = []
for component in components:
none = component.postprocess(None)
if isinstance(none, (GradioModel, GradioRootModel)):
none = none.model_dump()
none_values.append(none)
clear_values = json.dumps(none_values)
self.click(None, [], components, _js=f"() => {clear_values}")
Expand Down
2 changes: 2 additions & 0 deletions gradio/helpers.py
Expand Up @@ -444,6 +444,8 @@ def merge_generated_values_into_output(
if len(components) > 1:
chunk = chunk[output_index]
processed_chunk = output_component.postprocess(chunk)
if isinstance(processed_chunk, (GradioModel, GradioRootModel)):
processed_chunk = processed_chunk.model_dump()
binary_chunks.append(
output_component.stream_output(processed_chunk, "", i == 0)[0]
)
Expand Down
7 changes: 5 additions & 2 deletions gradio/processing_utils.py
Expand Up @@ -20,7 +20,7 @@
from PIL import Image, ImageOps, PngImagePlugin

from gradio import wasm_utils
from gradio.data_classes import FileData
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.utils import abspath, is_in_or_equal

with warnings.catch_warnings():
Expand Down Expand Up @@ -278,7 +278,7 @@ def move_files_to_cache(data: Any, block: Component):
Runs after postprocess and before preprocess.

Args:
data: The input or output data for a component.
data: The input or output data for a component. Can be a dictionary or a dataclass
block: The component
"""

Expand All @@ -288,6 +288,9 @@ def _move_to_cache(d: dict):
payload.path = temp_file_path
return payload.model_dump()

if isinstance(data, (GradioRootModel, GradioModel)):
data = data.model_dump()

return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)


Expand Down
10 changes: 9 additions & 1 deletion test/test_blocks.py
Expand Up @@ -23,6 +23,7 @@
from PIL import Image

import gradio as gr
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import SelectData
from gradio.exceptions import DuplicateBlockError
from gradio.networking import Server, get_first_available_port
Expand Down Expand Up @@ -512,8 +513,15 @@ def test_blocks_do_not_filter_none_values_from_updates(self, io_components):
output = demo.postprocess_data(
0, [gr.update(value=None) for _ in io_components], state=None
)

def process_and_dump(component):
output = component.postprocess(None)
if isinstance(output, (GradioModel, GradioRootModel)):
output = output.model_dump()
return output

assert all(
o["value"] == c.postprocess(None) for o, c in zip(output, io_components)
o["value"] == process_and_dump(c) for o, c in zip(output, io_components)
)

def test_blocks_does_not_replace_keyword_literal(self):
Expand Down