From c1f37e217800442b3eb0116eb752da1b7a631e68 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 10 Jun 2024 07:17:46 -0400 Subject: [PATCH] [Go] pgvector sample A sample demonstrating how to use Postgres's vector extension to build an indexer and retriever. --- go/go.mod | 2 + go/go.sum | 4 + go/samples/pgvector/main.go | 215 +++++++++++++++++++++++++++++++ go/samples/pgvector/pgvector.sql | 22 ++++ 4 files changed, 243 insertions(+) create mode 100644 go/samples/pgvector/main.go create mode 100644 go/samples/pgvector/pgvector.sql diff --git a/go/go.mod b/go/go.mod index f1a645f8ae..e012213c88 100644 --- a/go/go.mod +++ b/go/go.mod @@ -49,7 +49,9 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/pgvector/pgvector-go v0.1.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.opencensus.io v0.24.0 // indirect diff --git a/go/go.sum b/go/go.sum index fa72e192d0..cfc7e1a148 100644 --- a/go/go.sum +++ b/go/go.sum @@ -100,8 +100,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/pgvector/pgvector-go v0.1.1 h1:kqJigGctFnlWvskUiYIvJRNwUtQl/aMSUZVs0YWQe+g= +github.com/pgvector/pgvector-go v0.1.1/go.mod h1:wLJgD/ODkdtd2LJK4l6evHXTuG+8PxymYAVomKHOWac= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/go/samples/pgvector/main.go b/go/samples/pgvector/main.go new file mode 100644 index 0000000000..b6dbc985a6 --- /dev/null +++ b/go/samples/pgvector/main.go @@ -0,0 +1,215 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This program shows how to use Postgres's pgvector extension with Genkit. + +// This program can be manually tested like so: +// +// In development mode (with the environment variable GENKIT_ENV="dev"): +// Start the server listening on port 3100: +// +// go run . -dbconn "$DBCONN" -apikey $API_KEY & +// +// Ask a question: +// +// curl -d '{"Show": "Best Friends", "Question": "Who does Alice love?"}' http://localhost:3400/askQuestion +package main + +import ( + "context" + "database/sql" + "errors" + "flag" + "fmt" + "log" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googleai" + _ "github.com/lib/pq" + pgv "github.com/pgvector/pgvector-go" +) + +var ( + connString = flag.String("dbconn", "", "database connection string") + apiKey = flag.String("apikey", "", "Gemini API key") + index = flag.Bool("index", false, "index the existing data") +) + +func main() { + flag.Parse() + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { + if *connString == "" { + return errors.New("need -dbconn") + } + if *apiKey == "" { + return errors.New("need -apikey") + } + ctx := context.Background() + _, ems, err := googleai.Init(ctx, googleai.Config{ + APIKey: *apiKey, + Embedders: []string{"embedding-001"}, + }) + if err != nil { + return err + } + embedder := ems[0] + + db, err := sql.Open("postgres", *connString) + if err != nil { + return err + } + defer db.Close() + + if *index { + indexer := defineIndexer(db, embedder) + if err := indexExistingRows(ctx, db, indexer); err != nil { + return err + } + } + + retriever := defineRetriever(db, embedder) + + type input struct { + Question string + Show string + } + + genkit.DefineFlow("askQuestion", func(ctx context.Context, in input, _ genkit.NoStream) (string, error) { + res, err := ai.Retrieve(ctx, retriever, &ai.RetrieverRequest{ + Document: &ai.Document{Content: []*ai.Part{ai.NewTextPart(in.Question)}}, + Options: in.Show, + }) + if err != nil { + return "", err + } + for _, doc := range res.Documents { + fmt.Printf("%+v %q\n", doc.Metadata, doc.Content[0].Text) + } + // Use documents in RAG prompts. + return "", nil + }) + + return genkit.StartFlowServer("") +} + +const provider = "pgvector" + +func defineRetriever(db *sql.DB, embedder *ai.EmbedderAction) *ai.RetrieverAction { + f := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { + vals, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{Document: req.Document}) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, ` + SELECT episode_id, season_number, chunk as content + FROM embeddings + WHERE show_id = $1 + ORDER BY embedding <#> $2 + LIMIT 2`, + req.Options, pgv.NewVector(vals)) + if err != nil { + return nil, err + } + defer rows.Close() + + res := &ai.RetrieverResponse{} + for rows.Next() { + var eid, sn int + var content string + if err := rows.Scan(&eid, &sn, &content); err != nil { + return nil, err + } + meta := map[string]any{ + "episode_id": eid, + "season_number": sn, + } + doc := &ai.Document{ + Content: []*ai.Part{ai.NewTextPart(content)}, + Metadata: meta, + } + res.Documents = append(res.Documents, doc) + } + if err := rows.Err(); err != nil { + return nil, err + } + return res, nil + } + return ai.DefineRetriever(provider, "shows", f) +} + +func defineIndexer(db *sql.DB, embedder *ai.EmbedderAction) *ai.IndexerAction { + // The indexer assumes that each Document has a single part, to be embedded, and metadata fields + // for the table primary key: show_id, season_number, episode_id. + const query = ` + UPDATE embeddings + SET embedding = $4 + WHERE show_id = $1 AND season_number = $2 AND episode_id = $3 + ` + return ai.DefineIndexer(provider, "shows", func(ctx context.Context, req *ai.IndexerRequest) error { + for i, doc := range req.Documents { + vals, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{Document: doc}) + if err != nil { + return err + } + args := make([]any, 4) + for j, k := range []string{"show_id", "season_number", "episode_id"} { + if a, ok := doc.Metadata[k]; ok { + args[j] = a + } else { + return fmt.Errorf("doc[%d]: missing metadata key %q", i, k) + } + } + args[3] = pgv.NewVector(vals) + if _, err := db.ExecContext(ctx, query, args...); err != nil { + return err + } + } + return nil + }) +} + +func indexExistingRows(ctx context.Context, db *sql.DB, indexer *ai.IndexerAction) error { + rows, err := db.QueryContext(ctx, `SELECT show_id, season_number, episode_id, chunk FROM embeddings`) + if err != nil { + return err + } + defer rows.Close() + + req := &ai.IndexerRequest{} + for rows.Next() { + var sid, chunk string + var sn, eid int + if err := rows.Scan(&sid, &sn, &eid, &chunk); err != nil { + return err + } + req.Documents = append(req.Documents, &ai.Document{ + Content: []*ai.Part{ai.NewTextPart(chunk)}, + Metadata: map[string]any{ + "show_id": sid, + "season_number": sn, + "episode_id": eid, + }, + }) + } + if err := rows.Err(); err != nil { + return err + } + return ai.Index(ctx, indexer, req) +} diff --git a/go/samples/pgvector/pgvector.sql b/go/samples/pgvector/pgvector.sql new file mode 100644 index 0000000000..a252e9e3d5 --- /dev/null +++ b/go/samples/pgvector/pgvector.sql @@ -0,0 +1,22 @@ +-- This SQL enables the vector extension and creates the table and data used +-- in the accompanying sample. + +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE embeddings ( + show_id TEXT NOT NULL, + season_number INTEGER NOT NULL, + episode_id INTEGER NOT NULL, + chunk TEXT, + embedding vector(768), + PRIMARY KEY (show_id, season_number, episode_id) +); + +INSERT INTO embeddings (show_id, season_number, episode_id, chunk) VALUES + ('La Vie', 1, 1, 'Natasha confesses her love for Pierre.'), + ('La Vie', 1, 2, 'Pierre and Natasha become engaged.'), + ('La Vie', 1, 3, 'Margot and Henri divorce.'), + ('Best Friends', 1, 1, 'Alice confesses her love for Oscar.'), + ('Best Friends', 1, 2, 'Oscar and Alice become engaged.'), + ('Best Friends', 1, 3, 'Bob and Pat divorce.') +;