diff --git a/.changeset/afraid-bags-lose.md b/.changeset/afraid-bags-lose.md new file mode 100644 index 000000000000..17d1806a0008 --- /dev/null +++ b/.changeset/afraid-bags-lose.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Extract video filenames correctly from URLs diff --git a/gradio/components/video.py b/gradio/components/video.py index 901287383394..34e7b3cf10d2 100644 --- a/gradio/components/video.py +++ b/gradio/components/video.py @@ -275,9 +275,8 @@ def _format_video(self, video: str | Path | None) -> FileData | None: "Video does not have browser-compatible container or codec. Converting to mp4" ) video = processing_utils.convert_video_to_playable_mp4(video) - # Recalculate the format in case convert_video_to_playable_mp4 already made it the - # selected format - returned_format = video.split(".")[-1].lower() + # Recalculate the format in case convert_video_to_playable_mp4 already made it the selected format + returned_format = utils.get_extension_from_file_path_or_url(video).lower() if self.format is not None and returned_format != self.format: if wasm_utils.IS_WASM: raise wasm_utils.WasmUnsupportedError( diff --git a/gradio/utils.py b/gradio/utils.py index 8416d630f966..e724480e8247 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -15,6 +15,7 @@ import threading import traceback import typing +import urllib.parse import warnings from abc import ABC, abstractmethod from contextlib import contextmanager @@ -969,3 +970,13 @@ def default_input_labels(): while True: yield f"input {n}" n += 1 + + +def get_extension_from_file_path_or_url(file_path_or_url: str) -> str: + """ + Returns the file extension (without the dot) from a file path or URL. If the file path or URL does not have a file extension, returns an empty string. + For example, "https://example.com/avatar/xxxx.mp4?se=2023-11-16T06:51:23Z&sp=r" would return "mp4". + """ + parsed_url = urllib.parse.urlparse(file_path_or_url) + file_extension = os.path.splitext(os.path.basename(parsed_url.path))[1] + return file_extension[1:] if file_extension else "" diff --git a/test/test_utils.py b/test/test_utils.py index 71b8c2571da5..9b49872ff7ff 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -22,6 +22,7 @@ delete_none, format_ner_list, get_continuous_fn, + get_extension_from_file_path_or_url, get_type_hints, ipython_check, is_in_or_equal, @@ -402,3 +403,16 @@ def test_is_in_or_equal(): assert is_in_or_equal("/home/usr/notes.txt", "/home/usr/") assert not is_in_or_equal("/home/usr/subdirectory", "/home/usr/notes.txt") assert not is_in_or_equal("/home/usr/../../etc/notes.txt", "/home/usr/") + + +@pytest.mark.parametrize( + "path_or_url, extension", + [ + ("https://example.com/avatar/xxxx.mp4?se=2023-11-16T06:51:23Z&sp=r", "mp4"), + ("/home/user/documents/example.pdf", "pdf"), + ("C:\\Users\\user\\documents\\example.png", "png"), + ("C:/Users/user/documents/example", ""), + ], +) +def test_get_extension_from_file_path_or_url(path_or_url, extension): + assert get_extension_from_file_path_or_url(path_or_url) == extension