Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: make vector validation work with fp16 #1962

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions python/python/lance/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,19 @@ def validate_vector_index(
if np.isnan(vec).any():
total -= 1
continue
distance = dataset.to_table(
res = dataset.to_table(
nearest={
"column": column,
"q": vec,
"k": 1,
"nprobes": 1,
"refine_factor": refine_factor,
}
)["_distance"].to_pylist()[0]
passes += 1 if abs(distance) < 1e-6 else 0
},
with_row_id=True,
)
passes += (
1 if (res[column].to_numpy(zero_copy_only=False)[0] == vec).all() else 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to worry about epsilon/precision? E.g. should this be < 1e-6?

)

if passes / total < pass_threshold:
raise ValueError(
Expand Down
19 changes: 11 additions & 8 deletions python/python/lance/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
from . import LanceDataset


def _normalize_vectors(vectors, ndim, dtype="float32"):
    """Flatten *vectors* into a pyarrow FixedSizeListArray of width *ndim*.

    Parameters
    ----------
    vectors : iterable of sequences of numbers
        The input vectors; all are assumed to have the same length.
    ndim : int or None
        Vector dimensionality. If None, it is inferred from the first vector.
    dtype : str, default "float32"
        NumPy dtype name used for the flattened values (e.g. "float16" for
        fp16 data). Defaults to "float32" to preserve the previous behavior
        for callers that do not pass a dtype.

    Returns
    -------
    pa.FixedSizeListArray
        The vectors packed as a fixed-size-list array of size ``ndim``.
    """
    if ndim is None:
        # Infer dimensionality from the first vector without consuming
        # more of the iterable than necessary.
        ndim = len(next(iter(vectors)))
    values = np.array(vectors, dtype=dtype).ravel()
    return pa.FixedSizeListArray.from_arrays(values, list_size=ndim)


Expand All @@ -54,6 +54,7 @@ def vec_to_table(
names: Optional[Union[str, list]] = None,
ndim: Optional[int] = None,
check_ndim: bool = True,
dtype: str = "float32",
) -> pa.Table:
"""
Create a pyarrow Table containing vectors.
Expand Down Expand Up @@ -109,7 +110,7 @@ def vec_to_table(
values = list(data.values())
if check_ndim:
ndim = _validate_ndim(values, ndim)
vectors = _normalize_vectors(values, ndim)
vectors = _normalize_vectors(values, ndim, dtype)
ids = pa.array(data.keys())
arrays = [ids, vectors]
elif isinstance(data, list) or (
Expand All @@ -123,7 +124,7 @@ def vec_to_table(
raise ValueError(f"names cannot be more than 1 got {len(names)}")
if check_ndim:
ndim = _validate_ndim(data, ndim)
vectors = _normalize_vectors(data, ndim)
vectors = _normalize_vectors(data, ndim, dtype)
arrays = [vectors]
else:
raise NotImplementedError(
Expand Down Expand Up @@ -231,10 +232,12 @@ def compute_partitions(
with_row_id=True,
columns=[column],
)
output_schema = pa.schema([
pa.field("row_id", pa.uint64()),
pa.field("partition", pa.uint32()),
])
output_schema = pa.schema(
[
pa.field("row_id", pa.uint64()),
pa.field("partition", pa.uint32()),
]
)

def _partition_assignment() -> Iterable[pa.RecordBatch]:
with torch.no_grad():
Expand Down
43 changes: 33 additions & 10 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from lance.vector import vec_to_table # noqa: E402


def create_table(nvec=1000, ndim=128, nans=0):
def create_table(nvec=1000, ndim=128, nans=0, dtype="float32"):
mat = np.random.randn(nvec, ndim)
if nans > 0:
nans_mat = np.empty((nans, ndim))
Expand All @@ -41,7 +41,7 @@ def gen_str(n):

meta = np.array([gen_str(100) for _ in range(nvec + nans)])
tbl = (
vec_to_table(data=mat)
vec_to_table(data=mat, dtype=dtype)
.append_column("price", pa.array(price))
.append_column("meta", pa.array(meta))
.append_column("id", pa.array(range(nvec + nans)))
Expand Down Expand Up @@ -435,10 +435,12 @@ def create_uniform_table(min, max, nvec, offset, ndim=8):
mat = np.random.uniform(min, max, (nvec, ndim))
# rowid = np.arange(offset, offset + nvec)
tbl = vec_to_table(data=mat)
tbl = pa.Table.from_pydict({
"vector": tbl.column(0).chunk(0),
"filterable": np.arange(offset, offset + nvec),
})
tbl = pa.Table.from_pydict(
{
"vector": tbl.column(0).chunk(0),
"filterable": np.arange(offset, offset + nvec),
}
)
return tbl


Expand Down Expand Up @@ -506,10 +508,12 @@ def test_knn_with_deletions(tmp_path):
values = pa.array(
[x for val in range(50) for x in [float(val)] * 5], type=pa.float32()
)
tbl = pa.Table.from_pydict({
"vector": pa.FixedSizeListArray.from_arrays(values, dims),
"filterable": pa.array(range(50)),
})
tbl = pa.Table.from_pydict(
{
"vector": pa.FixedSizeListArray.from_arrays(values, dims),
"filterable": pa.array(range(50)),
}
)
dataset = lance.write_dataset(tbl, tmp_path, max_rows_per_group=10)

dataset.delete("not (filterable % 5 == 0)")
Expand Down Expand Up @@ -639,3 +643,22 @@ def direct_first_call_to_new_table(*args, **kwargs):
ds.sample = direct_first_call_to_new_table
with pytest.raises(ValueError, match="Vector index failed sanity check"):
validate_vector_index(ds, "vector", sample_size=100)


def test_validate_vector_index_fp16_torch(tmp_path: Path):
    """Sanity-check vector-index validation on an fp16 dataset.

    A larger vector count is used so that fp16 precision problems
    actually have a chance to surface.
    """
    dataset = lance.write_dataset(create_table(nvec=2048, dtype="float16"), tmp_path)
    dataset = dataset.create_index(
        "vector",
        index_type="IVF_PQ",
        metric="cosine",
        num_partitions=4,
        num_sub_vectors=16,
        accelerator=torch.device("cpu"),
        replace=True,
    )
    # fp16 rounding can occasionally assign a vector to the wrong IVF
    # partition, so allow a small failure rate instead of demanding 100%.
    validate_vector_index(
        dataset, "vector", nprobes=1, sample_size=100, pass_threshold=0.95
    )
Loading