Skip to content

Commit

Permalink
more format preservation re: gradio-app#7486
Browse files Browse the repository at this point in the history
  • Loading branch information
dfl committed Mar 12, 2024
1 parent 90315f6 commit bd1948f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
7 changes: 5 additions & 2 deletions gradio/components/annotated_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
]
| None = None,
*,
format: str | None = None,
show_legend: bool = True,
height: int | str | None = None,
width: int | str | None = None,
Expand All @@ -67,6 +68,7 @@ def __init__(
"""
Parameters:
value: Tuple of base image and list of (annotation, label) pairs.
format: Format to be returned by component, such as 'jpg' or 'png'. If set to None, image will keep uploaded format.
show_legend: If True, will show a legend of the annotations.
height: The height of the image, specified in pixels if a number is passed, or in CSS units if a string is passed.
width: The width of the image, specified in pixels if a number is passed, or in CSS units if a string is passed.
Expand All @@ -82,6 +84,7 @@ def __init__(
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.
"""
self.format = format
self.show_legend = show_legend
self.height = height
self.width = width
Expand Down Expand Up @@ -141,12 +144,12 @@ def postprocess(
base_img = np.array(PIL.Image.open(base_img))
elif isinstance(base_img, np.ndarray):
base_file = processing_utils.save_img_array_to_cache(
base_img, cache_dir=self.GRADIO_CACHE
base_img, cache_dir=self.GRADIO_CACHE, self.format
)
base_img_path = str(utils.abspath(base_file))
elif isinstance(base_img, PIL.Image.Image):
base_file = processing_utils.save_pil_to_cache(
base_img, cache_dir=self.GRADIO_CACHE
base_img, cache_dir=self.GRADIO_CACHE, self.format
)
base_img_path = str(utils.abspath(base_file))
base_img = np.array(base_img)
Expand Down
6 changes: 3 additions & 3 deletions gradio/components/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
"""
Parameters:
value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component.
format: Format of to be returned by component, such as 'jpg' or 'png'. If set to None, image will keep uploaded format.
format: Format to be returned by component, such as 'jpg' or 'png'. If set to None, image will keep uploaded format.
height: The height of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.
width: The width of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.
image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning.
Expand Down Expand Up @@ -186,7 +186,7 @@ def preprocess(
cast(Literal["numpy", "pil", "filepath"], self.type),
self.GRADIO_CACHE,
name=name,
format=suffix,
format=self.format or suffix,
)

def postprocess(
Expand All @@ -202,7 +202,7 @@ def postprocess(
return None
if isinstance(value, str) and value.lower().endswith(".svg"):
return FileData(path=value, orig_name=Path(value).name)
saved = image_utils.save_image(value, self.GRADIO_CACHE)
saved = image_utils.save_image(value, self.GRADIO_CACHE, self.format)
orig_name = Path(saved).name if Path(saved).exists() else None
return FileData(path=saved, orig_name=orig_name)

Expand Down
8 changes: 4 additions & 4 deletions gradio/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def format_image(
)


def save_image(y: np.ndarray | PIL.Image.Image | str | Path, cache_dir: str):
def save_image(y: np.ndarray | PIL.Image.Image | str | Path, cache_dir: str, format: str=None):
# numpy gets saved to png as default format
# PIL gets saved to its original format if possible
if isinstance(y, np.ndarray):
path = processing_utils.save_img_array_to_cache(y, cache_dir=cache_dir)
path = processing_utils.save_img_array_to_cache(y, cache_dir=cache_dir, format=format)
elif isinstance(y, PIL.Image.Image):
fmt = y.format
fmt = format or y.format
try:
path = processing_utils.save_pil_to_cache(
y,
Expand All @@ -67,7 +67,7 @@ def save_image(y: np.ndarray | PIL.Image.Image | str | Path, cache_dir: str):
# Catch error if format is not supported by PIL
except (KeyError, ValueError):
path = processing_utils.save_pil_to_cache(
y, cache_dir=cache_dir, format="png"
y, cache_dir=cache_dir, format=format or "png"
)
elif isinstance(y, Path):
path = str(y)
Expand Down

0 comments on commit bd1948f

Please sign in to comment.