diff --git a/milvus/grpc/Data.ts b/milvus/grpc/Data.ts index 23f8e6fb..49202540 100644 --- a/milvus/grpc/Data.ts +++ b/milvus/grpc/Data.ts @@ -11,8 +11,6 @@ import { LoadBalanceReq, ImportReq, ListImportTasksReq, - // ListIndexedSegmentReq, - // DescribeSegmentIndexDataReq, ErrorCode, FlushResult, GetFlushStateResponse, @@ -27,8 +25,6 @@ import { SearchResults, ImportResponse, ListImportTasksResponse, - // ListIndexedSegmentResponse, - // DescribeSegmentIndexDataResponse, GetMetricsRequest, QueryReq, GetReq, @@ -45,8 +41,6 @@ import { sleep, parseToKeyValue, checkCollectionName, - checkSearchParams, - parseBinaryVectorToBytes, DEFAULT_DYNAMIC_FIELD, buildDynamicRow, buildFieldDataMap, @@ -54,14 +48,17 @@ import { Field, buildFieldData, VectorTypes, - BinaryVectors, + BinaryVector, RowData, CountReq, CountResult, DEFAULT_COUNT_QUERY_STRING, - SparseFloatVectors, + SparseFloatVector, parseSparseRowsToBytes, getSparseDim, + parseBinaryVectorToBytes, + parseFloat16VectorToBytes, + Float16Vector, } from '../'; import { Collection } from './Collection'; @@ -70,6 +67,7 @@ export class Data extends Collection { vectorTypes = [ DataType.BinaryVector, DataType.FloatVector, + DataType.Float16Vector, DataType.SparseFloatVector, ]; @@ -194,7 +192,9 @@ export class Data extends Collection { switch (DataTypeMap[field.type]) { case DataType.BinaryVector: case DataType.FloatVector: - field.data = field.data.concat(buildFieldData(rowData, field)); + field.data = (field.data as number[]).concat( + buildFieldData(rowData, field) as number[] + ); break; default: field.data[rowIndex] = buildFieldData(rowData, field); @@ -227,25 +227,30 @@ export class Data extends Collection { }, }; break; + case DataType.Float16Vector: + keyValue = { + dim: field.dim, + [dataKey]: Buffer.concat(field.data as Buffer[]), + }; + break; case DataType.BinaryVector: keyValue = { dim: field.dim, - [dataKey]: parseBinaryVectorToBytes(field.data as BinaryVectors), + [dataKey]: parseBinaryVectorToBytes(field.data as BinaryVector), }; break; case DataType.SparseFloatVector: - const dim = getSparseDim(field.data as SparseFloatVectors[]); + const dim = getSparseDim(field.data as SparseFloatVector[]); keyValue = { dim, [dataKey]: { dim, contents: parseSparseRowsToBytes( - field.data as SparseFloatVectors[] + field.data as SparseFloatVector[] ), }, }; break; - case DataType.Array: keyValue = { [dataKey]: { diff --git a/milvus/types/Data.ts b/milvus/types/Data.ts index 5e2799fc..a8452ff6 100644 --- a/milvus/types/Data.ts +++ b/milvus/types/Data.ts @@ -14,10 +14,17 @@ import { } from '../'; // all value types supported by milvus -export type FloatVectors = number[]; -export type BinaryVectors = number[]; -export type SparseFloatVectors = { [key: string]: number }; -export type VectorTypes = FloatVectors | BinaryVectors | SparseFloatVectors; +export type FloatVector = number[]; +export type Float16Vector = number[]; +export type BFloat16Vector = number[]; +export type BinaryVector = number[]; +export type SparseFloatVector = { [key: string]: number }; +export type VectorTypes = + | FloatVector + | Float16Vector + | BinaryVector + | BFloat16Vector + | SparseFloatVector; export type Bool = boolean; export type Int8 = number; export type Int16 = number; @@ -50,9 +57,7 @@ export type FieldData = | VarChar | JSON | Array - | VectorTypes - | FloatVectors - | BinaryVectors; + | VectorTypes; // Represents a row of data in Milvus. export interface RowData { diff --git a/milvus/types/Http.ts b/milvus/types/Http.ts index 17b411e4..2fd57d70 100644 --- a/milvus/types/Http.ts +++ b/milvus/types/Http.ts @@ -1,4 +1,4 @@ -import { FloatVectors } from '..'; +import { FloatVector } from '..'; type Fetch = (input: any, init?: any) => Promise; // Class types @@ -137,7 +137,7 @@ export interface HttpVectorQueryResponse // search request export interface HttpVectorSearchReq extends Omit { - vector: FloatVectors; + vector: FloatVector; filter?: string; } diff --git a/milvus/utils/Bytes.ts b/milvus/utils/Bytes.ts index e10366c5..3905d70a 100644 --- a/milvus/utils/Bytes.ts +++ b/milvus/utils/Bytes.ts @@ -1,20 +1,22 @@ import { Root } from 'protobufjs'; +import { Float16Array } from '@petamoriken/float16'; import { - FloatVectors, - BinaryVectors, - SparseFloatVectors, + FloatVector, + BinaryVector, + SparseFloatVector, DataType, VectorTypes, + Float16Vector, } from '..'; /** * Converts a float vector into bytes format. * - * @param {FloatVectors} array - The float vector to convert. + * @param {FloatVector} array - The float vector to convert. * * @returns {Buffer} Bytes representing the float vector. */ -export const parseFloatVectorToBytes = (array: FloatVectors) => { +export const parseFloatVectorToBytes = (array: FloatVector) => { // create array buffer const a = new Float32Array(array); // need return bytes to milvus proto @@ -24,27 +26,40 @@ export const parseFloatVectorToBytes = (array: FloatVectors) => { /** * Converts a binary vector into bytes format. * - * @param {BinaryVectors} array - The binary vector to convert. + * @param {BinaryVector} array - The binary vector to convert. * * @returns {Buffer} Bytes representing the binary vector. */ -export const parseBinaryVectorToBytes = (array: BinaryVectors) => { - // create array buffer +export const parseBinaryVectorToBytes = (array: BinaryVector) => { const a = new Uint8Array(array); // need return bytes to milvus proto return Buffer.from(a.buffer); }; +export const parseFloat16VectorToBytes = (f16Array: Float16Vector) => { + const float16Bytes = new Float16Array(f16Array); + return Buffer.from(float16Bytes.buffer); +}; + +export const parseBytesToFloat16Vector = (float16Bytes: Uint8Array) => { + const buffer = new ArrayBuffer(float16Bytes.length); + const view = new Uint8Array(buffer); + view.set(float16Bytes); + + const float16Array = new Float16Array(buffer); + return Array.from(float16Array); +}; + /** * Converts a sparse float vector into bytes format. * - * @param {SparseFloatVectors} data - The sparse float vector to convert. + * @param {SparseFloatVector} data - The sparse float vector to convert. * * @returns {Uint8Array} Bytes representing the sparse float vector. * @throws {Error} If the length of indices and values is not the same, or if the index is not within the valid range, or if the value is NaN. */ export const parseSparseVectorToBytes = ( - data: SparseFloatVectors + data: SparseFloatVector ): Uint8Array => { const indices = Object.keys(data).map(Number); const values = Object.values(data); @@ -72,12 +87,12 @@ export const parseSparseVectorToBytes = ( /** * Converts an array of sparse float vectors into an array of bytes format. * - * @param {SparseFloatVectors[]} data - The array of sparse float vectors to convert. + * @param {SparseFloatVector[]} data - The array of sparse float vectors to convert. * * @returns {Uint8Array[]} An array of bytes representing the sparse float vectors. */ export const parseSparseRowsToBytes = ( - data: SparseFloatVectors[] + data: SparseFloatVector[] ): Uint8Array[] => { const result: Uint8Array[] = []; for (const row of data) { @@ -91,12 +106,12 @@ export const parseSparseRowsToBytes = ( * * @param {Buffer} bufferData - The buffer data to parse. * - * @returns {SparseFloatVectors} The parsed sparse float vectors. + * @returns {SparseFloatVector} The parsed sparse float vectors. */ export const parseBufferToSparseRow = ( bufferData: Buffer -): SparseFloatVectors => { - const result: SparseFloatVectors = {}; +): SparseFloatVector => { + const result: SparseFloatVector = {}; for (let i = 0; i < bufferData.length; i += 8) { const key: string = bufferData.readUInt32LE(i).toString(); const value: number = bufferData.readFloatLE(i + 4); @@ -124,14 +139,18 @@ export const buildPlaceholderGroupBytes = ( // parse vectors to bytes switch (vectorDataType) { case DataType.FloatVector: - bytes = vectors.map(v => parseFloatVectorToBytes(v as FloatVectors)); + bytes = vectors.map(v => parseFloatVectorToBytes(v as FloatVector)); break; case DataType.BinaryVector: - bytes = vectors.map(v => parseBinaryVectorToBytes(v as BinaryVectors)); + bytes = vectors.map(v => parseBinaryVectorToBytes(v as BinaryVector)); + break; + case DataType.Float16Vector: + case DataType.BFloat16Vector: + bytes = vectors.map(v => parseFloat16VectorToBytes(v as Float16Vector)); break; case DataType.SparseFloatVector: bytes = vectors.map(v => - parseSparseVectorToBytes(v as SparseFloatVectors) + parseSparseVectorToBytes(v as SparseFloatVector) ); break; diff --git a/milvus/utils/Format.ts b/milvus/utils/Format.ts index fc9ba51e..fece793d 100644 --- a/milvus/utils/Format.ts +++ b/milvus/utils/Format.ts @@ -31,6 +31,9 @@ import { RerankerObj, parseBufferToSparseRow, buildPlaceholderGroupBytes, + parseBytesToFloat16Vector, + parseFloat16VectorToBytes, + Float16Vector, } from '../'; /** @@ -427,6 +430,18 @@ export const buildFieldDataMap = (fields_data: any[]) => { }); break; + case 'float16_vector': + case 'bfloat16_vector': + field_data = []; + const f16Dim = Number(item.vectors!.dim) * 2; // float16 is 2 bytes, so we need to multiply dim with 2 = one element length + const f16Bytes = item.vectors![dataKey]!; + + // split buffer data to float16 vector(bytes) + for (let i = 0; i < f16Bytes.byteLength; i += f16Dim) { + const slice = f16Bytes.slice(i, i + f16Dim); + field_data.push(parseBytesToFloat16Vector(slice)); + } + break; case 'sparse_float_vector': const sparseVectorValue = item.vectors![dataKey]!.contents; field_data = []; @@ -509,6 +524,9 @@ export const buildFieldData = (rowData: RowData, field: Field): FieldData => { case DataType.BinaryVector: case DataType.FloatVector: return rowData[name]; + case DataType.Float16Vector: + case DataType.BFloat16Vector: + return parseFloat16VectorToBytes(rowData[name] as Float16Vector); case DataType.JSON: return Buffer.from(JSON.stringify(rowData[name] || {})); case DataType.Array: diff --git a/milvus/utils/Function.ts b/milvus/utils/Function.ts index 9ca70c96..61951d30 100644 --- a/milvus/utils/Function.ts +++ b/milvus/utils/Function.ts @@ -1,4 +1,4 @@ -import { KeyValuePair, DataType, ERROR_REASONS, SparseFloatVectors } from '../'; +import { KeyValuePair, DataType, ERROR_REASONS, SparseFloatVector } from '../'; import { Pool } from 'generic-pool'; /** @@ -89,6 +89,12 @@ export const getDataKey = (type: DataType, camelCase: boolean = false) => { case DataType.FloatVector: dataKey = 'float_vector'; break; + case DataType.Float16Vector: + dataKey = 'float16_vector'; + break; + case DataType.BFloat16Vector: + dataKey = 'bfloat16_vector'; + break; case DataType.BinaryVector: dataKey = 'binary_vector'; break; @@ -134,7 +140,7 @@ export const getDataKey = (type: DataType, camelCase: boolean = false) => { }; // get biggest size of sparse vector array -export const getSparseDim = (data: SparseFloatVectors[]) => { +export const getSparseDim = (data: SparseFloatVector[]) => { let dim = 0; for (const row of data) { const indices = Object.keys(row).map(Number); diff --git a/milvus/utils/Validate.ts b/milvus/utils/Validate.ts index bb2d5071..48d9440d 100644 --- a/milvus/utils/Validate.ts +++ b/milvus/utils/Validate.ts @@ -21,12 +21,6 @@ import { status as grpcStatus } from '@grpc/grpc-js'; * @param fields */ export const checkCollectionFields = (fields: FieldType[]) => { - // Define arrays of data types that are allowed for vector fields and primary keys, respectively - const vectorDataTypes = [ - DataType.BinaryVector, - DataType.FloatVector, - DataType.SparseFloatVector, - ]; const int64VarCharTypes = [DataType.Int64, DataType.VarChar]; let hasPrimaryKey = false; @@ -60,7 +54,7 @@ export const checkCollectionFields = (fields: FieldType[]) => { } // if this is the vector field, check dimension - const isVectorField = vectorDataTypes.includes(dataType!); + const isVectorField = isVectorType(dataType!); const typeParams = field.type_params; if (isVectorField) { const dim = Number(typeParams?.dim ?? field.dim); @@ -219,6 +213,8 @@ export const isVectorType = (type: DataType) => { return ( type === DataType.BinaryVector || type === DataType.FloatVector || + type === DataType.Float16Vector || + type === DataType.BFloat16Vector || type === DataType.SparseFloatVector ); }; diff --git a/package.json b/package.json index 262321c2..66d45a57 100644 --- a/package.json +++ b/package.json @@ -30,6 +30,7 @@ }, "devDependencies": { "@babel/plugin-transform-modules-commonjs": "^7.21.5", + "@petamoriken/float16": "^3.8.6", "@types/jest": "^29.5.1", "@types/node-fetch": "^2.6.8", "jest": "^29.5.0", diff --git a/test/grpc/Float16Vector.spec.ts b/test/grpc/Float16Vector.spec.ts new file mode 100644 index 00000000..79b4640d --- /dev/null +++ b/test/grpc/Float16Vector.spec.ts @@ -0,0 +1,129 @@ +import { + MilvusClient, + ErrorCode, + DataType, + IndexType, + MetricType, + parseBytesToFloat16Vector, +} from '../../milvus'; +import { + IP, + genCollectionParams, + GENERATE_NAME, + generateInsertData, +} from '../tools'; + +const milvusClient = new MilvusClient({ address: IP, logLevel: 'info' }); +const COLLECTION_NAME = GENERATE_NAME(); + +const dbParam = { + db_name: 'float_vector_16', +}; + +const p = { + collectionName: COLLECTION_NAME, + vectorType: [DataType.Float16Vector], + dim: [8], +}; +const collectionParams = genCollectionParams(p); +const data = generateInsertData(collectionParams.fields, 2); + +// console.log('data to insert', data); + +describe(`Float16 vector API testing`, () => { + beforeAll(async () => { + await milvusClient.createDatabase(dbParam); + await milvusClient.use(dbParam); + }); + + afterAll(async () => { + await milvusClient.dropCollection({ collection_name: COLLECTION_NAME }); + await milvusClient.dropDatabase(dbParam); + }); + + it(`Create collection with float16 vectors should be successful`, async () => { + const create = await milvusClient.createCollection(collectionParams); + expect(create.error_code).toEqual(ErrorCode.SUCCESS); + + const describe = await milvusClient.describeCollection({ + collection_name: COLLECTION_NAME, + }); + + const floatVector16Fields = describe.schema.fields.filter( + (field: any) => field.data_type === 'Float16Vector' + ); + expect(floatVector16Fields.length).toBe(1); + + // console.dir(describe.schema, { depth: null }); + }); + + it(`insert flaot16 vector data should be successful`, async () => { + const insert = await milvusClient.insert({ + collection_name: COLLECTION_NAME, + data, + }); + + // console.log(' insert', insert); + + expect(insert.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(insert.succ_index.length).toEqual(data.length); + }); + + it(`create index should be successful`, async () => { + const indexes = await milvusClient.createIndex([ + { + collection_name: COLLECTION_NAME, + field_name: 'vector', + metric_type: MetricType.L2, + index_type: IndexType.AUTOINDEX, + }, + ]); + + expect(indexes.error_code).toEqual(ErrorCode.SUCCESS); + }); + + it(`load collection should be successful`, async () => { + const load = await milvusClient.loadCollection({ + collection_name: COLLECTION_NAME, + }); + + expect(load.error_code).toEqual(ErrorCode.SUCCESS); + }); + + it(`query float16 vector should be successful`, async () => { + const count = await milvusClient.count({ + collection_name: COLLECTION_NAME, + }); + + expect(count.data).toEqual(data.length); + + const query = await milvusClient.query({ + collection_name: COLLECTION_NAME, + filter: 'id > 0', + output_fields: ['vector', 'id'], + }); + + // verify the query result + data.forEach((obj, index) => { + obj.vector.forEach((v: number, i: number) => { + expect(v).toBeCloseTo(query.data[index].vector[i], 3); + }); + }); + + expect(query.status.error_code).toEqual(ErrorCode.SUCCESS); + }); + + it(`search with float16 vector should be successful`, async () => { + const search = await milvusClient.search({ + vector: data[0].vector, + collection_name: COLLECTION_NAME, + output_fields: ['id', 'vector'], + limit: 5, + }); + + // console.log('search', search); + + expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(search.results.length).toBeGreaterThan(0); + }); +}); diff --git a/test/tools/data.ts b/test/tools/data.ts index 75267ac7..e98b4146 100644 --- a/test/tools/data.ts +++ b/test/tools/data.ts @@ -5,6 +5,7 @@ import { FieldType, } from '../../milvus'; import { MAX_LENGTH, P_KEY_VALUES } from './const'; +import { Float16Array } from '@petamoriken/float16'; import Long from 'long'; interface DataGenerator { @@ -161,6 +162,14 @@ export const genSparseVector: DataGenerator = params => { return vector; }; +export const genFloat16: DataGenerator = params => { + const float32Array = genFloatVector(params); + // console.log('origin float32array', float32Array); + // const float16Array = new Float16Array(float32Array as number[]); + // const float16Bytes = new Uint8Array(float16Array.buffer); + return float32Array; +}; + export const dataGenMap: { [key in DataType]: DataGenerator } = { [DataType.None]: genNone, [DataType.Bool]: genBool, @@ -175,8 +184,8 @@ export const dataGenMap: { [key in DataType]: DataGenerator } = { [DataType.JSON]: genJSON, [DataType.BinaryVector]: genBinaryVector, [DataType.FloatVector]: genFloatVector, - [DataType.Float16Vector]: genFloatVector, // TODO - [DataType.BFloat16Vector]: genFloatVector, // TODO + [DataType.Float16Vector]: genFloat16, + [DataType.BFloat16Vector]: genFloat16, [DataType.SparseFloatVector]: genSparseVector, }; diff --git a/test/utils/Bytes.spec.ts b/test/utils/Bytes.spec.ts index cd19aa8f..e01da165 100644 --- a/test/utils/Bytes.spec.ts +++ b/test/utils/Bytes.spec.ts @@ -1,7 +1,7 @@ import { parseBufferToSparseRow, parseSparseRowsToBytes, - SparseFloatVectors, + SparseFloatVector, parseSparseVectorToBytes, } from '../../milvus'; @@ -29,7 +29,7 @@ describe('Sparse rows <-> Bytes conversion', () => { it('Conversion is reversible', () => { const inputSparseRows = [ { '12': 0.875, '17': 0.789, '19': 0.934 }, - ] as SparseFloatVectors[]; + ] as SparseFloatVector[]; const bytesArray = parseSparseRowsToBytes(inputSparseRows); diff --git a/test/utils/Function.spec.ts b/test/utils/Function.spec.ts index c86c9d75..793a6082 100644 --- a/test/utils/Function.spec.ts +++ b/test/utils/Function.spec.ts @@ -1,4 +1,10 @@ -import { promisify, getSparseDim, SparseFloatVectors } from '../../milvus'; +import { + promisify, + getSparseDim, + SparseFloatVector, + getDataKey, + DataType, +} from '../../milvus'; describe('promisify', () => { let pool: any; @@ -52,13 +58,13 @@ describe('promisify', () => { { '0': 1, '1': 2, '2': 3 }, { '0': 1, '1': 2, '2': 3, '3': 4 }, { '0': 1, '1': 2 }, - ] as SparseFloatVectors[]; + ] as SparseFloatVector[]; const result = getSparseDim(data); expect(result).toBe(4); }); it('should return 0 for an empty array', () => { - const data = [] as SparseFloatVectors[]; + const data = [] as SparseFloatVector[]; const result = getSparseDim(data); expect(result).toBe(0); }); @@ -68,8 +74,50 @@ describe('promisify', () => { { '0': 1, '1': 2, '2': 3, '3': 4, '4': 5 }, { '0': 1, '1': 2 }, { '0': 1, '1': 2, '2': 3, '3': 4 }, - ] as SparseFloatVectors[]; + ] as SparseFloatVector[]; const result = getSparseDim(data); expect(result).toBe(5); }); + + it('should return the correct data key for each data type without camel case conversion', () => { + expect(getDataKey(DataType.FloatVector)).toEqual('float_vector'); + expect(getDataKey(DataType.Float16Vector)).toEqual('float16_vector'); + expect(getDataKey(DataType.BFloat16Vector)).toEqual('bfloat16_vector'); + expect(getDataKey(DataType.BinaryVector)).toEqual('binary_vector'); + expect(getDataKey(DataType.SparseFloatVector)).toEqual( + 'sparse_float_vector' + ); + expect(getDataKey(DataType.Double)).toEqual('double_data'); + expect(getDataKey(DataType.Float)).toEqual('float_data'); + expect(getDataKey(DataType.Int64)).toEqual('long_data'); + expect(getDataKey(DataType.Int32)).toEqual('int_data'); + expect(getDataKey(DataType.Int16)).toEqual('int_data'); + expect(getDataKey(DataType.Int8)).toEqual('int_data'); + expect(getDataKey(DataType.Bool)).toEqual('bool_data'); + expect(getDataKey(DataType.VarChar)).toEqual('string_data'); + expect(getDataKey(DataType.Array)).toEqual('array_data'); + expect(getDataKey(DataType.JSON)).toEqual('json_data'); + expect(getDataKey(DataType.None)).toEqual('none'); + }); + + it('should return the correct data key for each data type with camel case conversion', () => { + expect(getDataKey(DataType.FloatVector, true)).toEqual('floatVector'); + expect(getDataKey(DataType.Float16Vector, true)).toEqual('float16Vector'); + expect(getDataKey(DataType.BFloat16Vector, true)).toEqual('bfloat16Vector'); + expect(getDataKey(DataType.BinaryVector, true)).toEqual('binaryVector'); + expect(getDataKey(DataType.SparseFloatVector, true)).toEqual( + 'sparseFloatVector' + ); + expect(getDataKey(DataType.Double, true)).toEqual('doubleData'); + expect(getDataKey(DataType.Float, true)).toEqual('floatData'); + expect(getDataKey(DataType.Int64, true)).toEqual('longData'); + expect(getDataKey(DataType.Int32, true)).toEqual('intData'); + expect(getDataKey(DataType.Int16, true)).toEqual('intData'); + expect(getDataKey(DataType.Int8, true)).toEqual('intData'); + expect(getDataKey(DataType.Bool, true)).toEqual('boolData'); + expect(getDataKey(DataType.VarChar, true)).toEqual('stringData'); + expect(getDataKey(DataType.Array, true)).toEqual('arrayData'); + expect(getDataKey(DataType.JSON, true)).toEqual('jsonData'); + expect(getDataKey(DataType.None, true)).toEqual('none'); + }); }); diff --git a/yarn.lock b/yarn.lock index 7e700986..f33c707a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -698,6 +698,11 @@ "@jridgewell/resolve-uri" "3.1.0" "@jridgewell/sourcemap-codec" "1.4.14" +"@petamoriken/float16@^3.8.6": + version "3.8.6" + resolved "https://registry.yarnpkg.com/@petamoriken/float16/-/float16-3.8.6.tgz#580701cb97a510882342333d31c7cbfd9e14b4f4" + integrity sha512-GNJhABTtcmt9al/nqdJPycwFD46ww2+q2zwZzTjY0dFFwUAFRw9zszvEr9osyJRd9krRGy6hUDopWUg9fX7VVw== + "@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2": version "1.1.2" resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf"