Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: add sparse float vector support to restful v2 #33231

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,10 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche
if vectorField == nil {
return nil, errors.New("cannot find a vector field named: " + fieldName)
}
dim, _ := getDim(vectorField)
dim := int64(0)
if !typeutil.IsSparseFloatVectorType(vectorField.DataType) {
dim, _ = getDim(vectorField)
}
phv, err := convertVectors2Placeholder(body, vectorField.DataType, dim)
if err != nil {
return nil, err
Expand Down
39 changes: 35 additions & 4 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,11 @@ func TestDatabaseWrapper(t *testing.T) {
func TestCreateCollection(t *testing.T) {
postTestCases := []requestBodyTestCase{}
mp := mocks.NewMockProxy(t)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(11)
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(12)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(6)
mp.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Times(6)
mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice()
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once()
mp.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Twice()
testEngine := initHTTPServerV2(mp, false)
path := versionalV2(CollectionCategory, CreateAction)
// quickly create collection
Expand Down Expand Up @@ -564,6 +564,18 @@ func TestCreateCollection(t *testing.T) {
]
}}`),
})
// dim should not be specified for SparseFloatVector field
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "isPartitionKey": false, "elementTypeParams": {}},
{"fieldName": "partition_field", "dataType": "VarChar", "isPartitionKey": true, "elementTypeParams": {"max_length": 256}},
{"fieldName": "book_intro", "dataType": "SparseFloatVector", "elementTypeParams": {}}
]
}, "params": {"partitionsNum": "32"}}`),
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
Expand Down Expand Up @@ -612,6 +624,18 @@ func TestCreateCollection(t *testing.T) {
errMsg: "",
errCode: 65535,
})
postTestCases = append(postTestCases, requestBodyTestCase{
path: path,
requestBody: []byte(`{"collectionName": "` + DefaultCollectionName + `", "schema": {
"fields": [
{"fieldName": "book_id", "dataType": "Int64", "isPrimary": true, "elementTypeParams": {}},
{"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}},
{"fieldName": "book_intro", "dataType": "SparseFloatVector", "elementTypeParams": {"dim": 2}}
]
}, "indexParams": [{"fieldName": "book_intro", "indexName": "book_intro_vector", "metricType": "L2"}]}`),
errMsg: "",
errCode: 65535,
})

for _, testcase := range postTestCases {
t.Run("post"+testcase.path, func(t *testing.T) {
Expand Down Expand Up @@ -1240,16 +1264,19 @@ func TestSearchV2(t *testing.T) {
float16VectorField.Name = "float16Vector"
bfloat16VectorField := generateVectorFieldSchema(schemapb.DataType_BFloat16Vector)
bfloat16VectorField.Name = "bfloat16Vector"
sparseFloatVectorField := generateVectorFieldSchema(schemapb.DataType_SparseFloatVector)
sparseFloatVectorField.Name = "sparseFloatVector"
collSchema.Fields = append(collSchema.Fields, &binaryVectorField)
collSchema.Fields = append(collSchema.Fields, &float16VectorField)
collSchema.Fields = append(collSchema.Fields, &bfloat16VectorField)
collSchema.Fields = append(collSchema.Fields, &sparseFloatVectorField)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionName: DefaultCollectionName,
Schema: collSchema,
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(9)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Twice()
}, nil).Times(10)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
testEngine := initHTTPServerV2(mp, false)
queryTestCases := []requestBodyTestCase{}
queryTestCases = append(queryTestCases, requestBodyTestCase{
Expand Down Expand Up @@ -1377,6 +1404,10 @@ func TestSearchV2(t *testing.T) {
errMsg: "can only accept json format request, error: dimension: 2, bytesLen: 4, but length of []byte: 3: invalid parameter[expected=BFloat16Vector][actual=\x01\x02\x03]",
errCode: 1801,
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: SearchAction,
requestBody: []byte(`{"collectionName": "book", "data": [{"1": 0.1}], "annsField": "sparseFloatVector", "filter": "book_id in [2, 4, 6, 8]", "limit": 4, "outputFields": ["word_count"]}`),
})

for _, testcase := range queryTestCases {
t.Run("search", func(t *testing.T) {
Expand Down
51 changes: 51 additions & 0 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray
}
reallyData[fieldName] = vectorArray
case schemapb.DataType_SparseFloatVector:
if dataString == "" {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], "", "missing vector field: "+fieldName), reallyDataArray

Check warning on line 253 in internal/distributed/proxy/httpserver/utils.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/utils.go#L253

Added line #L253 was not covered by tests
}
sparseVec, err := typeutil.CreateSparseFloatRowFromJSON([]byte(dataString))
if err != nil {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray
}
reallyData[fieldName] = sparseVec
case schemapb.DataType_Float16Vector:
if dataString == "" {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], "", "missing vector field: "+fieldName), reallyDataArray
Expand Down Expand Up @@ -638,6 +647,9 @@
data = make([][]byte, 0, rowsLen)
dim, _ := getDim(field)
nameDims[field.Name] = dim
case schemapb.DataType_SparseFloatVector:
data = make([][]byte, 0, rowsLen)
nameDims[field.Name] = int64(0)
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", field.DataType, field.Name)
}
Expand Down Expand Up @@ -704,6 +716,13 @@
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), candi.v.Interface().([]byte))
case schemapb.DataType_BFloat16Vector:
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), candi.v.Interface().([]byte))
case schemapb.DataType_SparseFloatVector:
content := candi.v.Interface().([]byte)
rowSparseDim := typeutil.SparseFloatRowDim(content)
if rowSparseDim > nameDims[field.Name] {
nameDims[field.Name] = rowSparseDim
}
nameColumns[field.Name] = append(nameColumns[field.Name].([][]byte), content)
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", field.DataType, field.Name)
}
Expand Down Expand Up @@ -895,6 +914,18 @@
},
},
}
case schemapb.DataType_SparseFloatVector:
colData.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: nameDims[name],
Data: &schemapb.VectorField_SparseFloatVector{
SparseFloatVector: &schemapb.SparseFloatArray{
Dim: nameDims[name],
Contents: column.([][]byte),
},
},
},
}
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", colData.Type, name)
}
Expand Down Expand Up @@ -963,6 +994,19 @@
return values, nil
}

