Skip to content

Commit

Permalink
[Sparse Float Vector] add sparse float vector support to different
Browse files Browse the repository at this point in the history
milvus components, including proxy, data node to receive and write
sparse float vectors to binlog, query node to handle search requests,
index node to build index for sparse float column, etc.

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian committed Mar 6, 2024
1 parent 3699914 commit c75897d
Show file tree
Hide file tree
Showing 45 changed files with 1,448 additions and 151 deletions.
2 changes: 1 addition & 1 deletion internal/datacoord/compaction_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ func (t *compactionTrigger) handleSignal(signal *compactionSignal) {
segments := t.getCandidateSegments(channel, partitionID)

if len(segments) == 0 {
log.Info("the length of segments is 0, skip to handle compaction")
log.Info("the number of candidate segments is 0, skip to handle compaction")
return
}

Expand Down
9 changes: 9 additions & 0 deletions internal/datanode/compactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,15 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{}
data.Dim = len(data.Data) * 8 / int(numRows)
rst = data

case schemapb.DataType_SparseFloatVector:
data := &storage.SparseFloatVectorFieldData{}
for _, c := range content {
if err := data.AppendRow(c); err != nil {
return nil, fmt.Errorf("failed to append row: %v, %w", err, errTransferType)
}
}
rst = data

default:
return nil, errUnknownDataType
}
Expand Down
8 changes: 8 additions & 0 deletions internal/datanode/compactor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/testutils"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)

Expand Down Expand Up @@ -105,6 +106,13 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
{false, schemapb.DataType_BinaryVector, []interface{}{nil, nil}, "invalid binaryvector"},
{false, schemapb.DataType_Float16Vector, []interface{}{nil, nil}, "invalid float16vector"},
{false, schemapb.DataType_BFloat16Vector, []interface{}{nil, nil}, "invalid bfloat16vector"},

{false, schemapb.DataType_SparseFloatVector, []interface{}{nil, nil}, "invalid sparsefloatvector"},
{false, schemapb.DataType_SparseFloatVector, []interface{}{[]byte{255}, []byte{15}}, "invalid sparsefloatvector"},
{true, schemapb.DataType_SparseFloatVector, []interface{}{
testutils.CreateSparseFloatRow([]uint32{1, 2}, []float32{1.0, 2.0}),
testutils.CreateSparseFloatRow([]uint32{3, 4}, []float32{1.0, 2.0}),
}, "valid sparsefloatvector"},
}

// make sure all new data types missed to handle would throw unexpected error
Expand Down
4 changes: 4 additions & 0 deletions internal/indexnode/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package indexnode

import (
"errors"
"unsafe"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
Expand All @@ -37,5 +38,8 @@ func estimateFieldDataSize(dim int64, numRows int64, dataType schemapb.DataType)
if dataType == schemapb.DataType_BFloat16Vector {
return uint64(dim) * uint64(numRows) * 2, nil
}
if dataType == schemapb.DataType_SparseFloatVector {
return 0, errors.New("could not estimate field data size of SparseFloatVector")
}
return 0, nil
}
2 changes: 2 additions & 0 deletions internal/parser/planparserv2/plan_parser_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
vectorType = planpb.VectorType_Float16Vector
} else if dataType == schemapb.DataType_BFloat16Vector {
vectorType = planpb.VectorType_BFloat16Vector
} else if dataType == schemapb.DataType_SparseFloatVector {
vectorType = planpb.VectorType_SparseFloatVector
}
planNode := &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
Expand Down
11 changes: 11 additions & 0 deletions internal/parser/planparserv2/plan_parser_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,17 @@ func TestCreateBFloat16earchPlan(t *testing.T) {
assert.NoError(t, err)
}

func TestCreateSparseFloatVectorSearchPlan(t *testing.T) {
schema := newTestSchema()
_, err := CreateSearchPlan(schema, `$meta["A"] != 10`, "SparseFloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
})
assert.NoError(t, err)
}

