Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make transformer optional #305

Merged
merged 5 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 28 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ Machine learning and neural networks often use half-precision data types, such a

> However, these data types are not natively available in the Node.js environment, To enable users to utilize these formats, the Node SDK provides support for transformers during insert, query, and search operations.
>
> There are four built-in transformers available for performing a float32 to bytes transformation.
> `f32ArrayToF16Bytes`, `f16BytesToF32Array`, `f32ArrayToBf16Bytes`, `bf16BytesToF32Array`
> The transform parameter is optional. If not specified, it defaults to inserting/outputting bytes.
> There are four default transformers for performing a float32 to bytes transformation for BF16 and Float16 types: f32ArrayToF16Bytes, f16BytesToF32Array, f32ArrayToBf16Bytes, and bf16BytesToF32Array. If you wish to use your own transformers for Float16 and BFloat16, you can specify them.
>
> ```javascript
> import {
Expand All @@ -52,32 +50,33 @@ Machine learning and neural networks often use half-precision data types, such a
> f32ArrayToBf16Bytes,
> bf16BytesToF32Array,
> } from '@zilliz/milvus2-sdk-node';
> // insert
>
> //Insert float32 array for the float16 field. Node SDK will transform it to bytes using `f32ArrayToF16Bytes`. You can use your own transformer.
> const insert = await milvusClient.insert({
> collection_name: COLLECTION_NAME,
> data: data,
> transformers: {
> [DataType.BFloat16Vector]: f32ArrayToBf16Bytes,
> },
> // transformers: {
> // [DataType.BFloat16Vector]: f32ArrayToF16Bytes, // use your own transformer
> // },
> });
> // query
> // query: output float32 array other than bytes,
> const query = await milvusClient.query({
> collection_name: COLLECTION_NAME,
> filter: 'id > 0',
> output_fields: ['vector', 'id'],
> transformers: {
> [DataType.BFloat16Vector]: bf16BytesToF32Array,
> },
> // transformers: {
> // [DataType.BFloat16Vector]: bf16BytesToF32Array, // use your own transformer
> // },
> });
> // search
> // search: use bytes to search, output float32 array
> const search = await milvusClient.search({
> vector: data[0].vector,
> vector: data[0].vector, // if you pass bytes, no transform will performed
> collection_name: COLLECTION_NAME,
> output_fields: ['id', 'vector'],
> limit: 5,
> transformers: {
> [DataType.BFloat16Vector]: bf16BytesToF32Array,
> },
> // transformers: {
> // [DataType.BFloat16Vector]: bf16BytesToF32Array, // use your own transformer
> // },
> });
> ```