// serializeSparseFloatVectors converts every JSON value in vectors into a
// sparse float row by passing its string form to
// typeutil.CreateSparseFloatRowFromJSON, and returns the rows in input order.
//
// dataType is used only to name the expected type in the error message.
//
// The first value that fails to parse aborts the conversion: the function
// returns a nil slice together with a merr parameter-invalid error that
// carries the offending value and the parser's message. For empty input it
// returns an empty, non-nil slice and a nil error.
//
// NOTE(review): the byte layout of each returned row is decided by
// CreateSparseFloatRowFromJSON, which is not visible here — confirm there.
func serializeSparseFloatVectors(vectors []gjson.Result, dataType schemapb.DataType) ([][]byte, error) {
// Empty (not nil) slice, so callers get a non-nil result even with no input.
values := make([][]byte, 0)
for _, vector := range vectors {
// The parser takes bytes, so use the element's string form.
vectorBytes := []byte(vector.String())
sparseVector, err := typeutil.CreateSparseFloatRowFromJSON(vectorBytes)
if err != nil {
// Fail fast: report the offending value and drop any rows parsed so far.
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vector.String(), err.Error())
// NOTE(review): the lines from here down to the closing brace of this
// if-block are Codecov annotation text captured with the diff page, not Go.

Check warning on line 1003 in internal/distributed/proxy/httpserver/utils.go

View check run for this annotation

Codecov / codecov/patch

internal/distributed/proxy/httpserver/utils.go#L1003

Added line #L1003 was not covered by tests
}
values = append(values, sparseVector)
}
return values, nil
}

func convertVectors2Placeholder(body string, dataType schemapb.DataType, dimension int64) (*commonpb.PlaceholderValue, error) {
var valueType commonpb.PlaceholderType
var values [][]byte
Expand All @@ -980,6 +1024,9 @@
case schemapb.DataType_BFloat16Vector:
valueType = commonpb.PlaceholderType_BFloat16Vector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
case schemapb.DataType_SparseFloatVector:
valueType = commonpb.PlaceholderType_SparseFloatVector
values, err = serializeSparseFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether #20415 could happen again when a sparse vector contains int64 values.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the parsing function to reject out-of-range numbers.

}
if err != nil {
return nil, err
Expand Down Expand Up @@ -1070,6 +1117,8 @@
rowsNum = int64(len(fieldDataList[0].GetVectors().GetFloat16Vector())/2) / fieldDataList[0].GetVectors().GetDim()
case schemapb.DataType_BFloat16Vector:
rowsNum = int64(len(fieldDataList[0].GetVectors().GetBfloat16Vector())/2) / fieldDataList[0].GetVectors().GetDim()
case schemapb.DataType_SparseFloatVector:
rowsNum = int64(len(fieldDataList[0].GetVectors().GetSparseFloatVector().Contents))
default:
return nil, fmt.Errorf("the type(%v) of field(%v) is not supported, use other sdk please", fieldDataList[0].Type, fieldDataList[0].FieldName)
}
Expand Down Expand Up @@ -1125,6 +1174,8 @@
row[fieldDataList[j].FieldName] = fieldDataList[j].GetVectors().GetFloat16Vector()[i*(fieldDataList[j].GetVectors().GetDim()*2) : (i+1)*(fieldDataList[j].GetVectors().GetDim()*2)]
case schemapb.DataType_BFloat16Vector:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetVectors().GetBfloat16Vector()[i*(fieldDataList[j].GetVectors().GetDim()*2) : (i+1)*(fieldDataList[j].GetVectors().GetDim()*2)]
case schemapb.DataType_SparseFloatVector:
row[fieldDataList[j].FieldName] = typeutil.SparseFloatBytesToMap(fieldDataList[j].GetVectors().GetSparseFloatVector().Contents[i])
case schemapb.DataType_Array:
row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetArrayData().Data[i]
case schemapb.DataType_JSON:
Expand Down
Loading
Loading