diff --git a/pipeline/cloud/schemas/runs.py b/pipeline/cloud/schemas/runs.py index 8b79fd2f..e532fc7f 100644 --- a/pipeline/cloud/schemas/runs.py +++ b/pipeline/cloud/schemas/runs.py @@ -3,6 +3,9 @@ import typing as t from datetime import datetime from enum import Enum +from urllib.parse import quote, unquote + +from pydantic import root_validator, validator from pipeline.cloud.schemas import BaseModel @@ -208,13 +211,39 @@ def result_array(self) -> t.List[t.Any]: class RunInput(BaseModel): type: RunIOType value: t.Any - file_name: t.Optional[str] file_path: t.Optional[str] - # The file URL is only populated when this schema is - # returned by the API, the user should never populate it file_url: t.Optional[str] + @validator("file_url", pre=True, always=True) + def encode_url(cls, v): + if v is not None: + # check whether has already been encoded, to avoid + # multiple encoding + if v != unquote(v): + return v + return quote(v, safe="/:") + return v + + @classmethod + def encode_nested_urls(cls, value): + if isinstance(value, dict): + for key, val in value.items(): + if key == "file_url" and isinstance(val, str): + value[key] = cls.encode_url(val) + elif isinstance(val, (dict, list)): + cls.encode_nested_urls(val) + elif isinstance(value, list): + for item in value: + cls.encode_nested_urls(item) + return value + + @root_validator(pre=True) + def handle_nested_inputs(cls, values): + if "value" in values: + values["value"] = cls.encode_nested_urls(values["value"]) + return values + class ContainerRunErrorType(str, Enum): input_error = "input_error" diff --git a/pipeline/container/manager.py b/pipeline/container/manager.py index ee36bf2c..c16c3c9c 100644 --- a/pipeline/container/manager.py +++ b/pipeline/container/manager.py @@ -4,6 +4,7 @@ import traceback import typing as t import urllib.parse +from http.client import InvalidURL from pathlib import Path from types import NoneType, UnionType from urllib import request @@ -113,14 +114,23 @@ def _resolve_file_variable_to_local( local_host_dir = "/tmp" if hasattr(file, "url") and file.url is not None: + # Encode the URL to handle spaces and other non-URL-safe characters + encoded_url = run_schemas.RunInput.encode_url(file.url.geturl()) cache_name = hashlib.md5(file.url.geturl().encode()).hexdigest() file_name = file.url.geturl().split("/")[-1] local_path = f"{local_host_dir}/{cache_name}/{file_name}" file_path = Path(local_path) file_path.parent.mkdir(parents=True, exist_ok=True) + try: + # Use the encoded URL for retrieving the file + request.urlretrieve(encoded_url, local_path) + # This should not be raise due to encoded_url, but including it in case + except InvalidURL: + raise Exception("The file to download has an invalid URL.") + except Exception: + raise Exception("Error downloading file.") - request.urlretrieve(file.url.geturl(), local_path) elif file.remote_id is not None or file.path is not None: local_path = Path(file.path) local_path.parent.mkdir(parents=True, exist_ok=True) @@ -130,12 +140,6 @@ def _resolve_file_variable_to_local( if isinstance(file, Directory): raise NotImplementedError("Remote ID not implemented yet") - # file.path = Path(f"{local_host_dir}/{cache_name}_dir") - - # with zipfile.ZipFile(local_path, "r") as zip_ref: - # zip_ref.extractall(str(file.path)) - # return - file.path = Path(local_path) def _create_file_variable( diff --git a/pyproject.toml b/pyproject.toml index 5fc34000..e3f11655 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pipeline-ai" -version = "2.1.7" +version = "2.1.8" description = "Pipelines for machine learning workloads." authors = [ "Paul Hetherington ", diff --git a/tests/schemas/test_run_input.py b/tests/schemas/test_run_input.py new file mode 100644 index 00000000..c89b8602 --- /dev/null +++ b/tests/schemas/test_run_input.py @@ -0,0 +1,150 @@ +from pipeline.cloud.schemas.runs import RunInput, RunIOType + + +def test_url_with_spaces(): + url_with_spaces = "http://example.com/some file.png" + expected_url = "http://example.com/some%20file.png" + # First instance of RunInput + run_input1 = RunInput(file_url=url_with_spaces, type=RunIOType.file) + assert ( + run_input1.file_url == expected_url + ), "URL with spaces should be encoded correctly in the first instance" + + # Second instance of RunInput using the file_url from the first instance + run_input2 = RunInput(file_url=run_input1.file_url, type=RunIOType.file) + assert ( + run_input2.file_url == expected_url + ), "URL with spaces should be encoded correctly in the second instance" + + +def test_url_without_spaces(): + url_without_spaces = "http://example.com/file.png" + run_input = RunInput(file_url=url_without_spaces, type=RunIOType.file) + assert ( + run_input.file_url == url_without_spaces + ), "URL without spaces should not be altered" + + +def test_none_url(): + run_input = RunInput(file_url=None, type=RunIOType.file) + assert run_input.file_url is None, "None URL should remain None" + + +def test_nested_run_input_encoding(): + nested_input = { + "type": "dictionary", + "value": { + "file_1": RunInput( + type="file", file_url="http://example.com/some file.png" + ), + "file_2": RunInput( + type="file", file_url="http://example.com/another file.png" + ), + }, + "file_url": None, + } + run_input = RunInput(**nested_input) + assert ( + run_input.value["file_1"].file_url == "http://example.com/some%20file.png" + ), "Nested URL file_1 should be encoded correctly" + assert ( + run_input.value["file_2"].file_url == "http://example.com/another%20file.png" + ), "Nested URL file_2 should be encoded correctly" + + +def test_deeply_nested_run_input_encoding(): + deeply_nested_input = { + "type": "dictionary", + "value": { + "level_1": { + "level_2": { + "file_3": RunInput( + type="file", file_url="http://example.com/yet another file.png" + ) + } + } + }, + "file_url": None, + } + run_input = RunInput(**deeply_nested_input) + assert ( + run_input.value["level_1"]["level_2"]["file_3"].file_url + == "http://example.com/yet%20another%20file.png" + ), "Deeply nested URL should be encoded correctly" + + +def test_mixed_content_encoding(): + mixed_content_input = { + "type": "dictionary", + "value": { + "file_4": RunInput( + type="file", file_url="http://example.com/file with space.png" + ), + "non_file_data": { + "file_5": RunInput( + type="file", + file_url="http://example.com/another file with space.png", + ) + }, + }, + "file_url": None, + } + run_input = RunInput(**mixed_content_input) + assert ( + run_input.value["file_4"].file_url + == "http://example.com/file%20with%20space.png" + ), "Mixed content URL file_4 should be encoded correctly" + assert ( + run_input.value["non_file_data"]["file_5"].file_url + == "http://example.com/another%20file%20with%20space.png" + ), "Mixed content URL file_5 should be encoded correctly" + + +def test_json_run_input_list(): + input_list = [ + { + "type": "file", + "value": None, + "file_url": "https://storage.googleapis.com/catalyst-v4/pipeline_files/6f/d2/image 0.jpeg", # noqa + }, + { + "type": "dictionary", + "value": { + "file_1": { + "type": "file", + "value": None, + "file_url": "https://storage.googleapis.com/catalyst-v4/pipeline_files/d4/99/image 0.jpeg", # noqa + }, + "file_2": { + "type": "file", + "value": None, + "file_url": "https://storage.googleapis.com/catalyst-v4/pipeline_files/c7/81/image 0.jpeg", # noqa + }, + }, + }, + ] + + # Convert dictionaries to RunInput instances + run_inputs = [] + for item in input_list: + if "file_url" in item: + run_inputs.append(RunInput(**item)) + else: + # Create a new dictionary for the value field where each sub-item + # is converted to RunInput + modified_value = {k: RunInput(**v) for k, v in item["value"].items()} + # Create RunInput instance with the modified value + run_inputs.append(RunInput(type=item["type"], value=modified_value)) + + assert ( + run_inputs[0].file_url + == "https://storage.googleapis.com/catalyst-v4/pipeline_files/6f/d2/image%200.jpeg" # noqa + ), "Top-level URL should be encoded correctly" + assert ( + run_inputs[1].value["file_1"].file_url + == "https://storage.googleapis.com/catalyst-v4/pipeline_files/d4/99/image%200.jpeg" # noqa + ), "Nested URL file_1 should be encoded correctly" + assert ( + run_inputs[1].value["file_2"].file_url + == "https://storage.googleapis.com/catalyst-v4/pipeline_files/c7/81/image%200.jpeg" # noqa + ), "Nested URL file_2 should be encoded correctly"