Skip to content

Commit

Permalink
Gallery preview fix and optionally skip download of urls in postprcess (
Browse files Browse the repository at this point in the history
#6288)

* Add code

* add changeset

* Use urls from our s3 bucket

* Add notebook code

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot committed Nov 3, 2023
1 parent e8216be commit 9227872
Show file tree
Hide file tree
Showing 12 changed files with 98 additions and 24 deletions.
6 changes: 6 additions & 0 deletions .changeset/cuddly-snakes-dress.md
@@ -0,0 +1,6 @@
---
"@gradio/gallery": patch
"gradio": patch
---

fix:Gallery preview fix and optionally skip download of urls in postprcess
1 change: 1 addition & 0 deletions demo/gallery_component_events/run.ipynb
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: gallery_component_events"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr \n", "\n", "with gr.Blocks() as demo:\n", " cheetahs = [\n", " \"https://gradio-builds.s3.amazonaws.com/assets/cheetah-003.jpg\",\n", " \"https://gradio-builds.s3.amazonaws.com/assets/lite-logo.png\",\n", " \"https://gradio-builds.s3.amazonaws.com/assets/TheCheethcat.jpg\",\n", " ]\n", " with gr.Row():\n", " with gr.Column():\n", " btn = gr.Button()\n", " with gr.Column():\n", " gallery = gr.Gallery()\n", " with gr.Column():\n", " select_output = gr.Textbox(label=\"Select Data\")\n", " btn.click(lambda: cheetahs, None, [gallery])\n", "\n", " def select(select_data: gr.SelectData):\n", " return select_data.value['image']['url']\n", "\n", " gallery.select(select, None, select_output)\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
25 changes: 25 additions & 0 deletions demo/gallery_component_events/run.py
@@ -0,0 +1,25 @@
import gradio as gr

with gr.Blocks() as demo:
cheetahs = [
"https://gradio-builds.s3.amazonaws.com/assets/cheetah-003.jpg",
"https://gradio-builds.s3.amazonaws.com/assets/lite-logo.png",
"https://gradio-builds.s3.amazonaws.com/assets/TheCheethcat.jpg",
]
with gr.Row():
with gr.Column():
btn = gr.Button()
with gr.Column():
gallery = gr.Gallery()
with gr.Column():
select_output = gr.Textbox(label="Select Data")
btn.click(lambda: cheetahs, None, [gallery])

def select(select_data: gr.SelectData):
return select_data.value['image']['url']

gallery.select(select, None, select_output)


if __name__ == "__main__":
demo.launch()
2 changes: 1 addition & 1 deletion gradio/blocks.py
Expand Up @@ -1380,7 +1380,7 @@ def postprocess_data(
f"{block.__class__} Component with id {output_id} not a valid output component."
)
prediction_value = block.postprocess(prediction_value)
outputs_cached = processing_utils.move_files_to_cache(prediction_value, block) # type: ignore
outputs_cached = processing_utils.move_files_to_cache(prediction_value, block, postprocess=True) # type: ignore
output.append(outputs_cached)

return output
Expand Down
9 changes: 6 additions & 3 deletions gradio/components/base.py
Expand Up @@ -167,8 +167,11 @@ def __init__(
if not hasattr(self, "data_model"):
self.data_model: type[GradioDataModel] | None = None
self.temp_files: set[str] = set()
self.GRADIO_CACHE = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
self.GRADIO_CACHE = str(
Path(
os.environ.get("GRADIO_TEMP_DIR")
or str(Path(tempfile.gettempdir()) / "gradio")
).resolve()
)

Block.__init__(
Expand Down Expand Up @@ -208,7 +211,7 @@ def __init__(
if self._skip_init_processing
else self.postprocess(initial_value)
)
self.value = move_files_to_cache(initial_value, self) # type: ignore
self.value = move_files_to_cache(initial_value, self, postprocess=True) # type: ignore

if callable(load_fn):
self.attach_load_event(load_fn, every)
Expand Down
12 changes: 9 additions & 3 deletions gradio/components/gallery.py
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from gradio_client.documentation import document, set_documentation_group
from gradio_client.utils import is_http_url_like
from PIL import Image as _Image # using _ to minimize namespace pollution

from gradio import processing_utils, utils
Expand Down Expand Up @@ -138,6 +139,7 @@ def postprocess(
return GalleryData(root=[])
output = []
for img in value:
url = None
caption = None
if isinstance(img, (tuple, list)):
img, caption = img
Expand All @@ -151,12 +153,16 @@ def postprocess(
img, cache_dir=self.GRADIO_CACHE
)
file_path = str(utils.abspath(file))
elif isinstance(img, (str, Path)):
elif isinstance(img, str):
file_path = img
url = img if is_http_url_like(img) else None
elif isinstance(img, Path):
file_path = str(img)
else:
raise ValueError(f"Cannot process type as image: {type(img)}")

entry = GalleryImage(image=FileData(path=file_path), caption=caption)
entry = GalleryImage(
image=FileData(path=file_path, url=url), caption=caption
)
output.append(entry)
return GalleryData(root=output)

Expand Down
2 changes: 1 addition & 1 deletion gradio/helpers.py
Expand Up @@ -215,7 +215,7 @@ def __init__(
if isinstance(prediction_value, (GradioRootModel, GradioModel)):
prediction_value = prediction_value.model_dump()
prediction_value = processing_utils.move_files_to_cache(
prediction_value, component
prediction_value, component, postprocess=True
)
sub.append(prediction_value)
self.processed_examples.append(sub)
Expand Down
26 changes: 12 additions & 14 deletions gradio/processing_utils.py
Expand Up @@ -8,7 +8,6 @@
import shutil
import subprocess
import tempfile
import urllib.request
import warnings
from io import BytesIO
from pathlib import Path
Expand Down Expand Up @@ -121,17 +120,9 @@ def hash_file(file_path: str | Path, chunk_num_blocks: int = 128) -> str:
return sha1.hexdigest()


def hash_url(url: str, chunk_num_blocks: int = 128) -> str:
def hash_url(url: str) -> str:
sha1 = hashlib.sha1()
remote = urllib.request.urlopen(url)
max_file_size = 100 * 1024 * 1024 # 100MB
total_read = 0
while True:
data = remote.read(chunk_num_blocks * sha1.block_size)
total_read += chunk_num_blocks * sha1.block_size
if not data or total_read > max_file_size:
break
sha1.update(data)
sha1.update(url.encode("utf-8"))
return sha1.hexdigest()


Expand Down Expand Up @@ -207,7 +198,6 @@ def save_url_to_cache(url: str, cache_dir: str) -> str:
temp_dir = hash_url(url)
temp_dir = Path(cache_dir) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True)

name = client_utils.strip_invalid_filename_characters(Path(url).name)
full_temp_file_path = str(abspath(temp_dir / name))

Expand Down Expand Up @@ -273,19 +263,27 @@ def move_resource_to_block_cache(url_or_file_path: str | Path, block: Component)
return temp_file_path


def move_files_to_cache(data: Any, block: Component):
def move_files_to_cache(data: Any, block: Component, postprocess: bool = False):
"""Move files to cache and replace the file path with the cache path.
Runs after postprocess and before preprocess.
Args:
data: The input or output data for a component. Can be a dictionary or a dataclass
block: The component
postprocess: Whether its running from postprocessing
"""

def _move_to_cache(d: dict):
payload = FileData(**d)
temp_file_path = move_resource_to_block_cache(payload.path, block)
# If the gradio app developer is returning a URL from
# postprocess, it means the component can display a URL
# without it being served from the gradio server
# This makes it so that the URL is not downloaded and speeds up event processing
if payload.url and postprocess:
temp_file_path = payload.url
else:
temp_file_path = move_resource_to_block_cache(payload.path, block)
payload.path = temp_file_path
return payload.model_dump()

Expand Down
24 changes: 24 additions & 0 deletions js/app/test/gallery_component_events.spec.ts
@@ -0,0 +1,24 @@
import { test, expect } from "@gradio/tootils";

test("Gallery preview mode displays all images correctly.", async ({
page
}) => {
await page.getByRole("button", { name: "Run" }).click();
await page.getByLabel("Thumbnail 2 of 3").click();

expect(await page.getByTestId("detailed-image").getAttribute("src")).toEqual(
"https://gradio-builds.s3.amazonaws.com/assets/lite-logo.png"
);

expect(await page.getByTestId("thumbnail 1").getAttribute("src")).toEqual(
"https://gradio-builds.s3.amazonaws.com/assets/cheetah-003.jpg"
);
});

test("Gallery select event returns the right value", async ({ page }) => {
await page.getByRole("button", { name: "Run" }).click();
await page.getByLabel("Thumbnail 2 of 3").click();
expect(await page.getByLabel("Select Data")).toHaveValue(
"https://gradio-builds.s3.amazonaws.com/assets/lite-logo.png"
);
});
5 changes: 3 additions & 2 deletions js/gallery/shared/Gallery.svelte
Expand Up @@ -214,7 +214,7 @@
>
<img
data-testid="detailed-image"
src={_value[selected_index].image.path}
src={_value[selected_index].image.url}
alt={_value[selected_index].caption || ""}
title={_value[selected_index].caption || null}
class:with-caption={!!_value[selected_index].caption}
Expand All @@ -240,8 +240,9 @@
aria-label={"Thumbnail " + (i + 1) + " of " + _value.length}
>
<img
src={image.image.path}
src={image.image.url}
title={image.caption || null}
data-testid={"thumbnail " + (i + 1)}
alt=""
loading="lazy"
/>
Expand Down
5 changes: 5 additions & 0 deletions test/test_components.py
Expand Up @@ -2087,6 +2087,11 @@ def test_static(self):


class TestGallery:
def test_postprocess(self):
url = "https://huggingface.co/Norod78/SDXL-VintageMagStyle-Lora/resolve/main/Examples/00015-20230906102032-7778-Wonderwoman VintageMagStyle _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg"
gallery = gr.Gallery([url])
assert gallery.get_config()["value"][0]["image"]["path"] == url

@patch("uuid.uuid4", return_value="my-uuid")
def test_gallery(self, mock_uuid):
gallery = gr.Gallery()
Expand Down
5 changes: 5 additions & 0 deletions test/test_processing_utils.py
Expand Up @@ -99,6 +99,11 @@ def test_save_url_to_cache(self, gradio_temp_dir):
f = processing_utils.save_url_to_cache(url2, cache_dir=gradio_temp_dir)
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 2

def test_save_url_to_cache_with_spaces(self, gradio_temp_dir):
url = "https://huggingface.co/datasets/freddyaboulton/gradio-reviews/resolve/main00015-20230906102032-7778-Wonderwoman VintageMagStyle _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg"
processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir)
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1


class TestImagePreprocessing:
def test_decode_base64_to_image(self):
Expand Down

0 comments on commit 9227872

Please sign in to comment.