Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] fix HNSW param defaults in new configuration logic & require batch_size < sync_threshold #2526

Merged
merged 8 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions chromadb/api/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class StaticParameterError(Exception):
pass


class InvalidConfigurationError(ValueError):
"""Represents an error that occurs when a configuration is invalid."""

pass


ParameterValue = Union[str, int, float, bool, "ConfigurationInternal"]


Expand Down Expand Up @@ -110,8 +116,8 @@ def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None):
if not isinstance(parameter.value, type(definition.default_value)):
raise ValueError(f"Invalid parameter value: {parameter.value}")

validator = definition.validator
if not validator(parameter.value):
parameter_validator = definition.validator
if not parameter_validator(parameter.value):
raise ValueError(f"Invalid parameter value: {parameter.value}")
self.parameter_map[parameter.name] = parameter
# Apply the defaults for any missing parameters
Expand All @@ -121,6 +127,8 @@ def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None):
name=name, value=definition.default_value
)

self.configuration_validator()

def __repr__(self) -> str:
return f"Configuration({self.parameter_map.values()})"

Expand All @@ -129,6 +137,14 @@ def __eq__(self, __value: object) -> bool:
return NotImplemented
return self.parameter_map == __value.parameter_map

@abstractmethod
def configuration_validator(self) -> None:
"""Perform custom validation when parameters are dependent on each other.

Raises an InvalidConfigurationError if the configuration is invalid.
"""
pass

def get_parameters(self) -> List[ConfigurationParameter]:
"""Returns the parameters of the configuration."""
return list(self.parameter_map.values())
Expand Down Expand Up @@ -247,16 +263,30 @@ class HNSWConfigurationInternal(ConfigurationInternal):
name="batch_size",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=True,
default_value=1000,
default_value=100,
),
"sync_threshold": ConfigurationDefinition(
name="sync_threshold",
validator=lambda value: isinstance(value, int) and value >= 1,
is_static=True,
default_value=100,
default_value=1000,
),
}

@override
def configuration_validator(self) -> None:
batch_size = self.parameter_map.get("batch_size")
sync_threshold = self.parameter_map.get("sync_threshold")

if (
batch_size
and sync_threshold
and cast(int, batch_size.value) > cast(int, sync_threshold.value)
):
raise InvalidConfigurationError(
"batch_size must be less than or equal to sync_threshold"
)

@classmethod
def from_legacy_params(cls, params: Dict[str, Any]) -> Self:
"""Returns an HNSWConfiguration from a metadata dict containing legacy HNSW parameters. Used for migration."""
Expand Down Expand Up @@ -302,8 +332,8 @@ def __init__(
num_threads: int = cpu_count(),
M: int = 16,
resize_factor: float = 1.2,
batch_size: int = 1000,
sync_threshold: int = 100,
batch_size: int = 100,
sync_threshold: int = 1000,
atroyn marked this conversation as resolved.
Show resolved Hide resolved
):
parameters = [
ConfigurationParameter(name="space", value=space),
Expand Down Expand Up @@ -336,6 +366,10 @@ class CollectionConfigurationInternal(ConfigurationInternal):
),
}

@override
def configuration_validator(self) -> None:
pass


