Skip to content

Commit

Permalink
Fix processing_utils.save_url_to_cache() to follow redirects when accessing the URL (#7322)
Browse files Browse the repository at this point in the history

* Fix `processing_utils.save_url_to_cache()` to follow redirects when accessing the URL

* add changeset

* follow more redirects

* format

* add changeset

* add test

* validate urls

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people committed Feb 6, 2024
1 parent 200e251 commit b25e95e
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .changeset/smooth-ends-bet.md
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

fix:Fix `processing_utils.save_url_to_cache()` to follow redirects when accessing the URL
8 changes: 6 additions & 2 deletions client/python/gradio_client/utils.py
Expand Up @@ -633,7 +633,9 @@ def download_file(
temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20)
temp_dir.mkdir(exist_ok=True, parents=True)

with httpx.stream("GET", url_path, headers=headers) as response:
with httpx.stream(
"GET", url_path, headers=headers, follow_redirects=True
) as response:
response.raise_for_status()
with open(temp_dir / Path(url_path).name, "wb") as f:
for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size):
Expand Down Expand Up @@ -666,7 +668,9 @@ def download_tmp_copy_of_file(
directory.mkdir(exist_ok=True, parents=True)
file_path = directory / Path(url_path).name

with httpx.stream("GET", url_path, headers=headers) as response:
with httpx.stream(
"GET", url_path, headers=headers, follow_redirects=True
) as response:
response.raise_for_status()
with open(file_path, "wb") as f:
for chunk in response.iter_raw():
Expand Down
4 changes: 3 additions & 1 deletion gradio/processing_utils.py
Expand Up @@ -190,7 +190,9 @@ def save_url_to_cache(url: str, cache_dir: str) -> str:
full_temp_file_path = str(abspath(temp_dir / name))

if not Path(full_temp_file_path).exists():
with httpx.stream("GET", url) as r, open(full_temp_file_path, "wb") as f:
with httpx.stream("GET", url, follow_redirects=True) as r, open(
full_temp_file_path, "wb"
) as f:
for chunk in r.iter_raw():
f.write(chunk)

Expand Down
4 changes: 3 additions & 1 deletion gradio/utils.py
Expand Up @@ -592,7 +592,9 @@ def validate_url(possible_url: str) -> bool:
head_request = httpx.head(possible_url, headers=headers, follow_redirects=True)
# some URLs, such as AWS S3 presigned URLs, return a 405 or a 403 for HEAD requests
if head_request.status_code in (403, 405):
return httpx.get(possible_url, headers=headers).is_success
return httpx.get(
possible_url, headers=headers, follow_redirects=True
).is_success
return head_request.is_success
except Exception:
return False
Expand Down
5 changes: 5 additions & 0 deletions test/test_processing_utils.py
Expand Up @@ -101,6 +101,11 @@ def test_save_url_to_cache_with_spaces(self, gradio_temp_dir):
processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir)
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1

def test_save_url_to_cache_with_redirect(self, gradio_temp_dir):
url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"
processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir)
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1


class TestImagePreprocessing:
def test_encode_plot_to_base64(self):
Expand Down
3 changes: 3 additions & 0 deletions test/test_utils.py
Expand Up @@ -198,6 +198,9 @@ def test_valid_urls(self):
assert validate_url(
"https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg"
)
assert validate_url(
"https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"
)

def test_invalid_urls(self):
assert not (validate_url("C:/Users/"))
Expand Down

0 comments on commit b25e95e

Please sign in to comment.