Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions src/orcapod/databases/connector_arrow_database.py
Original file line number Diff line number Diff line change
def __init__(
    self,
    connector: "DBConnectorProtocol",
    max_hierarchy_depth: int = 10,
    _path_prefix: tuple[str, ...] = (),
    _shared_pending_batches: "dict[str, pa.Table] | None" = None,
    _shared_pending_record_ids: "dict[str, set[str]] | None" = None,
    _shared_pending_skip_existing: "dict[str, bool] | None" = None,
) -> None:
    """Initialize the database wrapper around *connector*.

    The ``_shared_*`` parameters are internal plumbing for ``at()``: a scoped
    view passes its parent's pending dicts so that all views of one root
    database share a single pending queue by reference. External callers
    should leave them as ``None`` (fresh, empty state).
    """
    self._connector = connector
    self.max_hierarchy_depth = max_hierarchy_depth
    self._path_prefix = _path_prefix
    self._pending_batches: dict[str, pa.Table] = _shared_pending_batches if _shared_pending_batches is not None else {}
    self._pending_record_ids: dict[str, set[str]] = _shared_pending_record_ids if _shared_pending_record_ids is not None else defaultdict(set)
    # Per-batch flag: True when the batch was added with skip_duplicates=True,
    # so flush() can pass skip_existing=True to the connector and let it use
    # native INSERT-OR-IGNORE semantics rather than Python-side prefiltering.
    self._pending_skip_existing: dict[str, bool] = _shared_pending_skip_existing if _shared_pending_skip_existing is not None else {}

# ── Path helpers ──────────────────────────────────────────────────────────

def _get_record_key(self, record_path: tuple[str, ...]) -> str:
    """Return the '/'-joined storage key for *record_path*, scoped under this view's prefix."""
    return "/".join(self._path_prefix + record_path)

