Skip to content
This repository was archived by the owner on Oct 30, 2024. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 63 additions & 40 deletions pkg/vectorstore/pgvector/pgvector.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ import (
"fmt"
"log/slog"
"strings"
"sync"

"github.com/google/uuid"
"github.com/gptscript-ai/knowledge/pkg/env"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore/types"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/pgvector/pgvector-go"
cg "github.com/philippgille/chromem-go"
"github.com/tmc/langchaingo/embeddings"
)

const (
Expand All @@ -34,6 +35,9 @@ const (
// pgLockIDCreateCollection is used for an advisory lock to fix issues arising from concurrent
// creation of the collection. The same value represents the same lock.
pgLockIDCreateCollection = 1573678846307946497

// VsPgvectorEmbeddingConcurrency is the name of the environment variable that controls the number of parallel API calls made to create embeddings for documents. Default is 100.
VsPgvectorEmbeddingConcurrency = "VS_PGVECTOR_EMBEDDING_CONCURRENCY"
)

var (
Expand All @@ -58,12 +62,13 @@ type CloseNoErr interface {
}

// VectorStore is the vector store implementation of the pgvector package.
// It carries what is needed to embed text (an embedder / embedding function
// and a concurrency limit for embedding calls), the database connection, the
// names of the embedding and collection tables, the vector dimension count,
// and an HNSW index configuration.
//
// NOTE(review): this listing appears to interleave the pre-change field set
// (embedder, ...) with the post-change one (embeddingFunc,
// embeddingConcurrency, ...), so several fields show up twice — confirm the
// live field list against the actual file.
type VectorStore struct {
embedder embeddings.Embedder
conn PGXConn
embeddingTableName string
collectionTableName string
vectorDimensions int
hnswIndex *HNSWIndex
embeddingFunc cg.EmbeddingFunc
embeddingConcurrency int
conn PGXConn
embeddingTableName string
collectionTableName string
vectorDimensions int
hnswIndex *HNSWIndex
}

// HNSWIndex lets you specify the HNSW index parameters.
Expand All @@ -84,35 +89,18 @@ var DefaultHNSWIndex = &HNSWIndex{
distanceFunction: "vector_l2_ops",
}

// embeddingFuncToEmbedderClientFunc adapts a single-text embedding function to
// the batch-oriented embeddings.EmbedderClientFunc signature.
//
// The returned function embeds the given texts one at a time, in input order,
// and collects the resulting vectors. The first embedding error aborts the
// batch and is returned unchanged together with a nil result.
func embeddingFuncToEmbedderClientFunc(embeddingFunc cg.EmbeddingFunc) embeddings.EmbedderClientFunc {
	embedAll := func(ctx context.Context, texts []string) ([][]float32, error) {
		// Left as a nil slice on purpose: an empty input yields a nil result.
		var vectors [][]float32
		for i := 0; i < len(texts); i++ {
			vector, err := embeddingFunc(ctx, texts[i])
			if err != nil {
				return nil, err
			}
			vectors = append(vectors, vector)
		}
		return vectors, nil
	}
	return embedAll
}

func New(ctx context.Context, dsn string, embeddingFunc cg.EmbeddingFunc) (*VectorStore, error) {
dsn = "postgres://" + strings.TrimPrefix(dsn, "pgvector://")

embedder, err := embeddings.NewEmbedder(embeddingFuncToEmbedderClientFunc(embeddingFunc))
if err != nil {
return nil, err
}

store := &VectorStore{
embeddingTableName: "knowledge_embeddings",
collectionTableName: "knowledge_collections",
embedder: embedder,
hnswIndex: nil,
embeddingTableName: "knowledge_embeddings",
collectionTableName: "knowledge_collections",
embeddingFunc: embeddingFunc,
embeddingConcurrency: env.GetIntFromEnvOrDefault(VsPgvectorEmbeddingConcurrency, 100),
hnswIndex: nil,
}

var err error
store.conn, err = pgxpool.New(ctx, dsn)
if err != nil {
return nil, err
Expand Down Expand Up @@ -277,25 +265,60 @@ func (v VectorStore) AddDocuments(ctx context.Context, docs []vs.Document, colle
texts = append(texts, doc.Content)
}

vectors, err := v.embedder.EmbedDocuments(ctx, texts)
if err != nil {
return nil, err
}
b := &pgx.Batch{}
ids := make([]string, len(docs))

if len(vectors) != len(docs) {
return nil, ErrEmbedderWrongNumberVectors
var sharedErr error
sharedErrLock := sync.Mutex{}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
setSharedErr := func(err error) {
sharedErrLock.Lock()
defer sharedErrLock.Unlock()
// Another goroutine might have already set the error.
if sharedErr == nil {
sharedErr = err
// Cancel the operation for all other goroutines.
cancel(sharedErr)
}
}

b := &pgx.Batch{}
sql := fmt.Sprintf(`INSERT INTO %s (uuid, document, embedding, cmetadata, collection_id)
VALUES($1, $2, $3, $4, $5)`, v.embeddingTableName)

ids := make([]string, len(docs))
var wg sync.WaitGroup
semaphore := make(chan struct{}, v.embeddingConcurrency)
for docIdx, doc := range docs {
id := uuid.New().String()
ids[docIdx] = id
b.Queue(sql, id, doc.Content, pgvector.NewVector(vectors[docIdx]), doc.Metadata, cid)
doc.ID = id

wg.Add(1)
go func(doc vs.Document) {
defer wg.Done()

// Don't even start if another goroutine already failed.
if ctx.Err() != nil {
return
}

// Wait here while $concurrency other goroutines are creating documents.
semaphore <- struct{}{}
defer func() { <-semaphore }()

vec, err := v.embeddingFunc(ctx, doc.Content)
if err != nil {
setSharedErr(fmt.Errorf("failed to embed document %s: %w", doc.ID, err))
return
}

b.Queue(sql, doc.ID, doc.Content, pgvector.NewVector(vec), doc.Metadata, cid)

}(doc)

docs[docIdx] = doc
}

return ids, v.conn.SendBatch(ctx, b).Close()
}

Expand All @@ -317,7 +340,7 @@ func (v VectorStore) SimilaritySearch(ctx context.Context, query string, numDocu
return nil, fmt.Errorf("pgvector does not support whereDocument")
}

queryEmbedding, err := v.embedder.EmbedQuery(ctx, query)
queryEmbedding, err := v.embeddingFunc(ctx, query)
if err != nil {
return nil, err
}
Expand Down