diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py index 13eb689375..13c08c2d73 100644 --- a/docarray/index/backends/weaviate.py +++ b/docarray/index/backends/weaviate.py @@ -18,6 +18,8 @@ TypeVar, Union, cast, + get_origin, + get_args, ) import numpy as np @@ -43,7 +45,6 @@ TSchema = TypeVar('TSchema', bound=BaseDoc) T = TypeVar('T', bound='WeaviateDocumentIndex') - DEFAULT_BATCH_CONFIG = { "batch_size": 20, "dynamic": False, @@ -210,6 +211,8 @@ def _create_schema(self) -> None: self.bytes_columns.append(column_name) if column_info.db_type == 'number[]': self.nonembedding_array_columns.append(column_name) + if column_info.db_type == 'text[]': + self.nonembedding_array_columns.append(column_name) prop = { "name": column_name if column_name != 'id' @@ -253,6 +256,8 @@ class DBConfig(BaseDocIndex.DBConfig): 'number': {}, 'boolean': {}, 'number[]': {}, + 'int[]': {}, + 'text[]': {}, 'blob': {}, } ) @@ -717,6 +722,23 @@ def python_type_to_db_type(self, python_type: Type) -> Any: bytes: 'blob', } + if get_origin(python_type) == list: + py_weaviate_list_type_map = { + int: 'int[]', + float: 'number[]', + str: 'text[]', + } + + container_type = None + args = get_args(python_type) + if args: + container_type = args[0] + if ( + container_type is not None + and container_type in py_weaviate_list_type_map + ): + return py_weaviate_list_type_map[container_type] + for py_type, weaviate_type in py_weaviate_type_map.items(): if safe_issubclass(python_type, py_type): return weaviate_type