Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more types of sparse vectors #293

Merged
merged 8 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions milvus/grpc/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ import {
parseSparseRowsToBytes,
getSparseDim,
parseBinaryVectorToBytes,
parseFloat16VectorToBytes,
Float16Vector,
} from '../';
import { Collection } from './Collection';

Expand Down Expand Up @@ -183,7 +181,7 @@ export class Data extends Collection {
}
if (
DataTypeMap[field.type] === DataType.BinaryVector &&
(rowData[name] as VectorTypes).length !== field.dim! / 8
(rowData[name] as BinaryVector).length !== field.dim! / 8
) {
throw new Error(ERROR_REASONS.INSERT_CHECK_WRONG_DIM);
}
Expand Down
16 changes: 15 additions & 1 deletion milvus/types/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,21 @@ export type FloatVector = number[];
export type Float16Vector = number[];
export type BFloat16Vector = number[];
export type BinaryVector = number[];
export type SparseFloatVector = { [key: string]: number };
/**
 * Sparse vector given as a plain array whose elements are numbers or
 * `undefined`.
 * NOTE(review): the byte serializer derives indices from the array keys, so
 * the element position appears to act as the dimension index — confirm.
 */
export type SparseVectorArray = (number | undefined)[];
/**
 * Sparse vector given as an object mapping string keys to numeric values.
 * The byte serializer converts each key to a number and uses it as the index.
 */
export type SparseVectorDic = { [key: string]: number };
/**
 * Sparse vector given as two arrays: `indices` and `values`.
 * The byte serializer pairs `indices[i]` with `values[i]`.
 */
export type SparseVectorCSR = {
  indices: number[];
  values: number[];
};
/**
 * Sparse vector given as a list of `{ index, value }` entries.
 */
export type SparseVectorCOO = { index: number; value: number }[];

/**
 * Any of the supported sparse float vector input formats.
 */
export type SparseFloatVector =
  | SparseVectorArray
  | SparseVectorDic
  | SparseVectorCSR
  | SparseVectorCOO;

// export type SparseFloatVector = { [key: string]: number };
export type VectorTypes =
| FloatVector
| Float16Vector
Expand Down
63 changes: 57 additions & 6 deletions milvus/utils/Bytes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import {
DataType,
VectorTypes,
Float16Vector,
SparseVectorCSR,
SparseVectorCOO,
} from '..';

/**
Expand Down Expand Up @@ -50,6 +52,26 @@ export const parseBytesToFloat16Vector = (float16Bytes: Uint8Array) => {
return Array.from(float16Array);
};

/**
 * Detects which sparse vector format a value is written in.
 *
 * Arrays are classified by their first element only: a number or `undefined`
 * (which includes the empty array) means 'array', anything else means 'coo'.
 * Non-array objects carrying both `indices` and `values` keys are 'csr';
 * every other object is treated as 'dict'.
 *
 * @param {SparseFloatVector} vector - The sparse float vector to inspect.
 *
 * @returns string, 'array' | 'coo' | 'csr' | 'dict'
 */
export const getSparseFloatVectorType = (vector: SparseFloatVector) => {
  // Object formats first: both CSR keys must be present, otherwise it is a dict.
  if (!Array.isArray(vector)) {
    return 'indices' in vector && 'values' in vector ? 'csr' : 'dict';
  }

  // Array formats: only the first element is examined.
  const headKind = typeof vector[0];
  return headKind === 'number' || headKind === 'undefined' ? 'array' : 'coo';
};

