Skip to content

Commit

Permalink
Initial impl for branch (#90)
Browse files Browse the repository at this point in the history
usage
ds.add_branch("a")
ds.set_current_branch("a")

then all write/read will go to this branch.

---------

Co-authored-by: coufon <zhoufang@google.com>
  • Loading branch information
huan233usc and coufon authored Feb 18, 2024
1 parent 0c9de6c commit f911cc7
Show file tree
Hide file tree
Showing 9 changed files with 341 additions and 106 deletions.
16 changes: 14 additions & 2 deletions python/src/space/core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,25 @@ def record_fields(self) -> List[str]:
return self._storage.record_fields

def add_tag(self, tag: str, snapshot_id: Optional[int] = None):
"""Add yag to a snapshot."""
"""Add tag to a dataset."""
self._storage.add_tag(tag, snapshot_id)

def remove_tag(self, tag: str):
"""Remove tag from a snapshot."""
"""Remove tag from a dataset."""
self._storage.remove_tag(tag)

def add_branch(self, branch: str):
  """Create a new branch on the underlying storage.

  Args:
    branch: name of the branch to create.
  """
  self._storage.add_branch(branch)

def remove_branch(self, branch: str):
  """Delete an existing branch from the underlying storage.

  Args:
    branch: name of the branch to delete.
  """
  self._storage.remove_branch(branch)

def set_current_branch(self, branch: str):
  """Switch subsequent reads and writes of this dataset to a branch.

  Args:
    branch: name of the branch to make current.
  """
  self._storage.set_current_branch(branch)

def local(self, file_options: Optional[FileOptions] = None) -> LocalRunner:
  """Return a runner that executes operations locally.

  Args:
    file_options: optional file I/O options forwarded to the runner.
  """
  runner = LocalRunner(self._storage, file_options)
  return runner
Expand Down
3 changes: 3 additions & 0 deletions python/src/space/core/proto/metadata_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ global___StorageStatistics = StorageStatistics
class ChangeLog(google.protobuf.message.Message):
"""Change log stores changes made by a snapshot.
NEXT_ID: 3
TODO: to replace RowBitmap list by runtime.FileSet (not backward
compatible).
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor
Expand All @@ -391,6 +393,7 @@ global___ChangeLog = ChangeLog
@typing_extensions.final
class RowBitmap(google.protobuf.message.Message):
"""Mark rows in a file by bitmap.
TODO: to replace it by runtime.DataFile (not backward compatible).
NEXT_ID: 5
"""

Expand Down
133 changes: 106 additions & 27 deletions python/src/space/core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@

# Initial snapshot ID.
_INIT_SNAPSHOT_ID = 0
# Name of the main branch; by default reads and writes use this branch.
_MAIN_BRANCH = "main"
# Set of reference names that users cannot add as branches or tags.
_RESERVED_REFERENCE = [_MAIN_BRANCH]


# pylint: disable=too-many-public-methods
Expand All @@ -59,8 +63,11 @@ class Storage(paths.StoragePathsMixin):
Not thread safe.
"""

def __init__(self, location: str, metadata_file: str,
metadata: meta.StorageMetadata):
def __init__(self,
location: str,
metadata_file: str,
metadata: meta.StorageMetadata,
current_branch: Optional[str] = None):
super().__init__(location)
self._fs = create_fs(location)
self._metadata = metadata
Expand All @@ -77,12 +84,21 @@ def __init__(self, location: str, metadata_file: str,
self._physical_schema)

self._primary_keys = set(self._metadata.schema.primary_keys)
self._current_branch = current_branch or _MAIN_BRANCH
self._max_snapshot_id = max(
[ref.snapshot_id for ref in self._metadata.refs.values()] +
[self._metadata.current_snapshot_id])

@property
def metadata(self) -> meta.StorageMetadata:
  """The storage metadata proto backing this instance."""
  return self._metadata

@property
def current_branch(self) -> str:
  """Name of the branch this storage instance reads and writes."""
  return self._current_branch

@property
def primary_keys(self) -> List[str]:
"""Return the storage primary keys."""
Expand All @@ -103,6 +119,13 @@ def physical_schema(self) -> pa.Schema:
"""Return the physical schema that uses reference for record fields."""
return self._physical_schema

def current_snapshot_id(self, branch: str) -> int:
  """Return the snapshot id that the given branch currently points to.

  Args:
    branch: branch name; the main branch maps to the storage-level
      current snapshot id, any other branch resolves via its reference.
  """
  if branch == _MAIN_BRANCH:
    return self.metadata.current_snapshot_id

  return self.lookup_reference(branch).snapshot_id

def serializer(self) -> DictSerializer:
  """Build a dict serializer (and deserializer) for the dataset schema."""
  schema = self.logical_schema
  return DictSerializer.create(schema)
Expand All @@ -112,7 +135,10 @@ def snapshot(self, snapshot_id: Optional[int] = None) -> meta.Snapshot:
if not specified.
"""
if snapshot_id is None:
snapshot_id = self._metadata.current_snapshot_id
if self.current_branch == _MAIN_BRANCH:
snapshot_id = self._metadata.current_snapshot_id
else:
snapshot_id = self.version_to_snapshot_id(self.current_branch)

if snapshot_id in self._metadata.snapshots:
return self._metadata.snapshots[snapshot_id]
Expand Down Expand Up @@ -185,7 +211,8 @@ def reload(self) -> bool:
return False

metadata = _read_metadata(self._fs, self._location, entry_point)
self.__init__(self.location, entry_point.metadata_file, metadata) # type: ignore[misc] # pylint: disable=unnecessary-dunder-call
self.__init__( # type: ignore[misc] # pylint: disable=unnecessary-dunder-call
self.location, entry_point.metadata_file, metadata, self.current_branch)
logging.info(
f"Storage reloaded to snapshot: {self._metadata.current_snapshot_id}")
return True
Expand All @@ -199,9 +226,9 @@ def version_to_snapshot_id(self, version: Version) -> int:
if isinstance(version, int):
return version

return self._lookup_reference(version).snapshot_id
return self.lookup_reference(version).snapshot_id

def _lookup_reference(self, tag_or_branch: str) -> meta.SnapshotReference:
def lookup_reference(self, tag_or_branch: str) -> meta.SnapshotReference:
"""Lookup a snapshot reference."""
if tag_or_branch in self._metadata.refs:
return self._metadata.refs[tag_or_branch]
Expand All @@ -210,58 +237,107 @@ def _lookup_reference(self, tag_or_branch: str) -> meta.SnapshotReference:

def add_tag(self, tag: str, snapshot_id: Optional[int] = None) -> None:
  """Tag a snapshot so it can be referenced by name.

  Args:
    tag: name of the tag to create.
    snapshot_id: snapshot to tag; defaults to the current snapshot.
  """
  self._add_reference(ref_name=tag,
                      ref_type=meta.SnapshotReference.TAG,
                      snapshot_id=snapshot_id)

def add_branch(self, branch: str) -> None:
  """Create a branch reference pointing at the current snapshot.

  Args:
    branch: name of the branch to create.
  """
  self._add_reference(ref_name=branch,
                      ref_type=meta.SnapshotReference.BRANCH,
                      snapshot_id=None)

def set_current_branch(self, branch: str) -> None:
  """Set the branch that subsequent reads and writes use.

  Args:
    branch: name of the branch to switch to; the main branch is always
      accepted without a reference lookup.

  Raises:
    errors.UserInputError: if the named reference exists but is a tag,
      not a branch.
  """
  if branch != _MAIN_BRANCH:
    snapshot_ref = self.lookup_reference(branch)
    if snapshot_ref.type != meta.SnapshotReference.BRANCH:
      # Bug fix: the message was a plain string, so callers saw the
      # literal text "{branch}" instead of the branch name.
      raise errors.UserInputError(f"{branch} is not a branch.")

  self._current_branch = branch

def _add_reference(self,
ref_name: str,
ref_type: meta.SnapshotReference.ReferenceType.ValueType,
snapshot_id: Optional[int] = None) -> None:
"""Add reference to a snapshot"""
if snapshot_id is None:
snapshot_id = self._metadata.current_snapshot_id

if snapshot_id not in self._metadata.snapshots:
raise errors.VersionNotFoundError(f"Snapshot {snapshot_id} is not found")

if len(tag) == 0:
raise errors.UserInputError("Tag cannot be empty")
if not ref_name:
raise errors.UserInputError("Reference name cannot be empty.")

if ref_name in _RESERVED_REFERENCE:
raise errors.UserInputError("{ref_name} is reserved.")

if tag in self._metadata.refs:
raise errors.VersionAlreadyExistError(f"Reference {tag} already exist")
if ref_name in self._metadata.refs:
raise errors.VersionAlreadyExistError(
f"Reference {ref_name} already exist")

new_metadata = meta.StorageMetadata()
new_metadata.CopyFrom(self._metadata)
tag_ref = meta.SnapshotReference(reference_name=tag,
snapshot_id=snapshot_id,
type=meta.SnapshotReference.TAG)
new_metadata.refs[tag].CopyFrom(tag_ref)
ref = meta.SnapshotReference(reference_name=ref_name,
snapshot_id=snapshot_id,
type=ref_type)
new_metadata.refs[ref_name].CopyFrom(ref)
new_metadata_path = self.new_metadata_path()
self._write_metadata(new_metadata_path, new_metadata)
self._metadata = new_metadata
self._metadata_file = new_metadata_path

def remove_tag(self, tag: str) -> None:
"""Remove tag from metadata"""
if (tag not in self._metadata.refs or
self._metadata.refs[tag].type != meta.SnapshotReference.TAG):
raise errors.VersionNotFoundError(f"Tag {tag} is not found")
self._remove_reference(tag, meta.SnapshotReference.TAG)

def remove_branch(self, branch: str) -> None:
  """Delete a branch reference from the storage metadata.

  Args:
    branch: name of the branch to delete; must not be the branch that
      is currently checked out.

  Raises:
    errors.UserInputError: if the branch is the current branch.
  """
  if self._current_branch == branch:
    raise errors.UserInputError("Cannot remove the current branch.")

  self._remove_reference(branch, meta.SnapshotReference.BRANCH)

def _remove_reference(
self, ref_name: str,
ref_type: meta.SnapshotReference.ReferenceType.ValueType) -> None:
if (ref_name not in self._metadata.refs or
self._metadata.refs[ref_name].type != ref_type):
raise errors.VersionNotFoundError(
f"Reference {ref_name} is not found or has a wrong type "
"(tag vs branch)")

new_metadata = meta.StorageMetadata()
new_metadata.CopyFrom(self._metadata)
del new_metadata.refs[tag]
del new_metadata.refs[ref_name]
new_metadata_path = self.new_metadata_path()
self._write_metadata(new_metadata_path, new_metadata)
self._metadata = new_metadata
self._metadata_file = new_metadata_path

def commit(self, patch: rt.Patch) -> None:
def commit(self, patch: rt.Patch, branch: str) -> None:
"""Commit changes to the storage.
TODO: only support a single writer; to ensure atomicity in commit by
concurrent writers.
Args:
patch: a patch describing changes made to the storage.
branch: the branch this commit is writing to.
"""
current_snapshot = self.snapshot()

new_metadata = meta.StorageMetadata()
new_metadata.CopyFrom(self._metadata)
new_snapshot_id = self._next_snapshot_id()
new_metadata.current_snapshot_id = new_snapshot_id
if branch != _MAIN_BRANCH:
branch_snapshot = self.lookup_reference(branch)
# To block the case delete branch and add a tag during commit
# TODO: move this check out of commit()
if branch_snapshot.type != meta.SnapshotReference.BRANCH:
raise errors.UserInputError("Branch {branch} is no longer exists.")

new_metadata.refs[branch].snapshot_id = new_snapshot_id
current_snapshot = self.snapshot(branch_snapshot.snapshot_id)
else:
new_metadata.current_snapshot_id = new_snapshot_id
current_snapshot = self.snapshot(self._metadata.current_snapshot_id)

new_metadata.last_update_time.CopyFrom(proto_now())
new_metadata_path = self.new_metadata_path()

Expand Down Expand Up @@ -417,7 +493,8 @@ def _initialize_files(self, metadata_path: str) -> None:
raise errors.StorageExistError(str(e)) from None

def _next_snapshot_id(self) -> int:
return self._metadata.current_snapshot_id + 1
self._max_snapshot_id = self._max_snapshot_id + 1
return self._max_snapshot_id

def _write_metadata(
self,
Expand Down Expand Up @@ -473,7 +550,7 @@ def __init__(self, storage: Storage):
self._txn_id = uuid_()
# The storage snapshot ID when the transaction starts.
self._snapshot_id: Optional[int] = None

self._branch = storage.current_branch
self._result: Optional[JobResult] = None

def commit(self, patch: Optional[rt.Patch]) -> None:
Expand All @@ -483,7 +560,9 @@ def commit(self, patch: Optional[rt.Patch]) -> None:
# Check that no other commit has taken place.
assert self._snapshot_id is not None
self._storage.reload()
if self._snapshot_id != self._storage.metadata.current_snapshot_id:
current_snapshot_id = self._storage.current_snapshot_id(self._branch)

if self._snapshot_id != current_snapshot_id:
self._result = JobResult(
JobResult.State.FAILED, None,
"Abort commit because the storage has been modified.")
Expand All @@ -493,7 +572,7 @@ def commit(self, patch: Optional[rt.Patch]) -> None:
self._result = JobResult(JobResult.State.SKIPPED)
return

self._storage.commit(patch)
self._storage.commit(patch, self._branch)
self._result = JobResult(JobResult.State.SUCCEEDED,
patch.storage_statistics_update)

Expand All @@ -509,7 +588,7 @@ def __enter__(self) -> Transaction:
# All mutations start with a transaction, so storage is always reloaded for
# mutations.
self._storage.reload()
self._snapshot_id = self._storage.metadata.current_snapshot_id
self._snapshot_id = self._storage.current_snapshot_id(self._branch)
logging.info(f"Start transaction {self._txn_id}")
return self

Expand Down
2 changes: 1 addition & 1 deletion python/tests/core/loaders/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ def test_append_parquet(self, tmp_path):
]).combine_chunks().sort_by("int64")

assert not ds.index_files(version="empty")
assert ds.index_files(version="after_append") == [file0, file1]
assert sorted(ds.index_files(version="after_append")) == [file0, file1]
4 changes: 2 additions & 2 deletions python/tests/core/ops/test_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_delete_all_types(self, tmp_path, all_types_schema,
for batch in input_data:
append_op.write(batch)

storage.commit(append_op.finish())
storage.commit(append_op.finish(), "main")
old_data_files = storage.data_files()

delete_op = FileSetDeleteOp(
Expand All @@ -54,7 +54,7 @@ def test_delete_all_types(self, tmp_path, all_types_schema,
_default_file_options)
patch = delete_op.delete()
assert patch is not None
storage.commit(patch)
storage.commit(patch, "main")

# Verify storage metadata after patch.
new_data_files = storage.data_files()
Expand Down
4 changes: 2 additions & 2 deletions python/tests/core/ops/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_read_all_types(self, tmp_path, all_types_schema,
for batch in input_data:
append_op.write(batch)

storage.commit(append_op.finish())
storage.commit(append_op.finish(), "main")

read_op = FileSetReadOp(str(location), storage.metadata,
storage.data_files())
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_read_with_record_filters(self, tmp_path, record_fields_schema,
for batch in input_data:
append_op.write(batch)

storage.commit(append_op.finish())
storage.commit(append_op.finish(), "main")
data_files = storage.data_files()

read_op = FileSetReadOp(str(location), storage.metadata, data_files)
Expand Down
Loading

0 comments on commit f911cc7

Please sign in to comment.