Skip to content

Commit

Permalink
feat: MMR Reranking for Document and Memory Search (#232)
Browse files Browse the repository at this point in the history
* mmr memory search

* remove withMMR from documents

* update swagger

* mmr for document search
  • Loading branch information
danielchalef committed Oct 16, 2023
1 parent 01077ee commit 06fe219
Show file tree
Hide file tree
Showing 20 changed files with 381 additions and 213 deletions.
2 changes: 1 addition & 1 deletion docs/docs.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/swagger.json

Large diffs are not rendered by default.

24 changes: 20 additions & 4 deletions docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,12 @@ definitions:
metadata:
additionalProperties: true
type: object
mmr_lambda:
type: number
text:
type: string
type:
$ref: '#/definitions/models.SearchType'
type: object
models.GetDocumentListRequest:
properties:
Expand Down Expand Up @@ -190,13 +194,21 @@ definitions:
metadata:
additionalProperties: true
type: object
mmr_lambda:
type: number
text:
type: string
type:
$ref: '#/definitions/models.SearchType'
type: object
models.MemorySearchResult:
properties:
dist:
type: number
embedding:
items:
type: number
type: array
message:
$ref: '#/definitions/models.Message'
metadata:
Expand Down Expand Up @@ -225,6 +237,14 @@ definitions:
uuid:
type: string
type: object
models.SearchType:
enum:
- similarity
- mmr
type: string
x-enum-varnames:
- SearchTypeSimilarity
- SearchTypeMMR
models.Session:
properties:
created_at:
Expand Down Expand Up @@ -910,10 +930,6 @@ paths:
in: query
name: limit
type: integer
- description: Use MMR to rerank the search results. Not Implemented
in: query
name: mmr
type: boolean
- description: Search criteria
in: body
name: searchPayload
Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ require (
github.com/uptrace/bun v1.1.16
github.com/uptrace/bun/dialect/pgdialect v1.1.16
github.com/uptrace/bun/driver/pgdriver v1.1.16
gonum.org/v1/gonum v0.14.0
)

require (
Expand All @@ -40,12 +39,14 @@ require (
github.com/tmc/langchaingo v0.0.0-20230929160525-e16b77704b8d
github.com/uptrace/bun/dbfixture v1.1.16
github.com/uptrace/bun/extra/bundebug v1.1.16
github.com/viterin/vek v0.4.2
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/chewxy/math32 v1.10.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
Expand Down Expand Up @@ -92,13 +93,15 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
github.com/sv-tools/openapi v0.2.2 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/viterin/partial v1.1.0 // indirect
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/tools v0.13.0 // indirect
golang.org/x/tools v0.14.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
mellium.im/sasl v0.3.1 // indirect
Expand Down
20 changes: 12 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ github.com/avast/retry-go/v4 v4.5.0/go.mod h1:7hLEXp0oku2Nir2xBAsg0PTphp9z71bN5A
github.com/brianvoe/gofakeit/v6 v6.23.2 h1:lVde18uhad5wII/f5RMVFLtdQNE0HaGFuBUXmYKk8i8=
github.com/brianvoe/gofakeit/v6 v6.23.2/go.mod h1:Ow6qC71xtwm79anlwKRlWZW6zVq9D2XHE4QSSMP/rU8=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/chewxy/math32 v1.10.1 h1:LFpeY0SLJXeaiej/eIp2L40VYfscTvKh/FSEZ68uMkU=
github.com/chewxy/math32 v1.10.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
github.com/chi-middleware/logrus-logger v0.2.0 h1:Do3vcVSRsLh7zSRKxsVg5Kr5//rTqytwprCR1HzVqT8=
github.com/chi-middleware/logrus-logger v0.2.0/go.mod h1:ie/rvKsXrtqqsnJd3qtSEnLxgCs1I758WYmHdv6CRt0=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
Expand Down Expand Up @@ -347,6 +349,10 @@ github.com/uptrace/bun/driver/pgdriver v1.1.16 h1:b/NiSXk6Ldw7KLfMLbOqIkm4odHd7Q
github.com/uptrace/bun/driver/pgdriver v1.1.16/go.mod h1:Rmfbc+7lx1z/umjMyAxkOHK81LgnGj71XC5YpA6k1vU=
github.com/uptrace/bun/extra/bundebug v1.1.16 h1:SgicRQGtnjhrIhlYOxdkOm1Em4s6HykmT3JblHnoTBM=
github.com/uptrace/bun/extra/bundebug v1.1.16/go.mod h1:SkiOkfUirBiO1Htc4s5bQKEq+JSeU1TkBVpMsPz2ePM=
github.com/viterin/partial v1.1.0 h1:iH1l1xqBlapXsYzADS1dcbizg3iQUKTU1rbwkHv/80E=
github.com/viterin/partial v1.1.0/go.mod h1:oKGAo7/wylWkJTLrWX8n+f4aDPtQMQ6VG4dd2qur5QA=
github.com/viterin/vek v0.4.2 h1:Vyv04UjQT6gcjEFX82AS9ocgNbAJqsHviheIBdPlv5U=
github.com/viterin/vek v0.4.2/go.mod h1:A4JRAe8OvbhdzBL5ofzjBS0J29FyUrf95tQogvtHHUc=
github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94=
github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ=
github.com/vmihailenco/msgpack/v5 v5.3.4/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
Expand Down Expand Up @@ -389,8 +395,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea h1:vLCWI/yYrdEHyN2JzIzPO3aaQJHQdp89IZBA/+azVC4=
golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
Expand All @@ -416,8 +422,8 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down Expand Up @@ -599,14 +605,12 @@ golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc=
golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
Expand Down
2 changes: 1 addition & 1 deletion pkg/models/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ type Document struct {
Embedding []float32 `bun:"type:vector,nullzero" json:"embedding,omitempty"`
}

type SearchDocumentQuery struct {
type SearchDocumentResult struct {
*Document
Score float64 `json:"score" bun:"score"`
}
Expand Down
1 change: 0 additions & 1 deletion pkg/models/documentstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ type DocumentStore[T any] interface {
ctx context.Context,
query *DocumentSearchPayload,
limit int,
withMMR bool, // withMMR is used to enable/disable the Maximal Marginal Relevance algorithm for search results.
pageNumber int,
pageSize int,
) (*DocumentSearchResultPage, error)
Expand Down
24 changes: 18 additions & 6 deletions pkg/models/search.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
package models

type SearchType string

const (
SearchTypeSimilarity SearchType = "similarity"
SearchTypeMMR SearchType = "mmr"
)

type MemorySearchResult struct {
Message *Message `json:"message"`
Summary *Summary `json:"summary"` // reserved for future use
Metadata map[string]interface{} `json:"metadata,omitempty"`
Dist float64 `json:"dist"`
Message *Message `json:"message"`
Summary *Summary `json:"summary"` // reserved for future use
Metadata map[string]interface{} `json:"metadata,omitempty"`
Dist float64 `json:"dist"`
Embedding []float32 `json:"embedding"`
}

type MemorySearchPayload struct {
Text string `json:"text"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Text string `json:"text"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type SearchType `json:"type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
}

type DocumentSearchPayload struct {
CollectionName string `json:"collection_name"`
Text string `json:"text,omitempty"`
Embedding []float32 `json:"embedding,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Type SearchType `json:"type"`
MMRLambda float32 `json:"mmr_lambda,omitempty"`
}

type DocumentSearchResult struct {
Expand Down
121 changes: 43 additions & 78 deletions pkg/search/mmr.go
Original file line number Diff line number Diff line change
@@ -1,120 +1,92 @@
package search

import (
"errors"
"fmt"
"math"

"gonum.org/v1/gonum/floats"

"gonum.org/v1/gonum/mat"
"github.com/getzep/zep/internal"
"github.com/viterin/vek"
"github.com/viterin/vek/vek32"
)

// CosineSimilarity calculates the cosine similarity between two vectors.
// The vectors must be of the same length.
func CosineSimilarity(X, Y *mat.Dense) (*mat.Dense, error) { // nolint: gocritic
rX, cX := X.Dims()
rY, cY := Y.Dims()

if rX == 0 || rY == 0 {
return mat.NewDense(0, 0, nil), nil
}

if cX != cY {
return nil, fmt.Errorf(
"number of columns in X and Y must be the same. X has shape [%d, %d] and Y has shape [%d, %d]",
rX,
cX,
rY,
cY,
)
}

Xnorm := mat.NewVecDense(rX, nil)
Ynorm := mat.NewVecDense(rY, nil)

for i := 0; i < rX; i++ {
Xnorm.SetVec(i, mat.Norm(X.RowView(i), 2))
}

for i := 0; i < rY; i++ {
Ynorm.SetVec(i, mat.Norm(Y.RowView(i), 2))
}
var log = internal.GetLogger()

var XT mat.Dense
XT.CloneFrom(X.T())

similarity := mat.NewDense(rX, rY, nil)
similarity.Product(X, &XT)
func init() {
log.Infof("MMR acceleration status: %v", vek.Info())
}

for i := 0; i < rX; i++ {
for j := 0; j < rY; j++ {
val := similarity.At(i, j) / (Xnorm.AtVec(i) * Ynorm.AtVec(j))
if math.IsNaN(val) || math.IsInf(val, 0) {
val = 0.0
// pairwiseCosineSimilarity takes two matrices of vectors and returns a matrix, where
// the value at [i][j] is the cosine similarity between the ith vector in matrix1 and
// the jth vector in matrix2.
func pairwiseCosineSimilarity(matrix1 [][]float32, matrix2 [][]float32) ([][]float32, error) {
result := make([][]float32, len(matrix1))
for i, vec1 := range matrix1 {
result[i] = make([]float32, len(matrix2))
for j, vec2 := range matrix2 {
if len(vec1) != len(vec2) {
return nil, fmt.Errorf("vector lengths do not match: %d != %d", len(vec1), len(vec2))
}
similarity.Set(i, j, val)
result[i][j] = vek32.CosineSimilarity(vec1, vec2)
}
}

return similarity, nil
return result, nil
}

// MaximalMarginalRelevance implements the Maximal Marginal Relevance algorithm.
// It takes a query embedding, a list of embeddings, a lambda multiplier, and a
// number of results to return. It returns a list of indices of the embeddings
// that are most relevant to the query.
// This is a relatively naive and unoptimized implementation of MMR. :-/
// See https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf
func MaximalMarginalRelevance(
queryEmbedding *mat.Dense,
embeddingList *mat.Dense,
lambdaMult float64,
k int,
) ([]int, error) {
rEmbed, _ := embeddingList.Dims()
if k <= 0 || rEmbed == 0 {
// Implementation borrowed from LangChain
// https://github.com/langchain-ai/langchain/blob/4a2f0c51a116cc3141142ea55254e270afb6acde/libs/langchain/langchain/vectorstores/utils.py
func MaximalMarginalRelevance(queryEmbedding []float32, embeddingList [][]float32, lambdaMult float32, k int) ([]int, error) {
// if either k or the length of the embedding list is 0, return an empty list
if min(k, len(embeddingList)) <= 0 {
return []int{}, nil
}

var mostSimilar int
var bestScore float64
var idxToAdd int
// We expect the query embedding and the embeddings in the list to have the same width
if len(queryEmbedding) != len(embeddingList[0]) {
return []int{}, errors.New("query embedding width does not match embedding vector width")
}

similarityToQuery, err := CosineSimilarity(queryEmbedding, embeddingList)
similarityToQueryMatrix, err := pairwiseCosineSimilarity([][]float32{queryEmbedding}, embeddingList)
if err != nil {
return nil, err
}
mostSimilar = floats.MaxIdx(similarityToQuery.RawMatrix().Data)
similarityToQuery := similarityToQueryMatrix[0]

mostSimilar := vek32.ArgMax(similarityToQuery)
idxs := []int{mostSimilar}
selected := mat.DenseCopyOf(embeddingList.RowView(mostSimilar))
selected := [][]float32{embeddingList[mostSimilar]}

for len(idxs) < min(k, rEmbed) {
bestScore = math.Inf(-1)
idxToAdd = -1
r, c := selected.Dims()
selectedTransposed := mat.NewDense(c, r, nil)
selectedTransposed.CloneFrom(selected.T())
similarityToSelected, err := CosineSimilarity(embeddingList, selectedTransposed)
for len(idxs) < min(k, len(embeddingList)) {
var bestScore float32 = -math.MaxFloat32
idxToAdd := -1
similarityToSelected, err := pairwiseCosineSimilarity(embeddingList, selected)
if err != nil {
return nil, err
}
for i, queryScore := range similarityToQuery.RawMatrix().Data {

for i, queryScore := range similarityToQuery {
if contains(idxs, i) {
continue
}
redundantScore := floats.Max(similarityToSelected.RawMatrix().Data)
redundantScore := vek32.Max(similarityToSelected[i])
equationScore := lambdaMult*queryScore - (1-lambdaMult)*redundantScore
if equationScore > bestScore {
bestScore = equationScore
idxToAdd = i
}
}
idxs = append(idxs, idxToAdd)
selected.Stack(selected, embeddingList.RowView(idxToAdd))
selected = append(selected, embeddingList[idxToAdd])
}
return idxs, nil
}

// contains returns true if the slice contains the value
func contains(slice []int, val int) bool {
for _, item := range slice {
if item == val {
Expand All @@ -123,10 +95,3 @@ func contains(slice []int, val int) bool {
}
return false
}

func min(a, b int) int {
if a < b {
return a
}
return b
}
Loading

0 comments on commit 06fe219

Please sign in to comment.