Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support list of strings for Weaviate #1852

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 23 additions & 1 deletion docarray/index/backends/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
TypeVar,
Union,
cast,
get_origin,
get_args,
)

import numpy as np
Expand All @@ -43,7 +45,6 @@
TSchema = TypeVar('TSchema', bound=BaseDoc)
T = TypeVar('T', bound='WeaviateDocumentIndex')


DEFAULT_BATCH_CONFIG = {
"batch_size": 20,
"dynamic": False,
Expand Down Expand Up @@ -210,6 +211,8 @@
self.bytes_columns.append(column_name)
if column_info.db_type == 'number[]':
self.nonembedding_array_columns.append(column_name)
if column_info.db_type == 'text[]':
self.nonembedding_array_columns.append(column_name)

Check warning on line 215 in docarray/index/backends/weaviate.py

View check run for this annotation

Codecov / codecov/patch

docarray/index/backends/weaviate.py#L215

Added line #L215 was not covered by tests
prop = {
"name": column_name
if column_name != 'id'
Expand Down Expand Up @@ -253,6 +256,8 @@
'number': {},
'boolean': {},
'number[]': {},
'int[]': {},
'text[]': {},
'blob': {},
}
)
Expand Down Expand Up @@ -717,6 +722,23 @@
bytes: 'blob',
}

if get_origin(python_type) == list:
py_weaviate_list_type_map = {

Check warning on line 726 in docarray/index/backends/weaviate.py

View check run for this annotation

Codecov / codecov/patch

docarray/index/backends/weaviate.py#L726

Added line #L726 was not covered by tests
int: 'int[]',
float: 'number[]',
str: 'text[]',
}

container_type = None
args = get_args(python_type)
if args:
container_type = args[0]
if (

Check warning on line 736 in docarray/index/backends/weaviate.py

View check run for this annotation

Codecov / codecov/patch

docarray/index/backends/weaviate.py#L732-L736

Added lines #L732 - L736 were not covered by tests
container_type is not None
and container_type in py_weaviate_list_type_map
):
return py_weaviate_list_type_map[container_type]

Check warning on line 740 in docarray/index/backends/weaviate.py

View check run for this annotation

Codecov / codecov/patch

docarray/index/backends/weaviate.py#L740

Added line #L740 was not covered by tests

for py_type, weaviate_type in py_weaviate_type_map.items():
if safe_issubclass(python_type, py_type):
return weaviate_type
Expand Down
6 changes: 3 additions & 3 deletions tests/index/weaviate/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-version: '3.8'
+version: '3.3'

services:

Expand All @@ -10,7 +10,7 @@ services:
- '8080'
- --scheme
- http
-image: semitechnologies/weaviate:1.18.3
+image: semitechnologies/weaviate:1.21.1
ports:
- "8080:8080"
restart: on-failure:0
Expand All @@ -24,4 +24,4 @@ services:
LOG_LEVEL: debug # verbose
LOG_FORMAT: text
# LOG_LEVEL: trace # very verbose
-GODEBUG: gctrace=1 # make go garbage collector verbose
+GODEBUG: gctrace=1 # make go garbage collector verbose
15 changes: 14 additions & 1 deletion tests/index/weaviate/test_index_get_del_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from pydantic import Field
from typing import List

from docarray import BaseDoc
from docarray.documents import ImageDoc, TextDoc
Expand All @@ -31,6 +32,9 @@ class SimpleDoc(BaseDoc):
class Document(BaseDoc):
embedding: NdArray[2] = Field(dim=2, is_embedding=True)
text: str = Field()
texts: List[str] = Field(default=[])
# integers: List[int] = Field(default=[])
floats: List[float] = Field(default=[])


class NestedDocument(BaseDoc):
Expand All @@ -50,7 +54,14 @@ def documents():

# create the docs by enumerating from 1 and use that as the id
docs = [
-Document(id=str(i), embedding=embedding, text=text)
+Document(
+id=str(i),
+embedding=embedding,
+text=text,
+texts=[f'text{i}_0', f'text{i}_1'],
+integers=[i, i],
+floats=[1.5 * i, 2.5 * i],
+)
for i, (embedding, text) in enumerate(zip(embeddings, texts))
]

Expand Down Expand Up @@ -170,6 +181,8 @@ class Document(BaseDoc):
({"path": ["text"], "operator": "Equal", "valueText": "lorem ipsum"}, 1),
({"path": ["text"], "operator": "Equal", "valueText": "foo"}, 0),
({"path": ["id"], "operator": "Equal", "valueString": "1"}, 1),
({"path": ["texts"], "operator": "ContainsAny", "valueText": ["text"]}, 3),
({"path": ["texts"], "operator": "ContainsAny", "valueText": ["text1_"]}, 1),
],
)
def test_filter(test_index, filter_query, expected_num_docs):
Expand Down