Skip to content

Commit

Permalink
WIP re: gradio-app#7486. Make less PNG-centric.
Browse files Browse the repository at this point in the history
preserve file format during upload
add format to image component
  • Loading branch information
dfl committed Mar 12, 2024
1 parent 6683ab2 commit 4e01ac3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
3 changes: 3 additions & 0 deletions gradio/components/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Image(StreamingInput, Component):

def __init__(
self,
format: str | None = None,
value: str | PIL.Image.Image | np.ndarray | None = None,
*,
height: int | str | None = None,
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(
"""
Parameters:
value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component.
format: Format of to be returned by component, such as 'jpg' or 'png'. If set to None, image will keep uploaded format.
height: The height of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.
width: The width of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.
image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning.
Expand All @@ -90,6 +92,7 @@ def __init__(
mirror_webcam: If True webcam will be mirrored. Default is True.
show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.
"""
self.format = format
self.mirror_webcam = mirror_webcam
valid_types = ["numpy", "pil", "filepath"]
if type not in valid_types:
Expand Down
33 changes: 23 additions & 10 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,17 @@ def extract_base64_data(x: str) -> str:
#########################


def encode_plot_to_base64(plt):
def encode_plot_to_base64(plt, format="png"):
with BytesIO() as output_bytes:
plt.savefig(output_bytes, format="png")
plt.savefig(output_bytes, format)
bytes_data = output_bytes.getvalue()
base64_str = str(base64.b64encode(bytes_data), "utf-8")
return "data:image/png;base64," + base64_str
return format_base64_str(base64_str, format)


def get_pil_exif_bytes(pil_image):
if "exif" in pil_image.info:
return pil_image.info["exif"]


def get_pil_metadata(pil_image):
Expand All @@ -78,23 +83,31 @@ def get_pil_metadata(pil_image):

def encode_pil_to_bytes(pil_image, format="png"):
with BytesIO() as output_bytes:
pil_image.save(output_bytes, format, pnginfo=get_pil_metadata(pil_image))
if format == "png":
params = {"pnginfo": get_pil_metadata(pil_image)}
else:
params = {"exif": get_pil_exif_bytes(pil_image)}
pil_image.save(output_bytes, format, **params)
return output_bytes.getvalue()


def encode_pil_to_base64(pil_image):
bytes_data = encode_pil_to_bytes(pil_image)
def encode_pil_to_base64(pil_image, format="png"):
bytes_data = encode_pil_to_bytes(pil_image, format)
base64_str = str(base64.b64encode(bytes_data), "utf-8")
return "data:image/png;base64," + base64_str
return format_base64_str(base64_str, format)


def encode_array_to_base64(image_array):
def encode_array_to_base64(image_array, format="png"):
with BytesIO() as output_bytes:
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
pil_image.save(output_bytes, "PNG")
pil_image.save(output_bytes, format)
bytes_data = output_bytes.getvalue()
base64_str = str(base64.b64encode(bytes_data), "utf-8")
return "data:image/png;base64," + base64_str
return format_base64_str(base64_str, format)


def format_base64_str(data, format=None) -> str:
return f"data:image/{format or 'png'};base64,{data}"


def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str:
Expand Down

0 comments on commit 4e01ac3

Please sign in to comment.