Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support f16 & bf16 #287

Merged
merged 13 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions milvus/grpc/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import {
LoadBalanceReq,
ImportReq,
ListImportTasksReq,
// ListIndexedSegmentReq,
// DescribeSegmentIndexDataReq,
ErrorCode,
FlushResult,
GetFlushStateResponse,
Expand All @@ -27,8 +25,6 @@ import {
SearchResults,
ImportResponse,
ListImportTasksResponse,
// ListIndexedSegmentResponse,
// DescribeSegmentIndexDataResponse,
GetMetricsRequest,
QueryReq,
GetReq,
Expand All @@ -45,23 +41,24 @@ import {
sleep,
parseToKeyValue,
checkCollectionName,
checkSearchParams,
parseBinaryVectorToBytes,
DEFAULT_DYNAMIC_FIELD,
buildDynamicRow,
buildFieldDataMap,
getDataKey,
Field,
buildFieldData,
VectorTypes,
BinaryVectors,
BinaryVector,
RowData,
CountReq,
CountResult,
DEFAULT_COUNT_QUERY_STRING,
SparseFloatVectors,
SparseFloatVector,
parseSparseRowsToBytes,
getSparseDim,
parseBinaryVectorToBytes,
parseFloat16VectorToBytes,
Float16Vector,
} from '../';
import { Collection } from './Collection';

Expand All @@ -70,6 +67,7 @@ export class Data extends Collection {
vectorTypes = [
DataType.BinaryVector,
DataType.FloatVector,
DataType.Float16Vector,
DataType.SparseFloatVector,
];

Expand Down Expand Up @@ -194,7 +192,9 @@ export class Data extends Collection {
switch (DataTypeMap[field.type]) {
case DataType.BinaryVector:
case DataType.FloatVector:
field.data = field.data.concat(buildFieldData(rowData, field));
field.data = (field.data as number[]).concat(
buildFieldData(rowData, field) as number[]
);
break;
default:
field.data[rowIndex] = buildFieldData(rowData, field);
Expand Down Expand Up @@ -227,25 +227,30 @@ export class Data extends Collection {
},
};
break;
case DataType.Float16Vector:
keyValue = {
dim: field.dim,
[dataKey]: Buffer.concat(field.data as Buffer[]),
};
break;
case DataType.BinaryVector:
keyValue = {
dim: field.dim,
[dataKey]: parseBinaryVectorToBytes(field.data as BinaryVectors),
[dataKey]: parseBinaryVectorToBytes(field.data as BinaryVector),
};
break;
case DataType.SparseFloatVector:
const dim = getSparseDim(field.data as SparseFloatVectors[]);
const dim = getSparseDim(field.data as SparseFloatVector[]);
keyValue = {
dim,
[dataKey]: {
dim,
contents: parseSparseRowsToBytes(
field.data as SparseFloatVectors[]
field.data as SparseFloatVector[]
),
},
};
break;

case DataType.Array:
keyValue = {
[dataKey]: {
Expand Down
19 changes: 12 additions & 7 deletions milvus/types/Data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ import {
} from '../';

// all value types supported by milvus
export type FloatVectors = number[];
export type BinaryVectors = number[];
export type SparseFloatVectors = { [key: string]: number };
export type VectorTypes = FloatVectors | BinaryVectors | SparseFloatVectors;
export type FloatVector = number[];
export type Float16Vector = number[];
export type BFloat16Vector = number[];
export type BinaryVector = number[];
export type SparseFloatVector = { [key: string]: number };
export type VectorTypes =
| FloatVector
| Float16Vector
| BinaryVector
| BFloat16Vector
| SparseFloatVector;
export type Bool = boolean;
export type Int8 = number;
export type Int16 = number;
Expand Down Expand Up @@ -50,9 +57,7 @@ export type FieldData =
| VarChar
| JSON
| Array
| VectorTypes
| FloatVectors
| BinaryVectors;
| VectorTypes;

// Represents a row of data in Milvus.
export interface RowData {
Expand Down
4 changes: 2 additions & 2 deletions milvus/types/Http.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { FloatVectors } from '..';
import { FloatVector } from '..';
type Fetch = (input: any, init?: any) => Promise<any>;

// Class types
Expand Down Expand Up @@ -137,7 +137,7 @@ export interface HttpVectorQueryResponse
// search request
export interface HttpVectorSearchReq
extends Omit<HttpVectorQueryReq, 'filter'> {
vector: FloatVectors;
vector: FloatVector;
filter?: string;
}

Expand Down
55 changes: 37 additions & 18 deletions milvus/utils/Bytes.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import { Root } from 'protobufjs';
import { Float16Array } from '@petamoriken/float16';
import {
FloatVectors,
BinaryVectors,
SparseFloatVectors,
FloatVector,
BinaryVector,
SparseFloatVector,
DataType,
VectorTypes,
Float16Vector,
} from '..';

/**
* Converts a float vector into bytes format.
*
* @param {FloatVectors} array - The float vector to convert.
* @param {FloatVector} array - The float vector to convert.
*
* @returns {Buffer} Bytes representing the float vector.
*/
export const parseFloatVectorToBytes = (array: FloatVectors) => {
export const parseFloatVectorToBytes = (array: FloatVector) => {
// create array buffer
const a = new Float32Array(array);
// need return bytes to milvus proto
Expand All @@ -24,27 +26,40 @@ export const parseFloatVectorToBytes = (array: FloatVectors) => {
/**
* Converts a binary vector into bytes format.
*
* @param {BinaryVectors} array - The binary vector to convert.
* @param {BinaryVector} array - The binary vector to convert.
*
* @returns {Buffer} Bytes representing the binary vector.
*/
export const parseBinaryVectorToBytes = (array: BinaryVectors) => {
// create array buffer
export const parseBinaryVectorToBytes = (array: BinaryVector) => {
const a = new Uint8Array(array);
// need return bytes to milvus proto
return Buffer.from(a.buffer);
};

export const parseFloat16VectorToBytes = (f16Array: Float16Vector) => {
const float16Bytes = new Float16Array(f16Array);
return Buffer.from(float16Bytes.buffer);
};

export const parseBytesToFloat16Vector = (float16Bytes: Uint8Array) => {
const buffer = new ArrayBuffer(float16Bytes.length);
const view = new Uint8Array(buffer);
view.set(float16Bytes);

const float16Array = new Float16Array(buffer);
return Array.from(float16Array);
};

/**
* Converts a sparse float vector into bytes format.
*
* @param {SparseFloatVectors} data - The sparse float vector to convert.
* @param {SparseFloatVector} data - The sparse float vector to convert.
*
* @returns {Uint8Array} Bytes representing the sparse float vector.
* @throws {Error} If the length of indices and values is not the same, or if the index is not within the valid range, or if the value is NaN.
*/
export const parseSparseVectorToBytes = (
data: SparseFloatVectors
data: SparseFloatVector
): Uint8Array => {
const indices = Object.keys(data).map(Number);
const values = Object.values(data);
Expand Down Expand Up @@ -72,12 +87,12 @@ export const parseSparseVectorToBytes = (
/**
* Converts an array of sparse float vectors into an array of bytes format.
*
* @param {SparseFloatVectors[]} data - The array of sparse float vectors to convert.
* @param {SparseFloatVector[]} data - The array of sparse float vectors to convert.
*
* @returns {Uint8Array[]} An array of bytes representing the sparse float vectors.
*/
export const parseSparseRowsToBytes = (
data: SparseFloatVectors[]
data: SparseFloatVector[]
): Uint8Array[] => {
const result: Uint8Array[] = [];
for (const row of data) {
Expand All @@ -91,12 +106,12 @@ export const parseSparseRowsToBytes = (
*
* @param {Buffer} bufferData - The buffer data to parse.
*
* @returns {SparseFloatVectors} The parsed sparse float vectors.
* @returns {SparseFloatVector} The parsed sparse float vectors.
*/
export const parseBufferToSparseRow = (
bufferData: Buffer
): SparseFloatVectors => {
const result: SparseFloatVectors = {};
): SparseFloatVector => {
const result: SparseFloatVector = {};
for (let i = 0; i < bufferData.length; i += 8) {
const key: string = bufferData.readUInt32LE(i).toString();
const value: number = bufferData.readFloatLE(i + 4);
Expand Down Expand Up @@ -124,14 +139,18 @@ export const buildPlaceholderGroupBytes = (
// parse vectors to bytes
switch (vectorDataType) {
case DataType.FloatVector:
bytes = vectors.map(v => parseFloatVectorToBytes(v as FloatVectors));
bytes = vectors.map(v => parseFloatVectorToBytes(v as FloatVector));
break;
case DataType.BinaryVector:
bytes = vectors.map(v => parseBinaryVectorToBytes(v as BinaryVectors));
bytes = vectors.map(v => parseBinaryVectorToBytes(v as BinaryVector));
break;
case DataType.Float16Vector:
case DataType.BFloat16Vector:
bytes = vectors.map(v => parseFloat16VectorToBytes(v as Float16Vector));
break;
case DataType.SparseFloatVector:
bytes = vectors.map(v =>
parseSparseVectorToBytes(v as SparseFloatVectors)
parseSparseVectorToBytes(v as SparseFloatVector)
);

break;
Expand Down
18 changes: 18 additions & 0 deletions milvus/utils/Format.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ import {
RerankerObj,
parseBufferToSparseRow,
buildPlaceholderGroupBytes,
parseBytesToFloat16Vector,
parseFloat16VectorToBytes,
Float16Vector,
} from '../';

/**
Expand Down Expand Up @@ -427,6 +430,18 @@ export const buildFieldDataMap = (fields_data: any[]) => {
});
break;

case 'float16_vector':
case 'bfloat16_vector':
field_data = [];
const f16Dim = Number(item.vectors!.dim) * 2; // float16 is 2 bytes, so we need to multiply dim with 2 = one element length
const f16Bytes = item.vectors![dataKey]!;

// split buffer data to float16 vector(bytes)
for (let i = 0; i < f16Bytes.byteLength; i += f16Dim) {
const slice = f16Bytes.slice(i, i + f16Dim);
field_data.push(parseBytesToFloat16Vector(slice));
}
break;
case 'sparse_float_vector':
const sparseVectorValue = item.vectors![dataKey]!.contents;
field_data = [];
Expand Down Expand Up @@ -509,6 +524,9 @@ export const buildFieldData = (rowData: RowData, field: Field): FieldData => {
case DataType.BinaryVector:
case DataType.FloatVector:
return rowData[name];
case DataType.Float16Vector:
case DataType.BFloat16Vector:
return parseFloat16VectorToBytes(rowData[name] as Float16Vector);
case DataType.JSON:
return Buffer.from(JSON.stringify(rowData[name] || {}));
case DataType.Array:
Expand Down
10 changes: 8 additions & 2 deletions milvus/utils/Function.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { KeyValuePair, DataType, ERROR_REASONS, SparseFloatVectors } from '../';
import { KeyValuePair, DataType, ERROR_REASONS, SparseFloatVector } from '../';
import { Pool } from 'generic-pool';

/**
Expand Down Expand Up @@ -89,6 +89,12 @@ export const getDataKey = (type: DataType, camelCase: boolean = false) => {
case DataType.FloatVector:
dataKey = 'float_vector';
break;
case DataType.Float16Vector:
dataKey = 'float16_vector';
break;
case DataType.BFloat16Vector:
dataKey = 'bfloat16_vector';
break;
case DataType.BinaryVector:
dataKey = 'binary_vector';
break;
Expand Down Expand Up @@ -134,7 +140,7 @@ export const getDataKey = (type: DataType, camelCase: boolean = false) => {
};

// get biggest size of sparse vector array
export const getSparseDim = (data: SparseFloatVectors[]) => {
export const getSparseDim = (data: SparseFloatVector[]) => {
let dim = 0;
for (const row of data) {
const indices = Object.keys(row).map(Number);
Expand Down
10 changes: 3 additions & 7 deletions milvus/utils/Validate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ import { status as grpcStatus } from '@grpc/grpc-js';
* @param fields
*/
export const checkCollectionFields = (fields: FieldType[]) => {
// Define arrays of data types that are allowed for vector fields and primary keys, respectively
const vectorDataTypes = [
DataType.BinaryVector,
DataType.FloatVector,
DataType.SparseFloatVector,
];
const int64VarCharTypes = [DataType.Int64, DataType.VarChar];

let hasPrimaryKey = false;
Expand Down Expand Up @@ -60,7 +54,7 @@ export const checkCollectionFields = (fields: FieldType[]) => {
}

// if this is the vector field, check dimension
const isVectorField = vectorDataTypes.includes(dataType!);
const isVectorField = isVectorType(dataType!);
const typeParams = field.type_params;
if (isVectorField) {
const dim = Number(typeParams?.dim ?? field.dim);
Expand Down Expand Up @@ -219,6 +213,8 @@ export const isVectorType = (type: DataType) => {
return (
type === DataType.BinaryVector ||
type === DataType.FloatVector ||
type === DataType.Float16Vector ||
type === DataType.BFloat16Vector ||
type === DataType.SparseFloatVector
);
};
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
},
"devDependencies": {
"@babel/plugin-transform-modules-commonjs": "^7.21.5",
"@petamoriken/float16": "^3.8.6",
"@types/jest": "^29.5.1",
"@types/node-fetch": "^2.6.8",
"jest": "^29.5.0",
Expand Down
Loading
Loading