def _path_to_table_name(self, record_path: tuple[str, ...]) -> str:
"""Map a record_path to a safe SQL table name.
Expand All @@ -93,10 +98,11 @@ def _path_to_table_name(self, record_path: tuple[str, ...]) -> str:
def _validate_record_path(self, record_path: tuple[str, ...]) -> None:
if not record_path:
raise ValueError("record_path cannot be empty")
if len(record_path) > self.max_hierarchy_depth:
if len(self._path_prefix) + len(record_path) > self.max_hierarchy_depth:
raise ValueError(
f"record_path depth {len(record_path)} exceeds maximum "
f"{self.max_hierarchy_depth}"
f"{self.max_hierarchy_depth - len(self._path_prefix)} "
f"(base_path uses {len(self._path_prefix)} components)"
)
for i, component in enumerate(record_path):
if not component or not isinstance(component, str):
Expand Down Expand Up @@ -165,7 +171,7 @@ def _get_committed_table(
self, record_path: tuple[str, ...]
) -> pa.Table | None:
"""Fetch all committed records for a path from the connector."""
table_name = self._path_to_table_name(record_path)
table_name = self._path_to_table_name(self._path_prefix + record_path)
if table_name not in self._connector.get_table_names():
return None
batches = list(
Expand Down Expand Up @@ -268,6 +274,42 @@ def add_records(
if flush:
self.flush()

# ── base_path / at ────────────────────────────────────────────────────────

@property
def base_path(self) -> tuple[str, ...]:
    """Relative root of this database view; an unscoped root instance reports ``()``."""
    return self._path_prefix

def at(self, *path_components: str) -> "ConnectorArrowDatabase":
    """Create a view of this database rooted at the given sub-path.

    The view shares the connector and every pending structure
    (_pending_batches, _pending_record_ids, _pending_skip_existing) with its
    parent by reference, so a flush() issued through any view drains the
    whole shared pending queue.

    Raises:
        ValueError: For a component that is empty, not a str, or that
            contains ``'/'`` or ``'\\0'`` — characters that would corrupt
            the ``'/'``-joined record keys that ``flush()`` splits back
            apart into record paths.
    """
    for idx, part in enumerate(path_components):
        if not part or not isinstance(part, str):
            raise ValueError(f"at() path component {idx} is invalid: {repr(part)}")
        if any(bad in part for bad in ("/", "\0")):
            raise ValueError(
                f"at() path component {repr(part)} contains an invalid character "
                "('/' or '\\0')"
            )
    return ConnectorArrowDatabase(
        connector=self._connector,
        max_hierarchy_depth=self.max_hierarchy_depth,
        _path_prefix=self._path_prefix + path_components,
        _shared_pending_batches=self._pending_batches,
        _shared_pending_record_ids=self._pending_record_ids,
        _shared_pending_skip_existing=self._pending_skip_existing,
    )
Comment on lines +279 to +311
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConnectorArrowDatabase.at() stores path_components into _path_prefix without validation. Since record keys are built via '/'.join(...) and flush reconstructs record_path using split('/'), a scoped prefix component containing '/' or '\0' will break round-tripping and could cause data to be written/read under unintended table names. Consider validating path_components (non-empty str; disallow '/' and '\0') before returning the new scoped instance.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in fa61441. Same validation as InMemoryArrowDatabase.at(): rejects empty strings and components containing '/' or '\\0'. Covered by test_at_rejects_slash_in_component, test_at_rejects_null_in_component, test_at_rejects_empty_component in TestAtMethod.


# ── Flush ─────────────────────────────────────────────────────────────────

def flush(self) -> None:
Expand Down Expand Up @@ -420,6 +462,7 @@ def to_config(self) -> dict[str, Any]:
return {
"type": "connector_arrow_database",
"connector": self._connector.to_config(),
"base_path": list(self._path_prefix),
"max_hierarchy_depth": self.max_hierarchy_depth,
}

Expand Down
110 changes: 90 additions & 20 deletions src/orcapod/databases/delta_lake_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,24 @@ def __init__(
batch_size: int = 1000,
max_hierarchy_depth: int = 10,
allow_schema_evolution: bool = True,
_path_prefix: tuple[str, ...] = (),
):
self._base_uri, self._storage_options = parse_base_path(base_path, storage_options)
self._is_cloud: bool = is_cloud_uri(self._base_uri)
self._root_uri, self._storage_options = parse_base_path(base_path, storage_options)
self._is_cloud: bool = is_cloud_uri(self._root_uri)
self._path_prefix = _path_prefix
self.batch_size = batch_size
self.max_hierarchy_depth = max_hierarchy_depth
self.allow_schema_evolution = allow_schema_evolution

if not self._is_cloud:
# Keep self.base_path for local-path operations (list_sources, etc.)
# NOTE: do NOT access self.base_path on cloud instances.
self.base_path = Path(self._base_uri)
# _local_root is the absolute filesystem root (for list_sources, mkdir, etc.)
# NOTE: do NOT access self._local_root on cloud instances.
self._local_root = Path(self._root_uri)
if create_base_path:
self.base_path.mkdir(parents=True, exist_ok=True)
elif not self.base_path.exists():
self._local_root.mkdir(parents=True, exist_ok=True)
elif not self._local_root.exists():
raise ValueError(
f"Base path {self.base_path} does not exist and create_base_path=False"
f"Base path {self._local_root} does not exist and create_base_path=False"
)
# For cloud paths: create_base_path is silently ignored (no directory needed).

Expand Down Expand Up @@ -100,17 +102,18 @@ def _sanitize_path_component(component: str) -> str:
return component

def _get_table_uri(self, record_path: tuple[str, ...], create_dir: bool = False) -> str:
"""Get the URI for a given record path (works for local and cloud).
"""Get the URI for a given record path, incorporating base_path prefix.

Args:
record_path: Tuple of path components.
record_path: Tuple of path components (relative to base_path).
create_dir: If True, create the local directory (no-op for cloud paths).
"""
full_path = self._path_prefix + record_path # prefix applied once, here only
if self._is_cloud:
return self._base_uri.rstrip("/") + "/" + "/".join(record_path)
return self._root_uri.rstrip("/") + "/" + "/".join(full_path)
else:
path = Path(self._base_uri)
for subpath in record_path:
path = self._local_root
for subpath in full_path:
path = path / self._sanitize_path_component(subpath)
if create_dir:
path.mkdir(parents=True, exist_ok=True)
Expand All @@ -130,9 +133,11 @@ def _validate_record_path(self, record_path: tuple[str, ...]) -> None:
if not record_path:
raise ValueError("Source path cannot be empty")