Expand Down Expand Up @@ -112,6 +111,17 @@ const sparseArray = [undefined, 0.0, 0.5, 0.3, undefined, 0.2];
Starting from Milvus 2.4, it supports [Multi-Vector Search](https://milvus.io/docs/multi-vector-search.md#API-overview), you can continue to utilize the search API with similar parameters to perform multi-vector searches, and the format of the results remains unchanged.

```javascript
// single-vector search on a collection with multiple vector fields
const search = await milvusClient.search({
collection_name: collection_name,
data: [1, 2, 3, 4, 5, 6, 7, 8],
anns_field: 'vector', // required if you have multiple vector fields in the collection
params: { nprobe: 2 },
filter: 'id > 100',
limit: 5,
});

// multi-vector search on a collection with multiple vector fields
const search = await milvusClient.search({
collection_name: collection_name,
data: [
Expand All @@ -126,6 +136,7 @@ const search = await milvusClient.search({
},
],
limit: 5,
filter: 'id > 100',
});
```

Expand Down Expand Up @@ -317,7 +328,7 @@ const res = await client.search({

- [What is Milvus](https://milvus.io/)
- [Milvus Node SDK API reference](https://milvus.io/api-reference/node/v2.3.x/About.md)
- [Feder, anns index visuliazation tool](https://github.com/zilliztech/feder)
- [Feder, anns index visualization tool](https://github.com/zilliztech/feder)

## How to contribute

Expand Down
4 changes: 2 additions & 2 deletions milvus/types/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import {

// all value types supported by milvus
export type FloatVector = number[];
export type Float16Vector = number[];
export type BFloat16Vector = number[];
export type Float16Vector = number[] | Uint8Array;
export type BFloat16Vector = number[] | Uint8Array;
export type BinaryVector = number[];
export type SparseVectorArray = (number | undefined)[];
export type SparseVectorDic = { [key: string]: number };
Expand Down
20 changes: 15 additions & 5 deletions milvus/utils/Bytes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
SparseVectorCSR,
SparseVectorCOO,
BFloat16Vector,
SparseVectorArray,
} from '..';

/**
Expand Down Expand Up @@ -164,8 +165,13 @@ export const sparseToBytes = (data: SparseFloatVector): Uint8Array => {

switch (type) {
case 'array':
indices = Object.keys(data).map(Number);
values = Object.values(data);
for (let i = 0; i < (data as SparseVectorArray).length; i++) {
const element = (data as SparseVectorArray)[i];
if (element !== undefined && !isNaN(element)) {
indices.push(i);
values.push(element);
}
}
break;
case 'coo':
indices = Object.values(
Expand Down Expand Up @@ -244,7 +250,7 @@ export const bytesToSparseRow = (bufferData: Buffer): SparseFloatVector => {
* This function builds a placeholder group in bytes format for Milvus.
*
* @param {Root} milvusProto - The root object of the Milvus protocol.
* @param {VectorTypes[]} searchVectors - An array of search vectors.
* @param {VectorTypes[]} vectors - An array of search vectors.
* @param {DataType} vectorDataType - The data type of the vectors.
*
* @returns {Uint8Array} The placeholder group in bytes format.
Expand All @@ -265,10 +271,14 @@ export const buildPlaceholderGroupBytes = (
bytes = vectors.map(v => f32ArrayToBinaryBytes(v as BinaryVector));
break;
case DataType.BFloat16Vector:
bytes = vectors.map(v => f32ArrayToBf16Bytes(v as BFloat16Vector));
bytes = vectors.map(v =>
Array.isArray(v) ? f32ArrayToBf16Bytes(v as BFloat16Vector) : v
);
break;
case DataType.Float16Vector:
bytes = vectors.map(v => f32ArrayToF16Bytes(v as Float16Vector));
bytes = vectors.map(v =>
Array.isArray(v) ? f32ArrayToF16Bytes(v as Float16Vector) : v
);
break;
case DataType.SparseFloatVector:
bytes = vectors.map(v => sparseToBytes(v as SparseFloatVector));
Expand Down
47 changes: 25 additions & 22 deletions milvus/utils/Format.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ import {
getSparseFloatVectorType,
InsertTransformers,
OutputTransformers,
SparseVectorArray,
f32ArrayToBf16Bytes,
f32ArrayToF16Bytes,
bf16BytesToF32Array,
f16BytesToF32Array,
} from '../';

/**
Expand Down Expand Up @@ -445,22 +450,18 @@ export const buildFieldDataMap = (
for (let i = 0; i < f16Bytes.byteLength; i += f16Dim) {
const slice = f16Bytes.slice(i, i + f16Dim);
const isFloat16 = dataKey === 'float16_vector';
const isBFloat16 = dataKey === 'bfloat16_vector';
let dataType: DataType.BFloat16Vector | DataType.Float16Vector;

dataType = isFloat16
? DataType.Float16Vector
: DataType.BFloat16Vector;

if (
(isFloat16 || isBFloat16) &&
transformers &&
transformers[dataType]
) {
field_data.push(transformers[dataType]!(slice));
} else {
field_data.push(slice);
}
const localTransformers = transformers || {
[DataType.BFloat16Vector]: bf16BytesToF32Array,
[DataType.Float16Vector]: f16BytesToF32Array,
};

field_data.push(localTransformers[dataType]!(slice));
}
break;
case 'sparse_float_vector':
Expand Down Expand Up @@ -545,22 +546,24 @@ export const buildFieldData = (
transformers?: InsertTransformers
): FieldData => {
const { type, elementType, name } = field;
const isFloat32 = Array.isArray(rowData[name]);

switch (DataTypeMap[type]) {
case DataType.BinaryVector:
case DataType.FloatVector:
return rowData[name];
case DataType.BFloat16Vector:
if (transformers && transformers[DataType.BFloat16Vector]) {
return transformers[DataType.BFloat16Vector](
rowData[name] as BFloat16Vector
);
}
const bf16Transformer =
transformers?.[DataType.BFloat16Vector] || f32ArrayToBf16Bytes;
return isFloat32
? bf16Transformer(rowData[name] as BFloat16Vector)
: rowData[name];
case DataType.Float16Vector:
if (transformers && transformers[DataType.Float16Vector]) {
return transformers[DataType.Float16Vector](
rowData[name] as Float16Vector
);
}
const f16Transformer =
transformers?.[DataType.Float16Vector] || f32ArrayToF16Bytes;
return isFloat32
? f16Transformer(rowData[name] as Float16Vector)
: rowData[name];
case DataType.JSON:
return Buffer.from(JSON.stringify(rowData[name] || {}));
case DataType.Array:
Expand Down Expand Up @@ -728,7 +731,7 @@ export const buildSearchRequest = (
searchSimpleReq.vector ||
searchSimpleReq.data;

// format saerching vector
// format searching vector
searchingVector = formatSearchVector(searchingVector, field.dataType!);

// create search request
Expand Down Expand Up @@ -898,7 +901,7 @@ export const formatSearchVector = (
return [searchVector] as VectorTypes[];
}
case DataType.SparseFloatVector:
const type = getSparseFloatVectorType(searchVector as VectorTypes);
const type = getSparseFloatVectorType(searchVector as SparseVectorArray);
if (type !== 'unknown') {
return [searchVector] as VectorTypes[];
}
Expand Down
16 changes: 5 additions & 11 deletions test/grpc/BFloat16Vector.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ describe(`BFloat16 vector API testing`, () => {
const insert = await milvusClient.insert({
collection_name: COLLECTION_NAME,
data: data,
transformers: {
[DataType.BFloat16Vector]: f32ArrayToBf16Bytes,
},
});

// console.log(' insert', insert);
Expand Down Expand Up @@ -108,9 +105,6 @@ describe(`BFloat16 vector API testing`, () => {
collection_name: COLLECTION_NAME,
filter: 'id > 0',
output_fields: ['vector', 'id'],
transformers: {
[DataType.BFloat16Vector]: bf16BytesToF32Array,
},
});

// console.dir(query, { depth: null });
Expand All @@ -127,13 +121,10 @@ describe(`BFloat16 vector API testing`, () => {

it(`search with Bfloat16 vector should be successful`, async () => {
const search = await milvusClient.search({
vector: data[0].vector,
data: f32ArrayToBf16Bytes(data[0].vector),
collection_name: COLLECTION_NAME,
output_fields: ['id', 'vector'],
limit: 5,
transformers: {
[DataType.BFloat16Vector]: bf16BytesToF32Array,
},
});

// console.log('search', search);
Expand All @@ -144,7 +135,10 @@ describe(`BFloat16 vector API testing`, () => {

it(`search with Bfloat16 vector and nq > 0 should be successful`, async () => {
const search = await milvusClient.search({
vector: [data[0].vector, data[1].vector],
data: [
f32ArrayToBf16Bytes(data[0].vector),
f32ArrayToBf16Bytes(data[1].vector),
],
collection_name: COLLECTION_NAME,
output_fields: ['id', 'vector'],
limit: 5,
Expand Down
2 changes: 1 addition & 1 deletion test/grpc/Basic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ describe(`Basic API without database`, () => {
it(`search should be successful`, async () => {
const search = await milvusClient.search({
collection_name: COLLECTION_NAME,
vector: [1, 2, 3, 4],
data: [1, 2, 3, 4],
});
expect(search.status.error_code).toEqual(ErrorCode.SUCCESS);
});
Expand Down
2 changes: 1 addition & 1 deletion test/grpc/BinaryVector.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ describe(`Binary vectors API testing`, () => {

it(`search with binary vector should be successful`, async () => {
const search = await milvusClient.search({
vector: data[0].vector,
data: data[0].vector,
collection_name: COLLECTION_NAME,
output_fields: ['id', 'vector'],
limit: 5,
Expand Down
Loading
Loading