[Sparse Float Vector] add sparse float vector support to various Milvus components: the proxy and data node to receive and write sparse float vectors to binlog, the query node to handle search requests, the index node to build indexes on sparse float columns, etc.

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
zhengbuqian committed Feb 18, 2024
1 parent 1329c72 commit fd1b555
Showing 44 changed files with 1,640 additions and 128 deletions.
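
For orientation before the per-file diffs: the sparse rows handled throughout this change are schemapb.SparseFloatRow messages carrying parallel index and value arrays that describe only the non-zero dimensions. A minimal, illustrative Go sketch of constructing one (the proto types are taken from the test code further down; this snippet is not part of the commit):

package main // illustrative sketch only, not part of the commit

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// newSparseRow builds a sparse float vector row from parallel slices:
// indices[i] is the dimension that holds values[i], so both slices must
// have the same length.
func newSparseRow(indices []int32, values []float32) *schemapb.SparseFloatRow {
	return &schemapb.SparseFloatRow{
		Indices: &schemapb.IntArray{Data: indices},
		Values:  &schemapb.FloatArray{Data: values},
	}
}

func main() {
	row := newSparseRow([]int32{1, 7, 42}, []float32{0.5, 1.25, 3.0})
	fmt.Println(len(row.Indices.Data), "non-zero entries")
}
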
2 changes: 1 addition & 1 deletion internal/datacoord/compaction_trigger.go
@@ -506,7 +506,7 @@ func (t *compactionTrigger) handleSignal(signal *compactionSignal) {
segments := t.getCandidateSegments(channel, partitionID)

if len(segments) == 0 {
log.Info("the length of segments is 0, skip to handle compaction")
log.Info("the number of candidate segments is 0, skip to handle compaction")
return
}

9 changes: 9 additions & 0 deletions internal/datanode/compactor.go
@@ -788,6 +788,15 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{}
data.Dim = len(data.Data) * 8 / int(numRows)
rst = data

case schemapb.DataType_SparseFloatVector:
data := &storage.SparseFloatVectorFieldData{}
for _, c := range content {
if err := data.AppendRow(c); err != nil {
return nil, fmt.Errorf("failed to append row: %v, %w", err, errTransferType)
}
}
rst = data

default:
return nil, errUnknownDataType
}
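
As a usage note on the new DataType_SparseFloatVector branch above (this sketch is not part of the commit): the compactor receives the sparse column as a []interface{} whose elements are expected to be sparse rows; per the test cases added below, nil entries or raw byte slices make AppendRow fail and abort the conversion. A hedged sketch of that accumulation loop, assuming only the AppendRow(interface{}) error method visible in the hunk:

package datanode // sketch; mirrors the new case in interface2FieldData

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/storage"
)

// sparseContentToFieldData accumulates sparse rows into a field-data object.
// Every element must be a valid sparse row, otherwise the conversion fails.
func sparseContentToFieldData(content []interface{}) (*storage.SparseFloatVectorFieldData, error) {
	data := &storage.SparseFloatVectorFieldData{}
	for _, c := range content {
		if err := data.AppendRow(c); err != nil {
			return nil, fmt.Errorf("failed to append sparse row: %w", err)
		}
	}
	return data, nil
}
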
22 changes: 22 additions & 0 deletions internal/datanode/compactor_test.go
@@ -104,6 +104,28 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
{false, schemapb.DataType_BinaryVector, []interface{}{nil, nil}, "invalid binaryvector"},
{false, schemapb.DataType_Float16Vector, []interface{}{nil, nil}, "invalid float16vector"},
{false, schemapb.DataType_BFloat16Vector, []interface{}{nil, nil}, "invalid bfloat16vector"},

// sparse float vector: each element is schemapb.SparseFloatRow
{false, schemapb.DataType_SparseFloatVector, []interface{}{nil, nil}, "invalid sparsefloatvector"},
{false, schemapb.DataType_SparseFloatVector, []interface{}{[]byte{255}, []byte{15}}, "invalid sparsefloatvector"},
{true, schemapb.DataType_SparseFloatVector, []interface{}{
&schemapb.SparseFloatRow{
Indices: &schemapb.IntArray{
Data: []int32{1, 2},
},
Values: &schemapb.FloatArray{
Data: []float32{1.0, 2.0},
},
},
&schemapb.SparseFloatRow{
Indices: &schemapb.IntArray{
Data: []int32{3, 4},
},
Values: &schemapb.FloatArray{
Data: []float32{1.0, 2.0},
},
},
}, "valid sparsefloatvector"},
}

// make sure all new data types missed to handle would throw unexpected error
4 changes: 4 additions & 0 deletions internal/indexnode/util.go
@@ -17,6 +17,7 @@
package indexnode

import (
"errors"
"unsafe"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@@ -37,5 +38,8 @@ func estimateFieldDataSize(dim int64, numRows int64, dataType schemapb.DataType)
if dataType == schemapb.DataType_BFloat16Vector {
return uint64(dim) * uint64(numRows) * 2, nil
}
if dataType == schemapb.DataType_SparseFloatVector {
return 0, errors.New("could not estimate field data size of SparseFloatVector")
}
return 0, nil
}
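
The new branch above exists because a sparse column has no fixed dim, so its size cannot be estimated as dim * numRows the way the dense branches do; an exact figure has to be computed from the rows themselves. A rough, hypothetical helper (not part of this commit) that sums per-row payloads at 4 bytes per int32 index plus 4 bytes per float32 value:

package example // hypothetical helper, not part of the commit

import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"

// sparseRowsDataSize returns the raw payload size of a batch of sparse rows.
// Unlike dense vectors, this depends on how many non-zeros each row has.
func sparseRowsDataSize(rows []*schemapb.SparseFloatRow) uint64 {
	var total uint64
	for _, r := range rows {
		// each non-zero entry stores one int32 index and one float32 value
		total += uint64(len(r.GetIndices().GetData()))*4 + uint64(len(r.GetValues().GetData()))*4
	}
	return total
}
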
2 changes: 2 additions & 0 deletions internal/parser/planparserv2/plan_parser_v2.go
@@ -147,6 +147,8 @@ func CreateSearchPlan(schemaPb *schemapb.CollectionSchema, exprStr string, vecto
vectorType = planpb.VectorType_Float16Vector
} else if dataType == schemapb.DataType_BFloat16Vector {
vectorType = planpb.VectorType_BFloat16Vector
} else if dataType == schemapb.DataType_SparseFloatVector {
vectorType = planpb.VectorType_SparseFloatVector
}
planNode := &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
11 changes: 11 additions & 0 deletions internal/parser/planparserv2/plan_parser_v2_test.go
@@ -418,6 +418,17 @@ func TestCreateBFloat16earchPlan(t *testing.T) {
assert.NoError(t, err)
}

func TestCreateSparseFloatVectorSearchPlan(t *testing.T) {
schema := newTestSchema()
_, err := CreateSearchPlan(schema, `$meta["A"] != 10`, "SparseFloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
})
assert.NoError(t, err)
}

func TestExpr_Invalid(t *testing.T) {
schema := newTestSchema()
helper, err := typeutil.CreateSchemaHelper(schema)
5 changes: 2 additions & 3 deletions internal/proxy/msg_pack.go
@@ -21,9 +21,6 @@ import (
"strconv"
"time"

"go.uber.org/zap"
"golang.org/x/sync/errgroup"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
@@ -36,6 +33,8 @@ import (
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/retry"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)

func genInsertMsgsByPartition(ctx context.Context,
4 changes: 2 additions & 2 deletions internal/proxy/task.go
@@ -313,8 +313,8 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error {
if err := validateFieldName(field.Name); err != nil {
return err
}
// validate vector field type parameters
if isVectorType(field.DataType) {
// validate dense vector field type parameters
if isVectorType(field.DataType) && !isSparseVectorType(field.DataType) {
err = validateDimension(field)
if err != nil {
return err
27 changes: 7 additions & 20 deletions internal/proxy/task_index.go
@@ -173,9 +173,7 @@ func (cit *createIndexTask) parseIndexParams() error {
fmt.Sprintf("create index on %s field", cit.fieldSchema.DataType.String()),
fmt.Sprintf("create index on %s field is not supported", cit.fieldSchema.DataType.String()))
}
}

if isVecIndex {
} else {
specifyIndexType, exist := indexParamsMap[common.IndexTypeKey]
if Params.AutoIndexConfig.Enable.GetAsBool() { // `enable` only for cloud instance.
log.Info("create index trigger AutoIndex",
@@ -308,13 +306,7 @@ func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.Fiel
}

func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error {
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
schemapb.DataType_Float16Vector,
schemapb.DataType_BFloat16Vector,
}
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
if !isVectorType(field.GetDataType()) {
return nil
}
params := make([]*commonpb.KeyValuePair, 0, len(field.GetTypeParams())+len(field.GetIndexParams()))
@@ -337,14 +329,7 @@ func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) e

func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) error {
indexType := indexParams[common.IndexTypeKey]
// skip params check of non-vector field.
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
schemapb.DataType_Float16Vector,
schemapb.DataType_BFloat16Vector,
}
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
if !isVectorType(field.GetDataType()) {
return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams)
}

@@ -354,8 +339,10 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
return fmt.Errorf("invalid index type: %s", indexType)
}

if err := fillDimension(field, indexParams); err != nil {
return err
if !isSparseVectorType(field.DataType) {
if err := fillDimension(field, indexParams); err != nil {
return err
}
}

if err := checker.CheckValidDataType(field.GetDataType()); err != nil {
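
With fillDimension now skipped for sparse fields, an index on a sparse column is described purely by its index params; the combination exercised by the new test just below is SPARSE_INVERTED_INDEX over the IP metric with a drop_ratio_build of 0.3. An illustrative params map (values copied from that test; the helper name and package are made up):

package example // illustrative only; values mirror Test_sparse_parseIndexParams below

import "github.com/milvus-io/milvus/pkg/common"

// sparseIndexParams returns the index params used for a sparse float vector
// index in the new test: an inverted index, inner-product metric, and a
// build-time drop ratio.
func sparseIndexParams() map[string]string {
	return map[string]string{
		common.IndexTypeKey:  "SPARSE_INVERTED_INDEX",
		common.MetricTypeKey: "IP",
		"drop_ratio_build":   "0.3",
	}
}
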
70 changes: 70 additions & 0 deletions internal/proxy/task_index_test.go
@@ -272,6 +272,76 @@ func TestCreateIndexTask_PreExecute(t *testing.T) {
})
}

func Test_sparse_parseIndexParams(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
Base: nil,
DbName: "",
CollectionName: "",
FieldName: "",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "SPARSE_INVERTED_INDEX",
},
{
Key: MetricTypeKey,
Value: "IP",
},
{
Key: common.IndexParamsKey,
Value: "{\"drop_ratio_build\": 0.3}",
},
},
IndexName: "",
},
ctx: nil,
rootCoord: nil,
result: nil,
isAutoIndex: false,
newIndexParams: nil,
newTypeParams: nil,
collectionID: 0,
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "FieldID",
IsPrimaryKey: false,
Description: "field no.1",
DataType: schemapb.DataType_SparseFloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: MetricTypeKey,
Value: "IP",
},
},
},
}

t.Run("parse index params", func(t *testing.T) {
err := cit.parseIndexParams()
assert.NoError(t, err)

assert.ElementsMatch(t,
[]*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "SPARSE_INVERTED_INDEX",
},
{
Key: MetricTypeKey,
Value: "IP",
},
{
Key: "drop_ratio_build",
Value: "0.3",
},
}, cit.newIndexParams)
assert.ElementsMatch(t,
[]*commonpb.KeyValuePair{}, cit.newTypeParams)
})
}

func Test_parseIndexParams(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
30 changes: 19 additions & 11 deletions internal/proxy/util.go
@@ -98,7 +98,12 @@ func isVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_FloatVector ||
dataType == schemapb.DataType_BinaryVector ||
dataType == schemapb.DataType_Float16Vector ||
dataType == schemapb.DataType_BFloat16Vector
dataType == schemapb.DataType_BFloat16Vector ||
dataType == schemapb.DataType_SparseFloatVector
}

func isSparseVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_SparseFloatVector
}

func validateMaxQueryResultWindow(offset int64, limit int64) error {
@@ -293,6 +298,7 @@ func validateFieldName(fieldName string) error {
return nil
}

// should not be called on sparse vector field
func validateDimension(field *schemapb.FieldSchema) error {
exist := false
var dim int64
@@ -499,7 +505,7 @@ func isVector(dataType schemapb.DataType) (bool, error) {
schemapb.DataType_Float, schemapb.DataType_Double:
return false, nil

case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_SparseFloatVector:
return true, nil
}

@@ -510,7 +516,7 @@ func validateMetricType(dataType schemapb.DataType, metricTypeStrRaw string) err
metricTypeStr := strings.ToUpper(metricTypeStrRaw)
switch metricTypeStr {
case metric.L2, metric.IP, metric.COSINE:
if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector || dataType == schemapb.DataType_BFloat16Vector {
if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector || dataType == schemapb.DataType_BFloat16Vector || dataType == schemapb.DataType_SparseFloatVector {
return nil
}
case metric.JACCARD, metric.HAMMING, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE:
@@ -571,13 +577,15 @@ func validateSchema(coll *schemapb.CollectionSchema) error {
if err2 != nil {
return err2
}
dimStr, ok := typeKv[common.DimKey]
if !ok {
return fmt.Errorf("dim not found in type_params for vector field %s(%d)", field.Name, field.FieldID)
}
dim, err := strconv.Atoi(dimStr)
if err != nil || dim < 0 {
return fmt.Errorf("invalid dim; %s", dimStr)
if !isSparseVectorType(field.DataType) {
dimStr, ok := typeKv[common.DimKey]
if !ok {
return fmt.Errorf("dim not found in type_params for vector field %s(%d)", field.Name, field.FieldID)
}
dim, err := strconv.Atoi(dimStr)
if err != nil || dim < 0 {
return fmt.Errorf("invalid dim; %s", dimStr)
}
}

metricTypeStr, ok := indexKv[common.MetricTypeKey]
@@ -614,7 +622,7 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
for i := range schema.Fields {
name := schema.Fields[i].Name
dType := schema.Fields[i].DataType
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector || dType == schemapb.DataType_BFloat16Vector
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector || dType == schemapb.DataType_BFloat16Vector || dType == schemapb.DataType_SparseFloatVector
if isVec && vecExist && !enableMultipleVectorFields {
return fmt.Errorf(
"multiple vector fields is not supported, fields name: %s, %s",
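
In user terms, the validateSchema change above means a collection schema may now declare a sparse vector field without any dim type param, while dense vector fields still require one; the other schema checks (metric type, field names, and so on) apply as before. An illustrative schema fragment (field IDs and names are placeholders, not taken from the commit):

package example // illustrative only

import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"

// exampleSparseSchema sketches a schema containing a sparse vector field.
// Note the absence of a "dim" type param on the sparse field: sparse float
// vectors have no fixed dimensionality, and validateSchema now skips the
// dim check for this data type.
func exampleSparseSchema() *schemapb.CollectionSchema {
	return &schemapb.CollectionSchema{
		Name: "docs",
		Fields: []*schemapb.FieldSchema{
			{
				FieldID:      100,
				Name:         "pk",
				IsPrimaryKey: true,
				DataType:     schemapb.DataType_Int64,
			},
			{
				FieldID:  101,
				Name:     "sparse_vec",
				DataType: schemapb.DataType_SparseFloatVector,
			},
		},
	}
}
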
20 changes: 20 additions & 0 deletions internal/proxy/validate_util.go
@@ -85,6 +85,10 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
if err := v.checkBinaryVectorFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_SparseFloatVector:
if err := v.checkSparseFloatFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_VarChar:
if err := v.checkVarCharFieldData(field, fieldSchema); err != nil {
return err
@@ -205,6 +209,13 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
if n != numRows {
return errNumRowsMismatch(field.GetFieldName(), n, numRows)
}

case schemapb.DataType_SparseFloatVector:
n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents))
if n != numRows {
return errNumRowsMismatch(field.GetFieldName(), n, numRows)
}

default:
// error won't happen here.
n, err := funcutil.GetNumRowOfFieldData(field)
@@ -326,6 +337,15 @@ func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fie
return nil
}

func (v *validateUtil) checkSparseFloatFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
sparseRows := field.GetVectors().GetSparseFloatVector().GetContents()
if sparseRows == nil {
msg := fmt.Sprintf("sparse float field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
}
return typeutil.ValidateSparseFloatRows(sparseRows...)
}

func (v *validateUtil) checkVarCharFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
strArr := field.GetScalars().GetStringData().GetData()
if strArr == nil && fieldSchema.GetDefaultValue() == nil {
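
checkSparseFloatFieldData above only rejects a missing contents array and then defers to typeutil.ValidateSparseFloatRows, whose body is outside this excerpt. As a heavily hedged sketch of the kind of per-row checks such a validator plausibly performs (an assumption, not the actual typeutil code): matching index/value lengths, strictly increasing non-negative indices, and finite values.

package example // assumption: approximates, but is not, typeutil.ValidateSparseFloatRows

import (
	"fmt"
	"math"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// validateSparseRowSketch is a hypothetical stand-in for per-row validation.
func validateSparseRowSketch(row *schemapb.SparseFloatRow) error {
	if row == nil {
		return fmt.Errorf("sparse row must not be nil")
	}
	indices := row.GetIndices().GetData()
	values := row.GetValues().GetData()
	if len(indices) != len(values) {
		return fmt.Errorf("indices/values length mismatch: %d vs %d", len(indices), len(values))
	}
	for i, idx := range indices {
		if idx < 0 {
			return fmt.Errorf("negative index %d at position %d", idx, i)
		}
		if i > 0 && idx <= indices[i-1] {
			return fmt.Errorf("indices must be strictly increasing: %d after %d", idx, indices[i-1])
		}
	}
	for i, v := range values {
		if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
			return fmt.Errorf("value at position %d is not finite", i)
		}
	}
	return nil
}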