Skip to content

Commit

Permalink
construct url-safe url when downloading file onto container (#453)
Browse files Browse the repository at this point in the history
* construct url-safe url when downloading file onto container

* safe encode url when downloading file

* include logic for safe file_url in RunInput schema

* include url encoding on RunInput schema

* accommodate for nested run input structures

* check for file_url key and setattr

* different approach

* new approach

* formatting

* do not do double encode in manager

* remove safe kwarg

* flake8 formatting

* fix typo
  • Loading branch information
plutopulp committed May 2, 2024
1 parent 0e5baa4 commit fcc8480
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 11 deletions.
35 changes: 32 additions & 3 deletions pipeline/cloud/schemas/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import typing as t
from datetime import datetime
from enum import Enum
from urllib.parse import quote, unquote

from pydantic import root_validator, validator

from pipeline.cloud.schemas import BaseModel

Expand Down Expand Up @@ -208,13 +211,39 @@ def result_array(self) -> t.List[t.Any]:
class RunInput(BaseModel):
    """Schema describing one input of a run.

    Any ``file_url`` (on this object, or in raw dicts nested inside
    ``value``) is percent-encoded during validation so that it can be
    used directly when downloading the file.
    """

    type: RunIOType
    value: t.Any

    file_name: t.Optional[str]
    file_path: t.Optional[str]
    # The file URL is only populated when this schema is
    # returned by the API, the user should never populate it
    file_url: t.Optional[str]

    @validator("file_url", pre=True, always=True)
    def encode_url(cls, v):
        """Return ``v`` percent-encoded so it is safe to use as a URL.

        ``None`` is passed through untouched. A URL that already contains
        percent-escapes (i.e. ``unquote`` would change it) is assumed to be
        encoded and is returned as-is, which makes the operation idempotent.
        NOTE: as a consequence, a partially encoded URL (escapes mixed with
        raw spaces) is also returned unchanged.

        Otherwise every character that is not allowed in a URL is escaped.
        The RFC 3986 reserved characters are kept verbatim so that the
        structure of the URL (scheme, authority, path, query string,
        fragment) survives; escaping only ``/`` and ``:`` would turn
        ``?a=b&c=d`` into ``%3Fa%3Db%26c%3Dd`` and break query strings.
        """
        if v is None:
            return v
        # Already encoded: do not encode a second time.
        if v != unquote(v):
            return v
        # gen-delims ``:/?#[]@`` and sub-delims ``!$&'()*+,;=`` stay as-is;
        # ``%`` is deliberately not safe so a literal percent becomes %25.
        return quote(v, safe=":/?#[]@!$&'()*+,;=")

    @classmethod
    def encode_nested_urls(cls, value):
        """Encode every string ``file_url`` entry found inside ``value``.

        Dicts and lists are walked recursively; any other object (including
        already-built ``RunInput`` instances) is left alone. The containers
        are modified in place and the same ``value`` object is returned.
        """
        if isinstance(value, dict):
            for key, val in value.items():
                if key == "file_url" and isinstance(val, str):
                    # Re-assigning an existing key is safe while iterating.
                    value[key] = cls.encode_url(val)
                elif isinstance(val, (dict, list)):
                    cls.encode_nested_urls(val)
        elif isinstance(value, list):
            for item in value:
                cls.encode_nested_urls(item)
        return value

    @root_validator(pre=True)
    def handle_nested_inputs(cls, values):
        """Encode file URLs nested inside the raw ``value`` before validation."""
        if "value" in values:
            values["value"] = cls.encode_nested_urls(values["value"])
        return values


class ContainerRunErrorType(str, Enum):
input_error = "input_error"
Expand Down
18 changes: 11 additions & 7 deletions pipeline/container/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import traceback
import typing as t
import urllib.parse
from http.client import InvalidURL
from pathlib import Path
from types import NoneType, UnionType
from urllib import request
Expand Down Expand Up @@ -113,14 +114,23 @@ def _resolve_file_variable_to_local(
local_host_dir = "/tmp"

if hasattr(file, "url") and file.url is not None:
# Encode the URL to handle spaces and other non-URL-safe characters
encoded_url = run_schemas.RunInput.encode_url(file.url.geturl())
cache_name = hashlib.md5(file.url.geturl().encode()).hexdigest()

file_name = file.url.geturl().split("/")[-1]
local_path = f"{local_host_dir}/{cache_name}/{file_name}"
file_path = Path(local_path)
file_path.parent.mkdir(parents=True, exist_ok=True)
try:
# Use the encoded URL for retrieving the file
request.urlretrieve(encoded_url, local_path)
# This should not be raised thanks to encoded_url, but handle it just in case
except InvalidURL:
raise Exception("The file to download has an invalid URL.")
except Exception:
raise Exception("Error downloading file.")

request.urlretrieve(file.url.geturl(), local_path)
elif file.remote_id is not None or file.path is not None:
local_path = Path(file.path)
local_path.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -130,12 +140,6 @@ def _resolve_file_variable_to_local(
if isinstance(file, Directory):
raise NotImplementedError("Remote ID not implemented yet")

# file.path = Path(f"{local_host_dir}/{cache_name}_dir")

# with zipfile.ZipFile(local_path, "r") as zip_ref:
# zip_ref.extractall(str(file.path))
# return

file.path = Path(local_path)

def _create_file_variable(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pipeline-ai"
version = "2.1.7"
version = "2.1.8"
description = "Pipelines for machine learning workloads."
authors = [
"Paul Hetherington <ph@mystic.ai>",
Expand Down
150 changes: 150 additions & 0 deletions tests/schemas/test_run_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from pipeline.cloud.schemas.runs import RunInput, RunIOType


def test_url_with_spaces():
    """A URL containing spaces is encoded once and then stays stable."""
    raw_url = "http://example.com/some file.png"
    encoded_url = "http://example.com/some%20file.png"

    # Encoding happens when the first instance is built.
    first = RunInput(file_url=raw_url, type=RunIOType.file)
    assert (
        first.file_url == encoded_url
    ), "URL with spaces should be encoded correctly in the first instance"

    # Feeding the encoded URL back in must not encode it a second time.
    second = RunInput(file_url=first.file_url, type=RunIOType.file)
    assert (
        second.file_url == encoded_url
    ), "URL with spaces should be encoded correctly in the second instance"


def test_url_without_spaces():
    """A URL with nothing to escape comes out of validation unchanged."""
    plain_url = "http://example.com/file.png"
    result = RunInput(file_url=plain_url, type=RunIOType.file)
    assert (
        result.file_url == plain_url
    ), "URL without spaces should not be altered"


def test_none_url():
    """A file URL of None is kept as None."""
    result = RunInput(type=RunIOType.file, file_url=None)
    assert result.file_url is None, "None URL should remain None"


def test_nested_run_input_encoding():
    """File inputs held one level deep in a dictionary value are encoded."""
    first_file = RunInput(type="file", file_url="http://example.com/some file.png")
    second_file = RunInput(
        type="file", file_url="http://example.com/another file.png"
    )

    run_input = RunInput(
        type="dictionary",
        value={"file_1": first_file, "file_2": second_file},
        file_url=None,
    )

    files = run_input.value
    assert (
        files["file_1"].file_url == "http://example.com/some%20file.png"
    ), "Nested URL file_1 should be encoded correctly"
    assert (
        files["file_2"].file_url == "http://example.com/another%20file.png"
    ), "Nested URL file_2 should be encoded correctly"


def test_deeply_nested_run_input_encoding():
    """A file input buried several dictionaries deep is still encoded."""
    innermost = RunInput(
        type="file", file_url="http://example.com/yet another file.png"
    )
    # Build the nesting from the inside out.
    payload = {"level_1": {"level_2": {"file_3": innermost}}}

    run_input = RunInput(type="dictionary", value=payload, file_url=None)

    found = run_input.value["level_1"]["level_2"]["file_3"]
    assert (
        found.file_url == "http://example.com/yet%20another%20file.png"
    ), "Deeply nested URL should be encoded correctly"


def test_mixed_content_encoding():
    """File inputs at different depths of the same value are all encoded."""
    shallow_file = RunInput(
        type="file", file_url="http://example.com/file with space.png"
    )
    deeper_file = RunInput(
        type="file",
        file_url="http://example.com/another file with space.png",
    )

    run_input = RunInput(
        type="dictionary",
        value={
            "file_4": shallow_file,
            "non_file_data": {"file_5": deeper_file},
        },
        file_url=None,
    )

    content = run_input.value
    assert (
        content["file_4"].file_url == "http://example.com/file%20with%20space.png"
    ), "Mixed content URL file_4 should be encoded correctly"
    assert (
        content["non_file_data"]["file_5"].file_url
        == "http://example.com/another%20file%20with%20space.png"
    ), "Mixed content URL file_5 should be encoded correctly"


def test_json_run_input_list():
    """JSON-style payloads give encoded URLs at the top level and nested."""
    top_level_file = {
        "type": "file",
        "value": None,
        "file_url": "https://storage.googleapis.com/catalyst-v4/pipeline_files/6f/d2/image 0.jpeg",  # noqa
    }
    dictionary_of_files = {
        "type": "dictionary",
        "value": {
            "file_1": {
                "type": "file",
                "value": None,
                "file_url": "https://storage.googleapis.com/catalyst-v4/pipeline_files/d4/99/image 0.jpeg",  # noqa
            },
            "file_2": {
                "type": "file",
                "value": None,
                "file_url": "https://storage.googleapis.com/catalyst-v4/pipeline_files/c7/81/image 0.jpeg",  # noqa
            },
        },
    }
    input_list = [top_level_file, dictionary_of_files]

    # Turn each raw dictionary into a RunInput instance.
    run_inputs = []
    for item in input_list:
        if "file_url" not in item:
            # No URL of its own: convert every sub-item of the value field
            # to a RunInput first, then wrap the result.
            wrapped = {
                name: RunInput(**spec) for name, spec in item["value"].items()
            }
            run_inputs.append(RunInput(type=item["type"], value=wrapped))
            continue
        run_inputs.append(RunInput(**item))

    top_level, nested = run_inputs
    assert (
        top_level.file_url
        == "https://storage.googleapis.com/catalyst-v4/pipeline_files/6f/d2/image%200.jpeg"  # noqa
    ), "Top-level URL should be encoded correctly"
    assert (
        nested.value["file_1"].file_url
        == "https://storage.googleapis.com/catalyst-v4/pipeline_files/d4/99/image%200.jpeg"  # noqa
    ), "Nested URL file_1 should be encoded correctly"
    assert (
        nested.value["file_2"].file_url
        == "https://storage.googleapis.com/catalyst-v4/pipeline_files/c7/81/image%200.jpeg"  # noqa
    ), "Nested URL file_2 should be encoded correctly"

0 comments on commit fcc8480

Please sign in to comment.