# This is the user-facing interface for HNSW index configuration parameters.
# Internally, we pass around HNSWConfigurationInternal objects, which perform
Expand Down
56 changes: 54 additions & 2 deletions chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set
from uuid import UUID
from overrides import override
Expand All @@ -8,6 +9,7 @@
CollectionConfigurationInternal,
ConfigurationParameter,
HNSWConfigurationInternal,
InvalidConfigurationError,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System
from chromadb.db.base import (
Expand Down Expand Up @@ -435,8 +437,8 @@ def get_collections(
metadata = self._metadata_from_rows(rows)
dimension = int(rows[0][3]) if rows[0][3] else None
if rows[0][2] is not None:
configuration = CollectionConfigurationInternal.from_json_str(
rows[0][2]
configuration = self._load_config_from_json_str_and_migrate(
str(collection_id), rows[0][2]
)
else:
# 07/2024: This is a legacy case where we don't have a collection
Expand Down Expand Up @@ -764,6 +766,56 @@ def _insert_metadata(
if sql:
cur.execute(sql, params)

def _load_config_from_json_str_and_migrate(
self, collection_id: str, json_str: str
) -> CollectionConfigurationInternal:
try:
config_json = json.loads(json_str)
except json.JSONDecodeError:
raise ValueError(
f"Unable to decode configuration from JSON string: {json_str}"
)

try:
return CollectionConfigurationInternal.from_json_str(json_str)
except InvalidConfigurationError as error:
# 07/17/2024: the initial migration from the legacy metadata-based config to the new sysdb-based config had a bug where the batch_size and sync_threshold were swapped. Along with this migration, a validator was added to HNSWConfigurationInternal to ensure that batch_size <= sync_threshold.
hnsw_configuration = config_json.get("hnsw_configuration")
if hnsw_configuration:
batch_size = hnsw_configuration.get("batch_size")
sync_threshold = hnsw_configuration.get("sync_threshold")

if batch_size and sync_threshold and batch_size > sync_threshold:
# Allow new defaults to be set
hnsw_configuration = {
k: v
for k, v in hnsw_configuration.items()
if k not in ["batch_size", "sync_threshold"]
}
config_json.update({"hnsw_configuration": hnsw_configuration})

configuration = CollectionConfigurationInternal.from_json(
config_json
)

collections_t = Table("collections")
q = (
self.querybuilder()
.update(collections_t)
.set(
collections_t.config_json_str,
ParameterValue(configuration.to_json_str()),
)
.where(collections_t.id == ParameterValue(collection_id))
)
sql, params = get_sql(q, self.parameter_format())
with self.tx() as cur:
cur.execute(sql, params)

return configuration

raise error

def _insert_config_from_legacy_params(
self, collection_id: Any, metadata: Optional[Metadata]
) -> CollectionConfigurationInternal:
Expand Down
32 changes: 32 additions & 0 deletions chromadb/test/configurations/test_configurations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from overrides import overrides
import pytest
from chromadb.api.configuration import (
ConfigurationInternal,
ConfigurationDefinition,
InvalidConfigurationError,
StaticParameterError,
ConfigurationParameter,
HNSWConfiguration,
)


Expand All @@ -23,6 +26,10 @@ class TestConfiguration(ConfigurationInternal):
),
}

@overrides
def configuration_validator(self) -> None:
pass


def test_default_values() -> None:
default_test_configuration = TestConfiguration()
Expand Down Expand Up @@ -76,3 +83,28 @@ def test_validation() -> None:
]
with pytest.raises(ValueError):
TestConfiguration(parameters=invalid_parameter_names)


def test_configuration_validation() -> None:
class FooConfiguration(ConfigurationInternal):
definitions = {
"foo": ConfigurationDefinition(
name="foo",
validator=lambda value: isinstance(value, str),
is_static=False,
default_value="default",
),
}

@overrides
def configuration_validator(self) -> None:
if self.parameter_map.get("foo") != "bar":
raise InvalidConfigurationError("foo must be 'bar'")

with pytest.raises(ValueError, match="foo must be 'bar'"):
FooConfiguration(parameters=[ConfigurationParameter(name="foo", value="baz")])


def test_hnsw_validation() -> None:
codetheweb marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(ValueError, match="must be less than or equal"):
HNSWConfiguration(batch_size=500, sync_threshold=100)
11 changes: 8 additions & 3 deletions chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,17 @@ def collections(
metadata = {}
metadata.update(test_hnsw_config)
if use_persistent_hnsw_params:
metadata["hnsw:batch_size"] = draw(
st.integers(min_value=3, max_value=max_hnsw_batch_size)
)
metadata["hnsw:sync_threshold"] = draw(
st.integers(min_value=3, max_value=max_hnsw_sync_threshold)
)
metadata["hnsw:batch_size"] = draw(
st.integers(
min_value=3,
max_value=min(
[metadata["hnsw:sync_threshold"], max_hnsw_batch_size]
),
)
)
# Sometimes, select a space at random
if draw(st.booleans()):
# TODO: pull the distance functions from a source of truth that lives not
Expand Down
Loading