Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [Sparse Float Vector] add sparse vector support to milvus components #30630

Merged
merged 2 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/datacoord/compaction_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ func (t *compactionTrigger) handleSignal(signal *compactionSignal) {
segments := t.getCandidateSegments(channel, partitionID)

if len(segments) == 0 {
log.Info("the length of segments is 0, skip to handle compaction")
log.Info("the number of candidate segments is 0, skip to handle compaction")
return
}

Expand Down
9 changes: 9 additions & 0 deletions internal/datanode/compactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,15 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{}
data.Dim = len(data.Data) * 8 / int(numRows)
rst = data

case schemapb.DataType_SparseFloatVector:
data := &storage.SparseFloatVectorFieldData{}
for _, c := range content {
if err := data.AppendRow(c); err != nil {
return nil, fmt.Errorf("failed to append row: %v, %w", err, errTransferType)
}
}
rst = data

default:
return nil, errUnknownDataType
}
Expand Down
8 changes: 8 additions & 0 deletions internal/datanode/compactor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/testutils"
"github.com/milvus-io/milvus/pkg/util/timerecord"
)

Expand Down Expand Up @@ -105,6 +106,13 @@ func TestCompactionTaskInnerMethods(t *testing.T) {
{false, schemapb.DataType_BinaryVector, []interface{}{nil, nil}, "invalid binaryvector"},
{false, schemapb.DataType_Float16Vector, []interface{}{nil, nil}, "invalid float16vector"},
{false, schemapb.DataType_BFloat16Vector, []interface{}{nil, nil}, "invalid bfloat16vector"},

{false, schemapb.DataType_SparseFloatVector, []interface{}{nil, nil}, "invalid sparsefloatvector"},
{false, schemapb.DataType_SparseFloatVector, []interface{}{[]byte{255}, []byte{15}}, "invalid sparsefloatvector"},
{true, schemapb.DataType_SparseFloatVector, []interface{}{
testutils.CreateSparseFloatRow([]uint32{1, 2}, []float32{1.0, 2.0}),
testutils.CreateSparseFloatRow([]uint32{3, 4}, []float32{1.0, 2.0}),
}, "valid sparsefloatvector"},
}

// make sure all new data types missed to handle would throw unexpected error
Expand Down
5 changes: 5 additions & 0 deletions internal/indexnode/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import (
"unsafe"

"github.com/cockroachdb/errors"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

Expand All @@ -37,5 +39,8 @@
if dataType == schemapb.DataType_BFloat16Vector {
return uint64(dim) * uint64(numRows) * 2, nil
}
if dataType == schemapb.DataType_SparseFloatVector {
return 0, errors.New("could not estimate field data size of SparseFloatVector")
}

Check warning on line 44 in internal/indexnode/util.go

View check run for this annotation

Codecov / codecov/patch

internal/indexnode/util.go#L42-L44

Added lines #L42 - L44 were not covered by tests
return 0, nil
}
2 changes: 2 additions & 0 deletions internal/parser/planparserv2/plan_parser_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
vectorType = planpb.VectorType_Float16Vector
} else if dataType == schemapb.DataType_BFloat16Vector {
vectorType = planpb.VectorType_BFloat16Vector
} else if dataType == schemapb.DataType_SparseFloatVector {
vectorType = planpb.VectorType_SparseFloatVector
}
planNode := &planpb.PlanNode{
Node: &planpb.PlanNode_VectorAnns{
Expand Down
11 changes: 11 additions & 0 deletions internal/parser/planparserv2/plan_parser_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,17 @@ func TestCreateBFloat16earchPlan(t *testing.T) {
assert.NoError(t, err)
}

func TestCreateSparseFloatVectorSearchPlan(t *testing.T) {
schema := newTestSchemaHelper(t)
_, err := CreateSearchPlan(schema, `$meta["A"] != 10`, "SparseFloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
})
assert.NoError(t, err)
}

func TestExpr_Invalid(t *testing.T) {
schema := newTestSchema()
helper, err := typeutil.CreateSchemaHelper(schema)
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error {
if err := validateFieldName(field.Name); err != nil {
return err
}
// validate vector field type parameters
// validate dense vector field type parameters
if isVectorType(field.DataType) {
err = validateDimension(field)
if err != nil {
Expand Down
34 changes: 14 additions & 20 deletions internal/proxy/task_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"github.com/milvus-io/milvus/pkg/util/indexparamcheck"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
Expand Down Expand Up @@ -174,9 +175,7 @@
fmt.Sprintf("create index on %s field", cit.fieldSchema.DataType.String()),
fmt.Sprintf("create index on %s field is not supported", cit.fieldSchema.DataType.String()))
}
}

if isVecIndex {
} else {
specifyIndexType, exist := indexParamsMap[common.IndexTypeKey]
if Params.AutoIndexConfig.Enable.GetAsBool() { // `enable` only for cloud instance.
log.Info("create index trigger AutoIndex",
Expand Down Expand Up @@ -258,6 +257,12 @@
return err
}
}
if indexType == indexparamcheck.IndexSparseInverted || indexType == indexparamcheck.IndexSparseWand {
metricType, metricTypeExist := indexParamsMap[common.MetricTypeKey]
if !metricTypeExist || metricType != metric.IP {
return fmt.Errorf("only IP is the supported metric type for sparse index")
}

Check warning on line 264 in internal/proxy/task_index.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/task_index.go#L263-L264

Added lines #L263 - L264 were not covered by tests
}

err := checkTrain(cit.fieldSchema, indexParamsMap)
if err != nil {
Expand Down Expand Up @@ -309,13 +314,7 @@
}

func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error {
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
schemapb.DataType_Float16Vector,
schemapb.DataType_BFloat16Vector,
}
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
if !isVectorType(field.GetDataType()) {
return nil
}
params := make([]*commonpb.KeyValuePair, 0, len(field.GetTypeParams())+len(field.GetIndexParams()))
Expand All @@ -338,14 +337,7 @@

func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) error {
indexType := indexParams[common.IndexTypeKey]
// skip params check of non-vector field.
vecDataTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
schemapb.DataType_Float16Vector,
schemapb.DataType_BFloat16Vector,
}
if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) {
if !isVectorType(field.GetDataType()) {
return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams)
}

Expand All @@ -355,8 +347,10 @@
return fmt.Errorf("invalid index type: %s", indexType)
}

