From b25e95e164e80d66203ef71ce6bdb67ceb6b24df Mon Sep 17 00:00:00 2001 From: "Yuichiro Tachibana (Tsuchiya)" Date: Tue, 6 Feb 2024 19:46:20 +0000 Subject: [PATCH] Fix `processing_utils.save_url_to_cache()` to follow redirects when accessing the URL (#7322) * Fix `processing_utils.save_url_to_cache()` to follow redirects when accessing the URL * add changeset * follow more redirects * format * add changeset * add test * validate urls --------- Co-authored-by: gradio-pr-bot Co-authored-by: Abubakar Abid --- .changeset/smooth-ends-bet.md | 6 ++++++ client/python/gradio_client/utils.py | 8 ++++++-- gradio/processing_utils.py | 4 +++- gradio/utils.py | 4 +++- test/test_processing_utils.py | 5 +++++ test/test_utils.py | 3 +++ 6 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 .changeset/smooth-ends-bet.md diff --git a/.changeset/smooth-ends-bet.md b/.changeset/smooth-ends-bet.md new file mode 100644 index 000000000000..cc81671c2f67 --- /dev/null +++ b/.changeset/smooth-ends-bet.md @@ -0,0 +1,6 @@ +--- +"gradio": patch +"gradio_client": patch +--- + +fix:Fix `processing_utils.save_url_to_cache()` to follow redirects when accessing the URL diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index d520f556e7b2..60d98ded9f52 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -633,7 +633,9 @@ def download_file( temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20) temp_dir.mkdir(exist_ok=True, parents=True) - with httpx.stream("GET", url_path, headers=headers) as response: + with httpx.stream( + "GET", url_path, headers=headers, follow_redirects=True + ) as response: response.raise_for_status() with open(temp_dir / Path(url_path).name, "wb") as f: for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size): @@ -666,7 +668,9 @@ def download_tmp_copy_of_file( directory.mkdir(exist_ok=True, parents=True) file_path = directory / Path(url_path).name - with httpx.stream("GET", url_path, headers=headers) as response: + with httpx.stream( + "GET", url_path, headers=headers, follow_redirects=True + ) as response: response.raise_for_status() with open(file_path, "wb") as f: for chunk in response.iter_raw(): diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index d6d6f2c3c7c2..4c07df2d4ab1 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -190,7 +190,9 @@ def save_url_to_cache(url: str, cache_dir: str) -> str: full_temp_file_path = str(abspath(temp_dir / name)) if not Path(full_temp_file_path).exists(): - with httpx.stream("GET", url) as r, open(full_temp_file_path, "wb") as f: + with httpx.stream("GET", url, follow_redirects=True) as r, open( + full_temp_file_path, "wb" + ) as f: for chunk in r.iter_raw(): f.write(chunk) diff --git a/gradio/utils.py b/gradio/utils.py index 0377ed4bf4b4..44a6e20013eb 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -592,7 +592,9 @@ def validate_url(possible_url: str) -> bool: head_request = httpx.head(possible_url, headers=headers, follow_redirects=True) # some URLs, such as AWS S3 presigned URLs, return a 405 or a 403 for HEAD requests if head_request.status_code in (403, 405): - return httpx.get(possible_url, headers=headers).is_success + return httpx.get( + possible_url, headers=headers, follow_redirects=True + ).is_success return head_request.is_success except Exception: return False diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index fbe1b1bbf65d..d5743da42c8f 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -101,6 +101,11 @@ def test_save_url_to_cache_with_spaces(self, gradio_temp_dir): processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir) assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1 + def test_save_url_to_cache_with_redirect(self, gradio_temp_dir): + url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png" + processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir) + assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1 + class TestImagePreprocessing: def test_encode_plot_to_base64(self): diff --git a/test/test_utils.py b/test/test_utils.py index 537a2fe9a11c..383a6f8e3cd1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -198,6 +198,9 @@ def test_valid_urls(self): assert validate_url( "https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg" ) + assert validate_url( + "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png" + ) def test_invalid_urls(self): assert not (validate_url("C:/Users/"))