diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 54207f1869e9..66fe2555c876 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -389,7 +389,7 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro } } - if err := checker.CheckValidDataType(field.GetDataType()); err != nil { + if err := checker.CheckValidDataType(field); err != nil { log.Info("create index with invalid data type", zap.Error(err), zap.String("data_type", field.GetDataType().String())) return err } diff --git a/pkg/util/indexparamcheck/auto_index_checker.go b/pkg/util/indexparamcheck/auto_index_checker.go index 9f960d96695b..cc83f196d2e0 100644 --- a/pkg/util/indexparamcheck/auto_index_checker.go +++ b/pkg/util/indexparamcheck/auto_index_checker.go @@ -13,7 +13,7 @@ func (c *AUTOINDEXChecker) CheckTrain(params map[string]string) error { return nil } -func (c *AUTOINDEXChecker) CheckValidDataType(dType schemapb.DataType) error { +func (c *AUTOINDEXChecker) CheckValidDataType(field *schemapb.FieldSchema) error { return nil } diff --git a/pkg/util/indexparamcheck/base_checker.go b/pkg/util/indexparamcheck/base_checker.go index 2566fc097591..bcddaa591758 100644 --- a/pkg/util/indexparamcheck/base_checker.go +++ b/pkg/util/indexparamcheck/base_checker.go @@ -36,7 +36,7 @@ func (c baseChecker) CheckTrain(params map[string]string) error { } // CheckValidDataType check whether the field data type is supported for the index type -func (c baseChecker) CheckValidDataType(dType schemapb.DataType) error { +func (c baseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { return nil } diff --git a/pkg/util/indexparamcheck/base_checker_test.go b/pkg/util/indexparamcheck/base_checker_test.go index a016d4da8849..a8dd77d56ba4 100644 --- a/pkg/util/indexparamcheck/base_checker_test.go +++ b/pkg/util/indexparamcheck/base_checker_test.go @@ -98,7 +98,8 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { c := newBaseChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + field_schema := &schemapb.FieldSchema{DataType: test.dType} + err := c.CheckValidDataType(field_schema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/bin_flat_checker_test.go b/pkg/util/indexparamcheck/bin_flat_checker_test.go index 7c10f2e62b3d..dfe496529cc6 100644 --- a/pkg/util/indexparamcheck/bin_flat_checker_test.go +++ b/pkg/util/indexparamcheck/bin_flat_checker_test.go @@ -136,7 +136,8 @@ func Test_binFlatChecker_CheckValidDataType(t *testing.T) { c := newBinFlatChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + field_schema := &schemapb.FieldSchema{DataType: test.dType} + err := c.CheckValidDataType(field_schema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go b/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go index 27ef913c2aee..48b6933fba45 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go +++ b/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go @@ -187,7 +187,8 @@ func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) { c := newBinIVFFlatChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + field_schema := &schemapb.FieldSchema{DataType: test.dType} + err := c.CheckValidDataType(field_schema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker.go b/pkg/util/indexparamcheck/binary_vector_base_checker.go index e700fab78aa0..e73bd8b62e40 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker.go +++ b/pkg/util/indexparamcheck/binary_vector_base_checker.go @@ -27,8 +27,8 @@ func (c binaryVectorBaseChecker) CheckTrain(params map[string]string) error { return c.staticCheck(params) } -func (c binaryVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error { - if dType != schemapb.DataType_BinaryVector { +func (c binaryVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { + if field.GetDataType() != schemapb.DataType_BinaryVector { return fmt.Errorf("binary vector is only supported") } return nil diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go b/pkg/util/indexparamcheck/binary_vector_base_checker_test.go index fc166fabd921..08a05297b9f5 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go +++ b/pkg/util/indexparamcheck/binary_vector_base_checker_test.go @@ -69,7 +69,8 @@ func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) { c := newBinaryVectorBaseChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + field_schema := &schemapb.FieldSchema{DataType: test.dType} + err := c.CheckValidDataType(field_schema) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/bitmap_checker_test.go b/pkg/util/indexparamcheck/bitmap_checker_test.go index 7f1bb38986a8..5d76b3a586f1 100644 --- a/pkg/util/indexparamcheck/bitmap_checker_test.go +++ b/pkg/util/indexparamcheck/bitmap_checker_test.go @@ -13,12 +13,24 @@ func Test_BitmapIndexChecker(t *testing.T) { assert.NoError(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "100"})) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Int64)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Float)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_String)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Array)) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int8})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int16})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_String})) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_JSON)) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Double})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) assert.Error(t, c.CheckTrain(map[string]string{})) assert.Error(t, c.CheckTrain(map[string]string{"bitmap_cardinality_limit": "0"})) } diff --git a/pkg/util/indexparamcheck/bitmap_index_checker.go b/pkg/util/indexparamcheck/bitmap_index_checker.go index 3b9be2786e3b..d5501cdfe19f 100644 --- a/pkg/util/indexparamcheck/bitmap_index_checker.go +++ b/pkg/util/indexparamcheck/bitmap_index_checker.go @@ -20,9 +20,18 @@ func (c *BITMAPChecker) CheckTrain(params map[string]string) error { return c.scalarIndexChecker.CheckTrain(params) } -func (c *BITMAPChecker) CheckValidDataType(dType schemapb.DataType) error { - if !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) && !typeutil.IsArrayType(dType) { - return fmt.Errorf("bitmap index are only supported on numeric, string and array field") +func (c *BITMAPChecker) CheckValidDataType(field *schemapb.FieldSchema) error { + main_type := field.GetDataType() + elem_type := field.GetElementType() + if !typeutil.IsBoolType(main_type) && !typeutil.IsIntegerType(main_type) && + !typeutil.IsStringType(main_type) && !typeutil.IsArrayType(main_type) { + return fmt.Errorf("bitmap index are only supported on bool, int, string and array field") + } + if typeutil.IsArrayType(main_type) { + if !typeutil.IsBoolType(elem_type) && !typeutil.IsIntegerType(elem_type) && + !typeutil.IsStringType(elem_type) { + return fmt.Errorf("bitmap index are only supported on bool, int, string for array field") + } } return nil } diff --git a/pkg/util/indexparamcheck/diskann_checker_test.go b/pkg/util/indexparamcheck/diskann_checker_test.go index 411e8f97d8e9..4fcfdbf019aa 100644 --- a/pkg/util/indexparamcheck/diskann_checker_test.go +++ b/pkg/util/indexparamcheck/diskann_checker_test.go @@ -144,7 +144,7 @@ func Test_diskannChecker_CheckValidDataType(t *testing.T) { c := newDiskannChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/float_vector_base_checker.go b/pkg/util/indexparamcheck/float_vector_base_checker.go index c6d2bd453f47..710dfb3a18a3 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker.go +++ b/pkg/util/indexparamcheck/float_vector_base_checker.go @@ -28,8 +28,8 @@ func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error { return c.staticCheck(params) } -func (c floatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error { - if !typeutil.IsDenseFloatVectorType(dType) { +func (c floatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { + if !typeutil.IsDenseFloatVectorType(field.GetDataType()) { return fmt.Errorf("data type should be FloatVector, Float16Vector or BFloat16Vector") } return nil diff --git a/pkg/util/indexparamcheck/float_vector_base_checker_test.go b/pkg/util/indexparamcheck/float_vector_base_checker_test.go index affc4d9d53c2..7eb0a97d36c6 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker_test.go +++ b/pkg/util/indexparamcheck/float_vector_base_checker_test.go @@ -69,7 +69,7 @@ func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) { c := newFloatVectorBaseChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/hnsw_checker.go b/pkg/util/indexparamcheck/hnsw_checker.go index 729fea6d72d2..b5f9e1f2b77e 100644 --- a/pkg/util/indexparamcheck/hnsw_checker.go +++ b/pkg/util/indexparamcheck/hnsw_checker.go @@ -32,9 +32,9 @@ func (c hnswChecker) CheckTrain(params map[string]string) error { return c.baseChecker.CheckTrain(params) } -func (c hnswChecker) CheckValidDataType(dType schemapb.DataType) error { - if !typeutil.IsVectorType(dType) { - return fmt.Errorf("can't build hnsw in not vector type.") +func (c hnswChecker) CheckValidDataType(field *schemapb.FieldSchema) error { + if !typeutil.IsVectorType(field.GetDataType()) { + return fmt.Errorf("can't build hnsw in not vector type") } return nil } diff --git a/pkg/util/indexparamcheck/hnsw_checker_test.go b/pkg/util/indexparamcheck/hnsw_checker_test.go index fd2499cafc0c..1b19099e93db 100644 --- a/pkg/util/indexparamcheck/hnsw_checker_test.go +++ b/pkg/util/indexparamcheck/hnsw_checker_test.go @@ -164,7 +164,7 @@ func Test_hnswChecker_CheckValidDataType(t *testing.T) { c := newHnswChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/index_checker.go b/pkg/util/indexparamcheck/index_checker.go index 2b52b6f1de34..1c1128089839 100644 --- a/pkg/util/indexparamcheck/index_checker.go +++ b/pkg/util/indexparamcheck/index_checker.go @@ -22,7 +22,7 @@ import ( type IndexChecker interface { CheckTrain(map[string]string) error - CheckValidDataType(dType schemapb.DataType) error + CheckValidDataType(field *schemapb.FieldSchema) error SetDefaultMetricTypeIfNotExist(map[string]string, schemapb.DataType) StaticCheck(map[string]string) error } diff --git a/pkg/util/indexparamcheck/inverted_checker.go b/pkg/util/indexparamcheck/inverted_checker.go index dfc24127d356..8d6893c10085 100644 --- a/pkg/util/indexparamcheck/inverted_checker.go +++ b/pkg/util/indexparamcheck/inverted_checker.go @@ -16,7 +16,8 @@ func (c *INVERTEDChecker) CheckTrain(params map[string]string) error { return c.scalarIndexChecker.CheckTrain(params) } -func (c *INVERTEDChecker) CheckValidDataType(dType schemapb.DataType) error { +func (c *INVERTEDChecker) CheckValidDataType(field *schemapb.FieldSchema) error { + dType := field.GetDataType() if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) && !typeutil.IsArrayType(dType) { return fmt.Errorf("INVERTED are not supported on %s field", dType.String()) diff --git a/pkg/util/indexparamcheck/inverted_checker_test.go b/pkg/util/indexparamcheck/inverted_checker_test.go index 7a3129006149..baecd97dd176 100644 --- a/pkg/util/indexparamcheck/inverted_checker_test.go +++ b/pkg/util/indexparamcheck/inverted_checker_test.go @@ -13,13 +13,13 @@ func Test_INVERTEDIndexChecker(t *testing.T) { assert.NoError(t, c.CheckTrain(map[string]string{})) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_VarChar)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_String)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Bool)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Int64)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Float)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Array)) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Array})) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_JSON)) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_FloatVector)) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector})) } diff --git a/pkg/util/indexparamcheck/ivf_base_checker_test.go b/pkg/util/indexparamcheck/ivf_base_checker_test.go index ad0ad42a2090..4a379038dde3 100644 --- a/pkg/util/indexparamcheck/ivf_base_checker_test.go +++ b/pkg/util/indexparamcheck/ivf_base_checker_test.go @@ -142,7 +142,7 @@ func Test_ivfBaseChecker_CheckValidDataType(t *testing.T) { c := newIVFBaseChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/ivf_pq_checker_test.go b/pkg/util/indexparamcheck/ivf_pq_checker_test.go index d9f655f87471..b4de37579fb4 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker_test.go +++ b/pkg/util/indexparamcheck/ivf_pq_checker_test.go @@ -207,7 +207,7 @@ func Test_ivfPQChecker_CheckValidDataType(t *testing.T) { c := newIVFPQChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/ivf_sq_checker_test.go b/pkg/util/indexparamcheck/ivf_sq_checker_test.go index fa8a5a73c86e..9478623fe89e 100644 --- a/pkg/util/indexparamcheck/ivf_sq_checker_test.go +++ b/pkg/util/indexparamcheck/ivf_sq_checker_test.go @@ -162,7 +162,7 @@ func Test_ivfSQChecker_CheckValidDataType(t *testing.T) { c := newIVFSQChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go b/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go index 1b6d1a2647f1..3d64f830392f 100644 --- a/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go +++ b/pkg/util/indexparamcheck/raft_ivf_flat_checker_test.go @@ -156,7 +156,7 @@ func Test_raftIvfFlatChecker_CheckValidDataType(t *testing.T) { c := newRaftIVFFlatChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go b/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go index 5d7e431135a4..8c882900e9ef 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go +++ b/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go @@ -216,7 +216,7 @@ func Test_raftIVFPQChecker_CheckValidDataType(t *testing.T) { c := newRaftIVFPQChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/scann_checker_test.go b/pkg/util/indexparamcheck/scann_checker_test.go index 7e86beeb1f83..4f7014c6fde5 100644 --- a/pkg/util/indexparamcheck/scann_checker_test.go +++ b/pkg/util/indexparamcheck/scann_checker_test.go @@ -159,7 +159,7 @@ func Test_scaNNChecker_CheckValidDataType(t *testing.T) { c := newScaNNChecker() for _, test := range cases { - err := c.CheckValidDataType(test.dType) + err := c.CheckValidDataType(&schemapb.FieldSchema{DataType: test.dType}) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go b/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go index 99ca1041a085..218d2d3e03a3 100644 --- a/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go +++ b/pkg/util/indexparamcheck/sparse_float_vector_base_checker.go @@ -32,8 +32,8 @@ func (c sparseFloatVectorBaseChecker) CheckTrain(params map[string]string) error return nil } -func (c sparseFloatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error { - if !typeutil.IsSparseFloatVectorType(dType) { +func (c sparseFloatVectorBaseChecker) CheckValidDataType(field *schemapb.FieldSchema) error { + if !typeutil.IsSparseFloatVectorType(field.GetDataType()) { return fmt.Errorf("only sparse float vector is supported for the specified index tpye") } return nil diff --git a/pkg/util/indexparamcheck/stl_sort_checker.go b/pkg/util/indexparamcheck/stl_sort_checker.go index f0b152cef9cc..4b3441ad6dfc 100644 --- a/pkg/util/indexparamcheck/stl_sort_checker.go +++ b/pkg/util/indexparamcheck/stl_sort_checker.go @@ -16,8 +16,8 @@ func (c *STLSORTChecker) CheckTrain(params map[string]string) error { return c.scalarIndexChecker.CheckTrain(params) } -func (c *STLSORTChecker) CheckValidDataType(dType schemapb.DataType) error { - if !typeutil.IsArithmetic(dType) { +func (c *STLSORTChecker) CheckValidDataType(field *schemapb.FieldSchema) error { + if !typeutil.IsArithmetic(field.GetDataType()) { return fmt.Errorf("STL_SORT are only supported on numeric field") } return nil diff --git a/pkg/util/indexparamcheck/stl_sort_checker_test.go b/pkg/util/indexparamcheck/stl_sort_checker_test.go index a4af0c51e64e..771a51cd32f6 100644 --- a/pkg/util/indexparamcheck/stl_sort_checker_test.go +++ b/pkg/util/indexparamcheck/stl_sort_checker_test.go @@ -13,10 +13,10 @@ func Test_STLSORTIndexChecker(t *testing.T) { assert.NoError(t, c.CheckTrain(map[string]string{})) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Int64)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Float)) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_Bool)) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_VarChar)) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_JSON)) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) } diff --git a/pkg/util/indexparamcheck/trie_checker.go b/pkg/util/indexparamcheck/trie_checker.go index 1c63fdc36624..002014e42022 100644 --- a/pkg/util/indexparamcheck/trie_checker.go +++ b/pkg/util/indexparamcheck/trie_checker.go @@ -16,8 +16,8 @@ func (c *TRIEChecker) CheckTrain(params map[string]string) error { return c.scalarIndexChecker.CheckTrain(params) } -func (c *TRIEChecker) CheckValidDataType(dType schemapb.DataType) error { - if !typeutil.IsStringType(dType) { +func (c *TRIEChecker) CheckValidDataType(field *schemapb.FieldSchema) error { + if !typeutil.IsStringType(field.GetDataType()) { return fmt.Errorf("TRIE are only supported on varchar field") } return nil diff --git a/pkg/util/indexparamcheck/trie_checker_test.go b/pkg/util/indexparamcheck/trie_checker_test.go index 25c6313ea899..3e1eaea1c589 100644 --- a/pkg/util/indexparamcheck/trie_checker_test.go +++ b/pkg/util/indexparamcheck/trie_checker_test.go @@ -13,11 +13,11 @@ func Test_TrieIndexChecker(t *testing.T) { assert.NoError(t, c.CheckTrain(map[string]string{})) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_VarChar)) - assert.NoError(t, c.CheckValidDataType(schemapb.DataType_String)) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) + assert.NoError(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_String})) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_Bool)) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_Int64)) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_Float)) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_JSON)) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_Float})) + assert.Error(t, c.CheckValidDataType(&schemapb.FieldSchema{DataType: schemapb.DataType_JSON})) }