From 12d8e90a1646374b46eb8258be7356c868d1cca3 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 2 Nov 2023 17:38:13 -0700 Subject: [PATCH] Fixes input `Image` component with `streaming=True` (#6272) * fix streaming * cleanpu * add changeset * fix * docstrings and tests * notebooks --------- Co-authored-by: gradio-pr-bot --- .changeset/kind-toys-stop.md | 5 +++++ gradio/components/image.py | 26 ++++++++++++++------------ test/test_components.py | 2 +- 3 files changed, 20 insertions(+), 13 deletions(-) create mode 100644 .changeset/kind-toys-stop.md diff --git a/.changeset/kind-toys-stop.md b/.changeset/kind-toys-stop.md new file mode 100644 index 000000000000..e231707079c6 --- /dev/null +++ b/.changeset/kind-toys-stop.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Fixes input `Image` component with `streaming=True` diff --git a/gradio/components/image.py b/gradio/components/image.py index 4e0da7bdf06a..2fc77a9bae13 100644 --- a/gradio/components/image.py +++ b/gradio/components/image.py @@ -4,7 +4,7 @@ import warnings from pathlib import Path -from typing import Any, Iterable, Literal, cast +from typing import Any, Literal, cast import numpy as np from gradio_client.documentation import document, set_documentation_group @@ -50,11 +50,7 @@ def __init__( image_mode: Literal[ "1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F" ] = "RGB", - sources: Iterable[Literal["upload", "webcam", "clipboard"]] = ( - "upload", - "webcam", - "clipboard", - ), + sources: list[Literal["upload", "webcam", "clipboard"]] | None = None, type: Literal["numpy", "pil", "filepath"] = "numpy", label: str | None = None, every: float | None = None, @@ -78,7 +74,7 @@ def __init__( height: Height of the displayed image in pixels. width: Width of the displayed image in pixels. image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. - sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. + sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"]. type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to. every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. @@ -107,9 +103,15 @@ def __init__( self.width = width self.image_mode = image_mode valid_sources = ["upload", "webcam", "clipboard"] - if isinstance(sources, str): - sources = [sources] # type: ignore - for source in sources: + if sources is None: + self.sources = ( + ["webcam"] if streaming else ["upload", "webcam", "clipboard"] + ) + elif isinstance(sources, str): + self.sources = [sources] # type: ignore + else: + self.sources = sources + for source in self.sources: # type: ignore if source not in valid_sources: raise ValueError( f"`sources` must a list consisting of elements in {valid_sources}" @@ -118,7 +120,7 @@ def __init__( self.streaming = streaming self.show_download_button = show_download_button - if streaming and sources != ("webcam"): + if streaming and self.sources != ["webcam"]: raise ValueError( "Image streaming only available if sources is ['webcam']. Streaming not supported with multiple sources." ) @@ -163,7 +165,7 @@ def postprocess( return FileData(path=image_utils.save_image(value, self.GRADIO_CACHE)) def check_streamable(self): - if self.streaming and self.sources != ("webcam"): + if self.streaming and self.sources != ["webcam"]: raise ValueError( "Image streaming only available if sources is ['webcam']. Streaming not supported with multiple sources." ) diff --git a/test/test_components.py b/test/test_components.py index 0776247a44df..f8338ce37dc8 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -572,7 +572,7 @@ def test_component_functions(self, gradio_temp_dir): image_input = gr.Image(type="pil", label="Upload Your Image") assert image_input.get_config() == { "image_mode": "RGB", - "sources": ("upload", "webcam", "clipboard"), + "sources": None, "name": "image", "show_share_button": False, "show_download_button": True,