Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add transformers for bf16 and f16 data-type #303

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions milvus/grpc/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ import {
CountResult,
DEFAULT_COUNT_QUERY_STRING,
SparseFloatVector,
parseSparseRowsToBytes,
sparseRowsToBytes,
getSparseDim,
f32ArrayToBinaryBytes,
} from '../';
Expand All @@ -81,7 +81,8 @@ export class Data extends Collection {
* @param {InsertReq} data - The request parameters.
* @param {string} data.collection_name - The name of the collection.
* @param {string} [data.partition_name] - The name of the partition (optional).
* @param {{ [x: string]: any }[]} data.fields_data - The data to be inserted. If the field type is binary, the vector data length needs to be dimension / 8.
* @param {{ [x: string]: any }[]} data.data - The data to be inserted. If the field type is binary, the vector data length needs to be dimension / 8.
* @param {ToBytesTransformers} [data.transformers] - Optional transformers for bf16 or f16 data. Each one accepts an f32 array and should return the f16 or bf16 bytes.
* @param {number} [data.timeout] - An optional duration of time in milliseconds to allow for the RPC. If it is set to undefined, the client keeps waiting until the server responds or error occurs. Default is undefined.
*
* @returns {Promise<MutationResult>} The result of the operation.
Expand Down Expand Up @@ -187,7 +188,11 @@ export class Data extends Collection {
);
break;
default:
field.data[rowIndex] = buildFieldData(rowData, field);
field.data[rowIndex] = buildFieldData(
rowData,
field,
data.transformers
);
break;
}
});
Expand Down Expand Up @@ -236,9 +241,7 @@ export class Data extends Collection {
dim,
[dataKey]: {
dim,
contents: parseSparseRowsToBytes(
field.data as SparseFloatVector[]
),
contents: sparseRowsToBytes(field.data as SparseFloatVector[]),
},
};
break;
Expand Down Expand Up @@ -408,6 +411,9 @@ export class Data extends Collection {
* @param {string} [data.filter] - Scalar field filter expression (optional).
* @param {string[]} [data.output_fields] - Support scalar field (optional).
* @param {object} [data.params] - Search params (optional).
* @param {FromBytesTransformers} [data.transformers] - Optional transformers for bf16 or f16 data. Each one accepts bytes or a sparse dict vector and can output an f32 array or another format.
* @param {number} [data.timeout] - An optional duration of time in milliseconds to allow for the RPC. If it is set to undefined, the client keeps waiting until the server responds or error occurs. Default is undefined.
*
* @returns {Promise<SearchResults>} The result of the operation.
* @returns {string} status.error_code - The error code of the operation.
* @returns {string} status.reason - The reason for the error, if any.
Expand Down Expand Up @@ -460,7 +466,10 @@ export class Data extends Collection {
}

// build final results array
const results = formatSearchResult(originSearchResult, { round_decimal });
const results = formatSearchResult(originSearchResult, {
round_decimal,
transformers: data.transformers,
});

return {
status: originSearchResult.status,
Expand Down Expand Up @@ -571,7 +580,8 @@ export class Data extends Collection {
* @param {string[]} [data.partitions_names] - Array of partition names (optional).
* @param {string[]} data.output_fields - Vector or scalar field to be returned.
* @param {number} [data.timeout] - An optional duration of time in millisecond to allow for the RPC. If it is set to undefined, the client keeps waiting until the server responds or error occurs. Default is undefined.
* @param {{key: value}[]} [data.params] - An optional key pair json array.
* @param {{key: value}[]} [data.params] - An optional key pair json array of search parameters.
* @param {FromBytesTransformers} [data.transformers] - Optional transformers for bf16 or f16 data. Each one accepts bytes or a sparse dict vector and can output an f32 array or another format.
*
* @returns {Promise<QueryResults>} The result of the operation.
* @returns {string} status.error_code - The error code of the operation.
Expand Down Expand Up @@ -634,7 +644,10 @@ export class Data extends Collection {
// always get output_fields from fields_data
const output_fields = promise.fields_data.map(f => f.field_name);

const fieldsDataMap = buildFieldDataMap(promise.fields_data);
const fieldsDataMap = buildFieldDataMap(
promise.fields_data,
data.transformers
);

// For each output field, check if it has a fixed schema or not
const fieldDataContainer = output_fields.map(field_name => {
Expand Down
21 changes: 21 additions & 0 deletions milvus/types/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,20 @@ export interface CountReq extends collectionNameReq {
expr?: string; // filter expression
}

/**
 * Optional per-data-type encoders for float16 and bfloat16 vectors.
 *
 * JavaScript has no float16 or bfloat16 type, and Milvus only accepts
 * bytes (a Buffer) for these vector types, so a caller can supply a custom
 * function per data type that turns the vector into a Buffer.
 * Both entries are optional.
 */
export type ToBytesTransformers = {
// encoder for bfloat16 vector fields: receives the vector, returns its bytes
[DataType.BFloat16Vector]?: (bf16: BFloat16Vector) => Buffer;
// encoder for float16 vector fields: receives the vector, returns its bytes
[DataType.Float16Vector]?: (f16: Float16Vector) => Buffer;
};

/**
 * Request parameters for inserting rows into a collection.
 *
 * `data` and `fields_data` are aliases for the rows to insert.
 * `transformers` lets the caller supply custom encoders for data types such as
 * bf16 or f16 vectors.
 */
export interface InsertReq extends collectionNameReq {
  partition_name?: string; // partition name
  data?: RowData[]; // data to insert
  fields_data?: RowData[]; // alias for data
  // primitive `number[]` rather than the boxed `Number[]` wrapper type
  hash_keys?: number[]; // user can generate hash value depend on primarykey value
  transformers?: ToBytesTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors
}

export interface DeleteEntitiesReq extends collectionNameReq {
Expand Down Expand Up @@ -272,6 +281,7 @@ export interface SearchReq extends collectionNameReq {
vector_type: DataType.BinaryVector | DataType.FloatVector; // vector field type
nq?: number; // number of query vectors
consistency_level?: ConsistencyLevelEnum; // consistency level
transformers?: FromBytesTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors
}

// simplified search api parameter type
Expand All @@ -293,6 +303,7 @@ export interface SearchSimpleReq extends collectionNameReq {
ignore_growing?: boolean; // ignore growing
group_by_field?: string; // group by field
round_decimal?: number; // round decimal
transformers?: FromBytesTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors
}

export type HybridSearchSingleReq = Pick<
Expand All @@ -302,6 +313,7 @@ export type HybridSearchSingleReq = Pick<
data: VectorTypes[] | VectorTypes; // vector to search
expr?: string; // filter expression
params?: keyValueObj; // extra search parameters
transformers?: FromBytesTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors
};

// rerank strategy and parameters
Expand Down Expand Up @@ -361,6 +373,14 @@ export interface SearchRes extends resStatusResponse {
};
}

/**
 * Optional per-data-type decoders for vector data read back from search and
 * query results.
 *
 * JavaScript has no float16 or bfloat16 type, so a caller can supply a custom
 * function per data type that turns the raw value into the representation
 * they want. All entries are optional.
 *
 * NOTE(review): the decoding code visible alongside this type only consults
 * the float16/bfloat16 entries; confirm where the sparse entry is applied.
 */
export type FromBytesTransformers = {
// decoder for bfloat16 vectors: receives the raw bytes of one vector
[DataType.BFloat16Vector]?: (bf16bytes: Uint8Array) => BFloat16Vector;
// decoder for float16 vectors: receives the raw bytes of one vector
[DataType.Float16Vector]?: (f16: Uint8Array) => Float16Vector;
// decoder for sparse float vectors: receives a SparseVectorDic
[DataType.SparseFloatVector]?: (sparse: SparseVectorDic) => SparseFloatVector;
};

export interface QueryReq extends collectionNameReq {
output_fields?: string[]; // fields to return
partition_names?: string[]; // partition names
Expand All @@ -370,6 +390,7 @@ export interface QueryReq extends collectionNameReq {
offset?: number; // skip how many results
limit?: number; // how many results you want
consistency_level?: ConsistencyLevelEnum; // consistency level
transformers?: FromBytesTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors
}

export interface GetReq extends collectionNameReq {
Expand Down
34 changes: 15 additions & 19 deletions milvus/utils/Bytes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,40 +40,40 @@ export const f32ArrayToBinaryBytes = (array: BinaryVector) => {
/**
* Converts a float16 vector into bytes format.
*
* @param {Float16Vector} f16Array - The float16 vector to convert.
* @param {Float16Vector} array - The float16 vector(f32 format) to convert.
* @returns {Buffer} Bytes representing the float16 vector.
*/
export const f32ArrayToF16Bytes = (f16Array: Float16Vector) => {
const float16Bytes = new Float16Array(f16Array);
export const f32ArrayToF16Bytes = (array: Float16Vector) => {
const float16Bytes = new Float16Array(array);
return Buffer.from(float16Bytes.buffer);
};

/**
* Convert float16 bytes to float32 array.
* @param {Uint8Array} f16Bytes - The float16 bytes to convert.
* @returns {Float32Array} The float32 array.
* @returns {Array} The float32 array.
*/
export const f16BytesToF32Array = (f16Bytes: Uint8Array) => {
const buffer = new ArrayBuffer(f16Bytes.length);
const view = new Uint8Array(buffer);
view.set(f16Bytes);

const f16Array = new Float16Array(buffer);
return f16Array;
return Array.from(f16Array);
};

/**
* Convert float32 array to BFloat16 bytes.
* @param {BFloat16Vector} float32Array - The float32 array to convert.
* Convert float32 array to BFloat16 bytes. This is not a real conversion: it just takes the last 2 bytes of each float32.
* @param {BFloat16Vector} array - The float32 array to convert.
* @returns {Buffer} The BFloat16 bytes.
*/
export const f32ArrayToBf16Bytes = (float32Array: BFloat16Vector) => {
const totalBytesNeeded = float32Array.length * 2; // 2 bytes per float32
export const f32ArrayToBf16Bytes = (array: BFloat16Vector) => {
const totalBytesNeeded = array.length * 2; // 2 bytes per float32
const buffer = new ArrayBuffer(totalBytesNeeded);
const bfloatView = new Uint8Array(buffer);

let byteIndex = 0;
float32Array.forEach(float32 => {
array.forEach(float32 => {
const floatBuffer = new ArrayBuffer(4);
const floatView = new Float32Array(floatBuffer);
const bfloatViewSingle = new Uint8Array(floatBuffer);
Expand All @@ -83,13 +83,13 @@ export const f32ArrayToBf16Bytes = (float32Array: BFloat16Vector) => {
byteIndex += 2;
});

return bfloatView;
return Buffer.from(bfloatView);
};

/**
* Convert BFloat16 bytes to Float32 array.
* @param {Uint8Array} bf16Bytes - The BFloat16 bytes to convert.
* @returns {float32Array} The Float32 array.
* @returns {Array} The Float32 array.
*/
export const bf16BytesToF32Array = (bf16Bytes: Uint8Array) => {
const float32Array: number[] = [];
Expand Down Expand Up @@ -150,7 +150,7 @@ export const getSparseFloatVectorType = (
/**
* Converts a sparse float vector into bytes format.
*
* @param {SparseFloatVector} data - The sparse float vector to convert.
* @param {SparseFloatVector} data - The sparse float vector to convert; supports the 'array' | 'coo' | 'csr' | 'dict' formats.
*
* @returns {Uint8Array} Bytes representing the sparse float vector.
* @throws {Error} If the length of indices and values is not the same, or if the index is not within the valid range, or if the value is NaN.
Expand Down Expand Up @@ -213,9 +213,7 @@ export const sparseToBytes = (data: SparseFloatVector): Uint8Array => {
*
* @returns {Uint8Array[]} An array of bytes representing the sparse float vectors.
*/
export const parseSparseRowsToBytes = (
data: SparseFloatVector[]
): Uint8Array[] => {
export const sparseRowsToBytes = (data: SparseFloatVector[]): Uint8Array[] => {
const result: Uint8Array[] = [];
for (const row of data) {
result.push(sparseToBytes(row));
Expand All @@ -230,9 +228,7 @@ export const parseSparseRowsToBytes = (
*
* @returns {SparseFloatVector} The parsed sparse float vectors.
*/
export const parseBufferToSparseRow = (
bufferData: Buffer
): SparseFloatVector => {
export const bytesToSparseRow = (bufferData: Buffer): SparseFloatVector => {
const result: SparseFloatVector = {};
for (let i = 0; i < bufferData.length; i += 8) {
const key: string = bufferData.readUInt32LE(i).toString();
Expand Down
60 changes: 43 additions & 17 deletions milvus/utils/Format.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ import {
isVectorType,
RANKER_TYPE,
RerankerObj,
parseBufferToSparseRow,
bytesToSparseRow,
buildPlaceholderGroupBytes,
f16BytesToF32Array,
f32ArrayToF16Bytes,
f32ArrayToBf16Bytes,
Float16Vector,
BFloat16Vector,
bf16BytesToF32Array,
getSparseFloatVectorType,
ToBytesTransformers,
FromBytesTransformers,
} from '../';

/**
Expand Down Expand Up @@ -398,7 +396,10 @@ export const buildDynamicRow = (
* If the field is a vector, split the data into chunks of the appropriate size.
* If the field is a scalar, decode the JSON/array data if necessary.
*/
export const buildFieldDataMap = (fields_data: any[]) => {
export const buildFieldDataMap = (
fields_data: any[],
transformers?: FromBytesTransformers
) => {
const fieldsDataMap = new Map<string, RowData[]>();

fields_data.forEach((item, i) => {
Expand Down Expand Up @@ -443,19 +444,31 @@ export const buildFieldDataMap = (fields_data: any[]) => {
// split buffer data to float16 vector(bytes)
for (let i = 0; i < f16Bytes.byteLength; i += f16Dim) {
const slice = f16Bytes.slice(i, i + f16Dim);
field_data.push(
dataKey == 'float16_vector'
? f16BytesToF32Array(slice)
: bf16BytesToF32Array(slice)
);
const isFloat16 = dataKey === 'float16_vector';
const isBFloat16 = dataKey === 'bfloat16_vector';
let dataType: DataType.BFloat16Vector | DataType.Float16Vector;

dataType = isFloat16
? DataType.Float16Vector
: DataType.BFloat16Vector;

if (
(isFloat16 || isBFloat16) &&
transformers &&
transformers[dataType]
) {
field_data.push(transformers[dataType]!(slice));
} else {
field_data.push(slice);
}
}
break;
case 'sparse_float_vector':
const sparseVectorValue = item.vectors![dataKey]!.contents;
field_data = [];

sparseVectorValue.forEach((buffer: any, i: number) => {
field_data[i] = parseBufferToSparseRow(buffer);
field_data[i] = bytesToSparseRow(buffer);
});
break;
default:
Expand Down Expand Up @@ -526,21 +539,33 @@ export const getAuthString = (data: {
* @param {Field} column - The column information.
* @returns {FieldData} The field data for the row and column.
*/
export const buildFieldData = (rowData: RowData, field: Field): FieldData => {
export const buildFieldData = (
rowData: RowData,
field: Field,
transformers?: ToBytesTransformers
): FieldData => {
const { type, elementType, name } = field;
switch (DataTypeMap[type]) {
case DataType.BinaryVector:
case DataType.FloatVector:
return rowData[name];
case DataType.BFloat16Vector:
return f32ArrayToBf16Bytes(rowData[name] as BFloat16Vector);
if (transformers && transformers[DataType.BFloat16Vector]) {
return transformers[DataType.BFloat16Vector](
rowData[name] as BFloat16Vector
);
}
case DataType.Float16Vector:
return f32ArrayToF16Bytes(rowData[name] as Float16Vector);
if (transformers && transformers[DataType.Float16Vector]) {
return transformers[DataType.Float16Vector](
rowData[name] as Float16Vector
);
}
case DataType.JSON:
return Buffer.from(JSON.stringify(rowData[name] || {}));
case DataType.Array:
const elementField = { ...field, type: elementType! };
return buildFieldData(rowData, elementField);
return buildFieldData(rowData, elementField, transformers);
default:
return rowData[name];
}
Expand Down Expand Up @@ -783,14 +808,15 @@ export const formatSearchResult = (
searchRes: SearchRes,
options: {
round_decimal: number;
transformers?: FromBytesTransformers;
}
) => {
const { round_decimal } = options;
// build final results array
const results: any[] = [];
const { topks, scores, fields_data, ids } = searchRes.results;
// build fields data map
const fieldsDataMap = buildFieldDataMap(fields_data);
const fieldsDataMap = buildFieldDataMap(fields_data, options.transformers);
// build output name array
const output_fields = [
'id',
Expand Down
Loading
Loading