From 46006a84e00661b39d606dc4f6e6d92fe68b82c0 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Wed, 25 Oct 2023 10:57:03 -0400 Subject: [PATCH] Add recover_kwargs --- gradio/blocks.py | 33 ++++++++++++++++++--------------- gradio/components/dataset.py | 8 +++----- gradio/utils.py | 9 --------- test/test_blocks.py | 13 +++++++++++++ 4 files changed, 34 insertions(+), 29 deletions(-) diff --git a/gradio/blocks.py b/gradio/blocks.py index 18296a818061..471036eae122 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -196,6 +196,21 @@ def get_config(self): config.pop("render", None) return {**config, "root_url": self.root_url, "name": self.get_block_name()} + @classmethod + def recover_kwargs( + cls, props: dict[str, Any], additional_keys: list[str] | None = None + ): + """ + Recovers kwargs from a dict of props. + """ + additional_keys = additional_keys or [] + signature = inspect.signature(cls.__init__) + kwargs = {} + for parameter in signature.parameters.values(): + if parameter.name in props and parameter.name not in additional_keys: + kwargs[parameter.name] = props[parameter.name] + return kwargs + class BlockContext(Block): def __init__( @@ -589,7 +604,7 @@ def get_block_instance(id: int) -> Block: raise ValueError(f"Cannot find block with id {id}") cls = component_or_layout_class(block_config["type"]) - block_config["props"] = utils.recover_kwargs(block_config["props"]) + block_config["props"] = cls.recover_kwargs(block_config["props"]) # If a Gradio app B is loaded into a Gradio app A, and B itself loads a # Gradio app C, then the root_urls of the components in A need to be the @@ -599,18 +614,6 @@ def get_block_instance(id: int) -> Block: else: root_urls.add(block_config["props"]["root_url"]) - # We treat dataset components as a special case because they reference other components - # in the config. Instead of using the component string names, we use the component ids. - if ( - block_config["type"] == "dataset" - and "component_ids" in block_config["props"] - ): - block_config["props"].pop("components", None) - block_config["props"]["components"] = [ - original_mapping[c] for c in block_config["props"]["component_ids"] - ] - block_config["props"].pop("component_ids", None) - # Any component has already processed its initial value, so we skip that step here block = cls(**block_config["props"], _skip_init_processing=True) return block @@ -820,9 +823,9 @@ def set_event_trigger( elif every: raise ValueError("Cannot set a value for `every` without a `fn`.") - if _targets[0][1] == "change" and trigger_mode == None: + if _targets[0][1] == "change" and trigger_mode is None: trigger_mode = "always_last" - elif trigger_mode == None: + elif trigger_mode is None: trigger_mode = "once" elif trigger_mode not in ["once", "multiple", "always_last"]: raise ValueError( diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index d1d2ab8a37c0..f20f2523e348 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -6,7 +6,6 @@ from gradio_client.documentation import document, set_documentation_group -import gradio.utils as utils from gradio.components.base import ( Component, get_component_instance, @@ -45,8 +44,6 @@ def __init__( container: bool = True, scale: int | None = None, min_width: int = 160, - component_props: dict[str, Any] | None = None, - component_ids: dict[str, Any] | None = None, ): """ Parameters: @@ -72,15 +69,16 @@ def __init__( elem_classes=elem_classes, root_url=root_url, _skip_init_processing=_skip_init_processing, + render=render, ) self.container = container self.scale = scale self.min_width = min_width self._components = [get_component_instance(c) for c in components] self.component_props = [ - utils.recover_kwargs( + component.recover_kwargs( component.get_config(), - ["value", "component_props", "component_ids"], + ["value"], ) for component in self._components ] diff --git a/gradio/utils.py b/gradio/utils.py index 37692e06323b..66a025c457e2 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -952,15 +952,6 @@ def find_user_stack_level() -> int: return n -def recover_kwargs(config: dict, additional_keys_to_ignore: list[str] | None = None): - not_kwargs = ["type", "name", "selectable", "server_fns", "streamable"] - return { - k: v - for k, v in config.items() - if k not in not_kwargs and k not in (additional_keys_to_ignore or []) - } - - class NamedString(str): """ Subclass of str that includes a .name attribute equal to the value of the string itself. This class is used when returning diff --git a/test/test_blocks.py b/test/test_blocks.py index aa53fd4b45c2..e72b045164ac 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -1666,3 +1666,16 @@ def test_temp_file_sets_get_extended(): demo2.render() assert demo3.temp_file_sets == demo1.temp_file_sets + demo2.temp_file_sets + + +def test_recover_kwargs(): + audio = gr.Audio(format="wav", autoplay=True) + props = audio.recover_kwargs( + {"format": "wav", "value": "foo.wav", "autoplay": False, "foo": "bar"} + ) + assert props == {"format": "wav", "value": "foo.wav", "autoplay": False} + props = audio.recover_kwargs( + {"format": "wav", "value": "foo.wav", "autoplay": False, "foo": "bar"}, + ["value"], + ) + assert props == {"format": "wav", "autoplay": False}