-
Notifications
You must be signed in to change notification settings - Fork 193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: allow fixed size tensors to be used as query vectors #1736
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,6 +60,13 @@ | |
Iterable[RecordBatch], | ||
pa.RecordBatchReader, | ||
] | ||
QueryVectorLike = Union[ | ||
pd.Series, | ||
pa.Array, | ||
pa.Scalar, | ||
np.ndarray, | ||
Iterable[float], | ||
] | ||
except ImportError: | ||
pd = None | ||
ReaderLike = Union[ | ||
|
@@ -69,6 +76,12 @@ | |
Iterable[RecordBatch], | ||
pa.RecordBatchReader, | ||
] | ||
QueryVectorLike = Union[ | ||
pa.Array, | ||
pa.Scalar, | ||
np.ndarray, | ||
Iterable[float], | ||
] | ||
|
||
if TYPE_CHECKING: | ||
import torch | ||
|
@@ -1570,37 +1583,32 @@ def with_fragments( | |
def nearest( | ||
self, | ||
column: str, | ||
q: pa.FloatingPointArray | List[float] | np.ndarray, | ||
q: QueryVectorLike, | ||
k: Optional[int] = None, | ||
metric: Optional[str] = None, | ||
nprobes: Optional[int] = None, | ||
refine_factor: Optional[int] = None, | ||
use_index: bool = True, | ||
) -> ScannerBuilder: | ||
column_field = self.ds.schema.field(column) | ||
q_size = q.size if isinstance(q, np.ndarray) else len(q) | ||
q = _coerce_query_vector(q) | ||
|
||
if self.ds.schema.get_field_index(column) < 0: | ||
raise ValueError(f"Embedding column {column} not in dataset") | ||
if not ( | ||
isinstance( | ||
column_field.type, (pa.FloatingPointArray, np.ndarray, list, tuple) | ||
) | ||
or pa.types.is_fixed_size_list(column_field.type) | ||
): | ||
raise ValueError(f"Embedding column {column} is not in the dataset") | ||
|
||
column_field = self.ds.schema.field(column) | ||
column_type = column_field.type | ||
if hasattr(column_type, "storage_type"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still need to check storage type for pyarrow > 12? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not for fixed shape tensor array. Once we are using pyarrow 13 we can directly check for |
||
column_type = column_type.storage_type | ||
if not pa.types.is_fixed_size_list(column_type): | ||
raise TypeError( | ||
f"Query column {column} must be a vector. Got {column_field.type}." | ||
) | ||
if q_size != column_field.type.list_size: | ||
if len(q) != column_type.list_size: | ||
raise ValueError( | ||
f"Query vector size {q_size} does not match index column size" | ||
f" {column_field.type.list_size}" | ||
f"Query vector size {len(q)} does not match index column size" | ||
f" {column_type.list_size}" | ||
) | ||
if isinstance(q, (np.ndarray, list, tuple)): | ||
q = np.array(q).astype("float64") # workaround for GH-608 | ||
q = pa.FloatingPointArray.from_pandas(q, type=pa.float32()) | ||
if not isinstance(q, pa.FloatingPointArray): | ||
raise TypeError("query vector must be list-like or pa.FloatingPointArray") | ||
|
||
if k is not None and int(k) <= 0: | ||
raise ValueError(f"Nearest-K must be > 0 but got {k}") | ||
if nprobes is not None and int(nprobes) <= 0: | ||
|
@@ -1978,6 +1986,39 @@ def _coerce_reader( | |
) | ||
|
||
|
||
def _coerce_query_vector(query: QueryVectorLike):
    """Coerce a user-supplied query vector into a ``pa.FloatingPointArray``.

    Accepts arrow scalars (including extension scalars such as fixed shape
    tensor scalars, and fixed-size-list scalars), numpy arrays, plain Python
    sequences, and anything ``pa.array`` can consume (e.g. a pandas Series).
    Integer arrays are cast to ``float32``.

    Parameters
    ----------
    query : QueryVectorLike
        The query vector to coerce.

    Returns
    -------
    pa.FloatingPointArray
        The query as a flat floating-point arrow array.

    Raises
    ------
    TypeError
        If the input cannot be coerced to a floating-point array.
    """
    if isinstance(query, pa.Scalar):
        if isinstance(query, pa.ExtensionScalar):
            # If it's an extension scalar then convert to storage
            query = query.value
        if isinstance(query.type, pa.FixedSizeListType):
            # Unwrap a fixed-size-list scalar into its flat values array
            query = query.values
    elif isinstance(query, (np.ndarray, list, tuple)):
        query = np.array(query).astype("float64")  # workaround for GH-608
        query = pa.FloatingPointArray.from_pandas(query, type=pa.float32())
    elif not isinstance(query, pa.Array):
        try:
            query = pa.array(query)
        except Exception as err:
            # Catch Exception (not bare `except:`) so KeyboardInterrupt /
            # SystemExit propagate; chain the cause for easier debugging.
            raise TypeError(
                "Query vectors should be an array of floats, "
                f"got {type(query)} which we cannot coerce to a "
                "float array"
            ) from err

    # At this point `query` should be an arrow array
    if not isinstance(query, pa.FloatingPointArray):
        if pa.types.is_integer(query.type):
            query = query.cast(pa.float32())
        else:
            raise TypeError(
                "query vector must be list-like or pa.FloatingPointArray "
                f"but received {query.type}"
            )

    return query
|
||
|
||
def _validate_schema(schema: pa.Schema): | ||
""" | ||
Make sure the metadata is valid utf8 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can we use scalar here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's to handle FixedSizeListScalar (and fixed shape tensor scalar) Pretty often I find myself doing something like: