Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add filter to index #99

Merged
merged 1 commit into from Jul 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/embeddings/qdrant/main.go
Expand Up @@ -12,7 +12,8 @@ import (
"github.com/henomis/lingoose/textsplitter"
)

// download https://frontiernerds.com/files/state_of_the_union.txt
// download https://raw.githubusercontent.com/hwchase17/chat-your-data/master/state_of_the_union.txt
// run qdrant docker run -p 6333:6333 qdrant/qdrant

func main() {

Expand Down
9 changes: 8 additions & 1 deletion index/options.go
Expand Up @@ -3,11 +3,18 @@ package index
type Option func(*options)

type options struct {
topK int
topK int
filter interface{}
}

func WithTopK(topK int) Option {
return func(opts *options) {
opts.topK = topK
}
}

func WithFilter(filter interface{}) Option {
return func(opts *options) {
opts.filter = filter
}
}
7 changes: 4 additions & 3 deletions index/pinecone.go
Expand Up @@ -135,7 +135,7 @@ func (p *Pinecone) SimilaritySearch(ctx context.Context, query string, opts ...O
opt(pineconeOptions)
}

matches, err := p.similaritySearch(ctx, query, pineconeOptions.topK)
matches, err := p.similaritySearch(ctx, query, pineconeOptions)
if err != nil {
return nil, fmt.Errorf("%s: %w", ErrInternal, err)
}
Expand All @@ -145,7 +145,7 @@ func (p *Pinecone) SimilaritySearch(ctx context.Context, query string, opts ...O
return filterSearchResponses(searchResponses, pineconeOptions.topK), nil
}

func (p *Pinecone) similaritySearch(ctx context.Context, query string, topK int) ([]pineconeresponse.QueryMatch, error) {
func (p *Pinecone) similaritySearch(ctx context.Context, query string, opts *options) ([]pineconeresponse.QueryMatch, error) {

err := p.getProjectID(ctx)
if err != nil {
Expand All @@ -164,10 +164,11 @@ func (p *Pinecone) similaritySearch(ctx context.Context, query string, topK int)
&pineconerequest.VectorQuery{
IndexName: p.indexName,
ProjectID: *p.projectID,
TopK: int32(topK),
TopK: int32(opts.topK),
Vector: embeddings[0],
IncludeMetadata: &includeMetadata,
Namespace: &p.namespace,
Filter: opts.filter.(map[string]string),
},
res,
)
Expand Down
7 changes: 4 additions & 3 deletions index/qdrant.go
Expand Up @@ -123,7 +123,7 @@ func (q *Qdrant) SimilaritySearch(ctx context.Context, query string, opts ...Opt
opt(qdrantOptions)
}

matches, err := q.similaritySearch(ctx, query, qdrantOptions.topK)
matches, err := q.similaritySearch(ctx, query, qdrantOptions)
if err != nil {
return nil, fmt.Errorf("%s: %w", ErrInternal, err)
}
Expand All @@ -133,7 +133,7 @@ func (q *Qdrant) SimilaritySearch(ctx context.Context, query string, opts ...Opt
return filterSearchResponses(searchResponses, qdrantOptions.topK), nil
}

func (p *Qdrant) similaritySearch(ctx context.Context, query string, topK int) ([]qdrantresponse.PointSearchResult, error) {
func (p *Qdrant) similaritySearch(ctx context.Context, query string, opts *options) ([]qdrantresponse.PointSearchResult, error) {

embeddings, err := p.embedder.Embed(ctx, []string{query})
if err != nil {
Expand All @@ -146,9 +146,10 @@ func (p *Qdrant) similaritySearch(ctx context.Context, query string, topK int) (
ctx,
&qdrantrequest.PointSearch{
CollectionName: p.collectionName,
Limit: topK,
Limit: opts.topK,
Vector: embeddings[0],
WithPayload: &includeMetadata,
Filter: opts.filter,
},
res,
)
Expand Down
6 changes: 6 additions & 0 deletions index/simpleVectorIndex.go
Expand Up @@ -30,6 +30,8 @@ type SimpleVectorIndex struct {
embedder Embedder
}

type SimpleVectorIndexFilterFn func([]SearchResponse) []SearchResponse

func NewSimpleVectorIndex(name string, outputPath string, embedder Embedder) *SimpleVectorIndex {
simpleVectorIndex := &SimpleVectorIndex{
data: []simpleVectorIndexData{},
Expand Down Expand Up @@ -152,6 +154,10 @@ func (s *SimpleVectorIndex) SimilaritySearch(ctx context.Context, query string,
}
}

if sviOptions.filter != nil {
searchResponses = sviOptions.filter.(SimpleVectorIndexFilterFn)(searchResponses)
}

return filterSearchResponses(searchResponses, sviOptions.topK), nil
}

Expand Down