Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug(python): fix path handling in windows #724

Merged
merged 8 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,19 @@ jobs:
run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest
run: pytest --doctest-modules lancedb
mac:
platform:
name: "Platform: ${{ matrix.config.name }}"
timeout-minutes: 30
strategy:
matrix:
mac-runner: [ "macos-13", "macos-13-xlarge" ]
runs-on: "${{ matrix.mac-runner }}"
config:
- name: x86 Mac
runner: macos-13
- name: Arm Mac
runner: macos-13-xlarge
- name: x86 Windows
runner: windows-latest
runs-on: "${{ matrix.config.runner }}"
defaults:
run:
shell: bash
Expand Down
11 changes: 5 additions & 6 deletions python/lancedb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pyarrow import fs

from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri

if TYPE_CHECKING:
from .common import DATA, URI
Expand Down Expand Up @@ -288,14 +288,13 @@ def table_names(
A list of table names.
"""
try:
filesystem, path = fs_from_uri(self.uri)
filesystem = fs_from_uri(self.uri)[0]
except pa.ArrowInvalid:
raise NotImplementedError("Unsupported scheme: " + self.uri)

try:
paths = filesystem.get_file_info(
fs.FileSelector(get_uri_location(self.uri))
)
loc = get_uri_location(self.uri)
paths = filesystem.get_file_info(fs.FileSelector(loc))
except FileNotFoundError:
# It is ok if the file does not exist since it will be created
paths = []
Expand Down Expand Up @@ -373,7 +372,7 @@ def drop_table(self, name: str, ignore_missing: bool = False):
"""
try:
filesystem, path = fs_from_uri(self.uri)
table_path = os.path.join(path, name + ".lance")
table_path = join_uri(path, name + ".lance")
filesystem.delete_dir(table_path)
except FileNotFoundError:
if not ignore_missing:
Expand Down
6 changes: 3 additions & 3 deletions python/lancedb/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas, value_to_sql
from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri
from .utils.events import register_event

if TYPE_CHECKING:
Expand Down Expand Up @@ -551,7 +551,7 @@ def to_arrow(self) -> pa.Table:

@property
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")
return join_uri(self._conn.uri, f"{self.name}.lance")

def create_index(
self,
Expand Down Expand Up @@ -597,7 +597,7 @@ def create_fts_index(self, field_names: Union[str, List[str]]):
register_event("create_fts_index")

def _get_fts_index_path(self):
return os.path.join(self._dataset_uri, "_indices", "tantivy")
return join_uri(self._dataset_uri, "_indices", "tantivy")

@cached_property
def _dataset(self) -> LanceDataset:
Expand Down
32 changes: 31 additions & 1 deletion python/lancedb/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import os
from datetime import date, datetime
from functools import singledispatch
from typing import Tuple
import pathlib
from typing import Tuple, Union
from urllib.parse import urlparse

import numpy as np
Expand Down Expand Up @@ -62,6 +63,12 @@ def get_uri_location(uri: str) -> str:
str: Location part of the URL, without scheme
"""
parsed = urlparse(uri)
if len(parsed.scheme) == 1:
# Windows drive names are parsed as the scheme
# e.g. "c:\path" -> ParseResult(scheme="c", netloc="", path="\path", ...)
# So we add special handling here for schemes that are a single character
return uri

if not parsed.netloc:
return parsed.path
else:
Expand All @@ -84,6 +91,29 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
return pa_fs.FileSystem.from_uri(uri)


def join_uri(
    base: Union[str, pathlib.Path], *parts: str
) -> Union[str, pathlib.Path]:
    """
    Join a base URI or local path with one or more path parts.

    Handles both local paths and remote (scheme-prefixed) URIs.

    Parameters
    ----------
    base : str or pathlib.Path
        The base URI or local path.
    parts : str
        The parts to append to the base, in order.

    Returns
    -------
    str or pathlib.Path
        A ``pathlib.Path`` when ``base`` is a ``pathlib.Path``; otherwise a
        string. Local ("file" scheme) strings are joined through ``pathlib``,
        so the result uses the separator of the running OS. Every other
        scheme is joined with "/".
    """
    if isinstance(base, pathlib.Path):
        # Path in, Path out: callers that pass a Path get a Path back,
        # not a string.
        return base.joinpath(*parts)
    base = str(base)
    if get_uri_scheme(base) == "file":
        # using pathlib for local paths make this windows compatible
        # `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`)
        return str(pathlib.Path(base, *parts))
    # Remote URIs always use "/" as the separator regardless of the host OS,
    # so join by hand instead of using os.path.join / pathlib. Trailing
    # slashes are stripped from each piece to avoid doubled separators.
    return "/".join(p.rstrip("/") for p in (base, *parts))


def safe_import_pandas():
try:
import pandas as pd
Expand Down
59 changes: 58 additions & 1 deletion python/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from lancedb.util import get_uri_scheme
import os
import pathlib

import pytest

from lancedb.util import get_uri_scheme, join_uri


def test_normalize_uri():
Expand All @@ -28,3 +33,55 @@ def test_normalize_uri():
for uri, expected_scheme in zip(uris, schemes):
parsed_scheme = get_uri_scheme(uri)
assert parsed_scheme == expected_scheme


def test_join_uri_remote():
    """Check that s3, az and gs URIs are joined with "/" separators."""
    for scheme in ("s3", "az", "gs"):
        expected = f"{scheme}://bucket/path/to/table.lance"

        # A base that already ends in "/" must not produce a doubled slash.
        nested_base = f"{scheme}://bucket/path/to/"
        assert join_uri(nested_base, "table.lance") == expected

        # Several parts appended to a bare bucket URI.
        bucket_base = f"{scheme}://bucket"
        assert join_uri(bucket_base, "path", "to", "table.lance") == expected


# skip this test if on windows
@pytest.mark.skipif(os.name == "nt", reason="Windows paths are not POSIX")
def test_join_uri_posix():
    """Check that local POSIX bases join the same way ``pathlib`` joins them.

    Each base is tried both as a plain string (expects a string back) and as
    a ``pathlib.Path`` (expects a ``pathlib.Path`` back).
    """
    posix_bases = (
        # relative path
        "relative/path",
        "relative/path/",
        # an absolute path
        "/absolute/path",
        "/absolute/path/",
        # a file URI
        "file:///absolute/path",
        "file:///absolute/path/",
    )
    for base in posix_bases:
        base_path = pathlib.Path(base)
        expected = base_path / "table.lance"

        # str input -> str output
        assert join_uri(base, "table.lance") == str(expected)
        # Path input -> Path output
        assert join_uri(base_path, "table.lance") == expected


# skip this test if not on windows
@pytest.mark.skipif(os.name != "nt", reason="Windows paths are not POSIX")
def test_local_join_uri_windows():
    """Check that local Windows bases join the same way ``pathlib`` joins them.

    Each base is tried both as a plain string (expects a string back) and as
    a ``pathlib.Path`` (expects a ``pathlib.Path`` back).
    """
    # https://learn.microsoft.com/en-us/dotnet/standard/io/file-path-formats
    windows_bases = (
        # windows relative path
        "relative\\path",
        "relative\\path\\",
        # windows absolute path from current drive
        "c:\\absolute\\path",
        # relative path from root of current drive
        "\\relative\\path",
    )
    for base in windows_bases:
        base_path = pathlib.Path(base)
        expected = base_path / "table.lance"

        # str input -> str output
        assert join_uri(base, "table.lance") == str(expected)
        # Path input -> Path output
        assert join_uri(base_path, "table.lance") == expected
Loading