Skip to content

Commit

Permalink
MB-61742: Interpreting vectors w/ spurious float32s from base64 encod…
Browse files Browse the repository at this point in the history
…ings (#2026)

+ Handle the situation when interpreted vectors have NaN or Inf when
decoded from base64 encoded strings.
+ Also, in the vector_base64 path, we were unnecessarily casting a
[]float32 into an interface{} only for it to be interpreted as a []float32
again; this round trip is now avoided.
  • Loading branch information
abhinavdangeti committed May 3, 2024
1 parent 490c4b9 commit 1a19544
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 22 deletions.
19 changes: 14 additions & 5 deletions document/field_vector_base64.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"math"

"github.com/blevesearch/bleve/v2/size"
"github.com/blevesearch/bleve/v2/util"
index "github.com/blevesearch/bleve_index_api"
)

Expand Down Expand Up @@ -81,14 +82,14 @@ func (n *VectorBase64Field) GoString() string {
func NewVectorBase64Field(name string, arrayPositions []uint64, vectorBase64 string,
dims int, similarity, vectorIndexOptimizedFor string) (*VectorBase64Field, error) {

vector, err := DecodeVector(vectorBase64)
decodedVector, err := DecodeVector(vectorBase64)
if err != nil {
return nil, err
}

return &VectorBase64Field{
vectorField: NewVectorFieldWithIndexingOptions(name, arrayPositions,
vector, dims, similarity,
decodedVector, dims, similarity,
vectorIndexOptimizedFor, DefaultVectorIndexingOptions),

base64Encoding: vectorBase64,
Expand All @@ -98,7 +99,6 @@ func NewVectorBase64Field(name string, arrayPositions []uint64, vectorBase64 str
// This function takes a base64 encoded string and decodes it into
// a vector.
func DecodeVector(encodedValue string) ([]float32, error) {

// We first decode the encoded string into a byte array.
decodedString, err := base64.StdEncoding.DecodeString(encodedValue)
if err != nil {
Expand All @@ -108,16 +108,25 @@ func DecodeVector(encodedValue string) ([]float32, error) {
// The array is expected to be divisible by 4 because each float32
// should occupy 4 bytes
if len(decodedString)%size.SizeOfFloat32 != 0 {
return nil, fmt.Errorf("Decoded byte array not divisible by %d", size.SizeOfFloat32)
return nil, fmt.Errorf("decoded byte array not divisible by %d", size.SizeOfFloat32)
}
dims := int(len(decodedString) / size.SizeOfFloat32)

if dims <= 0 {
return nil, fmt.Errorf("unable to decode encoded vector")
}

decodedVector := make([]float32, dims)

// We iterate through the array 4 bytes at a time and convert each of
// them to a float32 value by reading them in a little endian notation
for i := 0; i < dims; i++ {
bytes := decodedString[i*size.SizeOfFloat32 : (i+1)*size.SizeOfFloat32]
decodedVector[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes))
entry := math.Float32frombits(binary.LittleEndian.Uint32(bytes))
if !util.IsValidFloat32(float64(entry)) {
return nil, fmt.Errorf("invalid float32 value: %f", entry)
}
decodedVector[i] = entry
}

return decodedVector, nil
Expand Down
9 changes: 4 additions & 5 deletions document/field_vector_base64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ func TestDecodeVector(t *testing.T) {
vecBytes := bytifyVec(vec)
encodedVec := base64.StdEncoding.EncodeToString(vecBytes)

decodedVec, err := DecodeVector(encodedVec)
decodedVector, err := DecodeVector(encodedVec)
if err != nil {
t.Error(err)
}
if len(decodedVec) != len(vec) {
if len(decodedVector) != len(vec) {
t.Errorf("Decoded vector dimensions not same as original vector dimensions")
}

for i := range vec {
if vec[i] != decodedVec[i] {
t.Errorf("Decoded vector not the same as original vector")
if vec[i] != decodedVector[i] {
t.Fatalf("Decoded vector not the same as original vector %v != %v", vec[i], decodedVector[i])
}
}
}
Expand Down Expand Up @@ -99,7 +99,6 @@ func BenchmarkDecodeVector1536(b *testing.B) {
}

func bytifyVec(vec []float32) []byte {

buf := new(bytes.Buffer)

for _, v := range vec {
Expand Down
13 changes: 10 additions & 3 deletions mapping/mapping_vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,19 @@ func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interfac
return
}

propertyMightBeVector, err := document.DecodeVector(encodedString)
if err != nil {
decodedVector, err := document.DecodeVector(encodedString)
if err != nil || len(decodedVector) != fm.Dims {
return
}

fm.processVector(propertyMightBeVector, pathString, path, indexes, context)
fieldName := getFieldName(pathString, path, fm)
options := fm.Options()
field := document.NewVectorFieldWithIndexingOptions(fieldName, indexes, decodedVector,
fm.Dims, fm.Similarity, fm.VectorIndexOptimizedFor, options)
context.doc.AddField(field)

// "_all" composite field is not applicable for vector_base64 field
context.excludedFromAll = append(context.excludedFromAll, fieldName)
}

// -----------------------------------------------------------------------------
Expand Down
15 changes: 7 additions & 8 deletions search_knn.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,14 @@ func validateKNN(req *SearchRequest) error {
if q == nil {
return fmt.Errorf("knn query cannot be nil")
}
if q.VectorBase64 != "" {
if q.Vector == nil {
vec, err := document.DecodeVector(q.VectorBase64)
if err != nil {
return err
}

q.Vector = vec
if len(q.Vector) == 0 && q.VectorBase64 != "" {
// consider vector_base64 only if vector is not provided
decodedVector, err := document.DecodeVector(q.VectorBase64)
if err != nil {
return err
}

q.Vector = decodedVector
}
if q.K <= 0 || len(q.Vector) == 0 {
return fmt.Errorf("k must be greater than 0 and vector must be non-empty")
Expand Down
6 changes: 5 additions & 1 deletion util/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func ExtractNumericValFloat32(v interface{}) (float32, bool) {
switch {
case val.CanFloat():
floatVal := val.Float()
if floatVal > math.MaxFloat32 {
if !IsValidFloat32(floatVal) {
return 0, false
}
return float32(floatVal), true
Expand All @@ -60,3 +60,7 @@ func ExtractNumericValFloat32(v interface{}) (float32, bool) {

return 0, false
}

// IsValidFloat32 reports whether val is a finite number whose magnitude
// fits in the float32 range, i.e. whether converting it with float32(val)
// yields a finite value rather than NaN or ±Inf.
//
// It returns false for NaN, for +Inf and -Inf, and for any finite value
// outside [-math.MaxFloat32, math.MaxFloat32].
func IsValidFloat32(val float64) bool {
	// The range test must be symmetric: a large negative float64 (for
	// example -1e300) is finite, yet overflows to -Inf when narrowed to
	// float32, so checking only the upper bound is not sufficient.
	return !math.IsNaN(val) && !math.IsInf(val, 0) &&
		val >= -math.MaxFloat32 && val <= math.MaxFloat32
}

0 comments on commit 1a19544

Please sign in to comment.