Skip to content

Commit

Permalink
fix search_type json naming; use min limit for mmr (#234)
Browse files (browse the repository at this point in the history)
  • Loading branch information
danielchalef committed Oct 17, 2023
1 parent 06fe219 commit 57c6be0
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 15 deletions.
10 changes: 5 additions & 5 deletions pkg/models/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ type MemorySearchResult struct {
}

type MemorySearchPayload struct {
Text string `json:"text"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type SearchType `json:"type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
Text string `json:"text"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
SearchType SearchType `json:"search_type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
}

type DocumentSearchPayload struct {
CollectionName string `json:"collection_name"`
Text string `json:"text,omitempty"`
Embedding []float32 `json:"embedding,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type SearchType `json:"type"`
SearchType SearchType `json:"search_type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/server/document_routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ func TestSearchDocumentsHandler(t *testing.T) {
Metadata: map[string]interface{}{
"where": map[string]interface{}{"jsonpath": "$[*] ? (@.key == 'value')"},
},
Type: searchType,
SearchType: searchType,
}
p, err := json.Marshal(q)
assert.NoError(t, err)
Expand Down
7 changes: 5 additions & 2 deletions pkg/store/postgres/document_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (dso *documentSearchOperation) Execute() (*models.DocumentSearchResultPage,
return nil, fmt.Errorf("error executing search: %w", err)
}

if dso.searchPayload.Type == models.SearchTypeMMR {
if dso.searchPayload.SearchType == models.SearchTypeMMR {
results, err = dso.reRankMMR(results)
if err != nil {
return nil, fmt.Errorf("error reranking results: %w", err)
Expand Down Expand Up @@ -200,8 +200,11 @@ func (dso *documentSearchOperation) buildQuery(db bun.IDB) (*bun.SelectQuery, er
// If we're using MMR, fetch more candidates than requested (limit * DefaultMMRMultiplier,
// with a floor of 10) so the MMR algorithm has enough results to rerank and filter.
limit := dso.limit
if dso.searchPayload.Type == models.SearchTypeMMR {
if dso.searchPayload.SearchType == models.SearchTypeMMR {
limit *= DefaultMMRMultiplier
if limit < 10 {
limit = 10
}
}
query = query.Limit(limit)

Expand Down
4 changes: 2 additions & 2 deletions pkg/store/postgres/document_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ func TestReRankMMR(t *testing.T) {
// Initialize a documentSearchOperation with a searchPayload of type MMR
dso := &documentSearchOperation{
searchPayload: &models.DocumentSearchPayload{
Type: models.SearchTypeMMR,
MMRLambda: 0.5,
SearchType: models.SearchTypeMMR,
MMRLambda: 0.5,
},
queryVector: []float32{0.1, 0.2, 0.3},
limit: 2,
Expand Down
12 changes: 8 additions & 4 deletions pkg/store/postgres/search_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,15 @@ func searchMessages(

// If we're using MMR, we need to fetch more results than the limit
// (limit * DefaultMMRMultiplier, but at least 10) so we can rerank them.
if query.Type == models.SearchTypeMMR {
if query.SearchType == models.SearchTypeMMR {
if query.MMRLambda == 0 {
query.MMRLambda = DefaultMMRLambda
}
dbQuery = dbQuery.Limit(limit * DefaultMMRMultiplier)
tmpLimit := limit * DefaultMMRMultiplier
if tmpLimit < 10 {
tmpLimit = 10
}
dbQuery = dbQuery.Limit(tmpLimit)
} else {
dbQuery = dbQuery.Limit(limit)
}
Expand All @@ -86,7 +90,7 @@ func searchMessages(
filteredResults := filterValidMessageSearchResults(results, query.Metadata)

// If we're using MMR, rerank the results.
if query.Type == models.SearchTypeMMR {
if query.SearchType == models.SearchTypeMMR {
filteredResults, err = rerankMMR(filteredResults, queryEmbedding, query.MMRLambda, limit)
if err != nil {
return nil, store.NewStorageError("error applying mmr", err)
Expand Down Expand Up @@ -128,7 +132,7 @@ func buildMessagesSelectQuery(
ColumnExpr("m.metadata AS message__metadata").
ColumnExpr("m.token_count AS message__token_count")

if query.Type == models.SearchTypeMMR {
if query.SearchType == models.SearchTypeMMR {
dbQuery = dbQuery.ColumnExpr("me.embedding AS embedding")
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/store/postgres/search_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestMemorySearch(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
q := models.MemorySearchPayload{Text: tc.query, Type: tc.searchType}
q := models.MemorySearchPayload{Text: tc.query, SearchType: tc.searchType}
expectedLastN := tc.limit
if expectedLastN == 0 {
expectedLastN = 10 // Default value
Expand Down

0 comments on commit 57c6be0

Please sign in to comment.