From 80d9b18390c4b2fdcab84f51e95efd856431fcf8 Mon Sep 17 00:00:00 2001 From: Mohd Shaad Khan <65341373+moshaad7@users.noreply.github.com> Date: Fri, 3 Nov 2023 20:29:06 +0530 Subject: [PATCH] Support vector data type, query + searcher (#1857) + Add support for indexing and querying vector data (array of float32) + Corresponding to each field of type vector, there will be associated properties like Dimensions and Similarity, which will be used to determine the validity of a vector and the criterion for scoring search hits, respectively. + Supports vector reader interface (searcher) and the `knn` construct within the SearchRequest. Related PR: * https://github.com/blevesearch/bleve_index_api/pull/34 --------- Co-authored-by: Abhi Dangeti Co-authored-by: Aditi Ahuja --- document/field_vector.go | 138 +++++++++++++++++++++ geo/parse.go | 41 ++----- index/scorch/snapshot_index_vr.go | 156 ++++++++++++++++++++++++ index/scorch/snapshot_vector_index.go | 57 +++++++++ index_alias_impl.go | 16 +-- index_impl.go | 7 +- mapping/document.go | 18 ++- mapping/field.go | 18 +++ mapping/index.go | 27 +++++ mapping/mapping.go | 2 + mapping/mapping_no_vectors.go | 46 +++++++ mapping/mapping_vectors.go | 119 +++++++++++++++++++ mapping_vector.go | 24 ++++ search.go | 132 +++------------------ search/query/disjunction.go | 10 +- search/query/knn.go | 72 +++++++++++ search/scorer/scorer_knn.go | 104 ++++++++++++++++ search/searcher/search_knn.go | 123 +++++++++++++++++++ search_knn.go | 165 ++++++++++++++++++++++++++ search_no_knn.go | 149 +++++++++++++++++++++++ util/extract.go | 57 +++++++++ util/knn.go | 38 ++++++ 22 files changed, 1348 insertions(+), 171 deletions(-) create mode 100644 document/field_vector.go create mode 100644 index/scorch/snapshot_index_vr.go create mode 100644 index/scorch/snapshot_vector_index.go create mode 100644 mapping/mapping_no_vectors.go create mode 100644 mapping/mapping_vectors.go create mode 100644 mapping_vector.go create mode 100644 search/query/knn.go 
create mode 100644 search/scorer/scorer_knn.go create mode 100644 search/searcher/search_knn.go create mode 100644 search_knn.go create mode 100644 search_no_knn.go create mode 100644 util/extract.go create mode 100644 util/knn.go diff --git a/document/field_vector.go b/document/field_vector.go new file mode 100644 index 000000000..59ac02026 --- /dev/null +++ b/document/field_vector.go @@ -0,0 +1,138 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build vectors +// +build vectors + +package document + +import ( + "fmt" + "reflect" + + "github.com/blevesearch/bleve/v2/size" + index "github.com/blevesearch/bleve_index_api" +) + +var reflectStaticSizeVectorField int + +func init() { + var f VectorField + reflectStaticSizeVectorField = int(reflect.TypeOf(f).Size()) +} + +const DefaultVectorIndexingOptions = index.IndexField + +type VectorField struct { + name string + dims int // Dimensionality of the vector + similarity string // Similarity metric to use for scoring + options index.FieldIndexingOptions + value []float32 + numPlainTextBytes uint64 +} + +func (n *VectorField) Size() int { + return reflectStaticSizeVectorField + size.SizeOfPtr + + len(n.name) + + int(numBytesFloat32s(n.value)) +} + +func (n *VectorField) Name() string { + return n.name +} + +func (n *VectorField) ArrayPositions() []uint64 { + return nil +} + +func (n *VectorField) Options() index.FieldIndexingOptions { + return n.options +} + +func (n *VectorField) NumPlainTextBytes() uint64 { + return n.numPlainTextBytes +} + +func (n *VectorField) AnalyzedLength() int { + // vectors aren't analyzed + return 0 +} + +func (n *VectorField) EncodedFieldType() byte { + return 'v' +} + +func (n *VectorField) AnalyzedTokenFrequencies() index.TokenFrequencies { + // vectors aren't analyzed + return nil +} + +func (n *VectorField) Analyze() { + // vectors aren't analyzed +} + +func (n *VectorField) Value() []byte { + return nil +} + +func (n *VectorField) GoString() string { + return fmt.Sprintf("&document.VectorField{Name:%s, Options: %s, "+ + "Value: %+v}", n.name, n.options, n.value) +} + +// For the sake of not polluting the API, we are keeping arrayPositions as a +// parameter, but it is not used. 
+func NewVectorField(name string, arrayPositions []uint64, + vector []float32, dims int, similarity string) *VectorField { + return NewVectorFieldWithIndexingOptions(name, arrayPositions, + vector, dims, similarity, DefaultVectorIndexingOptions) +} + +// For the sake of not polluting the API, we are keeping arrayPositions as a +// parameter, but it is not used. +func NewVectorFieldWithIndexingOptions(name string, arrayPositions []uint64, + vector []float32, dims int, similarity string, + options index.FieldIndexingOptions) *VectorField { + options = options | DefaultVectorIndexingOptions + + return &VectorField{ + name: name, + dims: dims, + similarity: similarity, + options: options, + value: vector, + numPlainTextBytes: numBytesFloat32s(vector), + } +} + +func numBytesFloat32s(value []float32) uint64 { + return uint64(len(value) * size.SizeOfFloat32) +} + +// ----------------------------------------------------------------------------- +// Following methods help in implementing the bleve_index_api's VectorField +// interface. 
+ +func (n *VectorField) Vector() []float32 { + return n.value +} + +func (n *VectorField) Dims() int { + return n.dims +} + +func (n *VectorField) Similarity() string { + return n.similarity +} diff --git a/geo/parse.go b/geo/parse.go index 01ec1dd81..34f731a9e 100644 --- a/geo/parse.go +++ b/geo/parse.go @@ -18,6 +18,8 @@ import ( "reflect" "strconv" "strings" + + "github.com/blevesearch/bleve/v2/util" ) // ExtractGeoPoint takes an arbitrary interface{} and tries it's best to @@ -61,12 +63,12 @@ func ExtractGeoPoint(thing interface{}) (lon, lat float64, success bool) { first := thingVal.Index(0) if first.CanInterface() { firstVal := first.Interface() - lon, foundLon = extractNumericVal(firstVal) + lon, foundLon = util.ExtractNumericValFloat64(firstVal) } second := thingVal.Index(1) if second.CanInterface() { secondVal := second.Interface() - lat, foundLat = extractNumericVal(secondVal) + lat, foundLat = util.ExtractNumericValFloat64(secondVal) } } } @@ -105,12 +107,12 @@ func ExtractGeoPoint(thing interface{}) (lon, lat float64, success bool) { // is it a map if l, ok := thing.(map[string]interface{}); ok { if lval, ok := l["lon"]; ok { - lon, foundLon = extractNumericVal(lval) + lon, foundLon = util.ExtractNumericValFloat64(lval) } else if lval, ok := l["lng"]; ok { - lon, foundLon = extractNumericVal(lval) + lon, foundLon = util.ExtractNumericValFloat64(lval) } if lval, ok := l["lat"]; ok { - lat, foundLat = extractNumericVal(lval) + lat, foundLat = util.ExtractNumericValFloat64(lval) } } @@ -121,19 +123,19 @@ func ExtractGeoPoint(thing interface{}) (lon, lat float64, success bool) { if strings.HasPrefix(strings.ToLower(fieldName), "lon") { if thingVal.Field(i).CanInterface() { fieldVal := thingVal.Field(i).Interface() - lon, foundLon = extractNumericVal(fieldVal) + lon, foundLon = util.ExtractNumericValFloat64(fieldVal) } } if strings.HasPrefix(strings.ToLower(fieldName), "lng") { if thingVal.Field(i).CanInterface() { fieldVal := thingVal.Field(i).Interface() 
- lon, foundLon = extractNumericVal(fieldVal) + lon, foundLon = util.ExtractNumericValFloat64(fieldVal) } } if strings.HasPrefix(strings.ToLower(fieldName), "lat") { if thingVal.Field(i).CanInterface() { fieldVal := thingVal.Field(i).Interface() - lat, foundLat = extractNumericVal(fieldVal) + lat, foundLat = util.ExtractNumericValFloat64(fieldVal) } } } @@ -157,25 +159,6 @@ func ExtractGeoPoint(thing interface{}) (lon, lat float64, success bool) { return lon, lat, foundLon && foundLat } -// extract numeric value (if possible) and returns a float64 -func extractNumericVal(v interface{}) (float64, bool) { - val := reflect.ValueOf(v) - if !val.IsValid() { - return 0, false - } - typ := val.Type() - switch typ.Kind() { - case reflect.Float32, reflect.Float64: - return val.Float(), true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return float64(val.Int()), true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return float64(val.Uint()), true - } - - return 0, false -} - // various support interfaces which can be used to find lat/lon type loner interface { Lon() float64 @@ -209,12 +192,12 @@ func extractCoordinates(thing interface{}) []float64 { first := thingVal.Index(0) if first.CanInterface() { firstVal := first.Interface() - lon, foundLon = extractNumericVal(firstVal) + lon, foundLon = util.ExtractNumericValFloat64(firstVal) } second := thingVal.Index(1) if second.CanInterface() { secondVal := second.Interface() - lat, foundLat = extractNumericVal(secondVal) + lat, foundLat = util.ExtractNumericValFloat64(secondVal) } if !foundLon || !foundLat { diff --git a/index/scorch/snapshot_index_vr.go b/index/scorch/snapshot_index_vr.go new file mode 100644 index 000000000..9c4de2560 --- /dev/null +++ b/index/scorch/snapshot_index_vr.go @@ -0,0 +1,156 @@ +// Copyright (c) 2023 Couchbase, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package scorch + +import ( + "bytes" + "context" + "fmt" + "reflect" + + "github.com/blevesearch/bleve/v2/size" + index "github.com/blevesearch/bleve_index_api" + segment_api "github.com/blevesearch/scorch_segment_api/v2" +) + +var reflectStaticSizeIndexSnapshotVectorReader int + +func init() { + var istfr IndexSnapshotVectorReader + reflectStaticSizeIndexSnapshotVectorReader = int(reflect.TypeOf(istfr).Size()) +} + +type IndexSnapshotVectorReader struct { + vector []float32 + field string + k int64 + snapshot *IndexSnapshot + postings []segment_api.VecPostingsList + iterators []segment_api.VecPostingsIterator + segmentOffset int + currPosting segment_api.VecPosting + currID index.IndexInternalID + ctx context.Context +} + +func (i *IndexSnapshotVectorReader) Size() int { + sizeInBytes := reflectStaticSizeIndexSnapshotVectorReader + size.SizeOfPtr + + len(i.vector) + len(i.field) + len(i.currID) + + for _, entry := range i.postings { + sizeInBytes += entry.Size() + } + + for _, entry := range i.iterators { + sizeInBytes += entry.Size() + } + + if i.currPosting != nil { + sizeInBytes += i.currPosting.Size() + } + + return sizeInBytes +} + +func (i *IndexSnapshotVectorReader) Next(preAlloced *index.VectorDoc) ( + *index.VectorDoc, error) { + rv := preAlloced + if rv == nil { + rv = &index.VectorDoc{} + } + + for i.segmentOffset < len(i.iterators) { + next, err := 
i.iterators[i.segmentOffset].Next() + if err != nil { + return nil, err + } + if next != nil { + // make segment number into global number by adding offset + globalOffset := i.snapshot.offsets[i.segmentOffset] + nnum := next.Number() + rv.ID = docNumberToBytes(rv.ID, nnum+globalOffset) + rv.Score = float64(next.Score()) + + i.currID = rv.ID + i.currPosting = next + + return rv, nil + } + i.segmentOffset++ + } + + return nil, nil +} + +func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID, + preAlloced *index.VectorDoc) (*index.VectorDoc, error) { + + if i.currPosting != nil && bytes.Compare(i.currID, ID) >= 0 { + i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k) + if err != nil { + return nil, err + } + // close the current term field reader before replacing it with a new one + _ = i.Close() + *i = *(i2.(*IndexSnapshotVectorReader)) + } + + num, err := docInternalToNumber(ID) + if err != nil { + return nil, fmt.Errorf("error converting to doc number % x - %v", ID, err) + } + segIndex, ldocNum := i.snapshot.segmentIndexAndLocalDocNumFromGlobal(num) + if segIndex >= len(i.snapshot.segment) { + return nil, fmt.Errorf("computed segment index %d out of bounds %d", + segIndex, len(i.snapshot.segment)) + } + // skip directly to the target segment + i.segmentOffset = segIndex + next, err := i.iterators[i.segmentOffset].Advance(ldocNum) + if err != nil { + return nil, err + } + if next == nil { + // we jumped directly to the segment that should have contained it + // but it wasn't there, so reuse Next() which should correctly + // get the next hit after it (we moved i.segmentOffset) + return i.Next(preAlloced) + } + + if preAlloced == nil { + preAlloced = &index.VectorDoc{} + } + preAlloced.ID = docNumberToBytes(preAlloced.ID, next.Number()+ + i.snapshot.offsets[segIndex]) + i.currID = preAlloced.ID + i.currPosting = next + return preAlloced, nil +} + +func (i *IndexSnapshotVectorReader) Count() uint64 { + var rv uint64 + for _, posting := 
range i.postings { + rv += posting.Count() + } + return rv +} + +func (i *IndexSnapshotVectorReader) Close() error { + // TODO Consider if any scope of recycling here. + return nil +} diff --git a/index/scorch/snapshot_vector_index.go b/index/scorch/snapshot_vector_index.go new file mode 100644 index 000000000..86aa6df54 --- /dev/null +++ b/index/scorch/snapshot_vector_index.go @@ -0,0 +1,57 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build vectors +// +build vectors + +package scorch + +import ( + "context" + + index "github.com/blevesearch/bleve_index_api" + segment_api "github.com/blevesearch/scorch_segment_api/v2" +) + +func (is *IndexSnapshot) VectorReader(ctx context.Context, vector []float32, + field string, k int64) ( + index.VectorReader, error) { + + rv := &IndexSnapshotVectorReader{ + vector: vector, + field: field, + k: k, + snapshot: is, + } + + if rv.postings == nil { + rv.postings = make([]segment_api.VecPostingsList, len(is.segment)) + } + if rv.iterators == nil { + rv.iterators = make([]segment_api.VecPostingsIterator, len(is.segment)) + } + + for i, seg := range is.segment { + if sv, ok := seg.segment.(segment_api.VectorSegment); ok { + pl, err := sv.SimilarVectors(field, vector, k, seg.deleted) + if err != nil { + return nil, err + } + rv.postings[i] = pl + rv.iterators[i] = pl.Iterator(rv.iterators[i]) + } + } + + return rv, nil +} diff --git a/index_alias_impl.go b/index_alias_impl.go index a73dd6b8f..ccb52f244 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -430,21 +430,7 @@ func (i *indexAliasImpl) Swap(in, out []Index) { // Perhaps that part needs to be optional, // could be slower in remote usages. 
func createChildSearchRequest(req *SearchRequest) *SearchRequest { - rv := SearchRequest{ - Query: req.Query, - Size: req.Size + req.From, - From: 0, - Highlight: req.Highlight, - Fields: req.Fields, - Facets: req.Facets, - Explain: req.Explain, - Sort: req.Sort.Copy(), - IncludeLocations: req.IncludeLocations, - Score: req.Score, - SearchAfter: req.SearchAfter, - SearchBefore: req.SearchBefore, - } - return &rv + return copySearchRequest(req) } type asyncSearchResult struct { diff --git a/index_impl.go b/index_impl.go index d5f34a2a3..fe3a62e9e 100644 --- a/index_impl.go +++ b/index_impl.go @@ -496,7 +496,12 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey, search.GeoBufferPoolCallbackFunc(getBufferPool)) - searcher, err := req.Query.Searcher(ctx, indexReader, i.m, search.SearcherOptions{ + + // Using a disjunction query to get union of results from KNN query + // and the original query + searchQuery := disjunctQueryWithKNN(req) + + searcher, err := searchQuery.Searcher(ctx, indexReader, i.m, search.SearcherOptions{ Explain: req.Explain, IncludeTermVectors: req.IncludeLocations || req.Highlight != nil, Score: req.Score, diff --git a/mapping/document.go b/mapping/document.go index aacaa0a55..9f5aea581 100644 --- a/mapping/document.go +++ b/mapping/document.go @@ -77,10 +77,17 @@ func (dm *DocumentMapping) Validate(cache *registry.Cache) error { return err } } - switch field.Type { - case "text", "datetime", "number", "boolean", "geopoint", "geoshape", "IP": - default: - return fmt.Errorf("unknown field type: '%s'", field.Type) + + err := validateFieldType(field.Type) + if err != nil { + return err + } + + if field.Type == "vector" { + err := validateVectorField(field) + if err != nil { + return err + } } } return nil @@ -505,6 +512,9 @@ func (dm *DocumentMapping) processProperty(property interface{}, path []string, if subDocMapping != nil { for _, fieldMapping := range 
subDocMapping.Fields { switch fieldMapping.Type { + case "vector": + fieldMapping.processVector(property, pathString, path, + indexes, context) case "geopoint": fieldMapping.processGeoPoint(property, pathString, path, indexes, context) case "IP": diff --git a/mapping/field.go b/mapping/field.go index 82d51f317..41aeb1512 100644 --- a/mapping/field.go +++ b/mapping/field.go @@ -69,6 +69,14 @@ type FieldMapping struct { // the processing of freq/norm details when the default score based relevancy // isn't needed. SkipFreqNorm bool `json:"skip_freq_norm,omitempty"` + + // Dimensionality of the vector + Dims int `json:"dims,omitempty"` + + // Similarity is the similarity algorithm used for scoring + // vector fields. + // See: util.DefaultSimilarityMetric & util.SupportedSimilarityMetrics + Similarity string `json:"similarity,omitempty"` } // NewTextFieldMapping returns a default field mapping for text @@ -448,6 +456,16 @@ func (fm *FieldMapping) UnmarshalJSON(data []byte) error { if err != nil { return err } + case "dims": + err := json.Unmarshal(v, &fm.Dims) + if err != nil { + return err + } + case "similarity": + err := json.Unmarshal(v, &fm.Similarity) + if err != nil { + return err + } default: invalidKeys = append(invalidKeys, k) } diff --git a/mapping/index.go b/mapping/index.go index 0de4147a4..1c08bc589 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -431,6 +431,33 @@ func (im *IndexMappingImpl) FieldAnalyzer(field string) string { return im.AnalyzerNameForPath(field) } +// FieldMappingForPath returns the mapping for a specific field 'path'. +func (im *IndexMappingImpl) FieldMappingForPath(path string) FieldMapping { + if im.TypeMapping != nil { + for _, v := range im.TypeMapping { + for field, property := range v.Properties { + for _, v1 := range property.Fields { + if field == path { + // Return field mapping if the name matches the path param. 
+ return *v1 + } + } + } + } + } + + for field, property := range im.DefaultMapping.Properties { + for _, v1 := range property.Fields { + if field == path { + // Return field mapping if the name matches the path param. + return *v1 + } + } + } + + return FieldMapping{} +} + // wrapper to satisfy new interface func (im *IndexMappingImpl) DefaultSearchField() string { diff --git a/mapping/mapping.go b/mapping/mapping.go index a3e5a54e0..cbfc98faa 100644 --- a/mapping/mapping.go +++ b/mapping/mapping.go @@ -55,4 +55,6 @@ type IndexMapping interface { AnalyzerNameForPath(path string) string AnalyzerNamed(name string) analysis.Analyzer + + FieldMappingForPath(path string) FieldMapping } diff --git a/mapping/mapping_no_vectors.go b/mapping/mapping_no_vectors.go new file mode 100644 index 000000000..f4987596a --- /dev/null +++ b/mapping/mapping_no_vectors.go @@ -0,0 +1,46 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build !vectors +// +build !vectors + +package mapping + +import "fmt" + +func NewVectorFieldMapping() *FieldMapping { + return nil +} + +func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, + pathString string, path []string, indexes []uint64, context *walkContext) { + +} + +// ----------------------------------------------------------------------------- +// document validation functions + +func validateVectorField(fieldMapping *FieldMapping) error { + return nil +} + +func validateFieldType(fieldType string) error { + switch fieldType { + case "text", "datetime", "number", "boolean", "geopoint", "geoshape", "IP": + default: + return fmt.Errorf("unknown field type: '%s'", fieldType) + } + + return nil +} diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go new file mode 100644 index 000000000..a39820d96 --- /dev/null +++ b/mapping/mapping_vectors.go @@ -0,0 +1,119 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build vectors +// +build vectors + +package mapping + +import ( + "fmt" + "reflect" + + "github.com/blevesearch/bleve/v2/document" + "github.com/blevesearch/bleve/v2/util" +) + +func NewVectorFieldMapping() *FieldMapping { + return &FieldMapping{ + Type: "vector", + Store: false, + Index: true, + IncludeInAll: false, + DocValues: false, + SkipFreqNorm: true, + } +} + +func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, + pathString string, path []string, indexes []uint64, context *walkContext) { + propertyVal := reflect.ValueOf(propertyMightBeVector) + if !propertyVal.IsValid() { + return + } + + // Validating the length of the vector is required here, in order to + // help zapx in deciding the shape of the batch of vectors to be indexed. + if propertyVal.Kind() == reflect.Slice && propertyVal.Len() == fm.Dims { + vector := make([]float32, propertyVal.Len()) + isVectorValid := true + for i := 0; i < propertyVal.Len(); i++ { + item := propertyVal.Index(i) + if item.CanInterface() { + itemVal := item.Interface() + itemFloat, ok := util.ExtractNumericValFloat32(itemVal) + if !ok { + isVectorValid = false + break + } + vector[i] = itemFloat + } + } + // Even if one of the vector elements is not a float32, we do not index + // this field and return silently + if !isVectorValid { + return + } + + fieldName := getFieldName(pathString, path, fm) + options := fm.Options() + field := document.NewVectorFieldWithIndexingOptions(fieldName, + indexes, vector, fm.Dims, fm.Similarity, options) + context.doc.AddField(field) + + // "_all" composite field is not applicable for vector field + context.excludedFromAll = append(context.excludedFromAll, fieldName) + } +} + +// ----------------------------------------------------------------------------- +// document validation functions + +func validateVectorField(field *FieldMapping) error { + if field.Dims <= 0 || field.Dims > 2048 { + return fmt.Errorf("invalid vector dimension,"+ + " value should be in 
range (%d, %d)", 0, 2048) + } + + if field.Similarity == "" { + field.Similarity = util.DefaultSimilarityMetric + } + + // following fields are not applicable for vector + // thus, we set them to default values + field.IncludeInAll = false + field.IncludeTermVectors = false + field.Store = false + field.DocValues = false + field.SkipFreqNorm = true + + if _, ok := util.SupportedSimilarityMetrics[field.Similarity]; !ok { + return fmt.Errorf("invalid similarity metric: '%s', "+ + "valid metrics are: %+v", field.Similarity, + reflect.ValueOf(util.SupportedSimilarityMetrics).MapKeys()) + } + + return nil +} + +func validateFieldType(fieldType string) error { + switch fieldType { + case "text", "datetime", "number", "boolean", "geopoint", "geoshape", + "IP", "vector": + default: + return fmt.Errorf("unknown field type: '%s'", fieldType) + } + + return nil +} diff --git a/mapping_vector.go b/mapping_vector.go new file mode 100644 index 000000000..594313861 --- /dev/null +++ b/mapping_vector.go @@ -0,0 +1,24 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build vectors +// +build vectors + +package bleve + +import "github.com/blevesearch/bleve/v2/mapping" + +func NewVectorFieldMapping() *mapping.FieldMapping { + return mapping.NewVectorFieldMapping() +} diff --git a/search.go b/search.go index 8ca0310fb..b9da73d0b 100644 --- a/search.go +++ b/search.go @@ -15,7 +15,6 @@ package bleve import ( - "encoding/json" "fmt" "reflect" "sort" @@ -32,20 +31,20 @@ import ( "github.com/blevesearch/bleve/v2/util" ) -const defaultDateTimeParser = optional.Name - -var cache = registry.NewCache() - -var ( - reflectStaticSizeSearchResult int - reflectStaticSizeSearchStatus int -) +var reflectStaticSizeSearchResult int +var reflectStaticSizeSearchStatus int func init() { - reflectStaticSizeSearchResult = int(reflect.TypeOf(SearchResult{}).Size()) - reflectStaticSizeSearchStatus = int(reflect.TypeOf(SearchStatus{}).Size()) + var sr SearchResult + reflectStaticSizeSearchResult = int(reflect.TypeOf(sr).Size()) + var ss SearchStatus + reflectStaticSizeSearchStatus = int(reflect.TypeOf(ss).Size()) } +var cache = registry.NewCache() + +const defaultDateTimeParser = optional.Name + type dateTimeRange struct { Name string `json:"name,omitempty"` Start time.Time `json:"start,omitempty"` @@ -285,51 +284,10 @@ func (h *HighlightRequest) AddField(field string) { h.Fields = append(h.Fields, field) } -// A SearchRequest describes all the parameters -// needed to search the index. -// Query is required. -// Size/From describe how much and which part of the -// result set to return. -// Highlight describes optional search result -// highlighting. -// Fields describes a list of field values which -// should be retrieved for result documents, provided they -// were stored while indexing. -// Facets describe the set of facets to be computed. -// Explain triggers inclusion of additional search -// result score explanations. -// Sort describes the desired order for the results to be returned. 
-// Score controls the kind of scoring performed -// SearchAfter supports deep paging by providing a minimum sort key -// SearchBefore supports deep paging by providing a maximum sort key -// sortFunc specifies the sort implementation to use for sorting results. -// -// A special field named "*" can be used to return all fields. -type SearchRequest struct { - ClientContextID string `json:"client_context_id,omitempty"` - Query query.Query `json:"query"` - Size int `json:"size"` - From int `json:"from"` - Highlight *HighlightRequest `json:"highlight"` - Fields []string `json:"fields"` - Facets FacetsRequest `json:"facets"` - Explain bool `json:"explain"` - Sort search.SortOrder `json:"sort"` - IncludeLocations bool `json:"includeLocations"` - Score string `json:"score,omitempty"` - SearchAfter []string `json:"search_after"` - SearchBefore []string `json:"search_before"` - - sortFunc func(sort.Interface) -} - -func (r *SearchRequest) SetClientContextID(id string) { - r.ClientContextID = id -} - func (r *SearchRequest) Validate() error { if srq, ok := r.Query.(query.ValidatableQuery); ok { - if err := srq.Validate(); err != nil { + err := srq.Validate() + if err != nil { return err } } @@ -393,69 +351,6 @@ func (r *SearchRequest) SetSearchBefore(before []string) { r.SearchBefore = before } -// UnmarshalJSON deserializes a JSON representation of -// a SearchRequest -func (r *SearchRequest) UnmarshalJSON(input []byte) error { - var ( - temp struct { - ClientContextID string `json:"client_context_id"` - Q json.RawMessage `json:"query"` - Size *int `json:"size"` - From int `json:"from"` - Highlight *HighlightRequest `json:"highlight"` - Fields []string `json:"fields"` - Facets FacetsRequest `json:"facets"` - Explain bool `json:"explain"` - Sort []json.RawMessage `json:"sort"` - IncludeLocations bool `json:"includeLocations"` - Score string `json:"score"` - SearchAfter []string `json:"search_after"` - SearchBefore []string `json:"search_before"` - } - err error - ) - - if 
err = util.UnmarshalJSON(input, &temp); err != nil { - return err - } - - if temp.Size == nil { - r.Size = 10 - } else { - r.Size = *temp.Size - } - if temp.Sort == nil { - r.Sort = search.SortOrder{&search.SortScore{Desc: true}} - } else { - if r.Sort, err = search.ParseSortOrderJSON(temp.Sort); err != nil { - return err - } - } - r.ClientContextID = temp.ClientContextID - r.From = temp.From - r.Explain = temp.Explain - r.Highlight = temp.Highlight - r.Fields = temp.Fields - r.Facets = temp.Facets - r.IncludeLocations = temp.IncludeLocations - r.Score = temp.Score - r.SearchAfter = temp.SearchAfter - r.SearchBefore = temp.SearchBefore - if r.Query, err = query.ParseQuery(temp.Q); err != nil { - return err - } - - if r.Size < 0 { - r.Size = 10 - } - if r.From < 0 { - r.From = 0 - } - - return nil - -} - // NewSearchRequest creates a new SearchRequest // for the Query, using default values for all // other search parameters. @@ -491,7 +386,8 @@ func (iem IndexErrMap) MarshalJSON() ([]byte, error) { func (iem IndexErrMap) UnmarshalJSON(data []byte) error { var tmp map[string]string - if err := util.UnmarshalJSON(data, &tmp); err != nil { + err := util.UnmarshalJSON(data, &tmp) + if err != nil { return err } for k, v := range tmp { diff --git a/search/query/disjunction.go b/search/query/disjunction.go index f8573d081..e008a042a 100644 --- a/search/query/disjunction.go +++ b/search/query/disjunction.go @@ -73,11 +73,13 @@ func (q *DisjunctionQuery) Searcher(ctx context.Context, i index.IndexReader, m } return nil, err } - if _, ok := sr.(*searcher.MatchNoneSearcher); ok && q.queryStringMode { - // in query string mode, skip match none - continue + if sr != nil { + if _, ok := sr.(*searcher.MatchNoneSearcher); ok && q.queryStringMode { + // in query string mode, skip match none + continue + } + ss = append(ss, sr) } - ss = append(ss, sr) } if len(ss) < 1 { diff --git a/search/query/knn.go b/search/query/knn.go new file mode 100644 index 000000000..c485b4a12 --- 
/dev/null
+++ b/search/query/knn.go
@@ -0,0 +1,86 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package query
+
+import (
+	"context"
+
+	"github.com/blevesearch/bleve/v2/mapping"
+	"github.com/blevesearch/bleve/v2/search"
+	"github.com/blevesearch/bleve/v2/search/searcher"
+	"github.com/blevesearch/bleve/v2/util"
+	index "github.com/blevesearch/bleve_index_api"
+)
+
+// KNNQuery describes a k-nearest-neighbour search: the vector field to
+// search, the query vector, the number of neighbours K and an optional boost.
+type KNNQuery struct {
+	VectorField string    `json:"field"`
+	Vector      []float32 `json:"vector"`
+	K           int64     `json:"k"`
+	BoostVal    *Boost    `json:"boost,omitempty"`
+}
+
+// NewKNNQuery returns a KNNQuery for vector. The field, K and boost are left
+// at their zero values and are expected to be set through the setters.
+func NewKNNQuery(vector []float32) *KNNQuery {
+	return &KNNQuery{Vector: vector}
+}
+
+// Field returns the name of the vector field the query targets.
+func (q *KNNQuery) Field() string {
+	return q.VectorField
+}
+
+// SetK sets the number of neighbours requested.
+func (q *KNNQuery) SetK(k int64) {
+	q.K = k
+}
+
+// SetFieldVal sets the name of the vector field the query targets.
+func (q *KNNQuery) SetFieldVal(field string) {
+	q.VectorField = field
+}
+
+// SetBoost sets the query boost to b.
+func (q *KNNQuery) SetBoost(b float64) {
+	boost := Boost(b)
+	q.BoostVal = &boost
+}
+
+// Boost returns the value of BoostVal.
+// NOTE(review): BoostVal is nil until SetBoost is called, so this relies on
+// Boost.Value accepting a nil receiver - confirm.
+func (q *KNNQuery) Boost() float64 {
+	return q.BoostVal.Value()
+}
+
+// Searcher builds the KNN searcher for this query. The similarity metric is
+// read from the mapping of VectorField; when that mapping names none,
+// util.DefaultSimilarityMetric is used.
+func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
+	m mapping.IndexMapping, options search.SearcherOptions) (search.Searcher, error) {
+	fieldMapping := m.FieldMappingForPath(q.VectorField)
+	similarityMetric := fieldMapping.Similarity
+	if similarityMetric == "" {
+		similarityMetric = util.DefaultSimilarityMetric
+	}
+
+	return searcher.NewKNNSearcher(ctx, i, m, options, q.VectorField,
+		q.Vector, q.K, q.BoostVal.Value(), similarityMetric)
+}
diff --git a/search/scorer/scorer_knn.go b/search/scorer/scorer_knn.go
new file mode 100644
index 000000000..511a47ecb
--- /dev/null
+++ b/search/scorer/scorer_knn.go
@@ -0,0 +1,128 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package scorer
+
+import (
+	"math"
+	"reflect"
+
+	"github.com/blevesearch/bleve/v2/search"
+	"github.com/blevesearch/bleve/v2/util"
+	index "github.com/blevesearch/bleve_index_api"
+)
+
+var reflectStaticSizeKNNQueryScorer int
+
+func init() {
+	var sqs KNNQueryScorer
+	reflectStaticSizeKNNQueryScorer = int(reflect.TypeOf(sqs).Size())
+}
+
+// maxKNNScore is the score given to an exact match under the euclidean
+// metric, where the distance is 0 and therefore cannot be inverted.
+const maxKNNScore = math.MaxFloat32
+
+// KNNQueryScorer converts the vector hits of one KNN query into scored
+// document matches.
+type KNNQueryScorer struct {
+	queryVector      []float32
+	queryField       string
+	queryWeight      float64
+	queryBoost       float64
+	queryNorm        float64
+	docTerm          uint64
+	docTotal         uint64
+	options          search.SearcherOptions
+	includeScore     bool
+	similarityMetric string
+}
+
+// NewKNNQueryScorer returns a scorer for queryVector on queryField. The query
+// weight stays at 1.0 until SetQueryNorm is called, and a match's Score is
+// only filled in when options.Score is not "none".
+func NewKNNQueryScorer(queryVector []float32, queryField string, queryBoost float64,
+	docTerm uint64, docTotal uint64, options search.SearcherOptions,
+	similarityMetric string) *KNNQueryScorer {
+	return &KNNQueryScorer{
+		queryVector:      queryVector,
+		queryField:       queryField,
+		queryBoost:       queryBoost,
+		queryWeight:      1.0,
+		docTerm:          docTerm,
+		docTotal:         docTotal,
+		options:          options,
+		includeScore:     options.Score != "none",
+		similarityMetric: similarityMetric,
+	}
+}
+
+// Score returns a DocumentMatch taken from ctx's pool for knnMatch. Under the
+// euclidean metric the reported distance is inverted so that a smaller
+// distance gives a larger score; an exact match (distance 0) receives
+// maxKNNScore instead of +Inf. The score is then scaled by the query weight.
+func (sqs *KNNQueryScorer) Score(ctx *search.SearchContext,
+	knnMatch *index.VectorDoc) *search.DocumentMatch {
+	rv := ctx.DocumentMatchPool.Get()
+
+	if sqs.includeScore || sqs.options.Explain {
+		// no explanation is built for KNN hits yet, so Expl is set to nil
+		var scoreExplanation *search.Explanation
+		score := knnMatch.Score
+		if sqs.similarityMetric == util.EuclideanDistance {
+			if score == 0 {
+				// an exact match has distance 0, which cannot be
+				// inverted; give it the highest score instead
+				score = maxKNNScore
+			} else {
+				// euclidean distances need to be inverted to work
+				// with tf-idf scoring
+				score = 1.0 / score
+			}
+		}
+
+		// if the query weight isn't 1, multiply
+		if sqs.queryWeight != 1.0 {
+			score = score * sqs.queryWeight
+		}
+
+		if sqs.includeScore {
+			rv.Score = score
+		}
+
+		if sqs.options.Explain {
+			rv.Expl = scoreExplanation
+		}
+	}
+
+	rv.IndexInternalID = append(rv.IndexInternalID, knnMatch.ID...)
+	return rv
+}
+
+// Weight returns the square of the query boost.
+func (sqs *KNNQueryScorer) Weight() float64 {
+	return sqs.queryBoost * sqs.queryBoost
+}
+
+// SetQueryNorm records qnorm and recomputes the query weight as the product
+// of the query boost and qnorm.
+func (sqs *KNNQueryScorer) SetQueryNorm(qnorm float64) {
+	sqs.queryNorm = qnorm
+
+	// update the query weight
+	sqs.queryWeight = sqs.queryBoost * sqs.queryNorm
+}
diff --git a/search/searcher/search_knn.go b/search/searcher/search_knn.go
new file mode 100644
index 000000000..7dd59967e
--- /dev/null
+++ b/search/searcher/search_knn.go
@@ -0,0 +1,123 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package searcher
+
+import (
+	"context"
+
+	"github.com/blevesearch/bleve/v2/mapping"
+	"github.com/blevesearch/bleve/v2/search"
+	"github.com/blevesearch/bleve/v2/search/scorer"
+	index "github.com/blevesearch/bleve_index_api"
+)
+
+// KNNSearcher is a search.Searcher that yields the hits produced by a vector
+// reader for one KNN query, scoring each of them with a KNNQueryScorer.
+type KNNSearcher struct {
+	field        string
+	vector       []float32
+	k            int64
+	indexReader  index.IndexReader
+	vectorReader index.VectorReader
+	scorer       *scorer.KNNQueryScorer
+	count        uint64
+	vd           index.VectorDoc
+}
+
+// NewKNNSearcher opens a vector reader on field for vector and k and wraps it
+// in a KNNSearcher. Errors from opening the reader or from counting documents
+// are returned. When i does not implement index.VectorIndexReader the result
+// is (nil, nil), so callers must be prepared for a nil searcher.
+func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMapping,
+	options search.SearcherOptions, field string, vector []float32, k int64,
+	boost float64, similarityMetric string) (search.Searcher, error) {
+	if vr, ok := i.(index.VectorIndexReader); ok {
+		vectorReader, err := vr.VectorReader(ctx, vector, field, k)
+		if err != nil {
+			return nil, err
+		}
+
+		count, err := i.DocCount()
+		if err != nil {
+			_ = vectorReader.Close()
+			return nil, err
+		}
+
+		knnScorer := scorer.NewKNNQueryScorer(vector, field, boost,
+			vectorReader.Count(), count, options, similarityMetric)
+		return &KNNSearcher{
+			indexReader:  i,
+			vectorReader: vectorReader,
+			field:        field,
+			vector:       vector,
+			k:            k,
+			scorer:       knnScorer,
+		}, nil
+	}
+	return nil, nil
+}
+
+// Advance does not seek to ID: it returns the next hit of the vector reader,
+// exactly like Next.
+func (s *KNNSearcher) Advance(ctx *search.SearchContext, ID index.IndexInternalID) (
+	*search.DocumentMatch, error) {
+	return s.Next(ctx)
+}
+
+func (s *KNNSearcher) Close() error {
+	return s.vectorReader.Close()
+}
+
+func (s *KNNSearcher) Count() uint64 {
+	return s.vectorReader.Count()
+}
+
+func (s *KNNSearcher) DocumentMatchPoolSize() int {
+	return 1
+}
+
+func (s *KNNSearcher) Min() int {
+	return 0
+}
+
+// Next scores and returns the next hit of the vector reader, or (nil, nil)
+// when the reader returns no match.
+func (s *KNNSearcher) Next(ctx *search.SearchContext) (*search.DocumentMatch, error) {
+	knnMatch, err := s.vectorReader.Next(s.vd.Reset())
+	if err != nil {
+		return nil, err
+	}
+
+	if knnMatch == nil {
+		return nil, nil
+	}
+
+	return s.scorer.Score(ctx, knnMatch), nil
+}
+
+func (s *KNNSearcher) SetQueryNorm(qnorm float64) {
+	s.scorer.SetQueryNorm(qnorm)
+}
+
+func (s *KNNSearcher) Size() int {
+	return 0
+}
+
+func (s *KNNSearcher) Weight() float64 {
+	return s.scorer.Weight()
+}
diff --git a/search_knn.go b/search_knn.go
new file mode 100644
index 000000000..a2f8d343c
--- /dev/null
+++ b/search_knn.go
@@ -0,0 +1,165 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package bleve
+
+import (
+	"encoding/json"
+	"sort"
+
+	"github.com/blevesearch/bleve/v2/search"
+	"github.com/blevesearch/bleve/v2/search/query"
+)
+
+// SearchRequest describes the parameters of a search; this vectors build
+// adds KNN requests, which are OR-ed with Query (see disjunctQueryWithKNN).
+type SearchRequest struct {
+	ClientContextID  string            `json:"client_context_id,omitempty"`
+	Query            query.Query       `json:"query"`
+	Size             int               `json:"size"`
+	From             int               `json:"from"`
+	Highlight        *HighlightRequest `json:"highlight"`
+	Fields           []string          `json:"fields"`
+	Facets           FacetsRequest     `json:"facets"`
+	Explain          bool              `json:"explain"`
+	Sort             search.SortOrder  `json:"sort"`
+	IncludeLocations bool              `json:"includeLocations"`
+	Score            string            `json:"score,omitempty"`
+	SearchAfter      []string          `json:"search_after"`
+	SearchBefore     []string          `json:"search_before"`
+
+	KNN []*KNNRequest `json:"knn"`
+
+	sortFunc func(sort.Interface)
+}
+
+// KNNRequest is one k-nearest-neighbour clause of a SearchRequest: the
+// vector field, the query vector, the neighbour count K and a boost.
+type KNNRequest struct {
+	Field  string       `json:"field"`
+	Vector []float32    `json:"vector"`
+	K      int64        `json:"k"`
+	Boost  *query.Boost `json:"boost,omitempty"`
+}
+
+// AddKNN appends a KNN request on field for vector with the given k and boost.
+func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost float64) {
+	b := query.Boost(boost)
+	r.KNN = append(r.KNN, &KNNRequest{Field: field, Vector: vector, K: k, Boost: &b})
+}
+
+// UnmarshalJSON deserializes a JSON representation of a SearchRequest,
+// defaulting Size to 10 and Sort to descending score when they are absent.
+func (r *SearchRequest) UnmarshalJSON(input []byte) error {
+	var temp struct {
+		ClientContextID  string            `json:"client_context_id"`
+		Q                json.RawMessage   `json:"query"`
+		Size             *int              `json:"size"`
+		From             int               `json:"from"`
+		Highlight        *HighlightRequest `json:"highlight"`
+		Fields           []string          `json:"fields"`
+		Facets           FacetsRequest     `json:"facets"`
+		Explain          bool              `json:"explain"`
+		Sort             []json.RawMessage `json:"sort"`
+		IncludeLocations bool              `json:"includeLocations"`
+		Score            string            `json:"score"`
+		SearchAfter      []string          `json:"search_after"`
+		SearchBefore     []string          `json:"search_before"`
+		KNN              []*KNNRequest     `json:"knn"`
+	}
+
+	err := json.Unmarshal(input, &temp)
+	if err != nil {
+		return err
+	}
+
+	r.Size = 10
+	if temp.Size != nil {
+		r.Size = *temp.Size
+	}
+	r.Sort = search.SortOrder{&search.SortScore{Desc: true}}
+	if temp.Sort != nil {
+		if r.Sort, err = search.ParseSortOrderJSON(temp.Sort); err != nil {
+			return err
+		}
+	}
+	r.ClientContextID = temp.ClientContextID
+	r.From = temp.From
+	r.Explain = temp.Explain
+	r.Highlight = temp.Highlight
+	r.Fields = temp.Fields
+	r.Facets = temp.Facets
+	r.IncludeLocations = temp.IncludeLocations
+	r.Score = temp.Score
+	r.SearchAfter = temp.SearchAfter
+	r.SearchBefore = temp.SearchBefore
+	if r.Query, err = query.ParseQuery(temp.Q); err != nil {
+		return err
+	}
+
+	if r.Size < 0 {
+		r.Size = 10
+	}
+	if r.From < 0 {
+		r.From = 0
+	}
+
+	r.KNN = temp.KNN
+
+	return nil
+}
+
+// -----------------------------------------------------------------------------
+
+// copySearchRequest copies req, folding From into Size (Size+From, From=0).
+func copySearchRequest(req *SearchRequest) *SearchRequest {
+	rv := SearchRequest{
+		Query:            req.Query,
+		Size:             req.Size + req.From,
+		From:             0,
+		Highlight:        req.Highlight,
+		Fields:           req.Fields,
+		Facets:           req.Facets,
+		Explain:          req.Explain,
+		Sort:             req.Sort.Copy(),
+		IncludeLocations: req.IncludeLocations,
+		Score:            req.Score,
+		SearchAfter:      req.SearchAfter,
+		SearchBefore:     req.SearchBefore,
+		KNN:              req.KNN,
+	}
+	return &rv
+}
+
+// disjunctQueryWithKNN returns req.Query unchanged when req has no KNN
+// requests; otherwise a disjunction of req.Query and one KNNQuery per request.
+func disjunctQueryWithKNN(req *SearchRequest) query.Query {
+	if len(req.KNN) > 0 {
+		disjuncts := []query.Query{req.Query}
+		for _, knn := range req.KNN {
+			if knn != nil {
+				knnQuery := query.NewKNNQuery(knn.Vector)
+				knnQuery.SetFieldVal(knn.Field)
+				knnQuery.SetK(knn.K)
+				knnQuery.SetBoost(knn.Boost.Value())
+				disjuncts = append(disjuncts, knnQuery)
+			}
+		}
+		return query.NewDisjunctionQuery(disjuncts)
+	}
+	return req.Query
+}
diff --git a/search_no_knn.go b/search_no_knn.go
new file mode 100644
index 000000000..fb3814911
--- /dev/null
+++ b/search_no_knn.go
@@ -0,0 +1,149 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !vectors
+// +build !vectors
+
+package bleve
+
+import (
+	"encoding/json"
+	"sort"
+
+	"github.com/blevesearch/bleve/v2/search"
+	"github.com/blevesearch/bleve/v2/search/query"
+)
+
+// A SearchRequest describes all the parameters needed to search the index.
+// ClientContextID carries the "client_context_id" value of the request.
+// Query is required.
+// Size/From describe how much and which part of the result set to return.
+// Highlight describes optional search result highlighting.
+// Fields describes a list of field values which should be retrieved for
+// result documents, provided they were stored while indexing.
+// Facets describe the set of facets to be computed.
+// Explain triggers inclusion of additional search result score explanations.
+// Sort describes the desired order for the results to be returned.
+// Score controls the kind of scoring performed.
+// SearchAfter supports deep paging by providing a minimum sort key.
+// SearchBefore supports deep paging by providing a maximum sort key.
+// sortFunc specifies the sort implementation to use for sorting results.
+//
+// A special field named "*" can be used to return all fields.
+type SearchRequest struct {
+	ClientContextID  string            `json:"client_context_id,omitempty"`
+	Query            query.Query       `json:"query"`
+	Size             int               `json:"size"`
+	From             int               `json:"from"`
+	Highlight        *HighlightRequest `json:"highlight"`
+	Fields           []string          `json:"fields"`
+	Facets           FacetsRequest     `json:"facets"`
+	Explain          bool              `json:"explain"`
+	Sort             search.SortOrder  `json:"sort"`
+	IncludeLocations bool              `json:"includeLocations"`
+	Score            string            `json:"score,omitempty"`
+	SearchAfter      []string          `json:"search_after"`
+	SearchBefore     []string          `json:"search_before"`
+
+	sortFunc func(sort.Interface)
+}
+
+// UnmarshalJSON deserializes a JSON representation of a SearchRequest.
+// Size defaults to 10 when absent or negative, a negative From becomes 0 and
+// Sort defaults to descending score.
+func (r *SearchRequest) UnmarshalJSON(input []byte) error {
+	var temp struct {
+		ClientContextID  string            `json:"client_context_id"`
+		Q                json.RawMessage   `json:"query"`
+		Size             *int              `json:"size"`
+		From             int               `json:"from"`
+		Highlight        *HighlightRequest `json:"highlight"`
+		Fields           []string          `json:"fields"`
+		Facets           FacetsRequest     `json:"facets"`
+		Explain          bool              `json:"explain"`
+		Sort             []json.RawMessage `json:"sort"`
+		IncludeLocations bool              `json:"includeLocations"`
+		Score            string            `json:"score"`
+		SearchAfter      []string          `json:"search_after"`
+		SearchBefore     []string          `json:"search_before"`
+	}
+
+	err := json.Unmarshal(input, &temp)
+	if err != nil {
+		return err
+	}
+
+	r.Size = 10
+	if temp.Size != nil {
+		r.Size = *temp.Size
+	}
+	r.Sort = search.SortOrder{&search.SortScore{Desc: true}}
+	if temp.Sort != nil {
+		if r.Sort, err = search.ParseSortOrderJSON(temp.Sort); err != nil {
+			return err
+		}
+	}
+	r.ClientContextID = temp.ClientContextID
+	r.From = temp.From
+	r.Explain = temp.Explain
+	r.Highlight = temp.Highlight
+	r.Fields = temp.Fields
+	r.Facets = temp.Facets
+	r.IncludeLocations = temp.IncludeLocations
+	r.Score = temp.Score
+	r.SearchAfter = temp.SearchAfter
+	r.SearchBefore = temp.SearchBefore
+	if r.Query, err = query.ParseQuery(temp.Q); err != nil {
+		return err
+	}
+
+	if r.Size < 0 {
+		r.Size = 10
+	}
+	if r.From < 0 {
+		r.From = 0
+	}
+
+	return nil
+}
+
+// -----------------------------------------------------------------------------
+
+// copySearchRequest returns a copy of req in which From is folded into Size
+// (Size+From, From=0) and Sort is duplicated through Sort.Copy; the remaining
+// listed fields are copied as-is and ClientContextID is not carried over.
+func copySearchRequest(req *SearchRequest) *SearchRequest {
+	rv := SearchRequest{
+		Query:            req.Query,
+		Size:             req.Size + req.From,
+		From:             0,
+		Highlight:        req.Highlight,
+		Fields:           req.Fields,
+		Facets:           req.Facets,
+		Explain:          req.Explain,
+		Sort:             req.Sort.Copy(),
+		IncludeLocations: req.IncludeLocations,
+		Score:            req.Score,
+		SearchAfter:      req.SearchAfter,
+		SearchBefore:     req.SearchBefore,
+	}
+	return &rv
+}
+
+// disjunctQueryWithKNN returns req.Query unchanged: KNN requests only exist
+// in builds made with the "vectors" tag, whose variant of this function
+// combines them with the query.
+func disjunctQueryWithKNN(req *SearchRequest) query.Query {
+	return req.Query
+}
diff --git a/util/extract.go b/util/extract.go
new file mode 100644
index 000000000..f8e61546a
--- /dev/null
+++ b/util/extract.go
@@ -0,0 +1,57 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package util
+
+import (
+	"reflect"
+)
+
+// ExtractNumericValFloat64 returns v as a float64 and true when v holds a float,
+// signed or unsigned integer; for nil or any other kind it returns 0 and false.
+func ExtractNumericValFloat64(v interface{}) (float64, bool) {
+	rv := reflect.ValueOf(v)
+	if !rv.IsValid() {
+		return 0, false
+	}
+	switch rv.Kind() {
+	case reflect.Float32, reflect.Float64:
+		return rv.Float(), true
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return float64(rv.Int()), true
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return float64(rv.Uint()), true
+	default:
+		return 0, false
+	}
+}
+
+// ExtractNumericValFloat32 returns v as a float32 and true when v holds a float,
+// signed or unsigned integer; for nil or any other kind it returns 0 and false.
+func ExtractNumericValFloat32(v interface{}) (float32, bool) {
+	rv := reflect.ValueOf(v)
+	if !rv.IsValid() {
+		return 0, false
+	}
+	switch rv.Kind() {
+	case reflect.Float32, reflect.Float64:
+		return float32(rv.Float()), true
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return float32(rv.Int()), true
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return float32(rv.Uint()), true
+	default:
+		return 0, false
+	}
+}
diff --git a/util/knn.go b/util/knn.go
new file mode 100644
index 000000000..e50ff01da
--- /dev/null
+++ b/util/knn.go
@@ -0,0 +1,38 @@
+// Copyright (c) 2023 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package util
+
+const (
+	// EuclideanDistance names the L2-norm (euclidean distance) metric.
+	EuclideanDistance = "l2_norm"
+
+	// CosineSimilarity names the dot-product metric. Since
+	// vecA . vecB = |vecA| * |vecB| * cos(theta), the dot product of unit
+	// vectors is exactly cos(theta), so cosine needs no separate metric.
+	CosineSimilarity = "dot_product"
+
+	// DefaultSimilarityMetric is the metric assumed when none is configured.
+	DefaultSimilarityMetric = EuclideanDistance
+)
+
+// SupportedSimilarityMetrics is the set of metric names accepted for vector fields.
+var SupportedSimilarityMetrics = map[string]struct{}{
+	EuclideanDistance: {},
+	CosineSimilarity:  {},
+}
+