Skip to content

Commit

Permalink
Add more configurations for hnswlib
Browse files Browse the repository at this point in the history
Signed-off-by: anna-charlotte <charlotte.gerhaher@jina.ai>
  • Loading branch information
anna-charlotte committed Apr 27, 2023
1 parent de262f9 commit 30456bc
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 4 deletions.
54 changes: 53 additions & 1 deletion langchain/vectorstores/hnsw_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def __init__(
work_dir: str,
n_dim: int,
dist_metric: str = "cosine",
max_elements: int = 1024,
index: bool = True,
ef_construction: int = 200,
ef: int = 10,
M: int = 16,
allow_replace_deleted: bool = True,
num_threads: int = 1,
) -> None:
"""Initialize HnswLib store.
Expand All @@ -33,6 +40,19 @@ def __init__(
n_dim (int): dimension of an embedding.
dist_metric (str): Distance metric for HnswLib can be one of: "cosine",
"ip", and "l2". Defaults to "cosine".
max_elements (int): Maximum number of vectors that can be stored.
Defaults to 1024.
index (bool): Whether an index should be built for this field.
Defaults to True.
ef_construction (int): defines a construction time/accuracy trade-off.
Defaults to 200.
ef (int): parameter controlling query time/accuracy trade-off.
Defaults to 10.
M (int): parameter that defines the maximum number of outgoing
connections in the graph. Defaults to 16.
allow_replace_deleted (bool): Enables replacing of deleted elements
with new added ones. Defaults to True.
num_threads (int): Sets the number of cpu threads to use. Defaults to 1.
"""
_check_docarray_import()
from docarray.index import HnswDocumentIndex
Expand All @@ -45,7 +65,19 @@ def __init__(
"Please install it with `pip install \"langchain[hnswlib]\"`."
)

doc_cls = self._get_doc_cls({"dim": n_dim, "space": dist_metric})
doc_cls = self._get_doc_cls(
{
"dim": n_dim,
"space": dist_metric,
"max_elements": max_elements,
"index": index,
"ef_construction": ef_construction,
"ef": ef,
"M": M,
"allow_replace_deleted": allow_replace_deleted,
"num_threads": num_threads,
}
)
doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir)
super().__init__(doc_index, embedding)

Expand All @@ -58,6 +90,13 @@ def from_texts(
work_dir: str = None,
n_dim: int = None,
dist_metric: str = "cosine",
max_elements: int = 1024,
index: bool = True,
ef_construction: int = 200,
ef: int = 10,
M: int = 16,
allow_replace_deleted: bool = True,
num_threads: int = 1,
) -> HnswLib:
"""Create an HnswLib store and insert data.
Expand All @@ -70,6 +109,19 @@ def from_texts(
n_dim (int): dimension of an embedding.
dist_metric (str): Distance metric for HnswLib can be one of: "cosine",
"ip", and "l2". Defaults to "cosine".
max_elements (int): Maximum number of vectors that can be stored.
Defaults to 1024.
index (bool): Whether an index should be built for this field.
Defaults to True.
ef_construction (int): defines a construction time/accuracy trade-off.
Defaults to 200.
ef (int): parameter controlling query time/accuracy trade-off.
Defaults to 10.
M (int): parameter that defines the maximum number of outgoing
connections in the graph. Defaults to 16.
allow_replace_deleted (bool): Enables replacing of deleted elements
with new added ones. Defaults to True.
num_threads (int): Sets the number of cpu threads to use. Defaults to 1.
Returns:
HnswLib Vector Store
Expand Down
51 changes: 48 additions & 3 deletions tests/integration_tests/vectorstores/test_hnsw_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_hnswlib_vec_store_add_texts(tmp_path) -> None:
assert docsearch.doc_index.num_docs() == 3


@pytest.mark.parametrize('metric', ['cosine', 'ip', 'l2'])
@pytest.mark.parametrize('metric', ['cosine', 'l2'])
def test_sim_search(metric, tmp_path) -> None:
"""Test end to end construction and simple similarity search."""
texts = ["foo", "bar", "baz"]
Expand All @@ -45,12 +45,35 @@ def test_sim_search(metric, tmp_path) -> None:
FakeEmbeddings(),
work_dir=str(tmp_path),
n_dim=10,
dist_metric=metric,
)
output = hnswlib_vec_store.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]


@pytest.mark.parametrize('metric', ['cosine', 'l2'])
def test_sim_search_all_configurations(metric, tmp_path) -> None:
"""Test end to end construction and simple similarity search."""
texts = ["foo", "bar", "baz"]
hnswlib_vec_store = HnswLib.from_texts(
texts,
FakeEmbeddings(),
work_dir=str(tmp_path),
dist_metric=metric,
n_dim=10,
max_elements=8,
index=False,
ef_construction=300,
ef=20,
M=8,
allow_replace_deleted=False,
num_threads=2,
)
output = hnswlib_vec_store.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]


@pytest.mark.parametrize('metric', ['cosine', 'ip', 'l2'])
@pytest.mark.parametrize('metric', ['cosine', 'l2'])
def test_sim_search_by_vector(metric, tmp_path) -> None:
"""Test end to end construction and similarity search by vector."""
texts = ["foo", "bar", "baz"]
Expand All @@ -59,14 +82,15 @@ def test_sim_search_by_vector(metric, tmp_path) -> None:
FakeEmbeddings(),
work_dir=str(tmp_path),
n_dim=10,
dist_metric=metric,
)
embedding = [1.0] * 10
output = hnswlib_vec_store.similarity_search_by_vector(embedding, k=1)

assert output == [Document(page_content="bar")]


@pytest.mark.parametrize('metric', ['cosine', 'ip', 'l2'])
@pytest.mark.parametrize('metric', ['cosine', 'l2'])
def test_sim_search_with_score(metric, tmp_path) -> None:
"""Test end to end construction and similarity search with score."""
texts = ["foo", "bar", "baz"]
Expand All @@ -75,6 +99,7 @@ def test_sim_search_with_score(metric, tmp_path) -> None:
FakeEmbeddings(),
work_dir=str(tmp_path),
n_dim=10,
dist_metric=metric,
)
output = hnswlib_vec_store.similarity_search_with_score("foo", k=1)
assert len(output) == 1
Expand All @@ -84,6 +109,26 @@ def test_sim_search_with_score(metric, tmp_path) -> None:
assert np.isclose(out_score, 0.0, atol=1.e-6)


def test_sim_search_with_score_for_ip_metric(tmp_path) -> None:
"""
Test end to end construction and similarity search with score for ip
(inner-product) metric.
"""
texts = ["foo", "bar", "baz"]
hnswlib_vec_store = HnswLib.from_texts(
texts,
FakeEmbeddings(),
work_dir=str(tmp_path),
n_dim=10,
dist_metric='ip',
)
output = hnswlib_vec_store.similarity_search_with_score("foo", k=3)
assert len(output) == 3

for result in output:
assert result[1] == -8.0


@pytest.mark.parametrize('metric', ['cosine', 'l2'])
def test_max_marginal_relevance_search(metric, tmp_path) -> None:
"""Test MRR search."""
Expand Down

0 comments on commit 30456bc

Please sign in to comment.