Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: mmr search_type json naming; use min limit for mmr #234

Merged
merged 1 commit into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pkg/models/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ type MemorySearchResult struct {
}

type MemorySearchPayload struct {
Text string `json:"text"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type SearchType `json:"type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
Text string `json:"text"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
SearchType SearchType `json:"search_type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
}

type DocumentSearchPayload struct {
CollectionName string `json:"collection_name"`
Text string `json:"text,omitempty"`
Embedding []float32 `json:"embedding,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type SearchType `json:"type"`
SearchType SearchType `json:"search_type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/server/document_routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ func TestSearchDocumentsHandler(t *testing.T) {
Metadata: map[string]interface{}{
"where": map[string]interface{}{"jsonpath": "$[*] ? (@.key == 'value')"},
},
Type: searchType,
SearchType: searchType,
}
p, err := json.Marshal(q)
assert.NoError(t, err)
Expand Down
7 changes: 5 additions & 2 deletions pkg/store/postgres/document_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (dso *documentSearchOperation) Execute() (*models.DocumentSearchResultPage,
return nil, fmt.Errorf("error executing search: %w", err)
}

if dso.searchPayload.Type == models.SearchTypeMMR {
if dso.searchPayload.SearchType == models.SearchTypeMMR {
results, err = dso.reRankMMR(results)
if err != nil {
return nil, fmt.Errorf("error reranking results: %w", err)
Expand Down Expand Up @@ -200,8 +200,11 @@ func (dso *documentSearchOperation) buildQuery(db bun.IDB) (*bun.SelectQuery, er
// If we're using MMR, fetch DefaultMMRMultiplier times the requested limit (and never fewer
// than 10 results) so the MMR algorithm has enough candidates to rerank and filter.
limit := dso.limit
if dso.searchPayload.Type == models.SearchTypeMMR {
if dso.searchPayload.SearchType == models.SearchTypeMMR {
limit *= DefaultMMRMultiplier
if limit < 10 {
limit = 10
}
}
query = query.Limit(limit)

Expand Down
4 changes: 2 additions & 2 deletions pkg/store/postgres/document_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ func TestReRankMMR(t *testing.T) {
// Initialize a documentSearchOperation with a searchPayload of type MMR
dso := &documentSearchOperation{
searchPayload: &models.DocumentSearchPayload{
Type: models.SearchTypeMMR,
MMRLambda: 0.5,
SearchType: models.SearchTypeMMR,
MMRLambda: 0.5,
},
queryVector: []float32{0.1, 0.2, 0.3},
limit: 2,
Expand Down
12 changes: 8 additions & 4 deletions pkg/store/postgres/search_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,15 @@ func searchMessages(

// If we're using MMR, we need to return more results than the limit so we can
// rerank them.
if query.Type == models.SearchTypeMMR {
if query.SearchType == models.SearchTypeMMR {
if query.MMRLambda == 0 {
query.MMRLambda = DefaultMMRLambda
}
dbQuery = dbQuery.Limit(limit * DefaultMMRMultiplier)
tmpLimit := limit * DefaultMMRMultiplier
if tmpLimit < 10 {
tmpLimit = 10
}
dbQuery = dbQuery.Limit(tmpLimit)
} else {
dbQuery = dbQuery.Limit(limit)
}
Expand All @@ -86,7 +90,7 @@ func searchMessages(
filteredResults := filterValidMessageSearchResults(results, query.Metadata)

// If we're using MMR, rerank the results.
if query.Type == models.SearchTypeMMR {
if query.SearchType == models.SearchTypeMMR {
filteredResults, err = rerankMMR(filteredResults, queryEmbedding, query.MMRLambda, limit)
if err != nil {
return nil, store.NewStorageError("error applying mmr", err)
Expand Down Expand Up @@ -128,7 +132,7 @@ func buildMessagesSelectQuery(
ColumnExpr("m.metadata AS message__metadata").
ColumnExpr("m.token_count AS message__token_count")

if query.Type == models.SearchTypeMMR {
if query.SearchType == models.SearchTypeMMR {
dbQuery = dbQuery.ColumnExpr("me.embedding AS embedding")
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/store/postgres/search_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestMemorySearch(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
q := models.MemorySearchPayload{Text: tc.query, Type: tc.searchType}
q := models.MemorySearchPayload{Text: tc.query, SearchType: tc.searchType}
expectedLastN := tc.limit
if expectedLastN == 0 {
expectedLastN = 10 // Default value
Expand Down
Loading