From a2a1f9d08c5b2fd9003148e2a98139d0f9af49fa Mon Sep 17 00:00:00 2001 From: dragondriver Date: Mon, 6 Dec 2021 10:03:34 +0800 Subject: [PATCH] Disable multiple vector fields (#12691) Signed-off-by: dragondriver --- internal/proxy/task.go | 4 ++ internal/proxy/task_test.go | 45 ++++++++++++++++++- internal/proxy/validate_util.go | 25 +++++++++++ internal/proxy/validate_util_test.go | 36 +++++++++++++++ .../testcases/test_collection.py | 2 + .../python_client/testcases/test_insert_20.py | 1 + tests/python_client/testcases/test_query.py | 6 +++ 7 files changed, 118 insertions(+), 1 deletion(-) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 2ec3d6940224..f3f2a63bec28 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1173,6 +1173,10 @@ func (cct *createCollectionTask) PreExecute(ctx context.Context) error { } } + if err := validateMultipleVectorFields(cct.schema); err != nil { + return err + } + return nil } diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index cccfd26449c4..6abdbf79bc1e 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -178,6 +178,23 @@ func constructCollectionSchemaWithAllType( AutoID: false, } + if enableMultipleVectorFields { + return &schemapb.CollectionSchema{ + Name: collectionName, + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + b, + i32, + i64, + f, + d, + fVec, + bVec, + }, + } + } + return &schemapb.CollectionSchema{ Name: collectionName, Description: "", @@ -189,7 +206,7 @@ func constructCollectionSchemaWithAllType( f, d, fVec, - bVec, + // bVec, }, } } @@ -1114,6 +1131,32 @@ func TestCreateCollectionTask(t *testing.T) { task.CreateCollectionRequest.Schema = binaryTooLargeDimSchema err = task.PreExecute(ctx) assert.Error(t, err) + + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 0, + Name: "second_vector", + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: strconv.Itoa(128), + }, + }, + IndexParams: nil, + AutoID: false, + }) + twoVecFieldsSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = twoVecFieldsSchema + err = task.PreExecute(ctx) + if enableMultipleVectorFields { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } }) } diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index f9f65b2ef261..22b511a5f390 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -26,6 +26,8 @@ import ( "github.com/milvus-io/milvus/internal/proto/schemapb" ) +const enableMultipleVectorFields = false + func isAlpha(c uint8) bool { if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') { return false @@ -343,3 +345,26 @@ func validateSchema(coll *schemapb.CollectionSchema) error { return nil } + +func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error { + vecExist := false + var vecName string + + for i := range schema.Fields { + name := schema.Fields[i].Name + dType := schema.Fields[i].DataType + isVec := (dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector) + if isVec && vecExist && !enableMultipleVectorFields { + return fmt.Errorf( + "multiple vector fields is not supported, fields name: %s, %s", + vecName, + name, + ) + } else if isVec { + vecExist = true + vecName = name + } + } + + return nil +} diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index b750141494d2..ef6dbc5e295b 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -372,3 +372,39 @@ func TestValidateSchema(t *testing.T) { pf3.IndexParams = ip3Good assert.Nil(t, validateSchema(coll)) } + +func TestValidateMultipleVectorFields(t *testing.T) { + // case1, no vector field + schema1 := &schemapb.CollectionSchema{} + assert.NoError(t, validateMultipleVectorFields(schema1)) + + // case2, only one vector field + schema2 := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "case2", + DataType: schemapb.DataType_FloatVector, + }, + }, + } + assert.NoError(t, validateMultipleVectorFields(schema2)) + + // case3, multiple vectors + schema3 := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "case3_f", + DataType: schemapb.DataType_FloatVector, + }, + { + Name: "case3_b", + DataType: schemapb.DataType_BinaryVector, + }, + }, + } + if enableMultipleVectorFields { + assert.NoError(t, validateMultipleVectorFields(schema3)) + } else { + assert.Error(t, validateMultipleVectorFields(schema3)) + } +} diff --git a/tests/python_client/testcases/test_collection.py b/tests/python_client/testcases/test_collection.py index 07b2f1803cb9..77d9a6f728a1 100644 --- a/tests/python_client/testcases/test_collection.py +++ b/tests/python_client/testcases/test_collection.py @@ -422,6 +422,7 @@ def test_collection_only_vector_field(self, field): self.collection_schema_wrap.init_collection_schema([field], check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") def test_collection_multi_float_vectors(self): """ target: test collection with multi float vectors @@ -436,6 +437,7 @@ def test_collection_multi_float_vectors(self): check_items={exp_name: c_name, exp_schema: schema}) @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") def test_collection_mix_vectors(self): """ target: test collection with mix vectors diff --git a/tests/python_client/testcases/test_insert_20.py b/tests/python_client/testcases/test_insert_20.py index 3dc0397a6964..b1a8c9441df5 100644 --- a/tests/python_client/testcases/test_insert_20.py +++ b/tests/python_client/testcases/test_insert_20.py @@ -421,6 +421,7 @@ def test_insert_without_connection(self): collection_w.insert(data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") @pytest.mark.parametrize("vec_fields", [[cf.gen_float_vec_field(name="float_vector1")], [cf.gen_binary_vec_field()], [cf.gen_binary_vec_field(), cf.gen_binary_vec_field("binary_vec")]]) diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index cfc373bde284..497ae821417f 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -531,6 +531,7 @@ def test_query_output_float_vec_field(self): check_items={exp_res: res, "with_vec": True}) @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") @pytest.mark.parametrize("vec_fields", [[cf.gen_float_vec_field(name="float_vector1")]]) def test_query_output_multi_float_vec_field(self, vec_fields): """ @@ -557,6 +558,7 @@ def test_query_output_multi_float_vec_field(self, vec_fields): check_items={exp_res: res, "with_vec": True}) @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") @pytest.mark.parametrize("vec_fields", [[cf.gen_binary_vec_field()], [cf.gen_binary_vec_field(), cf.gen_binary_vec_field("binary_vec1")]]) def test_query_output_mix_float_binary_field(self, vec_fields): @@ -642,6 +644,7 @@ def test_query_invalid_output_fields(self): check_items=error) @pytest.mark.tags(CaseLabel.L0) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") def test_query_output_fields_simple_wildcard(self): """ target: test query output_fields with simple wildcard (* and %) @@ -673,6 +676,7 @@ def test_query_output_fields_simple_wildcard(self): check_items={exp_res: res3, "with_vec": True}) @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") def test_query_output_fields_part_scale_wildcard(self): """ target: test query output_fields with part wildcard @@ -697,6 +701,7 @@ def test_query_output_fields_part_scale_wildcard(self): check_items={exp_res: res2}) @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") def test_query_output_fields_part_vector_wildcard(self): """ target: test query output_fields with part wildcard @@ -721,6 +726,7 @@ def test_query_output_fields_part_vector_wildcard(self): check_items={exp_res: res2, "with_vec": True}) @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.skip("https://github.com/milvus-io/milvus/issues/12680") @pytest.mark.parametrize("output_fields", [["*%"], ["**"], ["*", "@"]]) def test_query_invalid_wildcard(self, output_fields): """