/**
* Converts a sparse float vector into bytes format.
*
Expand All @@ -61,10 +83,39 @@ export const parseBytesToFloat16Vector = (float16Bytes: Uint8Array) => {
export const parseSparseVectorToBytes = (
data: SparseFloatVector
): Uint8Array => {
const indices = Object.keys(data).map(Number);
const values = Object.values(data);
// detect the format of the sparse vector
const type = getSparseFloatVectorType(data);

let indices: number[];
let values: number[];

switch (type) {
case 'array':
indices = Object.keys(data).map(Number);
values = Object.values(data);
break;
case 'coo':
indices = Object.values(
(data as SparseVectorCOO).map((item: any) => item.index)
);
values = Object.values(
(data as SparseVectorCOO).map((item: any) => item.value)
);
break;
case 'csr':
indices = (data as SparseVectorCSR).indices;
values = (data as SparseVectorCSR).values;
break;
case 'dict':
indices = Object.keys(data).map(Number);
values = Object.values(data);
break;
}

// create a buffer to store the bytes
const bytes = new Uint8Array(8 * indices.length);

// loop through the indices and values and add them to the buffer
for (let i = 0; i < indices.length; i++) {
const index = indices[i];
const value = values[i];
Expand All @@ -73,9 +124,7 @@ export const parseSparseVectorToBytes = (
`Sparse vector index must be positive and less than 2^32-1: ${index}`
);
}
if (isNaN(value)) {
throw new Error('Sparse vector value must not be NaN');
}

const indexBytes = new Uint32Array([index]);
const valueBytes = new Float32Array([value]);
bytes.set(new Uint8Array(indexBytes.buffer), i * 8);
Expand Down Expand Up @@ -115,7 +164,9 @@ export const parseBufferToSparseRow = (
for (let i = 0; i < bufferData.length; i += 8) {
const key: string = bufferData.readUInt32LE(i).toString();
const value: number = bufferData.readFloatLE(i + 4);
result[key] = value;
if (value) {
result[key] = value;
}
}
return result;
};
Expand Down
131 changes: 131 additions & 0 deletions test/grpc/SparseVector.array.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import {
MilvusClient,
ErrorCode,
DataType,
IndexType,
MetricType,
} from '../../milvus';
import {
IP,
genCollectionParams,
GENERATE_NAME,
generateInsertData,
} from '../tools';

// Client shared by every test in this file; the address comes from ../tools.
const milvusClient = new MilvusClient({ address: IP, logLevel: 'info' });
// Collection name produced by the GENERATE_NAME helper.
const COLLECTION_NAME = GENERATE_NAME();

// Database that the suite creates in beforeAll and drops in afterAll.
const dbParam = {
  db_name: 'sparse_array_vector_DB',
};

// Parameters for a collection with a single sparse float vector field.
const p = {
  collectionName: COLLECTION_NAME,
  vectorType: [DataType.SparseFloatVector],
  dim: [24], // original note: "useless" — presumably dim is ignored for sparse vector fields
};
const collectionParams = genCollectionParams(p);
// 10 rows of insert data, with sparse vectors requested in the 'array' format.
const data = generateInsertData(collectionParams.fields, 10, {
  sparseType: 'array',
});

// End-to-end suite for sparse vectors supplied in the 'array' format:
// create collection -> insert -> create index -> load -> query -> search.
// The tests share one collection and rely on running in declaration order.
describe(`Sparse vectors type:array API testing`, () => {
  // Create a dedicated database and switch the client to it.
  beforeAll(async () => {
    await milvusClient.createDatabase(dbParam);
    await milvusClient.use(dbParam);
  });

  // Drop the collection first, then the database that contained it.
  afterAll(async () => {
    await milvusClient.dropCollection({ collection_name: COLLECTION_NAME });
    await milvusClient.dropDatabase(dbParam);
  });

  it(`Create collection with sparse vectors should be successful`, async () => {
    const create = await milvusClient.createCollection(collectionParams);
    expect(create.error_code).toEqual(ErrorCode.SUCCESS);

    const describe = await milvusClient.describeCollection({
      collection_name: COLLECTION_NAME,
    });

    // The schema must report exactly one sparse float vector field.
    const sparseFloatVectorFields = describe.schema.fields.filter(
      (field: any) => field.data_type === 'SparseFloatVector'
    );
    expect(sparseFloatVectorFields.length).toBe(1);

    // console.dir(describe.schema, { depth: null });
  });

  it(`insert sparse vector data should be successful`, async () => {
    const insert = await milvusClient.insert({
      collection_name: COLLECTION_NAME,
      data,
    });

    // console.log('data to insert', data);

    // Every generated row must be reported as successfully inserted.
    expect(insert.status.error_code).toEqual(ErrorCode.SUCCESS);
    expect(insert.succ_index.length).toEqual(data.length);
  });

  it(`create index should be successful`, async () => {
    const indexes = await milvusClient.createIndex([
      {
        collection_name: COLLECTION_NAME,
        field_name: 'vector',
        metric_type: MetricType.IP,
        index_type: IndexType.SPARSE_WAND,
        params: {
          drop_ratio_build: 0.2,
        },
      },
    ]);

    expect(indexes.error_code).toEqual(ErrorCode.SUCCESS);
  });

  it(`load collection should be successful`, async () => {
    const load = await milvusClient.loadCollection({
      collection_name: COLLECTION_NAME,
    });

    expect(load.error_code).toEqual(ErrorCode.SUCCESS);
  });

  it(`query sparse vector should be successful`, async () => {
    const query = await milvusClient.query({
      collection_name: COLLECTION_NAME,
      filter: 'id > 0',
      output_fields: ['vector', 'id'],
    });

    // console.dir(query, { depth: null });

    // NOTE(review): both the "origin" and "output" values below are read from
    // the same query result, so these assertions compare the result with
    // itself. Comparing against the inserted `data` rows would need a way to
    // match rows by id, which is not visible from this file — confirm and fix.
    const originKeys = Object.keys(query.data[0].vector);
    const originValues = Object.values(query.data[0].vector);

    const outputKeys: string[] = Object.keys(query.data[0].vector);
    const outputValues: number[] = Object.values(query.data[0].vector);

    expect(originKeys).toEqual(outputKeys);

    // filter undefined in originValues
    originValues.forEach((value, index) => {
      if (value) {
        expect(value).toBeCloseTo(outputValues[index]);
      }
    });
  });

  it(`search with sparse vector should be successful`, async () => {
    // Search using the first inserted row's vector as the query vector.
    const search = await milvusClient.search({
      vector: data[0].vector,
      collection_name: COLLECTION_NAME,
      output_fields: ['id', 'vector'],
      limit: 5,
    });

    expect(search.status.error_code).toEqual(ErrorCode.SUCCESS);
    expect(search.results.length).toBeGreaterThan(0);
  });
});
Loading
Loading