diff --git a/pkg/models/search.go b/pkg/models/search.go index 37805a4c..24f45272 100644 --- a/pkg/models/search.go +++ b/pkg/models/search.go @@ -16,10 +16,10 @@ type MemorySearchResult struct { } type MemorySearchPayload struct { - Text string `json:"text"` - Metadata map[string]interface{} `json:"metadata,omitempty"` - Type SearchType `json:"type"` - MMRLambda float32 `json:"mmr_lambda,omitempty"` + Text string `json:"text"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + SearchType SearchType `json:"search_type"` + MMRLambda float32 `json:"mmr_lambda,omitempty"` } type DocumentSearchPayload struct { @@ -27,7 +27,7 @@ type DocumentSearchPayload struct { Text string `json:"text,omitempty"` Embedding []float32 `json:"embedding,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"` - Type SearchType `json:"type"` + SearchType SearchType `json:"search_type"` MMRLambda float32 `json:"mmr_lambda,omitempty"` } diff --git a/pkg/server/document_routes_test.go b/pkg/server/document_routes_test.go index 5d4c8d67..087d449d 100644 --- a/pkg/server/document_routes_test.go +++ b/pkg/server/document_routes_test.go @@ -370,7 +370,7 @@ func TestSearchDocumentsHandler(t *testing.T) { Metadata: map[string]interface{}{ "where": map[string]interface{}{"jsonpath": "$[*] ? (@.key == 'value')"}, }, - Type: searchType, + SearchType: searchType, } p, err := json.Marshal(q) assert.NoError(t, err) diff --git a/pkg/store/postgres/document_search.go b/pkg/store/postgres/document_search.go index 43cf25de..09649073 100644 --- a/pkg/store/postgres/document_search.go +++ b/pkg/store/postgres/document_search.go @@ -91,7 +91,7 @@ func (dso *documentSearchOperation) Execute() (*models.DocumentSearchResultPage, return nil, fmt.Errorf("error executing search: %w", err) } - if dso.searchPayload.Type == models.SearchTypeMMR { + if dso.searchPayload.SearchType == models.SearchTypeMMR { results, err = dso.reRankMMR(results) if err != nil { return nil, fmt.Errorf("error reranking results: %w", err) @@ -200,8 +200,11 @@ func (dso *documentSearchOperation) buildQuery(db bun.IDB) (*bun.SelectQuery, er // If we're using MMR, we need to add a limit of 2x the requested limit to allow for the MMR // algorithm to rerank and filter out results. limit := dso.limit - if dso.searchPayload.Type == models.SearchTypeMMR { + if dso.searchPayload.SearchType == models.SearchTypeMMR { limit *= DefaultMMRMultiplier + if limit < 10 { + limit = 10 + } } query = query.Limit(limit) diff --git a/pkg/store/postgres/document_search_test.go b/pkg/store/postgres/document_search_test.go index 40fa7101..82ed8289 100644 --- a/pkg/store/postgres/document_search_test.go +++ b/pkg/store/postgres/document_search_test.go @@ -78,8 +78,8 @@ func TestReRankMMR(t *testing.T) { // Initialize a documentSearchOperation with a searchPayload of type MMR dso := &documentSearchOperation{ searchPayload: &models.DocumentSearchPayload{ - Type: models.SearchTypeMMR, - MMRLambda: 0.5, + SearchType: models.SearchTypeMMR, + MMRLambda: 0.5, }, queryVector: []float32{0.1, 0.2, 0.3}, limit: 2, diff --git a/pkg/store/postgres/search_memory.go b/pkg/store/postgres/search_memory.go index fdce96a1..a42c234b 100644 --- a/pkg/store/postgres/search_memory.go +++ b/pkg/store/postgres/search_memory.go @@ -69,11 +69,15 @@ func searchMessages( // If we're using MMR, we need to return more results than the limit so we can // rerank them. - if query.Type == models.SearchTypeMMR { + if query.SearchType == models.SearchTypeMMR { if query.MMRLambda == 0 { query.MMRLambda = DefaultMMRLambda } - dbQuery = dbQuery.Limit(limit * DefaultMMRMultiplier) + tmpLimit := limit * DefaultMMRMultiplier + if tmpLimit < 10 { + tmpLimit = 10 + } + dbQuery = dbQuery.Limit(tmpLimit) } else { dbQuery = dbQuery.Limit(limit) } @@ -86,7 +90,7 @@ func searchMessages( filteredResults := filterValidMessageSearchResults(results, query.Metadata) // If we're using MMR, rerank the results. - if query.Type == models.SearchTypeMMR { + if query.SearchType == models.SearchTypeMMR { filteredResults, err = rerankMMR(filteredResults, queryEmbedding, query.MMRLambda, limit) if err != nil { return nil, store.NewStorageError("error applying mmr", err) @@ -128,7 +132,7 @@ func buildMessagesSelectQuery( ColumnExpr("m.metadata AS message__metadata"). ColumnExpr("m.token_count AS message__token_count") - if query.Type == models.SearchTypeMMR { + if query.SearchType == models.SearchTypeMMR { dbQuery = dbQuery.ColumnExpr("me.embedding AS embedding") } diff --git a/pkg/store/postgres/search_memory_test.go b/pkg/store/postgres/search_memory_test.go index 72ba8d20..98db7712 100644 --- a/pkg/store/postgres/search_memory_test.go +++ b/pkg/store/postgres/search_memory_test.go @@ -50,7 +50,7 @@ func TestMemorySearch(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - q := models.MemorySearchPayload{Text: tc.query, Type: tc.searchType} + q := models.MemorySearchPayload{Text: tc.query, SearchType: tc.searchType} expectedLastN := tc.limit if expectedLastN == 0 { expectedLastN = 10 // Default value