Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions elasticsearch/dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
MatchOnlyText,
Murmur3,
Nested,
NumpyDenseVector,
Object,
Passthrough,
Percolator,
Expand Down Expand Up @@ -189,6 +190,7 @@
"Murmur3",
"Nested",
"NestedFacet",
"NumpyDenseVector",
"Object",
"Passthrough",
"Percolator",
Expand Down
32 changes: 27 additions & 5 deletions elasticsearch/dsl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,11 +1616,33 @@ def __init__(
kwargs["multi"] = True
super().__init__(*args, **kwargs)

def _deserialize(self, data: Any) -> Any:
if self._element_type == "float":
return float(data)
elif self._element_type == "byte":
return int(data)

class NumpyDenseVector(DenseVector):
    """A dense vector field that uses numpy arrays.

    Accepts the same arguments as class ``DenseVector`` plus:

    :arg dtype: The numpy data type to use for the array. If not given, numpy
        will select the type based on the data.
    """

    def __init__(self, *args: Any, dtype: Optional[type] = None, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # passed unchanged to ``np.array`` when a list is deserialized
        self._dtype = dtype

    def deserialize(self, data: Any) -> Any:
        """Turn a plain list into a numpy array.

        Any value that is not a ``list`` is handed to the parent class
        implementation.
        """
        if isinstance(data, list):
            # imported here rather than at module level, so numpy is only
            # loaded when a list actually has to be converted
            import numpy as np

            return np.array(data, dtype=self._dtype)
        return super().deserialize(data)

    def clean(self, data: Any) -> Any:
        """Deserialize ``data`` and enforce the ``required`` constraint.

        This method does the same as the one in the parent classes, but it
        avoids comparisons that do not work for numpy arrays.

        :arg data: the raw value assigned to the field, or ``None``.
        :returns: the deserialized value (``None`` stays ``None``).
        :raises ValidationException: if the field is required and the value
            is ``None`` or has a length of zero.
        """
        if data is not None:
            data = self.deserialize(data)
        if self._required:
            try:
                missing = data is None or len(data) == 0
            except TypeError:
                # values without a length (scalars, 0-d arrays) are not
                # empty; let them through instead of failing inside len()
                missing = False
            if missing:
                raise ValidationException("Value required for this field.")
        return data


Expand Down
13 changes: 11 additions & 2 deletions elasticsearch/dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,17 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]:
if skip_empty:
# don't serialize empty values
# careful not to include numeric zeros
if v in ([], {}, None):
continue
try:
if v in ([], {}, None):
continue
except ValueError:
# the above fails when v is a numpy array
# try using len() instead
try:
if len(v) == 0:
continue
except TypeError:
pass

out[k] = v
return out
Expand Down
18 changes: 13 additions & 5 deletions examples/quotes/backend/quotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,30 @@
from typing import Annotated

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, ValidationError
import numpy as np
from pydantic import BaseModel, Field, PlainSerializer
from sentence_transformers import SentenceTransformer

from elasticsearch import NotFoundError
from elasticsearch import NotFoundError, OrjsonSerializer
from elasticsearch.dsl.pydantic import AsyncBaseESModel
from elasticsearch import dsl

model = SentenceTransformer("all-MiniLM-L6-v2")
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']])
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer())


class Quote(AsyncBaseESModel):
quote: str
author: Annotated[str, dsl.Keyword()]
tags: Annotated[list[str], dsl.Keyword()]
embedding: Annotated[list[float], dsl.DenseVector()] = Field(init=False, default=[])
embedding: Annotated[
np.ndarray,
PlainSerializer(lambda v: v.tolist()),
dsl.NumpyDenseVector(dtype=np.float32)
] = Field(init=False, default_factory=lambda: np.array([], dtype=np.float32))

class Config:
arbitrary_types_allowed = True

class Index:
name = 'quotes'
Expand Down Expand Up @@ -135,7 +143,7 @@ async def search_quotes(req: SearchRequest) -> SearchResponse:
def embed_quotes(quotes):
embeddings = model.encode([q.quote for q in quotes])
for q, e in zip(quotes, embeddings):
q.embedding = e.tolist()
q.embedding = e


async def ingest_quotes():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Tuple, Union

import numpy as np
import pytest
from pytest import raises
from pytz import timezone
Expand All @@ -47,6 +48,7 @@
Mapping,
MetaField,
Nested,
NumpyDenseVector,
Object,
Q,
RankFeatures,
Expand Down Expand Up @@ -865,25 +867,33 @@ class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

