Skip to content

Commit

Permalink
Implement UPath.joinuri (#189)
Browse files Browse the repository at this point in the history
* tests: add tests for query passthrough and joinuri
* upath._flavour: add upath_urijoin
* upath: add UPath.joinuri method
* upath: UPath().name returns last non-empty part
  • Loading branch information
ap-- committed Feb 18, 2024
1 parent 3bab4c0 commit 1a117b3
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 0 deletions.
62 changes: 62 additions & 0 deletions upath/_flavour.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

# Public names exported by ``from upath._flavour import *``.
__all__ = [
"FSSpecFlavour",
"upath_urijoin",
]


Expand Down Expand Up @@ -299,3 +300,64 @@ def splitroot(p):
return splitroot
else:
raise NotImplementedError(f"unsupported module: {mod!r}")


def upath_urijoin(base: str, uri: str) -> str:
    """Join a base URI and a possibly relative URI to form an absolute
    interpretation of the latter.

    Resolution follows ``urllib.parse.urljoin``: the scheme and netloc are
    inherited from ``base`` when ``uri`` lacks them, an empty reference path
    keeps the base path (and the base query if the reference has none), and
    ``.`` / ``..`` segments are collapsed. The fragment of the result is
    always the fragment of ``uri``; the base fragment is never inherited.

    Parameters:
        base: URI that relative references are resolved against.
        uri: absolute or relative URI reference.

    Returns:
        The resolved URI string. ``uri`` is returned unchanged when ``base``
        is empty or the schemes differ; ``base`` is returned unchanged when
        ``uri`` is empty.
    """
    # see:
    #   https://github.com/python/cpython/blob/ae6c01d9d2/Lib/urllib/parse.py#L539-L605
    # modifications:
    #   - removed allow_fragments parameter
    #   - all schemes are considered to allow relative paths
    #   - all schemes are considered to allow netloc (revisit this)
    #   - no bytes support (removes encoding and decoding)
    if not base:
        return uri
    if not uri:
        return base

    bs = urlsplit(base, scheme="")
    us = urlsplit(uri, scheme=bs.scheme)

    if us.scheme != bs.scheme:
        # a different scheme means the reference is already absolute
        return uri
    if us.netloc:
        # network-path reference ("//host/..."): only the scheme is inherited
        return us.geturl()
    us = us._replace(netloc=bs.netloc)

    if not us.path:
        # Query-only / fragment-only reference: keep the base path, and the
        # base query unless the reference supplies its own. The fragment is
        # intentionally left as given by the reference (RFC 3986, 5.2.2).
        us = us._replace(path=bs.path, query=us.query or bs.query)
        return us.geturl()

    base_parts = bs.path.split("/")
    if base_parts[-1] != "":
        # the last item is not a directory, so it does not take part in
        # resolving the relative path
        del base_parts[-1]

    if us.path[:1] == "/":
        # an absolute reference path ignores the base path entirely
        segments = us.path.split("/")
    else:
        segments = base_parts + us.path.split("/")
        # drop empty inner segments that would produce redundant slashes
        segments[1:-1] = filter(None, segments[1:-1])

    resolved_path = []
    for seg in segments:
        if seg == "..":
            # ".." above the root is silently ignored
            if resolved_path:
                resolved_path.pop()
        elif seg != ".":
            resolved_path.append(seg)

    if segments[-1] in (".", ".."):
        # a trailing relative-directory segment resolves to a directory,
        # so the result must end with "/"
        resolved_path.append("")

    return us._replace(path="/".join(resolved_path) or "/").geturl()
24 changes: 24 additions & 0 deletions upath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from upath._compat import str_remove_prefix
from upath._compat import str_remove_suffix
from upath._flavour import FSSpecFlavour
from upath._flavour import upath_urijoin
from upath._protocol import get_upath_protocol
from upath._stat import UPathStatResult
from upath.registry import get_upath_class
Expand Down Expand Up @@ -253,6 +254,18 @@ def fs(self) -> AbstractFileSystem:
def path(self) -> str:
return super().__str__()

def joinuri(self, uri: str | os.PathLike[str]) -> UPath:
    """Resolve ``uri`` against this path with urljoin-like semantics.

    If ``uri`` carries a protocol that differs from this path's protocol,
    a new ``UPath`` is built from ``uri`` alone. Otherwise the string form
    of this path and ``uri`` are combined via ``upath_urijoin`` and the
    result is wrapped in a ``UPath`` that reuses this path's storage
    options.
    """
    uri_protocol = get_upath_protocol(uri)
    if uri_protocol and uri_protocol != self._protocol:
        # the target lives on another protocol: nothing to join against
        return UPath(uri)

    joined = upath_urijoin(str(self), str(uri))
    return UPath(
        joined,
        protocol=uri_protocol or self._protocol,
        **self.storage_options,
    )

# === upath.UPath CUSTOMIZABLE API ================================

@classmethod
Expand Down Expand Up @@ -590,6 +603,17 @@ def is_relative_to(self, other, /, *_deprecated):
return False
return super().is_relative_to(other, *_deprecated)

@property
def name(self):
    """Return the last non-empty part of the path tail.

    An empty tail yields ``""``. When the final tail element is empty and
    at least two elements exist, the element before it is returned instead.
    """
    parts = self._tail
    if not parts:
        return ""
    last = parts[-1]
    if last or len(parts) < 2:
        return last
    # trailing empty part: fall back to the element before it
    return parts[-2]

# === pathlib.Path ================================================

def stat(self, *, follow_symlinks=True) -> UPathStatResult:
Expand Down
42 changes: 42 additions & 0 deletions upath/tests/implementations/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,45 @@ def test_empty_parts(args, parts):
pth = UPath(args)
pth_parts = pth.parts
assert pth_parts == parts


def test_query_parameters_passthrough():
    """The query string of an http URL is kept as its own path part."""
    expected_parts = ("http://example.com/", "?a=1&b=2")
    assert UPath("http://example.com/?a=1&b=2").parts == expected_parts


@pytest.mark.parametrize(
    ("base", "rel", "expected"),
    [
        # relative file name with a query string
        (
            "http://www.example.com/a/b/index.html",
            "image.png?version=1",
            "http://www.example.com/a/b/image.png?version=1",
        ),
        # parent-directory reference
        (
            "http://www.example.com/a/b/index.html",
            "../image.png",
            "http://www.example.com/a/image.png",
        ),
        # absolute path on the same host
        (
            "http://www.example.com/a/b/index.html",
            "/image.png",
            "http://www.example.com/image.png",
        ),
        # fully qualified URI with a different protocol
        (
            "http://www.example.com/a/b/index.html",
            "ftp://other.com/image.png",
            "ftp://other.com/image.png",
        ),
        # network-path reference: inherits only the scheme
        (
            "http://www.example.com/a/b/index.html",
            "//other.com/image.png",
            "http://other.com/image.png",
        ),
    ],
)
def test_joinuri_behavior(base, rel, expected):
    """``UPath.joinuri`` resolves ``rel`` against ``base`` like urljoin."""
    assert UPath(base).joinuri(rel) == UPath(expected)

0 comments on commit 1a117b3

Please sign in to comment.