Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore Refactor similarity search results #127

Merged
merged 3 commits into from Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/embeddings/knowledge_base/main.go
Expand Up @@ -58,10 +58,10 @@ func main() {

for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Document.Content)
fmt.Println("Metadata: ", similarity.Document.Metadata)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("----------")
content += similarity.Document.Content + "\n"
content += similarity.Content() + "\n"
}

systemPrompt := prompt.New("You are an helpful assistant. Answer to the questions using only " +
Expand Down
6 changes: 3 additions & 3 deletions examples/embeddings/pinecone/main.go
Expand Up @@ -59,11 +59,11 @@ func main() {
content := ""
for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Document.Content)
fmt.Println("Metadata: ", similarity.Document.Metadata)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("ID: ", similarity.ID)
fmt.Println("----------")
content += similarity.Document.Content + "\n"
content += similarity.Content() + "\n"
}

llmOpenAI := openai.NewCompletion().WithVerbose(true)
Expand Down
6 changes: 3 additions & 3 deletions examples/embeddings/qdrant/main.go
Expand Up @@ -57,11 +57,11 @@ func main() {
content := ""
for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Document.Content)
fmt.Println("Metadata: ", similarity.Document.Metadata)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("ID: ", similarity.ID)
fmt.Println("----------")
content += similarity.Document.Content + "\n"
content += similarity.Content() + "\n"
}

llmOpenAI := openai.NewCompletion().WithVerbose(true)
Expand Down
6 changes: 3 additions & 3 deletions examples/embeddings/simpleVector/main.go
Expand Up @@ -42,14 +42,14 @@ func main() {

for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Document.Content)
fmt.Println("Metadata: ", similarity.Document.Metadata)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("----------")
}

documentContext := ""
for _, similarity := range similarities {
documentContext += similarity.Document.Content + "\n\n"
documentContext += similarity.Content() + "\n\n"
}

llmOpenAI := openai.NewCompletion()
Expand Down
36 changes: 24 additions & 12 deletions index/index.go
Expand Up @@ -18,18 +18,30 @@ const (
DefaultKeyContent = "content"
)