if len(record_path) > self.max_hierarchy_depth:
if len(self._path_prefix) + len(record_path) > self.max_hierarchy_depth:
raise ValueError(
f"Source path depth {len(record_path)} exceeds maximum {self.max_hierarchy_depth}"
f"Source path depth {len(record_path)} exceeds maximum "
f"{self.max_hierarchy_depth - len(self._path_prefix)} "
f"(base_path uses {len(self._path_prefix)} components)"
)

# Validate path components
Expand Down Expand Up @@ -833,7 +838,8 @@ def to_config(self) -> dict[str, Any]:
"""Serialize database configuration to a JSON-compatible dict."""
config: dict[str, Any] = {
"type": "delta_table",
"base_path": self._base_uri,
"root_uri": self._root_uri, # renamed from "base_path"
"base_path": list(self._path_prefix), # new: relative prefix tuple
"batch_size": self.batch_size,
"max_hierarchy_depth": self.max_hierarchy_depth,
"allow_schema_evolution": self.allow_schema_evolution,
Expand All @@ -843,15 +849,30 @@ def to_config(self) -> dict[str, Any]:
return config

@classmethod
def from_config(cls, config: dict[str, Any]) -> "DeltaTableDatabase":
    """Reconstruct a DeltaTableDatabase from a config dict.

    Supports both the current format (``"root_uri"`` for the storage root,
    ``"base_path"`` as a list for the scoping prefix) and the legacy format
    produced before ENG-341 (``"base_path"`` as a URI string, no prefix).
    """
    if "root_uri" in config:
        # Current format (post-ENG-341)
        root_uri = config["root_uri"]
        base_path_value = config.get("base_path", [])
        # A list-valued base_path is the new scoping prefix; any other shape
        # (e.g. a stray string) is treated as "no prefix" for safety.
        _path_prefix = tuple(base_path_value) if isinstance(base_path_value, list) else ()
    else:
        # Legacy format (pre-ENG-341): "base_path" was the root URI string
        root_uri = config["base_path"]
        _path_prefix = ()
    return cls(
        base_path=root_uri,
        storage_options=config.get("storage_options"),
        create_base_path=True,
        batch_size=config.get("batch_size", 1000),
        max_hierarchy_depth=config.get("max_hierarchy_depth", 10),
        allow_schema_evolution=config.get("allow_schema_evolution", True),
        _path_prefix=_path_prefix,
    )
Comment on lines 851 to 876
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeltaTableDatabase.from_config now requires config["root_uri"], which will raise KeyError when loading configs produced before this change (where the filesystem root was stored under "base_path"). Since pipeline deserialization calls cls.from_config(config) directly, this is a backwards-compatibility break. Consider accepting both shapes: treat config.get("root_uri") as preferred, but fall back to config.get("base_path") when it’s a string/Path-like root, while still supporting the new list-valued base_path prefix tuple.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in fa61441. from_config now checks for "root_uri" first (current format) and falls back to "base_path" as the URI string (legacy pre-ENG-341 format). When falling back, _path_prefix is set to (). Covered by test_from_config_accepts_legacy_base_path in TestDeltaTableDatabaseConfig.


def flush(self) -> None:
Expand Down Expand Up @@ -937,6 +958,50 @@ def flush_batch(self, record_path: tuple[str, ...]) -> None:
self._pending_record_ids[record_key] = pending_ids
raise

@property
def base_path(self) -> tuple[str, ...]:
    """Relative root of this database view; a root (unscoped) instance reports ``()``."""
    return self._path_prefix

def at(self, *path_components: str) -> "DeltaTableDatabase":
    """Return a new DeltaTableDatabase scoped to the given sub-path.

    The returned instance uses the same underlying filesystem root but
    all reads and writes are relative to the extended prefix. Unlike
    InMemoryArrowDatabase and ConnectorArrowDatabase, DeltaTableDatabase
    does NOT share pending state — the filesystem is the shared storage.

    Components are validated here but not pre-sanitized: _get_table_uri
    applies _sanitize_path_component at use time, keeping base_path
    predictable across platforms.

    Raises:
        TypeError: If any component is not a str.
        ValueError: If any component is empty, is ``'.'`` or ``'..'``, or
            contains filesystem-unsafe characters (``/``, ``\\``, ``*``,
            ``?``, ``"``, ``<``, ``>``, ``|``, ``\\0``).
    """
    _unsafe_chars = ["/", "\\", "*", "?", '"', "<", ">", "|", "\0"]
    for i, component in enumerate(path_components):
        if not isinstance(component, str):
            raise TypeError(
                f"at() path component {i} must be str, got {type(component)!r}"
            )
        if not component:
            raise ValueError(f"at() path component {i} must not be empty")
        if component in (".", ".."):
            raise ValueError(
                f"at() path component {repr(component)}: '.' and '..' are not allowed"
            )
        if any(char in component for char in _unsafe_chars):
            raise ValueError(
                f"at() path component {repr(component)} contains invalid characters"
            )
    return DeltaTableDatabase(
        base_path=self._root_uri,
        storage_options=self._storage_options,
        batch_size=self.batch_size,
        max_hierarchy_depth=self.max_hierarchy_depth,
        allow_schema_evolution=self.allow_schema_evolution,
        _path_prefix=self._path_prefix + path_components,
    )

def list_sources(self) -> list[tuple[str, ...]]:
"""
List all record paths that contain a valid Delta table under base_path.
Expand Down Expand Up @@ -972,5 +1037,10 @@ def _scan(current_path: Path, path_components: tuple[str, ...]) -> None:
except deltalake.exceptions.TableNotFoundError:
_scan(item, components)

_scan(self.base_path, ())
# Build the effective scoped root directory
scoped_root = self._local_root
for component in self._path_prefix:
scoped_root = scoped_root / self._sanitize_path_component(component)

_scan(scoped_root, ())
return sources
63 changes: 56 additions & 7 deletions src/orcapod/databases/in_memory_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,36 @@ class InMemoryArrowDatabase:

RECORD_ID_COLUMN = "__record_id"

def __init__(
    self,
    max_hierarchy_depth: int = 10,
    _path_prefix: tuple[str, ...] = (),
    _shared_tables: "dict[str, pa.Table] | None" = None,
    _shared_pending_batches: "dict[str, pa.Table] | None" = None,
    _shared_pending_record_ids: "dict[str, set[str]] | None" = None,
):
    """Initialize the in-memory store.

    The ``_shared_*`` parameters are internal plumbing for ``at()``: a scoped
    view passes its parent's dicts so that all views of one root database
    share committed tables and pending state by reference. External callers
    should leave them as ``None`` (fresh, empty state).
    """
    self._path_prefix = _path_prefix
    self.max_hierarchy_depth = max_hierarchy_depth
    self._tables: dict[str, pa.Table] = _shared_tables if _shared_tables is not None else {}
    self._pending_batches: dict[str, pa.Table] = _shared_pending_batches if _shared_pending_batches is not None else {}
    self._pending_record_ids: dict[str, set[str]] = _shared_pending_record_ids if _shared_pending_record_ids is not None else defaultdict(set)

# ------------------------------------------------------------------
# Path helpers
# ------------------------------------------------------------------

def _get_record_key(self, record_path: tuple[str, ...]) -> str:
    """Return the '/'-joined storage key for *record_path*, scoped under this view's prefix."""
    return "/".join(self._path_prefix + record_path)

def _validate_record_path(self, record_path: tuple[str, ...]) -> None:
if not record_path:
raise ValueError("record_path cannot be empty")

if len(record_path) > self.max_hierarchy_depth:
if len(self._path_prefix) + len(record_path) > self.max_hierarchy_depth:
raise ValueError(
f"record_path depth {len(record_path)} exceeds maximum {self.max_hierarchy_depth}"
f"record_path depth {len(record_path)} exceeds maximum "
f"{self.max_hierarchy_depth - len(self._path_prefix)} "
f"(base_path uses {len(self._path_prefix)} components)"
)

# Only restrict characters that break the "/".join(record_path) key scheme.
Expand Down Expand Up @@ -248,6 +258,43 @@ def flush(self) -> None:
kept = committed.filter(mask)
self._tables[record_key] = pa.concat_tables([kept, pending])

# ------------------------------------------------------------------
# Path scoping
# ------------------------------------------------------------------

@property
def base_path(self) -> tuple[str, ...]:
    """Relative root of this view of the in-memory store; ``()`` for an unscoped root."""
    return self._path_prefix

def at(self, *path_components: str) -> "InMemoryArrowDatabase":
    """Create a view of this database rooted at the given sub-path.

    The view shares the underlying storage dicts (_tables,
    _pending_batches, _pending_record_ids) with its parent by reference,
    so a write through any view is visible from every view of the same
    root database.

    Raises:
        ValueError: For a component that is empty, not a str, or that
            contains ``'/'`` or ``'\\0'`` — characters that would corrupt
            the ``'/'``-joined record key scheme.
    """
    for idx, part in enumerate(path_components):
        if not part or not isinstance(part, str):
            raise ValueError(f"at() path component {idx} is invalid: {repr(part)}")
        if any(bad in part for bad in ("/", "\0")):
            raise ValueError(
                f"at() path component {repr(part)} contains an invalid character "
                "('/' or '\\0')"
            )
    return InMemoryArrowDatabase(
        max_hierarchy_depth=self.max_hierarchy_depth,
        _path_prefix=self._path_prefix + path_components,
        _shared_tables=self._tables,
        _shared_pending_batches=self._pending_batches,
        _shared_pending_record_ids=self._pending_record_ids,
    )
Comment on lines +265 to +296
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InMemoryArrowDatabase.at() builds _path_prefix directly from path_components without validating them. Because _get_record_key uses '/'.join(...) and flush reconstructs paths by splitting on '/', any path component containing '/' or '\0' will break key round-tripping and can corrupt reads/writes across views. Consider validating path_components (same checks as _validate_record_path applies to record_path components) before constructing the scoped view.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in fa61441. at() now validates each component before building _path_prefix: rejects empty strings, non-str values, and components containing '/' or '\\0'. These are the same characters rejected by _validate_record_path for individual components. Covered by test_at_rejects_slash_in_component, test_at_rejects_null_in_component, test_at_rejects_empty_component in TestAtMethod.


# ------------------------------------------------------------------
# Read helpers
# ------------------------------------------------------------------
Expand Down Expand Up @@ -338,6 +385,7 @@ def to_config(self) -> dict[str, Any]:
"""Serialize database configuration to a JSON-compatible dict."""
return {
"type": "in_memory",
"base_path": list(self._path_prefix),
"max_hierarchy_depth": self.max_hierarchy_depth,
}

Expand All @@ -346,6 +394,7 @@ def from_config(cls, config: dict[str, Any]) -> "InMemoryArrowDatabase":
"""Reconstruct an InMemoryArrowDatabase from a config dict."""
return cls(
max_hierarchy_depth=config.get("max_hierarchy_depth", 10),
_path_prefix=tuple(config.get("base_path", [])),
)

def get_records_with_column_value(
Expand Down
Loading
Loading