Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions upath/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def _fsspec_protocol_equals(p0: str, p1: str) -> bool:
try:
o0 = _fsspec_registry_map[p0]
except KeyError:
raise ValueError(f"Protocol not known: {p0}")
raise ValueError(f"Protocol not known: {p0!r}")
try:
o1 = _fsspec_registry_map[p1]
except KeyError:
raise ValueError(f"Protocol not known: {p1}")
raise ValueError(f"Protocol not known: {p1!r}")

return o0 == o1

Expand Down Expand Up @@ -81,14 +81,22 @@ def get_upath_protocol(
pth_protocol = _match_protocol(str(pth))
# if storage_options and not protocol and not pth_protocol:
# protocol = "file"
if (
if protocol is None:
return pth_protocol or ""
elif (
protocol
and pth_protocol
and not _fsspec_protocol_equals(pth_protocol, protocol)
):
raise ValueError(
f"requested protocol {protocol!r} incompatible with {pth_protocol!r}"
)
elif protocol == "" and pth_protocol:
# explicitly requested empty protocol, but path has non-empty protocol
raise ValueError(
f"explicitly requested empty protocol {protocol!r}"
f" incompatible with {pth_protocol!r}"
)
return protocol or pth_protocol or ""


Expand Down
5 changes: 4 additions & 1 deletion upath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,7 +1219,10 @@ def __get_pydantic_core_schema__(
),
"protocol": core_schema.typed_dict_field(
core_schema.with_default_schema(
core_schema.str_schema(), default=""
core_schema.nullable_schema(
core_schema.str_schema(),
),
default=None,
),
required=False,
),
Expand Down
35 changes: 35 additions & 0 deletions upath/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,38 @@ def test_open_a_local_upath(tmp_path, protocol):
u = UPath(p, protocol=protocol)
with open(u, "rb") as f:
assert f.read() == b"hello world"


@pytest.mark.parametrize(
"uri,protocol",
[
("s3://bucket/folder", "s3"),
("gs://bucket/folder", "gs"),
("bucket/folder", "s3"),
("memory://folder", "memory"),
("file:/tmp/folder", "file"),
("/tmp/folder", "file"),
("/tmp/folder", ""),
("a/b/c", ""),
],
)
def test_constructor_compatible_protocol_uri(uri, protocol):
p = UPath(uri, protocol=protocol)
assert p.protocol == protocol


@pytest.mark.parametrize(
"uri,protocol",
[
("s3://bucket/folder", "gs"),
("gs://bucket/folder", "s3"),
("memory://folder", "s3"),
("file:/tmp/folder", "s3"),
("s3://bucket/folder", ""),
("memory://folder", ""),
("file:/tmp/folder", ""),
],
)
def test_constructor_incompatible_protocol_uri(uri, protocol):
with pytest.raises(ValueError, match=r".*incompatible with"):
UPath(uri, protocol=protocol)