Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] simplify logic for when to persist index changes #2539

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 13 additions & 47 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil
from overrides import override
import pickle
from typing import Any, Dict, List, Optional, Sequence, Set, cast
from typing import Dict, List, Optional, Sequence, Set, cast
from chromadb.config import System
from chromadb.segment.impl.vector.batch import Batch
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
Expand Down Expand Up @@ -40,9 +40,6 @@ class PersistentData:
"""Stores the data and metadata needed for a PersistentLocalHnswSegment"""

dimensionality: Optional[int]
total_elements_added: int
total_elements_updated: int
total_invalid_operations: int
max_seq_id: SeqId

id_to_label: Dict[str, int]
Expand All @@ -52,29 +49,17 @@ class PersistentData:
def __init__(
self,
dimensionality: Optional[int],
total_elements_added: int,
total_elements_updated: int,
total_invalid_operations: int,
max_seq_id: int,
id_to_label: Dict[str, int],
label_to_id: Dict[int, str],
id_to_seq_id: Dict[str, SeqId],
):
self.dimensionality = dimensionality
self.total_elements_added = total_elements_added
self.total_elements_updated = total_elements_updated
self.total_invalid_operations = total_invalid_operations
self.max_seq_id = max_seq_id
self.id_to_label = id_to_label
self.label_to_id = label_to_id
self.id_to_seq_id = id_to_seq_id

def __setstate__(self, state: Any) -> None:
# Fields were added after the initial implementation
self.total_elements_updated = 0
self.total_invalid_operations = 0
self.__dict__.update(state)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested with pickle and it seems to be fine if you remove a field from a class def


@staticmethod
def load_from_file(filename: str) -> "PersistentData":
"""Load persistent data from a file"""
Expand All @@ -100,6 +85,9 @@ class PersistentLocalHnswSegment(LocalHnswSegment):

_opentelemtry_client: OpenTelemetryClient

_num_log_records_since_last_batch: int = 0
_num_log_records_since_last_persist: int = 0

def __init__(self, system: System, segment: Segment):
super().__init__(system, segment)

Expand All @@ -120,7 +108,6 @@ def __init__(self, system: System, segment: Segment):
self._get_metadata_file()
)
self._dimensionality = self._persist_data.dimensionality
self._total_elements_added = self._persist_data.total_elements_added
self._max_seq_id = self._persist_data.max_seq_id
self._id_to_label = self._persist_data.id_to_label
self._label_to_id = self._persist_data.label_to_id
Expand All @@ -132,9 +119,6 @@ def __init__(self, system: System, segment: Segment):
else:
self._persist_data = PersistentData(
self._dimensionality,
self._total_elements_added,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confirming this was not used anywhere else

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confirmed

Copy link
Contributor Author

@codetheweb codetheweb Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

turns out this was the issue, I should have fully reasoned through it :/
self._total_elements_added isn't used in local_persistent_hnsw, but it is used in local_hnsw (the super class) and there's some non-obvious contracts between the two

self._total_elements_updated,
self._total_invalid_operations,
self._max_seq_id,
self._id_to_label,
self._label_to_id,
Expand Down Expand Up @@ -208,8 +192,6 @@ def _persist(self) -> None:

# Persist the metadata
self._persist_data.dimensionality = self._dimensionality
self._persist_data.total_elements_added = self._total_elements_added
self._persist_data.total_elements_updated = self._total_elements_updated
self._persist_data.max_seq_id = self._max_seq_id

# TODO: This should really be stored in sqlite, the index itself, or a better
Expand All @@ -221,30 +203,19 @@ def _persist(self) -> None:
with open(self._get_metadata_file(), "wb") as metadata_file:
pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL)

self._num_log_records_since_last_persist = 0

@trace_method(
"PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL
)
@override
def _apply_batch(self, batch: Batch) -> None:
super()._apply_batch(batch)
num_elements_added_since_last_persist = (
self._total_elements_added - self._persist_data.total_elements_added
)
num_elements_updated_since_last_persist = (
self._total_elements_updated - self._persist_data.total_elements_updated
)
num_invalid_operations_since_last_persist = (
self._total_invalid_operations - self._persist_data.total_invalid_operations
)

if (
num_elements_added_since_last_persist
+ num_elements_updated_since_last_persist
+ num_invalid_operations_since_last_persist
>= self._sync_threshold
):
if self._num_log_records_since_last_persist >= self._sync_threshold:
self._persist()

self._num_log_records_since_last_batch = 0

@trace_method(
"PersistentLocalHnswSegment._write_records", OpenTelemetryGranularity.ALL
)
Expand All @@ -255,6 +226,9 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
raise RuntimeError("Cannot add embeddings to stopped component")
with WriteRWLock(self._lock):
for record in records:
self._num_log_records_since_last_batch += 1
self._num_log_records_since_last_persist += 1

if record["record"]["embedding"] is not None:
self._ensure_index(len(records), len(record["record"]["embedding"]))
if not self._index_initialized:
Expand Down Expand Up @@ -305,15 +279,7 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
self._curr_batch.apply(record, exists_in_index)
self._brute_force_index.upsert([record])

num_invalid_operations_since_last_persist = (
self._total_invalid_operations
- self._persist_data.total_invalid_operations
)

if (
len(self._curr_batch) + num_invalid_operations_since_last_persist
>= self._batch_size
):
if self._num_log_records_since_last_batch >= self._batch_size:
self._apply_batch(self._curr_batch)
self._curr_batch = Batch()
self._brute_force_index.clear()
Expand Down
Loading