type SearchResponse struct {
type SearchResult struct {
ID string
Document document.Document
Values []float64
Metadata types.Meta
Score float64
}

type SearchResponses []SearchResponse
// Content returns the document text stored in the result's metadata under
// DefaultKeyContent.
//
// It returns an empty string when the key is absent or when the stored value
// is not a string, instead of panicking on a failed type assertion. The key
// can legitimately be absent: index implementations may remove it from the
// metadata when content is not requested.
func (s *SearchResult) Content() string {
	// Comma-ok assertion: a missing key yields a nil interface value, which
	// would panic with the single-value form.
	content, ok := s.Metadata[DefaultKeyContent].(string)
	if !ok {
		return ""
	}
	return content
}

type SearchResults []SearchResult

func (s SearchResponses) ToDocuments() []document.Document {
func (s SearchResults) ToDocuments() []document.Document {
documents := make([]document.Document, len(s))
for i, searchResponse := range s {
documents[i] = searchResponse.Document
for i, searchResult := range s {
metadata := DeepCopyMetadata(searchResult.Metadata)
content := metadata[DefaultKeyContent].(string)
delete(metadata, DefaultKeyContent)

documents[i] = document.Document{
Content: content,
Metadata: metadata,
}
}
return documents
}
Expand All @@ -38,18 +50,18 @@ type Embedder interface {
Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error)
}

func FilterSearchResponses(searchResponses SearchResponses, topK int) SearchResponses {
func FilterSearchResults(searchResults SearchResults, topK int) SearchResults {
//sort by similarity score
sort.Slice(searchResponses, func(i, j int) bool {
return searchResponses[i].Score > searchResponses[j].Score
sort.Slice(searchResults, func(i, j int) bool {
return searchResults[i].Score > searchResults[j].Score
})

maxTopK := topK
if maxTopK > len(searchResponses) {
maxTopK = len(searchResponses)
if maxTopK > len(searchResults) {
maxTopK = len(searchResults)
}

return searchResponses[:maxTopK]
return searchResults[:maxTopK]
}

func DeepCopyMetadata(metadata types.Meta) types.Meta {
Expand Down
30 changes: 12 additions & 18 deletions index/pinecone/pinecone.go
Expand Up @@ -126,7 +126,7 @@ func (p *Index) IsEmpty(ctx context.Context) (bool, error) {

}

func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) {
func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

pineconeOptions := &option.Options{
TopK: defaultTopK,
Expand All @@ -145,9 +145,9 @@ func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}

searchResponses := buildSearchReponsesFromPineconeMatches(matches, p.includeContent)
searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent)

return index.FilterSearchResponses(searchResponses, pineconeOptions.TopK), nil
return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil
}

func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) {
Expand Down Expand Up @@ -353,17 +353,13 @@ func buildPineconeVectorsFromEmbeddingsAndDocuments(
return vectors, nil
}

func buildSearchReponsesFromPineconeMatches(matches []pineconeresponse.QueryMatch, includeContent bool) index.SearchResponses {
searchResponses := make([]index.SearchResponse, len(matches))
func buildSearchResultsFromPineconeMatches(matches []pineconeresponse.QueryMatch, includeContent bool) index.SearchResults {
searchResults := make([]index.SearchResult, len(matches))

for i, match := range matches {

metadata := index.DeepCopyMetadata(match.Metadata)

content := ""
// extract document content from vector metadata
if includeContent {
content = metadata[index.DefaultKeyContent].(string)
if !includeContent {
delete(metadata, index.DefaultKeyContent)
}

Expand All @@ -377,15 +373,13 @@ func buildSearchReponsesFromPineconeMatches(matches []pineconeresponse.QueryMatc
score = *match.Score
}

searchResponses[i] = index.SearchResponse{
ID: id,
Document: document.Document{
Metadata: metadata,
Content: content,
},
Score: score,
searchResults[i] = index.SearchResult{
ID: id,
Metadata: metadata,
Values: match.Values,
Score: score,
}
}

return searchResponses
return searchResults
}
30 changes: 12 additions & 18 deletions index/qdrant/qdrant.go
Expand Up @@ -115,7 +115,7 @@ func (p *Index) IsEmpty(ctx context.Context) (bool, error) {

}

func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) {
func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

qdrantOptions := &option.Options{
TopK: defaultTopK,
Expand All @@ -130,9 +130,9 @@ func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}

searchResponses := buildSearchReponsesFromQdrantMatches(matches, q.includeContent)
searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent)

return index.FilterSearchResponses(searchResponses, qdrantOptions.TopK), nil
return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil
}

func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) {
Expand Down Expand Up @@ -288,29 +288,23 @@ func buildQdrantPointsFromEmbeddingsAndDocuments(
return vectors, nil
}

func buildSearchReponsesFromQdrantMatches(matches []qdrantresponse.PointSearchResult, includeContent bool) index.SearchResponses {
searchResponses := make([]index.SearchResponse, len(matches))
func buildSearchResultsFromQdrantMatches(matches []qdrantresponse.PointSearchResult, includeContent bool) index.SearchResults {
searchResults := make([]index.SearchResult, len(matches))

for i, match := range matches {

metadata := index.DeepCopyMetadata(match.Payload)

content := ""
// extract document content from vector metadata
if includeContent {
content = metadata[index.DefaultKeyContent].(string)
if !includeContent {
delete(metadata, index.DefaultKeyContent)
}

searchResponses[i] = index.SearchResponse{
ID: match.ID,
Document: document.Document{
Metadata: metadata,
Content: content,
},
Score: match.Score,
searchResults[i] = index.SearchResult{
ID: match.ID,
Metadata: metadata,
Values: match.Vector,
Score: match.Score,
}
}

return searchResponses
return searchResults
}
22 changes: 10 additions & 12 deletions index/simpleVectorIndex/simpleVectorIndex.go
Expand Up @@ -33,7 +33,7 @@ type Index struct {
embedder index.Embedder
}

type SimpleVectorIndexFilterFn func([]index.SearchResponse) []index.SearchResponse
type SimpleVectorIndexFilterFn func([]index.SearchResult) []index.SearchResult

func New(name string, outputPath string, embedder index.Embedder) *Index {
simpleVectorIndex := &Index{
Expand Down Expand Up @@ -140,7 +140,7 @@ func (s *Index) IsEmpty() (bool, error) {
return len(s.data) == 0, nil
}

func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) {
func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

sviOptions := &option.Options{
TopK: defaultTopK,
Expand All @@ -162,24 +162,22 @@ func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti

scores := s.cosineSimilarityBatch(embeddings[0])

searchResponses := make([]index.SearchResponse, len(scores))
searchResults := make([]index.SearchResult, len(scores))

for i, score := range scores {
searchResponses[i] = index.SearchResponse{
ID: s.data[i].ID,
Document: document.Document{
Content: s.data[i].Metadata[index.DefaultKeyContent].(string),
Metadata: s.data[i].Metadata,
},
Score: score,
searchResults[i] = index.SearchResult{
ID: s.data[i].ID,
Values: s.data[i].Values,
Metadata: s.data[i].Metadata,
Score: score,
}
}

if sviOptions.Filter != nil {
searchResponses = sviOptions.Filter.(SimpleVectorIndexFilterFn)(searchResponses)
searchResults = sviOptions.Filter.(SimpleVectorIndexFilterFn)(searchResults)
}

return index.FilterSearchResponses(searchResponses, sviOptions.TopK), nil
return index.FilterSearchResults(searchResults, sviOptions.TopK), nil
}

func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 {
Expand Down