feat: allow fixed size tensors to be used as query vectors #1736

Merged: 4 commits, Dec 19, 2023
77 changes: 59 additions & 18 deletions python/python/lance/dataset.py
@@ -60,6 +60,13 @@
         Iterable[RecordBatch],
         pa.RecordBatchReader,
     ]
+    QueryVectorLike = Union[
+        pd.Series,
+        pa.Array,
+        pa.Scalar,

Contributor:
Why can we use scalar here?

Contributor Author:
It's to handle FixedSizeListScalar (and fixed shape tensor scalar). Pretty often I find myself doing something like:

queries = dataset.take(random_indices)
for query in queries:
    # Here `query` is a FixedSizeListScalar
    ...
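
For illustration, a minimal end-to-end sketch of that pattern with this change in place (the dataset path and the "vector" column name are hypothetical, not part of this PR):

import numpy as np

import lance

# Hypothetical dataset with a fixed-size-list "vector" column.
ds = lance.dataset("/tmp/vectors.lance")
random_indices = np.random.randint(0, ds.count_rows(), size=5)
rows = ds.take(random_indices, columns=["vector"])

for query in rows["vector"]:
    # `query` is a pa.FixedSizeListScalar; with this change it can be passed
    # to `nearest()` / `to_table(nearest=...)` directly, no manual conversion.
    hits = ds.to_table(nearest={"column": "vector", "q": query, "k": 10})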

+        np.ndarray,
+        Iterable[float],
+    ]
 except ImportError:
     pd = None
     ReaderLike = Union[
@@ -69,6 +76,12 @@
         Iterable[RecordBatch],
         pa.RecordBatchReader,
     ]
+    QueryVectorLike = Union[
+        pa.Array,
+        pa.Scalar,
+        np.ndarray,
+        Iterable[float],
+    ]

 if TYPE_CHECKING:
     import torch
@@ -1570,37 +1583,32 @@ def with_fragments(
     def nearest(
         self,
         column: str,
-        q: pa.FloatingPointArray | List[float] | np.ndarray,
+        q: QueryVectorLike,
         k: Optional[int] = None,
         metric: Optional[str] = None,
         nprobes: Optional[int] = None,
         refine_factor: Optional[int] = None,
         use_index: bool = True,
     ) -> ScannerBuilder:
-        column_field = self.ds.schema.field(column)
-        q_size = q.size if isinstance(q, np.ndarray) else len(q)
+        q = _coerce_query_vector(q)

         if self.ds.schema.get_field_index(column) < 0:
             raise ValueError(f"Embedding column {column} not in dataset")
-        if not (
-            isinstance(
-                column_field.type, (pa.FloatingPointArray, np.ndarray, list, tuple)
-            )
-            or pa.types.is_fixed_size_list(column_field.type)
-        ):
-            raise ValueError(f"Embedding column {column} is not in the dataset")

+        column_field = self.ds.schema.field(column)
+        column_type = column_field.type
+        if hasattr(column_type, "storage_type"):

Contributor:
Do we still need to check storage type for pyarrow > 12?

Contributor Author:
Not for fixed shape tensor arrays. Once we are on pyarrow 13 we can check for FixedShapeTensorArray directly. However, another advantage of using storage_type is that it future-proofs us against other possible extension types.
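
For illustration (a standalone snippet, not part of the diff): a fixed shape tensor type is an extension type whose storage is a fixed-size list, so unwrapping storage_type lets one fixed-size-list check cover both plain and extension-typed vector columns.

import pyarrow as pa  # assumes pyarrow >= 12 for fixed_shape_tensor

# Extension types expose the underlying storage type.
tensor_type = pa.fixed_shape_tensor(pa.float32(), [4])
assert hasattr(tensor_type, "storage_type")
assert pa.types.is_fixed_size_list(tensor_type.storage_type)

# A plain fixed-size-list column already satisfies the same check.
plain_type = pa.list_(pa.float32(), 4)
assert pa.types.is_fixed_size_list(plain_type)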

+            column_type = column_type.storage_type
+        if not pa.types.is_fixed_size_list(column_type):
+            raise TypeError(
+                f"Query column {column} must be a vector. Got {column_field.type}."
+            )
-        if q_size != column_field.type.list_size:
+        if len(q) != column_type.list_size:
             raise ValueError(
-                f"Query vector size {q_size} does not match index column size"
-                f" {column_field.type.list_size}"
+                f"Query vector size {len(q)} does not match index column size"
+                f" {column_type.list_size}"
             )
-        if isinstance(q, (np.ndarray, list, tuple)):
-            q = np.array(q).astype("float64")  # workaround for GH-608
-            q = pa.FloatingPointArray.from_pandas(q, type=pa.float32())
-        if not isinstance(q, pa.FloatingPointArray):
-            raise TypeError("query vector must be list-like or pa.FloatingPointArray")

         if k is not None and int(k) <= 0:
             raise ValueError(f"Nearest-K must be > 0 but got {k}")
         if nprobes is not None and int(nprobes) <= 0:
@@ -1978,6 +1986,39 @@ def _coerce_reader(
         )


+def _coerce_query_vector(query: QueryVectorLike):
+    if isinstance(query, pa.Scalar):
+        if isinstance(query, pa.ExtensionScalar):
+            # If it's an extension scalar then convert to storage
+            query = query.value
+        if isinstance(query.type, pa.FixedSizeListType):
+            query = query.values
+    elif isinstance(query, (np.ndarray, list, tuple)):
+        query = np.array(query).astype("float64")  # workaround for GH-608
+        query = pa.FloatingPointArray.from_pandas(query, type=pa.float32())
+    elif not isinstance(query, pa.Array):
+        try:
+            query = pa.array(query)
+        except:  # noqa: E722
+            raise TypeError(
+                "Query vectors should be an array of floats, "
+                f"got {type(query)} which we cannot coerce to a "
+                "float array"
+            )
+
+    # At this point `query` should be an arrow array
+    if not isinstance(query, pa.FloatingPointArray):
+        if pa.types.is_integer(query.type):
+            query = query.cast(pa.float32())
+        else:
+            raise TypeError(
+                "query vector must be list-like or pa.FloatingPointArray "
+                f"but received {query.type}"
+            )
+
+    return query
+
+
 def _validate_schema(schema: pa.Schema):
     """
     Make sure the metadata is valid utf8
24 changes: 24 additions & 0 deletions python/python/tests/test_dataset.py
@@ -1062,3 +1062,27 @@ def test_dataset_progress(tmp_path: Path):
     assert len(ds.get_fragments()) == 2
     assert progress.begin_called == 2
     assert progress.complete_called == 2
+
+
+def test_tensor_type(tmp_path: Path):
+    arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
+    storage = pa.array(arr, pa.list_(pa.float32(), 4))
+    tensor_type = pa.fixed_shape_tensor(pa.float32(), [4])
+    ext_arr = pa.ExtensionArray.from_storage(tensor_type, storage)
+    data = pa.table({"tensor": ext_arr})
+
+    ds = lance.write_dataset(data, tmp_path)
+
+    query_arr = [[10, 20, 30, 40]]
+    storage = pa.array(query_arr, pa.list_(pa.float32(), 4))
+    ext_arr = pa.ExtensionArray.from_storage(tensor_type, storage)
+    ext_scalar = ext_arr[0]
+
+    results = ds.to_table(
+        nearest={
+            "column": "tensor",
+            "k": 1,
+            "q": ext_scalar,
+        }
+    )
+    assert results.num_rows == 1