Skip to content

Commit

Permalink
Chore Refactor similarity search results (#127)
Browse files Browse the repository at this point in the history
* chore: refactor SearchResponse

* chore: refactor examples

* chore: refactor naming
  • Loading branch information
henomis committed Sep 13, 2023
1 parent 1dfd66b commit be60c13
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 72 deletions.
6 changes: 3 additions & 3 deletions examples/embeddings/knowledge_base/main.go
Expand Up @@ -58,10 +58,10 @@ func main() {

for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Document.Content)
fmt.Println("Metadata: ", similarity.Document.Metadata)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("----------")
content += similarity.Document.Content + "\n"
content += similarity.Content() + "\n"
}

systemPrompt := prompt.New("You are an helpful assistant. Answer to the questions using only " +
Expand Down
6 changes: 3 additions & 3 deletions examples/embeddings/pinecone/main.go
Expand Up @@ -59,11 +59,11 @@ func main() {
content := ""
for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Document.Content)
fmt.Println("Metadata: ", similarity.Document.Metadata)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("ID: ", similarity.ID)
fmt.Println("----------")
content += similarity.Document.Content + "\n"
content += similarity.Content() + "\n"
}

llmOpenAI := openai.NewCompletion().WithVerbose(true)
Expand Down
6 changes: 3 additions & 3 deletions examples/embeddings/qdrant/main.go
Expand Up @@ -57,11 +57,11 @@ func main() {
content := ""
for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Document.Content)
fmt.Println("Metadata: ", similarity.Document.Metadata)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("ID: ", similarity.ID)
fmt.Println("----------")
content += similarity.Document.Content + "\n"
content += similarity.Content() + "\n"
}

