Skip to content

Commit

Permalink
sparse: update sparse related tests. not testing sparse vectors
Browse files Browse the repository at this point in the history
in golang tests that uses indexcgowrapper/dataset.go to build index,

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian committed Mar 13, 2024
1 parent 2e506f0 commit 817c7f3
Show file tree
Hide file tree
Showing 15 changed files with 111 additions and 54 deletions.
4 changes: 2 additions & 2 deletions internal/querynodev2/local_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (suite *LocalWorkerTestSuite) BeforeTest(suiteName, testName string) {
err = suite.node.Start()
suite.NoError(err)

suite.schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
suite.schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
suite.indexMeta = segments.GenTestIndexMeta(suite.collectionID, suite.schema)
collection := segments.NewCollection(suite.collectionID, suite.schema, suite.indexMeta, querypb.LoadType_LoadCollection)
loadMata := &querypb.LoadMetaInfo{
Expand All @@ -111,7 +111,7 @@ func (suite *LocalWorkerTestSuite) AfterTest(suiteName, testName string) {

func (suite *LocalWorkerTestSuite) TestLoadSegment() {
// load empty
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: suite.node.session.GetServerID(),
Expand Down
4 changes: 2 additions & 2 deletions internal/querynodev2/pipeline/insert_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (suite *InsertNodeSuite) SetupSuite() {

func (suite *InsertNodeSuite) TestBasic() {
// data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
in := suite.buildInsertNodeMsg(schema)

collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), querypb.LoadType_LoadCollection)
Expand Down Expand Up @@ -92,7 +92,7 @@ func (suite *InsertNodeSuite) TestBasic() {
}

func (suite *InsertNodeSuite) TestDataTypeNotSupported() {
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
in := suite.buildInsertNodeMsg(schema)

collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), querypb.LoadType_LoadCollection)
Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/pipeline/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (suite *PipelineTestSuite) SetupTest() {
func (suite *PipelineTestSuite) TestBasic() {
// init mock
// mock collection manager
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), querypb.LoadType_LoadCollection)
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection)

Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/segments/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (s *ManagerSuite) SetupTest() {
s.mgr = NewSegmentManager()

for i, id := range s.segmentIDs {
schema := GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64)
schema := GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64, true)
segment, err := NewSegment(
context.Background(),
NewCollection(s.collectionIDs[i], schema, GenTestIndexMeta(s.collectionIDs[i], schema), querypb.LoadType_LoadCollection),
Expand Down
66 changes: 59 additions & 7 deletions internal/querynodev2/segments/mock_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/testutils"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

Expand All @@ -60,6 +61,7 @@ const (
IndexFaissBinIDMap = "BIN_FLAT"
IndexFaissBinIVFFlat = "BIN_IVF_FLAT"
IndexHNSW = "HNSW"
IndexSparseWand = "SPARSE_WAND"

nlist = 100
m = 4
Expand Down Expand Up @@ -130,6 +132,13 @@ var simpleBFloat16VecField = vecFieldParam{
fieldName: "bfloat16VectorField",
}

var simpleSparseFloatVectorField = vecFieldParam{
id: 114,
metricType: metric.IP,
vecType: schemapb.DataType_SparseFloatVector,
fieldName: "sparseFloatVectorField",
}

var simpleBoolField = constFieldParam{
id: 102,
dataType: schemapb.DataType_Bool,
Expand Down Expand Up @@ -235,23 +244,27 @@ func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema {
Name: param.fieldName,
IsPrimaryKey: false,
DataType: param.vecType,
TypeParams: []*commonpb.KeyValuePair{
{
Key: dimKey,
Value: strconv.Itoa(param.dim),
},
},
IndexParams: []*commonpb.KeyValuePair{
{
Key: metricTypeKey,
Value: param.metricType,
},
},
}
if fieldVec.DataType != schemapb.DataType_SparseFloatVector {
fieldVec.TypeParams = []*commonpb.KeyValuePair{
{
Key: dimKey,
Value: strconv.Itoa(param.dim),
},
}
}
return fieldVec
}

func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType) *schemapb.CollectionSchema {
// some tests do not yet support sparse float vector, see comments of
// GenSparseFloatVecDataset in indexcgowrapper/dataset.go
func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, withSparse bool) *schemapb.CollectionSchema {
fieldRowID := genConstantFieldSchema(rowIDField)
fieldTimestamp := genConstantFieldSchema(timestampField)
fieldBool := genConstantFieldSchema(simpleBoolField)
Expand Down Expand Up @@ -292,6 +305,10 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType) *s
},
}

if withSparse {
schema.Fields = append(schema.Fields, genVectorFieldSchema(simpleSparseFloatVectorField))
}

for i, field := range schema.GetFields() {
field.FieldID = 100 + int64(i)
}
Expand Down Expand Up @@ -329,6 +346,14 @@ func GenTestIndexInfoList(collectionID int64, schema *schemapb.CollectionSchema)
{Key: "nlist", Value: "128"},
}
}
case schemapb.DataType_SparseFloatVector:
{
index.IndexParams = []*commonpb.KeyValuePair{
{Key: common.MetricTypeKey, Value: metric.IP},
{Key: common.IndexTypeKey, Value: IndexSparseWand},
{Key: "M", Value: "16"},
}
}
}
res = append(res, index)
}
Expand Down Expand Up @@ -622,6 +647,7 @@ func GenTestScalarFieldData(dType schemapb.DataType, fieldName string, fieldID i
return ret
}