test_float_vector = [1.0, 1.2, 2.3]
test_byte_vector = [12, 23, 34, 45]
test_bit_vector = [18, -43, -112]

doc = Doc(
float_vector=[1.0, 1.2, 2.3],
byte_vector=[12, 23, 34, 45],
bit_vector=[18, -43, -112],
float_vector=test_float_vector,
byte_vector=test_byte_vector,
bit_vector=test_bit_vector,
numpy_float_vector=np.array(test_float_vector),
)
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert [round(v, 1) for v in docs[0].float_vector] == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector
assert [round(v, 1) for v in docs[0].float_vector] == test_float_vector
assert docs[0].byte_vector == test_byte_vector
assert docs[0].bit_vector == test_bit_vector
assert type(docs[0].numpy_float_vector) is np.ndarray
assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector


@pytest.mark.anyio
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import pytest
from pytest import raises
from pytz import timezone
Expand All @@ -46,6 +47,7 @@
Mapping,
MetaField,
Nested,
NumpyDenseVector,
Object,
Q,
RankFeatures,
Expand Down Expand Up @@ -853,25 +855,33 @@ class Doc(Document):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
numpy_float_vector: np.ndarray = mapped_field(NumpyDenseVector())

class Index:
name = "vectors"

Doc._index.delete(ignore_unavailable=True)
Doc.init()

test_float_vector = [1.0, 1.2, 2.3]
test_byte_vector = [12, 23, 34, 45]
test_bit_vector = [18, -43, -112]

doc = Doc(
float_vector=[1.0, 1.2, 2.3],
byte_vector=[12, 23, 34, 45],
bit_vector=[18, -43, -112],
float_vector=test_float_vector,
byte_vector=test_byte_vector,
bit_vector=test_bit_vector,
numpy_float_vector=np.array(test_float_vector),
)
doc.save(refresh=True)

docs = Doc.search().execute()
assert len(docs) == 1
assert [round(v, 1) for v in docs[0].float_vector] == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector
assert [round(v, 1) for v in docs[0].float_vector] == test_float_vector
assert docs[0].byte_vector == test_byte_vector
assert docs[0].bit_vector == test_bit_vector
assert type(docs[0].numpy_float_vector) is np.ndarray
assert [round(v, 1) for v in docs[0].numpy_float_vector] == test_float_vector


@pytest.mark.sync
Expand Down
29 changes: 24 additions & 5 deletions utils/templates/field.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,30 @@ class {{ k.name }}({{ k.parent }}):
kwargs["multi"] = True
super().__init__(*args, **kwargs)

def _deserialize(self, data: Any) -> Any:
if self._element_type == "float":
return float(data)
elif self._element_type == "byte":
return int(data)
class NumpyDenseVector(DenseVector):
    """A dense vector field that uses numpy arrays.

    Accepts the same arguments as class ``DenseVector`` plus:

    :arg dtype: The numpy data type to use for the array. If not given, numpy
        will select the type based on the data.
    """
    def __init__(self, *args: Any, dtype: Optional[type] = None, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # passed unchanged to ``np.array`` when a list is deserialized
        self._dtype = dtype

    def deserialize(self, data: Any) -> Any:
        """Turn a plain list into a numpy array.

        Any value that is not a ``list`` is handed to the parent class
        implementation.
        """
        if isinstance(data, list):
            # imported here rather than at module level, so numpy is only
            # loaded when a list actually has to be converted
            import numpy as np
            return np.array(data, dtype=self._dtype)
        return super().deserialize(data)

    def clean(self, data: Any) -> Any:
        """Deserialize ``data`` and enforce the ``required`` constraint.

        This method does the same as the one in the parent classes, but it
        avoids comparisons that do not work for numpy arrays.

        :arg data: the raw value assigned to the field, or ``None``.
        :returns: the deserialized value (``None`` stays ``None``).
        :raises ValidationException: if the field is required and the value
            is ``None`` or has a length of zero.
        """
        if data is not None:
            data = self.deserialize(data)
        if self._required:
            try:
                missing = data is None or len(data) == 0
            except TypeError:
                # values without a length (scalars, 0-d arrays) are not
                # empty; let them through instead of failing inside len()
                missing = False
            if missing:
                raise ValidationException("Value required for this field.")
        return data
{% elif k.field == "scaled_float" %}
if 'scaling_factor' not in kwargs:
Expand Down