if err := fillDimension(field, indexParams); err != nil {
return err
if !isSparseVectorType(field.DataType) {
if err := fillDimension(field, indexParams); err != nil {
return err
}
}

if err := checker.CheckValidDataType(field.GetDataType()); err != nil {
Expand Down
70 changes: 70 additions & 0 deletions internal/proxy/task_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,76 @@ func TestCreateIndexTask_PreExecute(t *testing.T) {
})
}

func Test_sparse_parseIndexParams(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
Base: nil,
DbName: "",
CollectionName: "",
FieldName: "",
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "SPARSE_INVERTED_INDEX",
},
{
Key: MetricTypeKey,
Value: "IP",
},
{
Key: common.IndexParamsKey,
Value: "{\"drop_ratio_build\": 0.3}",
},
},
IndexName: "",
},
ctx: nil,
rootCoord: nil,
result: nil,
isAutoIndex: false,
newIndexParams: nil,
newTypeParams: nil,
collectionID: 0,
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "FieldID",
IsPrimaryKey: false,
Description: "field no.1",
DataType: schemapb.DataType_SparseFloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: MetricTypeKey,
Value: "IP",
},
},
},
}

t.Run("parse index params", func(t *testing.T) {
err := cit.parseIndexParams()
assert.NoError(t, err)

assert.ElementsMatch(t,
[]*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "SPARSE_INVERTED_INDEX",
},
{
Key: MetricTypeKey,
Value: "IP",
},
{
Key: "drop_ratio_build",
Value: "0.3",
},
}, cit.newIndexParams)
assert.ElementsMatch(t,
[]*commonpb.KeyValuePair{}, cit.newTypeParams)
})
}

func Test_parseIndexParams(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
Expand Down
35 changes: 24 additions & 11 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@
return dataType == schemapb.DataType_FloatVector ||
dataType == schemapb.DataType_BinaryVector ||
dataType == schemapb.DataType_Float16Vector ||
dataType == schemapb.DataType_BFloat16Vector
dataType == schemapb.DataType_BFloat16Vector ||
dataType == schemapb.DataType_SparseFloatVector
}

func isSparseVectorType(dataType schemapb.DataType) bool {
return dataType == schemapb.DataType_SparseFloatVector
}

func validateMaxQueryResultWindow(offset int64, limit int64) error {
Expand Down Expand Up @@ -307,6 +312,12 @@
break
}
}
if isSparseVectorType(field.DataType) {
if exist {
return fmt.Errorf("dim should not be specified for sparse vector field %s(%d)", field.Name, field.FieldID)
}
return nil

Check warning on line 319 in internal/proxy/util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/util.go#L316-L319

Added lines #L316 - L319 were not covered by tests
}
if !exist {
return errors.New("dimension is not defined in field type params, check type param `dim` for vector field")
}
Expand Down Expand Up @@ -509,7 +520,7 @@
schemapb.DataType_Float, schemapb.DataType_Double:
return false, nil