// dim is ignored for sparse
func GenTestVectorFiledData(dType schemapb.DataType, fieldName string, fieldID int64, numRows int, dim int) *schemapb.FieldData {
ret := &schemapb.FieldData{
Type: dType,
Expand Down Expand Up @@ -671,6 +697,20 @@ func GenTestVectorFiledData(dType schemapb.DataType, fieldName string, fieldID i
},
},
}
case schemapb.DataType_SparseFloatVector:
ret.FieldId = fieldID
sparseData := testutils.GenerateSparseFloatVectors(numRows)
ret.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: sparseData.Dim,
Data: &schemapb.VectorField_SparseFloatVector{
SparseFloatVector: &schemapb.SparseFloatArray{
Dim: sparseData.Dim,
Contents: sparseData.Contents,
},
},
},
}
default:
panic("data type not supported")
}
Expand Down Expand Up @@ -864,6 +904,11 @@ func genInsertData(msgLength int, schema *schemapb.CollectionSchema) (*storage.I
Data: generateBinaryVectors(msgLength, dim),
Dim: dim,
}
case schemapb.DataType_SparseFloatVector:
sparseData := testutils.GenerateSparseFloatVectors(msgLength)
insertData.Data[f.FieldID] = &storage.SparseFloatVectorFieldData{
SparseFloatArray: *sparseData,
}
default:
err := errors.New("data type not supported")
return nil, err
Expand Down Expand Up @@ -963,6 +1008,11 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
dataset = indexcgowrapper.GenBinaryVecDataset(generateBinaryVectors(msgLength, defaultDim))
case schemapb.DataType_FloatVector:
dataset = indexcgowrapper.GenFloatVecDataset(generateFloatVectors(msgLength, defaultDim))
case schemapb.DataType_SparseFloatVector:
data := testutils.GenerateSparseFloatVectors(msgLength)
dataset = indexcgowrapper.GenSparseFloatVecDataset(&storage.SparseFloatVectorFieldData{
SparseFloatArray: *data,
})
}