func TestExpr_Invalid(t *testing.T) {
schema := newTestSchema()
helper, err := typeutil.CreateSchemaHelper(schema)
Expand Down
5 changes: 2 additions & 3 deletions internal/proxy/msg_pack.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ import (
"strconv"
"time"

"go.uber.org/zap"
"golang.org/x/sync/errgroup"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
Expand All @@ -36,6 +33,8 @@ import (
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/retry"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)

func genInsertMsgsByPartition(ctx context.Context,
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error {
if err := validateFieldName(field.Name); err != nil {
return err
}
// validate vector field type parameters
// validate dense vector field type parameters
if isVectorType(field.DataType) {
err = validateDimension(field)
if err != nil {
Expand Down
34 changes: 14 additions & 20 deletions internal/proxy/task_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
Expand Down Expand Up @@ -174,9 +175,7 @@ func (cit *createIndexTask) parseIndexParams() error {
fmt.Sprintf("create index on %s field", cit.fieldSchema.DataType.String()),
fmt.Sprintf("create index on %s field is not supported", cit.fieldSchema.DataType.String()))
}
}

if isVecIndex {
} else {
specifyIndexType, exist := indexParamsMap[common.IndexTypeKey]
if Params.AutoIndexConfig.Enable.GetAsBool() { // `enable` only for cloud instance.
log.Info("create index trigger AutoIndex",
Expand Down Expand Up @@ -258,6 +257,12 @@ func (cit *createIndexTask) parseIndexParams() error {
return err
}
}
if indexType == indexparamcheck.IndexSparseInverted || indexType == indexparamcheck.IndexSparseWand {
metricType, metricTypeExist := indexParamsMap[common.MetricTypeKey]
if !metricTypeExist || metricType != metric.IP {
return fmt.Errorf("only IP is the supported metric type for sparse index")
}
}

err := checkTrain(cit.fieldSchema, indexParamsMap)
if err != nil {
Expand Down Expand Up @@ -309,13 +314,7 @@ func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.Fiel
}

func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error {
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
schemapb.DataType_Float16Vector,
schemapb.DataType_BFloat16Vector,
}
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
if !isVectorType(field.GetDataType()) {
return nil
}
params := make([]*commonpb.KeyValuePair, 0, len(field.GetTypeParams())+len(field.GetIndexParams()))
Expand All @@ -338,14 +337,7 @@ func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) e

func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) error {
indexType := indexParams[common.IndexTypeKey]
// skip params check of non-vector field.
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
schemapb.DataType_Float16Vector,
schemapb.DataType_BFloat16Vector,
}
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
if !isVectorType(field.GetDataType()) {
return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams)
}

Expand All @@ -355,8 +347,10 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro
return fmt.Errorf("invalid index type: %s", indexType)
}

if err := fillDimension(field, indexParams); err != nil {
return err
if !isSparseVectorType(field.DataType) {
if err := fillDimension(field, indexParams); err != nil {
return err
}
}

if err := checker.CheckValidDataType(field.GetDataType()); err != nil {
Expand Down
70 changes: 70 additions & 0 deletions internal/proxy/task_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,76 @@ func TestCreateIndexTask_PreExecute(t *testing.T) {
})
}

func Test_sparse_parseIndexParams(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
Base: nil,
DbName: "",
CollectionName: "",
FieldName: "",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "SPARSE_INVERTED_INDEX",
},
{
Key: MetricTypeKey,
Value: "IP",
},
{
Key: common.IndexParamsKey,
Value: "{\"drop_ratio_build\": 0.3}",
},
},
IndexName: "",
},
ctx: nil,
rootCoord: nil,
result: nil,
isAutoIndex: false,
newIndexParams: nil,
newTypeParams: nil,
collectionID: 0,
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "FieldID",
IsPrimaryKey: false,
Description: "field no.1",
DataType: schemapb.DataType_SparseFloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: MetricTypeKey,
Value: "IP",
},
},
},
}

t.Run("parse index params", func(t *testing.T) {
err := cit.parseIndexParams()
assert.NoError(t, err)

assert.ElementsMatch(t,
[]*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "SPARSE_INVERTED_INDEX",
},
{
Key: MetricTypeKey,
Value: "IP",
},
{
Key: "drop_ratio_build",
Value: "0.3",
},
}, cit.newIndexParams)
assert.ElementsMatch(t,
[]*commonpb.KeyValuePair{}, cit.newTypeParams)
})
}

func Test_parseIndexParams(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
Expand Down
35 changes: 24 additions & 11 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ func isVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_FloatVector ||
dataType == schemapb.DataType_BinaryVector ||
dataType == schemapb.DataType_Float16Vector ||
dataType == schemapb.DataType_BFloat16Vector
dataType == schemapb.DataType_BFloat16Vector ||
dataType == schemapb.DataType_SparseFloatVector
}

func isSparseVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_SparseFloatVector
}

func validateMaxQueryResultWindow(offset int64, limit int64) error {
Expand Down Expand Up @@ -307,6 +312,12 @@ func validateDimension(field *schemapb.FieldSchema) error {
break
}
}
if isSparseVectorType(field.DataType) {
if exist {
return fmt.Errorf("dim should not be specified for sparse vector field %s(%d)", field.Name, field.FieldID)
}
return nil
}
if !exist {
return errors.New("dimension is not defined in field type params, check type param `dim` for vector field")
}
Expand Down Expand Up @@ -509,7 +520,7 @@ func isVector(dataType schemapb.DataType) (bool, error) {
schemapb.DataType_Float, schemapb.DataType_Double:
return false, nil

case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_SparseFloatVector:
return true, nil
}

Expand All @@ -520,7 +531,7 @@ func validateMetricType(dataType schemapb.DataType, metricTypeStrRaw string) err
metricTypeStr := strings.ToUpper(metricTypeStrRaw)
switch metricTypeStr {
case metric.L2, metric.IP, metric.COSINE:
if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector || dataType == schemapb.DataType_BFloat16Vector {
if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector || dataType == schemapb.DataType_BFloat16Vector || dataType == schemapb.DataType_SparseFloatVector {
return nil
}
case metric.JACCARD, metric.HAMMING, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE:
Expand Down Expand Up @@ -581,13 +592,15 @@ func validateSchema(coll *schemapb.CollectionSchema) error {
if err2 != nil {
return err2
}
dimStr, ok := typeKv[common.DimKey]
if !ok {
return fmt.Errorf("dim not found in type_params for vector field %s(%d)", field.Name, field.FieldID)
}
dim, err := strconv.Atoi(dimStr)
if err != nil || dim < 0 {
return fmt.Errorf("invalid dim; %s", dimStr)
if !isSparseVectorType(field.DataType) {
dimStr, ok := typeKv[common.DimKey]
if !ok {
return fmt.Errorf("dim not found in type_params for vector field %s(%d)", field.Name, field.FieldID)
}
dim, err := strconv.Atoi(dimStr)
if err != nil || dim < 0 {
return fmt.Errorf("invalid dim; %s", dimStr)
}
}

metricTypeStr, ok := indexKv[common.MetricTypeKey]
Expand Down Expand Up @@ -624,7 +637,7 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
for i := range schema.Fields {
name := schema.Fields[i].Name
dType := schema.Fields[i].DataType
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector || dType == schemapb.DataType_BFloat16Vector
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector || dType == schemapb.DataType_BFloat16Vector || dType == schemapb.DataType_SparseFloatVector
if isVec && vecExist && !enableMultipleVectorFields {
return fmt.Errorf(
"multiple vector fields is not supported, fields name: %s, %s",
Expand Down
24 changes: 24 additions & 0 deletions internal/proxy/validate_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
if err := v.checkBinaryVectorFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_SparseFloatVector:
if err := v.checkSparseFloatFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_VarChar:
if err := v.checkVarCharFieldData(field, fieldSchema); err != nil {
return err
Expand Down Expand Up @@ -205,6 +209,13 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
if n != numRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}

case schemapb.DataType_SparseFloatVector:
n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents))
if n != numRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}

default:
// error won't happen here.
n, err := funcutil.GetNumRowOfFieldData(field)
Expand Down Expand Up @@ -326,6 +337,19 @@ func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fie
return nil
}

func (v *validateUtil) checkSparseFloatFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
if field.GetVectors() == nil || field.GetVectors().GetSparseFloatVector() == nil {
msg := fmt.Sprintf("sparse float field '%v' is illegal, nil SparseFloatVector", field.GetFieldName())
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
}
sparseRows := field.GetVectors().GetSparseFloatVector().GetContents()
if sparseRows == nil {
msg := fmt.Sprintf("sparse float field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
}
return typeutil.ValidateSparseFloatRows(sparseRows...)
}

func (v *validateUtil) checkVarCharFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
strArr := field.GetScalars().GetStringData().GetData()
if strArr == nil && fieldSchema.GetDefaultValue() == nil {
Expand Down
Loading

0 comments on commit c75897d

Please sign in to comment.