case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_SparseFloatVector:
return true, nil
}

Expand All @@ -520,7 +531,7 @@
metricTypeStr := strings.ToUpper(metricTypeStrRaw)
switch metricTypeStr {
case metric.L2, metric.IP, metric.COSINE:
if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector || dataType == schemapb.DataType_BFloat16Vector {
if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector || dataType == schemapb.DataType_BFloat16Vector || dataType == schemapb.DataType_SparseFloatVector {
return nil
}
case metric.JACCARD, metric.HAMMING, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE:
Expand Down Expand Up @@ -581,13 +592,15 @@
if err2 != nil {
return err2
}
dimStr, ok := typeKv[common.DimKey]
if !ok {
return fmt.Errorf("dim not found in type_params for vector field %s(%d)", field.Name, field.FieldID)
}
dim, err := strconv.Atoi(dimStr)
if err != nil || dim < 0 {
return fmt.Errorf("invalid dim; %s", dimStr)
if !isSparseVectorType(field.DataType) {
dimStr, ok := typeKv[common.DimKey]
if !ok {
return fmt.Errorf("dim not found in type_params for vector field %s(%d)", field.Name, field.FieldID)
}
dim, err := strconv.Atoi(dimStr)
if err != nil || dim < 0 {
return fmt.Errorf("invalid dim; %s", dimStr)
}
}

metricTypeStr, ok := indexKv[common.MetricTypeKey]
Expand Down Expand Up @@ -624,7 +637,7 @@
for i := range schema.Fields {
name := schema.Fields[i].Name
dType := schema.Fields[i].DataType
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector || dType == schemapb.DataType_BFloat16Vector
isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector || dType == schemapb.DataType_BFloat16Vector || dType == schemapb.DataType_SparseFloatVector
if isVec && vecExist && !enableMultipleVectorFields {
return fmt.Errorf(
"multiple vector fields is not supported, fields name: %s, %s",
Expand Down
24 changes: 24 additions & 0 deletions internal/proxy/validate_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@
if err := v.checkBinaryVectorFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_SparseFloatVector:
if err := v.checkSparseFloatFieldData(field, fieldSchema); err != nil {
return err
}

Check warning on line 91 in internal/proxy/validate_util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/validate_util.go#L88-L91

Added lines #L88 - L91 were not covered by tests
case schemapb.DataType_VarChar:
if err := v.checkVarCharFieldData(field, fieldSchema); err != nil {
return err
Expand Down Expand Up @@ -205,6 +209,13 @@
if n != numRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}

case schemapb.DataType_SparseFloatVector:
n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents))
if n != numRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}

Check warning on line 217 in internal/proxy/validate_util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/validate_util.go#L213-L217

Added lines #L213 - L217 were not covered by tests

default:
// error won't happen here.
n, err := funcutil.GetNumRowOfFieldData(field)
Expand Down Expand Up @@ -326,6 +337,19 @@
return nil
}

func (v *validateUtil) checkSparseFloatFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
if field.GetVectors() == nil || field.GetVectors().GetSparseFloatVector() == nil {
msg := fmt.Sprintf("sparse float field '%v' is illegal, nil SparseFloatVector", field.GetFieldName())
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
}
sparseRows := field.GetVectors().GetSparseFloatVector().GetContents()
if sparseRows == nil {
msg := fmt.Sprintf("sparse float field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
}
return typeutil.ValidateSparseFloatRows(sparseRows...)

Check warning on line 350 in internal/proxy/validate_util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/validate_util.go#L340-L350

Added lines #L340 - L350 were not covered by tests
}

func (v *validateUtil) checkVarCharFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
strArr := field.GetScalars().GetStringData().GetData()
if strArr == nil && fieldSchema.GetDefaultValue() == nil {
Expand Down
4 changes: 2 additions & 2 deletions internal/querynodev2/local_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (suite *LocalWorkerTestSuite) BeforeTest(suiteName, testName string) {
err = suite.node.Start()
suite.NoError(err)

suite.schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
suite.schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
suite.indexMeta = segments.GenTestIndexMeta(suite.collectionID, suite.schema)
collection := segments.NewCollection(suite.collectionID, suite.schema, suite.indexMeta, querypb.LoadType_LoadCollection)
loadMata := &querypb.LoadMetaInfo{
Expand All @@ -111,7 +111,7 @@ func (suite *LocalWorkerTestSuite) AfterTest(suiteName, testName string) {

func (suite *LocalWorkerTestSuite) TestLoadSegment() {
// load empty
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: suite.node.session.GetServerID(),
Expand Down
Loading
Loading