Skip to content

Commit

Permalink
Support more types of sparse vectors (#293)
Browse files Browse the repository at this point in the history
* generate different types of sparse vector

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* fix sparse array in js

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* add sparse array test

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* add csr sparse vector test

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* add coo support

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* remove unused import

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* refine comments

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

* refine comment

Signed-off-by: ryjiang <jiangruiyi@gmail.com>

---------

Signed-off-by: ryjiang <jiangruiyi@gmail.com>
  • Loading branch information
shanghaikid committed Mar 29, 2024
1 parent 5034a82 commit 611a550
Show file tree
Hide file tree
Showing 10 changed files with 649 additions and 28 deletions.
4 changes: 1 addition & 3 deletions milvus/grpc/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ import {
parseSparseRowsToBytes,
getSparseDim,
parseBinaryVectorToBytes,
parseFloat16VectorToBytes,
Float16Vector,
} from '../';
import { Collection } from './Collection';

Expand Down Expand Up @@ -183,7 +181,7 @@ export class Data extends Collection {
}
if (
DataTypeMap[field.type] === DataType.BinaryVector &&
(rowData[name] as VectorTypes).length !== field.dim! / 8
(rowData[name] as BinaryVector).length !== field.dim! / 8
) {
throw new Error(ERROR_REASONS.INSERT_CHECK_WRONG_DIM);
}
Expand Down
16 changes: 15 additions & 1 deletion milvus/types/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,21 @@ export type FloatVector = number[];
export type Float16Vector = number[];
export type BFloat16Vector = number[];
export type BinaryVector = number[];
export type SparseFloatVector = { [key: string]: number };
export type SparseVectorArray = (number | undefined)[];
export type SparseVectorDic = { [key: string]: number };
export type SparseVectorCSR = {
indices: number[];
values: number[];
};
export type SparseVectorCOO = { index: number; value: number }[];

export type SparseFloatVector =
| SparseVectorArray
| SparseVectorDic
| SparseVectorCSR
| SparseVectorCOO;

// export type SparseFloatVector = { [key: string]: number };
export type VectorTypes =
| FloatVector
| Float16Vector
Expand Down
63 changes: 57 additions & 6 deletions milvus/utils/Bytes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import {
DataType,
VectorTypes,
Float16Vector,
SparseVectorCSR,
SparseVectorCOO,
} from '..';

/**
Expand Down Expand Up @@ -50,6 +52,26 @@ export const parseBytesToFloat16Vector = (float16Bytes: Uint8Array) => {
return Array.from(float16Array);
};

/**
* Get SparseVector type.
* @param {SparseFloatVector} vector - The sparse float vector to convert.
*
* @returns string, 'array' | 'coo' | 'csr' | 'dict'
*/
export const getSparseFloatVectorType = (vector: SparseFloatVector) => {
if (Array.isArray(vector)) {
if (typeof vector[0] === 'number' || typeof vector[0] === 'undefined') {
return 'array';
} else {
return 'coo';
}
} else if ('indices' in vector && 'values' in vector) {
return 'csr';
} else {
return 'dict';
}
};

/**
* Converts a sparse float vector into bytes format.
*
Expand All @@ -61,10 +83,39 @@ export const parseBytesToFloat16Vector = (float16Bytes: Uint8Array) => {
export const parseSparseVectorToBytes = (
data: SparseFloatVector
): Uint8Array => {
const indices = Object.keys(data).map(Number);
const values = Object.values(data);
// detect the format of the sparse vector
const type = getSparseFloatVectorType(data);

let indices: number[];
let values: number[];

switch (type) {
case 'array':
indices = Object.keys(data).map(Number);
values = Object.values(data);
break;
case 'coo':
indices = Object.values(
(data as SparseVectorCOO).map((item: any) => item.index)
);
values = Object.values(
(data as SparseVectorCOO).map((item: any) => item.value)
);
break;
case 'csr':
indices = (data as SparseVectorCSR).indices;
values = (data as SparseVectorCSR).values;
break;
case 'dict':
indices = Object.keys(data).map(Number);
values = Object.values(data);
break;
}

// create a buffer to store the bytes
const bytes = new Uint8Array(8 * indices.length);

// loop through the indices and values and add them to the buffer
for (let i = 0; i < indices.length; i++) {
const index = indices[i];
const value = values[i];
Expand All @@ -73,9 +124,7 @@ export const parseSparseVectorToBytes = (
`Sparse vector index must be positive and less than 2^32-1: ${index}`
);
}
if (isNaN(value)) {
throw new Error('Sparse vector value must not be NaN');
}

const indexBytes = new Uint32Array([index]);
const valueBytes = new Float32Array([value]);
bytes.set(new Uint8Array(indexBytes.buffer), i * 8);
Expand Down Expand Up @@ -115,7 +164,9 @@ export const parseBufferToSparseRow = (
for (let i = 0; i < bufferData.length; i += 8) {
const key: string = bufferData.readUInt32LE(i).toString();
const value: number = bufferData.readFloatLE(i + 4);
result[key] = value;
if (value) {
result[key] = value;
}
}
return result;
};
Expand Down
131 changes: 131 additions & 0 deletions test/grpc/SparseVector.array.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import {
MilvusClient,
ErrorCode,
DataType,
IndexType,
MetricType,
} from '../../milvus';
import {
IP,
genCollectionParams,
GENERATE_NAME,
generateInsertData,
} from '../tools';

const milvusClient = new MilvusClient({ address: IP, logLevel: 'info' });
const COLLECTION_NAME = GENERATE_NAME();

const dbParam = {
db_name: 'sparse_array_vector_DB',
};

const p = {
collectionName: COLLECTION_NAME,
vectorType: [DataType.SparseFloatVector],
dim: [24], // useless
};
const collectionParams = genCollectionParams(p);
const data = generateInsertData(collectionParams.fields, 10, {
sparseType: 'array',
});

describe(`Sparse vectors type:object API testing`, () => {
beforeAll(async () => {
await milvusClient.createDatabase(dbParam);
await milvusClient.use(dbParam);
});

afterAll(async () => {
await milvusClient.dropCollection({ collection_name: COLLECTION_NAME });
await milvusClient.dropDatabase(dbParam);
});

it(`Create collection with sparse vectors should be successful`, async () => {
const create = await milvusClient.createCollection(collectionParams);
expect(create.error_code).toEqual(ErrorCode.SUCCESS);

const describe = await milvusClient.describeCollection({
collection_name: COLLECTION_NAME,
});

const sparseFloatVectorFields = describe.schema.fields.filter(
(field: any) => field.data_type === 'SparseFloatVector'
);
expect(sparseFloatVectorFields.length).toBe(1);

// console.dir(describe.schema, { depth: null });
});

it(`insert sparse vector data should be successful`, async () => {
const insert = await milvusClient.insert({
collection_name: COLLECTION_NAME,
data,
});

// console.log('data to insert', data);

expect(insert.status.error_code).toEqual(ErrorCode.SUCCESS);
expect(insert.succ_index.length).toEqual(data.length);
});

it(`create index should be successful`, async () => {
const indexes = await milvusClient.createIndex([
{
collection_name: COLLECTION_NAME,
field_name: 'vector',
metric_type: MetricType.IP,
index_type: IndexType.SPARSE_WAND,
params: {
drop_ratio_build: 0.2,
},
},
]);

expect(indexes.error_code).toEqual(ErrorCode.SUCCESS);
});

it(`load collection should be successful`, async () => {
const load = await milvusClient.loadCollection({
collection_name: COLLECTION_NAME,
});

expect(load.error_code).toEqual(ErrorCode.SUCCESS);
});

it(`query sparse vector should be successful`, async () => {
const query = await milvusClient.query({
collection_name: COLLECTION_NAME,
filter: 'id > 0',
output_fields: ['vector', 'id'],
});

// console.dir(query, { depth: null });

const originKeys = Object.keys(query.data[0].vector);
const originValues = Object.values(query.data[0].vector);

const outputKeys: string[] = Object.keys(query.data[0].vector);
const outputValues: number[] = Object.values(query.data[0].vector);

expect(originKeys).toEqual(outputKeys);

// filter undefined in originValues
originValues.forEach((value, index) => {
if (value) {
expect(value).toBeCloseTo(outputValues[index]);
}
});
});

it(`search with sparse vector should be successful`, async () => {
const search = await milvusClient.search({
vector: data[0].vector,
collection_name: COLLECTION_NAME,
output_fields: ['id', 'vector'],
limit: 5,
});

expect(search.status.error_code).toEqual(ErrorCode.SUCCESS);
expect(search.results.length).toBeGreaterThan(0);
});
});
Loading

0 comments on commit 611a550

Please sign in to comment.