Skip to content

Commit

Permalink
Add recover_kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Oct 25, 2023
1 parent c66b9a2 commit 46006a8
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 29 deletions.
33 changes: 18 additions & 15 deletions gradio/blocks.py
Expand Up @@ -196,6 +196,21 @@ def get_config(self):
config.pop("render", None)
return {**config, "root_url": self.root_url, "name": self.get_block_name()}

@classmethod
def recover_kwargs(
cls, props: dict[str, Any], additional_keys: list[str] | None = None
):
"""
Recovers kwargs from a dict of props.
"""
additional_keys = additional_keys or []
signature = inspect.signature(cls.__init__)
kwargs = {}
for parameter in signature.parameters.values():
if parameter.name in props and parameter.name not in additional_keys:
kwargs[parameter.name] = props[parameter.name]
return kwargs


class BlockContext(Block):
def __init__(
Expand Down Expand Up @@ -589,7 +604,7 @@ def get_block_instance(id: int) -> Block:
raise ValueError(f"Cannot find block with id {id}")
cls = component_or_layout_class(block_config["type"])

block_config["props"] = utils.recover_kwargs(block_config["props"])
block_config["props"] = cls.recover_kwargs(block_config["props"])

# If a Gradio app B is loaded into a Gradio app A, and B itself loads a
# Gradio app C, then the root_urls of the components in A need to be the
Expand All @@ -599,18 +614,6 @@ def get_block_instance(id: int) -> Block:
else:
root_urls.add(block_config["props"]["root_url"])

# We treat dataset components as a special case because they reference other components
# in the config. Instead of using the component string names, we use the component ids.
if (
block_config["type"] == "dataset"
and "component_ids" in block_config["props"]
):
block_config["props"].pop("components", None)
block_config["props"]["components"] = [
original_mapping[c] for c in block_config["props"]["component_ids"]
]
block_config["props"].pop("component_ids", None)

# Any component has already processed its initial value, so we skip that step here
block = cls(**block_config["props"], _skip_init_processing=True)
return block
Expand Down Expand Up @@ -820,9 +823,9 @@ def set_event_trigger(
elif every:
raise ValueError("Cannot set a value for `every` without a `fn`.")

if _targets[0][1] == "change" and trigger_mode == None:
if _targets[0][1] == "change" and trigger_mode is None:
trigger_mode = "always_last"
elif trigger_mode == None:
elif trigger_mode is None:
trigger_mode = "once"
elif trigger_mode not in ["once", "multiple", "always_last"]:
raise ValueError(
Expand Down
8 changes: 3 additions & 5 deletions gradio/components/dataset.py
Expand Up @@ -6,7 +6,6 @@

from gradio_client.documentation import document, set_documentation_group

import gradio.utils as utils
from gradio.components.base import (
Component,
get_component_instance,
Expand Down Expand Up @@ -45,8 +44,6 @@ def __init__(
container: bool = True,
scale: int | None = None,
min_width: int = 160,
component_props: dict[str, Any] | None = None,
component_ids: dict[str, Any] | None = None,
):
"""
Parameters:
Expand All @@ -72,15 +69,16 @@ def __init__(
elem_classes=elem_classes,
root_url=root_url,
_skip_init_processing=_skip_init_processing,
render=render,
)
self.container = container
self.scale = scale
self.min_width = min_width
self._components = [get_component_instance(c) for c in components]
self.component_props = [
utils.recover_kwargs(
component.recover_kwargs(
component.get_config(),
["value", "component_props", "component_ids"],
["value"],
)
for component in self._components
]
Expand Down
9 changes: 0 additions & 9 deletions gradio/utils.py
Expand Up @@ -952,15 +952,6 @@ def find_user_stack_level() -> int:
return n


def recover_kwargs(config: dict, additional_keys_to_ignore: list[str] | None = None):
not_kwargs = ["type", "name", "selectable", "server_fns", "streamable"]
return {
k: v
for k, v in config.items()
if k not in not_kwargs and k not in (additional_keys_to_ignore or [])
}


class NamedString(str):
"""
Subclass of str that includes a .name attribute equal to the value of the string itself. This class is used when returning
Expand Down
13 changes: 13 additions & 0 deletions test/test_blocks.py
Expand Up @@ -1666,3 +1666,16 @@ def test_temp_file_sets_get_extended():
demo2.render()

assert demo3.temp_file_sets == demo1.temp_file_sets + demo2.temp_file_sets


def test_recover_kwargs():
audio = gr.Audio(format="wav", autoplay=True)
props = audio.recover_kwargs(
{"format": "wav", "value": "foo.wav", "autoplay": False, "foo": "bar"}
)
assert props == {"format": "wav", "value": "foo.wav", "autoplay": False}
props = audio.recover_kwargs(
{"format": "wav", "value": "foo.wav", "autoplay": False, "foo": "bar"},
["value"],
)
assert props == {"format": "wav", "autoplay": False}

0 comments on commit 46006a8

Please sign in to comment.