diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b1fb92687e2..6d9acf58f38b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The `api_name` parameter will take precendence over the `fn_index` parameter. ## Bug Fixes: * Fixed bug where None could not be used for File,Model3D, and Audio examples by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2588](https://github.com/gradio-app/gradio/pull/2588) * Fixed links in Plotly map guide + demo by [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 2578](https://github.com/gradio-app/gradio/pull/2578) +* `gr.Blocks.load()` now correctly loads example files from Spaces [@abidlabs](https://github.com/abidlabs) in [PR 2594](https://github.com/gradio-app/gradio/pull/2594) ## Documentation Changes: No changes to highlight. diff --git a/gradio/blocks.py b/gradio/blocks.py index a5ac0e59737b..90c10f367d91 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -66,11 +66,20 @@ class Block: - def __init__(self, *, render=True, elem_id=None, visible=True, **kwargs): + def __init__( + self, + *, + render: bool = True, + elem_id: str | None = None, + visible: bool = True, + root_url: str | None = None, # URL that is prepended to all file paths + **kwargs, + ): self._id = Context.id Context.id += 1 self.visible = visible self.elem_id = elem_id + self.root_url = root_url self._style = {} if render: self.render() @@ -246,6 +255,7 @@ def get_config(self): "visible": self.visible, "elem_id": self.elem_id, "style": self._style, + "root_url": self.root_url, } @classmethod @@ -570,8 +580,17 @@ def share(self, value: Optional[bool]): self._share = value @classmethod - def from_config(cls, config: dict, fns: List[Callable]) -> Blocks: - """Factory method that creates a Blocks from a config and list of functions.""" + def from_config( + cls, config: dict, fns: List[Callable], root_url: str | None = None + ) -> Blocks: + """ + Factory method that creates a Blocks from a config and list of functions. + + Parameters: + config: a dictionary containing the configuration of the Blocks. + fns: a list of functions that are used in the Blocks. Must be in the same order as the dependencies in the config. + root_url: an optional root url to use for the components in the Blocks. Allows serving files from an external URL. + """ config = copy.deepcopy(config) components_config = config["components"] original_mapping: Dict[int, Block] = {} @@ -586,6 +605,8 @@ def get_block_instance(id: int) -> Block: block_config["props"].pop("type", None) block_config["props"].pop("name", None) style = block_config["props"].pop("style", None) + if block_config["props"].get("root_url") is None and root_url: + block_config["props"]["root_url"] = root_url + "/" block = cls(**block_config["props"]) if style: block.style(**style) @@ -603,8 +624,13 @@ def iterate_over_children(children_list): iterate_over_children(children) with Blocks(theme=config["theme"], css=config["theme"]) as blocks: + # ID 0 should be the root Blocks component + original_mapping[0] = Context.root_block or blocks + iterate_over_children(config["layout"]["children"]) + first_dependency = None + # add the event triggers for dependency, fn in zip(config["dependencies"], fns): targets = dependency.pop("targets") @@ -618,19 +644,24 @@ def iterate_over_children(children_list): original_mapping[o] for o in dependency["outputs"] ] dependency.pop("status_tracker", None) - dependency["_js"] = dependency.pop("js", None) dependency["preprocess"] = False dependency["postprocess"] = False for target in targets: - event_method = getattr(original_mapping[target], trigger) - event_method(fn=fn, **dependency) + dependency = original_mapping[target].set_event_trigger( + event_name=trigger, fn=fn, **dependency + ) + if first_dependency is None: + first_dependency = dependency # Allows some use of Interface-specific methods with loaded Spaces blocks.predict = [fns[0]] - dependency = blocks.dependencies[0] - blocks.input_components = [blocks.blocks[i] for i in dependency["inputs"]] - blocks.output_components = [blocks.blocks[o] for o in dependency["outputs"]] + blocks.input_components = [ + Context.root_block.blocks[i] for i in first_dependency["inputs"] + ] + blocks.output_components = [ + Context.root_block.blocks[o] for o in first_dependency["outputs"] + ] if config.get("mode", "blocks") == "interface": blocks.__name__ = "Interface" @@ -1073,7 +1104,7 @@ def load( fn: Instance Method - Callable function inputs: Instance Method - input list outputs: Instance Method - output list - every: Run this event 'every' number of seconds. Interpreted in seconds. Queue must be enabled. + every: Instance Method - Run this event 'every' number of seconds. Interpreted in seconds. Queue must be enabled. Example: import gradio as gr import datetime @@ -1088,14 +1119,8 @@ def get_time(): if isinstance(self_or_cls, type): if name is None: raise ValueError( - "Blocks.load() requires passing `name` as a keyword argument" + "Blocks.load() requires passing parameters as keyword arguments" ) - if fn is not None: - kwargs["fn"] = fn - if inputs is not None: - kwargs["inputs"] = inputs - if outputs is not None: - kwargs["outputs"] = outputs return external.load_blocks_from_repo(name, src, api_key, alias, **kwargs) else: return self_or_cls.set_event_trigger( diff --git a/gradio/components.py b/gradio/components.py index da4b7717489e..70879fb4f88f 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -2022,7 +2022,7 @@ def style( **kwargs, ) - def as_example(self, input_data: str) -> str: + def as_example(self, input_data: str | None) -> str: return Path(input_data).name if input_data else "" @@ -2462,8 +2462,10 @@ def style( **kwargs, ) - def as_example(self, input_data): - if isinstance(input_data, pd.DataFrame): + def as_example(self, input_data: pd.DataFrame | np.ndarray | str | None): + if input_data is None: + return "" + elif isinstance(input_data, pd.DataFrame): return input_data.head(n=5).to_dict(orient="split")["data"] elif isinstance(input_data, np.ndarray): return input_data.tolist() @@ -3616,7 +3618,7 @@ def style(self, **kwargs): **kwargs, ) - def as_example(self, input_data: str) -> str: + def as_example(self, input_data: str | None) -> str: return Path(input_data).name if input_data else "" diff --git a/gradio/external.py b/gradio/external.py index 40ad248716ad..bbb2edd03fe4 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -21,7 +21,6 @@ get_ws_fn, postprocess_label, rows_to_cols, - streamline_spaces_blocks, streamline_spaces_interface, use_websocket, ) @@ -312,6 +311,7 @@ def query_huggingface_api(*params): def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks: space_url = "https://huggingface.co/spaces/{}".format(space_name) + print("Fetching Space from: {}".format(space_url)) headers = {} @@ -345,14 +345,12 @@ def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> B space_name, config, alias, api_key, iframe_url, **kwargs ) else: # Create a Blocks for Gradio 3.x Spaces - return from_spaces_blocks(space_name, config, api_key, iframe_url) + return from_spaces_blocks(config, api_key, iframe_url) -def from_spaces_blocks( - model_name: str, config: Dict, api_key: str | None, iframe_url: str -) -> Blocks: - config = streamline_spaces_blocks(config) +def from_spaces_blocks(config: Dict, api_key: str | None, iframe_url: str) -> Blocks: api_url = "{}/api/predict/".format(iframe_url) + headers = {"Content-Type": "application/json"} if api_key is not None: headers["Authorization"] = f"Bearer {api_key}" @@ -398,7 +396,7 @@ def fn(*data): fns.append(fn) else: fns.append(None) - return gradio.Blocks.from_config(config, fns) + return gradio.Blocks.from_config(config, fns, iframe_url) def from_spaces_interface( diff --git a/gradio/external_utils.py b/gradio/external_utils.py index a3692ea3ac3b..db496bac334b 100644 --- a/gradio/external_utils.py +++ b/gradio/external_utils.py @@ -153,7 +153,7 @@ def use_websocket(config, dependency): ################## -# Helper functions for cleaning up Interfaces/Blocks loaded from HF Spaces +# Helper function for cleaning up an Interface loaded from HF Spaces ################## @@ -178,12 +178,3 @@ def streamline_spaces_interface(config: Dict) -> Dict: } config = {k: config[k] for k in parameters} return config - - -def streamline_spaces_blocks(config: dict) -> dict: - """Streamlines the blocks config dictionary to fix components that don't render correctly.""" - # TODO(abidlabs): Need a better way to fix relative paths in dataset component - for c, component in enumerate(config["components"]): - if component["type"] == "dataset": - config["components"][c]["props"]["visible"] = False - return config diff --git a/gradio/interface.py b/gradio/interface.py index 05e84d2e319a..e08ee07bd9f4 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -105,7 +105,7 @@ def load( demo = gr.Interface.load("models/EleutherAI/gpt-neo-1.3B", description=description, examples=examples) demo.launch() """ - return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs) + return super().load(name=name, src=src, api_key=api_key, alias=alias) @classmethod def from_pipeline(cls, pipeline: transformers.Pipeline, **kwargs) -> Interface: diff --git a/test/test_components.py b/test/test_components.py index c46e63815d9c..1cf68d8f28d2 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -102,6 +102,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, }, ) self.assertIsInstance(text_input.generate_sample(), str) @@ -201,6 +202,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, }, ) @@ -248,6 +250,7 @@ def test_component_functions_integer(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, }, ) @@ -372,6 +375,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, }, ) @@ -447,6 +451,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, }, ) @@ -492,6 +497,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, }, ) with self.assertRaises(ValueError): @@ -535,6 +541,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, }, ) with self.assertRaises(ValueError): @@ -593,6 +600,7 @@ async def test_component_functions(self): "visible": True, "value": None, "interactive": None, + "root_url": None, "mirror_webcam": True, }, ) @@ -763,6 +771,7 @@ def test_component_functions(self): "visible": True, "value": None, "interactive": None, + "root_url": None, }, ) self.assertIsNone(audio_input.preprocess(None)) @@ -799,6 +808,7 @@ def test_component_functions(self): "visible": True, "value": None, "interactive": None, + "root_url": None, }, ) self.assertTrue( @@ -893,6 +903,7 @@ def test_component_functions(self): "visible": True, "value": None, "interactive": None, + "root_url": None, }, ) self.assertIsNone(file_input.preprocess(None)) @@ -974,6 +985,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, "wrap": False, }, ) @@ -1008,6 +1020,7 @@ def test_component_functions(self): "headers": [1, 2, 3], }, "interactive": None, + "root_url": None, "wrap": False, }, ) @@ -1227,6 +1240,7 @@ def test_component_functions(self): "visible": True, "value": None, "interactive": None, + "root_url": None, "mirror_webcam": True, }, ) @@ -1364,6 +1378,7 @@ def test_component_functions(self): "visible": True, "value": None, "interactive": None, + "root_url": None, }, ) self.assertIsNone(timeseries_input.preprocess(None)) @@ -1388,6 +1403,7 @@ def test_component_functions(self): "visible": True, "value": None, "interactive": None, + "root_url": None, }, ) data = {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]} @@ -1532,6 +1548,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, }, ) @@ -1628,6 +1645,7 @@ def test_component_functions(self): "visible": True, "value": None, "interactive": None, + "root_url": None, }, ) @@ -1678,6 +1696,7 @@ def test_component_functions(self): "label": None, "name": "json", "interactive": None, + "root_url": None, }, ) @@ -1728,6 +1747,7 @@ def test_component_functions(self): "label": "HTML Input", "name": "html", "interactive": None, + "root_url": None, }, html_component.get_config(), ) @@ -1757,6 +1777,7 @@ def test_component_functions(self): "label": "Model", "show_label": True, "interactive": None, + "root_url": None, "name": "model3d", "visible": True, "elem_id": None, @@ -1801,6 +1822,7 @@ def test_component_functions(self): "elem_id": None, "visible": True, "interactive": None, + "root_url": None, "name": "colorpicker", }, ) @@ -1929,18 +1951,20 @@ def test_as_example_returns_file_basename(component): @patch("gradio.components.IOComponent.as_example") +@patch("gradio.components.Image.as_example") @patch("gradio.components.File.as_example") @patch("gradio.components.Dataframe.as_example") @patch("gradio.components.Model3D.as_example") def test_dataset_calls_as_example(*mocks): gr.Dataset( - components=[gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D()], + components=[gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D(), gr.Textbox()], samples=[ [ pd.DataFrame({"a": np.array([1, 2, 3])}), "foo.png", "bar.jpeg", "duck.obj", + "hello", ] ], ) diff --git a/test/test_external.py b/test/test_external.py index 7ee3a09a1596..92685a6cd812 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -265,6 +265,16 @@ def test_interface_load_cache_examples(self, tmp_path): cache_examples=True, ) + def test_root_url(self): + demo = gr.Interface.load("spaces/gradio/test-loading-examples") + assert all( + [ + c["props"]["root_url"] + == "https://gradio-test-loading-examples.hf.space/" + for c in demo.get_config_file()["components"] + ] + ) + def test_get_tabular_examples_replaces_nan_with_str_nan(): readme = """ diff --git a/ui/packages/app/src/components/Audio/Audio.svelte b/ui/packages/app/src/components/Audio/Audio.svelte index f951454668d2..1fcd4e4186d5 100644 --- a/ui/packages/app/src/components/Audio/Audio.svelte +++ b/ui/packages/app/src/components/Audio/Audio.svelte @@ -30,11 +30,12 @@ export let show_label: boolean; export let pending: boolean; export let streaming: boolean; + export let root_url: null | string; export let loading_status: LoadingStatus; let _value: null | FileData; - $: _value = normalise_file(value, root); + $: _value = normalise_file(value, root_url ?? root); let dragging: boolean; diff --git a/ui/packages/app/src/components/Dataset/Dataset.svelte b/ui/packages/app/src/components/Dataset/Dataset.svelte index 19dfbcb30d71..baef3532e63e 100644 --- a/ui/packages/app/src/components/Dataset/Dataset.svelte +++ b/ui/packages/app/src/components/Dataset/Dataset.svelte @@ -11,11 +11,12 @@ export let visible: boolean = true; export let value: number | null = null; export let root: string; + export let root_url: null | string; export let samples_per_page: number = 10; const dispatch = createEventDispatcher<{ click: number }>(); - let samples_dir: string = root + "file="; + let samples_dir: string = (root_url ?? root) + "file="; let page = 0; let gallery = headers.length === 1; let paginate = samples.length > samples_per_page; diff --git a/ui/packages/app/src/components/File/File.svelte b/ui/packages/app/src/components/File/File.svelte index e406a19aee72..84af4d08f1ce 100644 --- a/ui/packages/app/src/components/File/File.svelte +++ b/ui/packages/app/src/components/File/File.svelte @@ -17,11 +17,12 @@ export let label: string; export let show_label: boolean; export let file_count: string; + export let root_url: null | string; export let loading_status: LoadingStatus; let _value: null | FileData; - $: _value = normalise_file(value, root); + $: _value = normalise_file(value, root_url ?? root); let dragging = false; diff --git a/ui/packages/app/src/components/Gallery/Gallery.svelte b/ui/packages/app/src/components/Gallery/Gallery.svelte index ecd3b272ee20..0ae633cd0257 100644 --- a/ui/packages/app/src/components/Gallery/Gallery.svelte +++ b/ui/packages/app/src/components/Gallery/Gallery.svelte @@ -15,6 +15,7 @@ export let show_label: boolean; export let label: string; export let root: string; + export let root_url: null | string; export let elem_id: string = ""; export let visible: boolean = true; export let value: Array | Array | null = null; @@ -25,8 +26,8 @@ ? null : value.map((img) => Array.isArray(img) - ? [normalise_file(img[0], root), img[1]] - : [normalise_file(img, root), null] + ? [normalise_file(img[0], root_url ?? root), img[1]] + : [normalise_file(img, root_url ?? root), null] ); let prevValue: string[] | FileData[] | null = null; diff --git a/ui/packages/app/src/components/Model3D/Model3D.svelte b/ui/packages/app/src/components/Model3D/Model3D.svelte index e32fed0135a0..2f583217e9f8 100644 --- a/ui/packages/app/src/components/Model3D/Model3D.svelte +++ b/ui/packages/app/src/components/Model3D/Model3D.svelte @@ -13,6 +13,7 @@ export let value: null | FileData = null; export let mode: "static" | "dynamic"; export let root: string; + export let root_url: null | string; export let clearColor: Array; export let loading_status: LoadingStatus; @@ -20,7 +21,7 @@ export let show_label: boolean; let _value: null | FileData; - $: _value = normalise_file(value, root); + $: _value = normalise_file(value, root_url ?? root); let dragging = false; diff --git a/ui/packages/app/src/components/Video/Video.svelte b/ui/packages/app/src/components/Video/Video.svelte index e07d389fd6fe..d88f0bfef2c1 100644 --- a/ui/packages/app/src/components/Video/Video.svelte +++ b/ui/packages/app/src/components/Video/Video.svelte @@ -15,6 +15,7 @@ export let label: string; export let source: string; export let root: string; + export let root_url: null | string; export let show_label: boolean; export let loading_status: LoadingStatus; export let style: Styles = {}; @@ -23,7 +24,7 @@ export let mode: "static" | "dynamic"; let _value: null | FileData; - $: _value = normalise_file(value, root); + $: _value = normalise_file(value, root_url ?? root); let dragging = false;