llmOpenAI := openai.NewCompletion().WithVerbose(true)
Expand Down
6 changes: 3 additions & 3 deletions examples/embeddings/simpleVector/main.go
Expand Up @@ -42,14 +42,14 @@ func main() {

for _, similarity := range similarities {
fmt.Printf("Similarity: %f\n", similarity.Score)
fmt.Printf("Document: %s\n", similarity.Document.Content)
fmt.Println("Metadata: ", similarity.Document.Metadata)
fmt.Printf("Document: %s\n", similarity.Content())
fmt.Println("Metadata: ", similarity.Metadata)
fmt.Println("----------")
}

documentContext := ""
for _, similarity := range similarities {
documentContext += similarity.Document.Content + "\n\n"
documentContext += similarity.Content() + "\n\n"
}

llmOpenAI := openai.NewCompletion()
Expand Down
36 changes: 24 additions & 12 deletions index/index.go
Expand Up @@ -18,18 +18,30 @@ const (
DefaultKeyContent = "content"
)

type SearchResponse struct {
type SearchResult struct {
ID string
Document document.Document
Values []float64
Metadata types.Meta
Score float64
}

type SearchResponses []SearchResponse
// Content returns the document text stored in the result's metadata under
// DefaultKeyContent.
//
// It returns an empty string when the key is absent, when the metadata map is
// nil, or when the stored value is not a string. The key can legitimately be
// missing: the pinecone and qdrant result builders delete it from the metadata
// when content is not included in the index.
func (s *SearchResult) Content() string {
	// Use the comma-ok form: the single-value assertion panics on a nil or
	// non-string interface value.
	content, ok := s.Metadata[DefaultKeyContent].(string)
	if !ok {
		return ""
	}

	return content
}

type SearchResults []SearchResult

func (s SearchResponses) ToDocuments() []document.Document {
func (s SearchResults) ToDocuments() []document.Document {
documents := make([]document.Document, len(s))
for i, searchResponse := range s {
documents[i] = searchResponse.Document
for i, searchResult := range s {
metadata := DeepCopyMetadata(searchResult.Metadata)
content := metadata[DefaultKeyContent].(string)
delete(metadata, DefaultKeyContent)

documents[i] = document.Document{
Content: content,
Metadata: metadata,
}
}
return documents
}
Expand All @@ -38,18 +50,18 @@ type Embedder interface {
Embed(ctx context.Context, texts []string) ([]embedder.Embedding, error)
}

func FilterSearchResponses(searchResponses SearchResponses, topK int) SearchResponses {
func FilterSearchResults(searchResults SearchResults, topK int) SearchResults {
//sort by similarity score
sort.Slice(searchResponses, func(i, j int) bool {
return searchResponses[i].Score > searchResponses[j].Score
sort.Slice(searchResults, func(i, j int) bool {
return searchResults[i].Score > searchResults[j].Score
})

maxTopK := topK
if maxTopK > len(searchResponses) {
maxTopK = len(searchResponses)
if maxTopK > len(searchResults) {
maxTopK = len(searchResults)
}

return searchResponses[:maxTopK]
return searchResults[:maxTopK]
}

func DeepCopyMetadata(metadata types.Meta) types.Meta {
Expand Down
30 changes: 12 additions & 18 deletions index/pinecone/pinecone.go
Expand Up @@ -126,7 +126,7 @@ func (p *Index) IsEmpty(ctx context.Context) (bool, error) {

}

func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) {
func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

pineconeOptions := &option.Options{
TopK: defaultTopK,
Expand All @@ -145,9 +145,9 @@ func (p *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}

searchResponses := buildSearchReponsesFromPineconeMatches(matches, p.includeContent)
searchResults := buildSearchResultsFromPineconeMatches(matches, p.includeContent)

return index.FilterSearchResponses(searchResponses, pineconeOptions.TopK), nil
return index.FilterSearchResults(searchResults, pineconeOptions.TopK), nil
}

func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]pineconeresponse.QueryMatch, error) {
Expand Down Expand Up @@ -353,17 +353,13 @@ func buildPineconeVectorsFromEmbeddingsAndDocuments(
return vectors, nil
}

func buildSearchReponsesFromPineconeMatches(matches []pineconeresponse.QueryMatch, includeContent bool) index.SearchResponses {
searchResponses := make([]index.SearchResponse, len(matches))
func buildSearchResultsFromPineconeMatches(matches []pineconeresponse.QueryMatch, includeContent bool) index.SearchResults {
searchResults := make([]index.SearchResult, len(matches))

for i, match := range matches {

metadata := index.DeepCopyMetadata(match.Metadata)

content := ""
// extract document content from vector metadata
if includeContent {
content = metadata[index.DefaultKeyContent].(string)
if !includeContent {
delete(metadata, index.DefaultKeyContent)
}

Expand All @@ -377,15 +373,13 @@ func buildSearchReponsesFromPineconeMatches(matches []pineconeresponse.QueryMatc
score = *match.Score
}

searchResponses[i] = index.SearchResponse{
ID: id,
Document: document.Document{
Metadata: metadata,
Content: content,
},
Score: score,
searchResults[i] = index.SearchResult{
ID: id,
Metadata: metadata,
Values: match.Values,
Score: score,
}
}

return searchResponses
return searchResults
}
30 changes: 12 additions & 18 deletions index/qdrant/qdrant.go
Expand Up @@ -115,7 +115,7 @@ func (p *Index) IsEmpty(ctx context.Context) (bool, error) {

}

func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) {
func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

qdrantOptions := &option.Options{
TopK: defaultTopK,
Expand All @@ -130,9 +130,9 @@ func (q *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti
return nil, fmt.Errorf("%s: %w", index.ErrInternal, err)
}

searchResponses := buildSearchReponsesFromQdrantMatches(matches, q.includeContent)
searchResults := buildSearchResultsFromQdrantMatches(matches, q.includeContent)

return index.FilterSearchResponses(searchResponses, qdrantOptions.TopK), nil
return index.FilterSearchResults(searchResults, qdrantOptions.TopK), nil
}

func (p *Index) similaritySearch(ctx context.Context, query string, opts *option.Options) ([]qdrantresponse.PointSearchResult, error) {
Expand Down Expand Up @@ -288,29 +288,23 @@ func buildQdrantPointsFromEmbeddingsAndDocuments(
return vectors, nil
}

func buildSearchReponsesFromQdrantMatches(matches []qdrantresponse.PointSearchResult, includeContent bool) index.SearchResponses {
searchResponses := make([]index.SearchResponse, len(matches))
func buildSearchResultsFromQdrantMatches(matches []qdrantresponse.PointSearchResult, includeContent bool) index.SearchResults {
searchResults := make([]index.SearchResult, len(matches))

for i, match := range matches {

metadata := index.DeepCopyMetadata(match.Payload)

content := ""
// extract document content from vector metadata
if includeContent {
content = metadata[index.DefaultKeyContent].(string)
if !includeContent {
delete(metadata, index.DefaultKeyContent)
}

searchResponses[i] = index.SearchResponse{
ID: match.ID,
Document: document.Document{
Metadata: metadata,
Content: content,
},
Score: match.Score,
searchResults[i] = index.SearchResult{
ID: match.ID,
Metadata: metadata,
Values: match.Vector,
Score: match.Score,
}
}

return searchResponses
return searchResults
}
22 changes: 10 additions & 12 deletions index/simpleVectorIndex/simpleVectorIndex.go
Expand Up @@ -33,7 +33,7 @@ type Index struct {
embedder index.Embedder
}

type SimpleVectorIndexFilterFn func([]index.SearchResponse) []index.SearchResponse
// SimpleVectorIndexFilterFn is a caller-supplied hook that receives the search
// results and returns the slice to keep. SimilaritySearch applies it, when a
// filter option is set, before the results are sorted and truncated to TopK.
type SimpleVectorIndexFilterFn func([]index.SearchResult) []index.SearchResult

func New(name string, outputPath string, embedder index.Embedder) *Index {
simpleVectorIndex := &Index{
Expand Down Expand Up @@ -140,7 +140,7 @@ func (s *Index) IsEmpty() (bool, error) {
return len(s.data) == 0, nil
}

func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResponses, error) {
func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...option.Option) (index.SearchResults, error) {

sviOptions := &option.Options{
TopK: defaultTopK,
Expand All @@ -162,24 +162,22 @@ func (s *Index) SimilaritySearch(ctx context.Context, query string, opts ...opti

scores := s.cosineSimilarityBatch(embeddings[0])

searchResponses := make([]index.SearchResponse, len(scores))
searchResults := make([]index.SearchResult, len(scores))

for i, score := range scores {
searchResponses[i] = index.SearchResponse{
ID: s.data[i].ID,
Document: document.Document{
Content: s.data[i].Metadata[index.DefaultKeyContent].(string),
Metadata: s.data[i].Metadata,
},
Score: score,
searchResults[i] = index.SearchResult{
ID: s.data[i].ID,
Values: s.data[i].Values,
Metadata: s.data[i].Metadata,
Score: score,
}
}

if sviOptions.Filter != nil {
searchResponses = sviOptions.Filter.(SimpleVectorIndexFilterFn)(searchResponses)
searchResults = sviOptions.Filter.(SimpleVectorIndexFilterFn)(searchResults)
}

return index.FilterSearchResponses(searchResponses, sviOptions.TopK), nil
return index.FilterSearchResults(searchResults, sviOptions.TopK), nil
}

func (s *Index) cosineSimilarity(a embedder.Embedding, b embedder.Embedding) float64 {
Expand Down

0 comments on commit be60c13

Please sign in to comment.