err = index.Build(dataset)
Expand Down Expand Up @@ -1366,6 +1416,8 @@ func genInsertMsg(collection *Collection, partitionID, segment int64, numRows in
case schemapb.DataType_BFloat16Vector:
dim := simpleBFloat16VecField.dim // if no dim specified, use simpleFloatVecField's dim
fieldsData = append(fieldsData, GenTestVectorFiledData(f.DataType, f.Name, f.FieldID, numRows, dim))
case schemapb.DataType_SparseFloatVector:
fieldsData = append(fieldsData, GenTestVectorFiledData(f.DataType, f.Name, f.FieldID, numRows, 0))
default:
err := errors.New("data type not supported")
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/segments/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (suite *PlanSuite) SetupTest() {
suite.collectionID = 100
suite.partitionID = 10
suite.segmentID = 1
schema := GenTestCollectionSchema("plan-suite", schemapb.DataType_Int64)
schema := GenTestCollectionSchema("plan-suite", schemapb.DataType_Int64, true)
suite.collection = NewCollection(suite.collectionID, schema, GenTestIndexMeta(suite.collectionID, schema), querypb.LoadType_LoadCollection)
suite.collection.AddPartition(suite.partitionID)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/segments/reduce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (suite *ReduceSuite) SetupTest() {
suite.collectionID = 100
suite.partitionID = 10
suite.segmentID = 1
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64)
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
suite.collection = NewCollection(suite.collectionID,
schema,
GenTestIndexMeta(suite.collectionID, schema),
Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/segments/retrieve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (suite *RetrieveSuite) SetupTest() {
suite.segmentID = 1

suite.manager = NewManager()
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64)
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
indexMeta := GenTestIndexMeta(suite.collectionID, schema)
suite.manager.Collection.PutOrRef(suite.collectionID,
schema,
Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/segments/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (suite *SearchSuite) SetupTest() {
suite.segmentID = 1

suite.manager = NewManager()
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64)
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
indexMeta := GenTestIndexMeta(suite.collectionID, schema)
suite.manager.Collection.PutOrRef(suite.collectionID,
schema,
Expand Down
8 changes: 4 additions & 4 deletions internal/querynodev2/segments/segment_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (suite *SegmentLoaderSuite) SetupTest() {
initcore.InitRemoteChunkManager(paramtable.Get())

// Data
suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64)
suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64, false)
indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema)
loadMeta := &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
Expand Down Expand Up @@ -665,7 +665,7 @@ func (suite *SegmentLoaderDetailSuite) SetupSuite() {
suite.partitionID = rand.Int63()
suite.segmentID = rand.Int63()
suite.segmentNum = 5
suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64)
suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64, false)
}

func (suite *SegmentLoaderDetailSuite) SetupTest() {
Expand All @@ -684,7 +684,7 @@ func (suite *SegmentLoaderDetailSuite) SetupTest() {
initcore.InitRemoteChunkManager(paramtable.Get())

// Data
schema := GenTestCollectionSchema("test", schemapb.DataType_Int64)
schema := GenTestCollectionSchema("test", schemapb.DataType_Int64, false)

indexMeta := GenTestIndexMeta(suite.collectionID, schema)
loadMeta := &querypb.LoadMetaInfo{
Expand Down Expand Up @@ -853,7 +853,7 @@ func (suite *SegmentLoaderV2Suite) SetupTest() {
initcore.InitRemoteChunkManager(paramtable.Get())

// Data
suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64)
suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64, false)
indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema)
loadMeta := &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection,
Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/segments/segment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (suite *SegmentSuite) SetupTest() {
suite.segmentID = 1

suite.manager = NewManager()
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64)
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
indexMeta := GenTestIndexMeta(suite.collectionID, schema)
suite.manager.Collection.PutOrRef(suite.collectionID,
schema,
Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (suite *QueryNodeSuite) TestStop() {

suite.node.manager = segments.NewManager()

schema := segments.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64)
schema := segments.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true)
collection := segments.NewCollection(1, schema, nil, querypb.LoadType_LoadCollection)
segment, err := segments.NewSegment(
context.Background(),
Expand Down
Loading

0 comments on commit 817c7f3

Please sign in to comment.