From 0ebc184a431c2e9e7e100d2737aee44819038895 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 10 Jun 2024 08:24:00 -0400 Subject: [PATCH 1/2] [Go] googleai.Init: return actions Init returns the models and embedders that it defines. Users can access defined actions in one of three ways: by slice index, by calling Action.Name on a slice element, or by calling the Model or Embedder functions. --- go/plugins/googleai/googleai.go | 27 ++++++++++++++++----------- go/plugins/googleai/googleai_test.go | 2 +- go/samples/coffee-shop/main.go | 2 +- go/samples/rag/main.go | 2 +- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index dd4aa42ec3..73df7ebf8d 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -42,7 +42,12 @@ type Config struct { Embedders []string } -func Init(ctx context.Context, cfg Config) (err error) { +// Init initializes the plugin. +// It defines all the configured models and embedders, and returns their actions. +// If [Config.Models] or [Config.Embedders] is non-empty, the actions are returned in the same +// order; otherwise the order is undefined. Call the [Model] or [Embedder] functions to get an action +// from the registry by name, or call the Name method of the action to get its name. +func Init(ctx context.Context, cfg Config) (models []*ai.ModelAction, embedders []*ai.EmbedderAction, err error) { defer func() { if err != nil { err = fmt.Errorf("googleai.Init: %w", err) @@ -50,12 +55,12 @@ func Init(ctx context.Context, cfg Config) (err error) { }() if cfg.APIKey == "" { - return errors.New("missing API key") + return nil, nil, errors.New("missing API key") } client, err := genai.NewClient(ctx, option.WithAPIKey(cfg.APIKey)) if err != nil { - return err + return nil, nil, err } needModels := len(cfg.Models) == 0 @@ -68,7 +73,7 @@ func Init(ctx context.Context, cfg Config) (err error) { break } if err != nil { - return err + return nil, nil, err } // Model names are of the form "models/name". name := path.Base(mi.Name) @@ -81,15 +86,15 @@ func Init(ctx context.Context, cfg Config) (err error) { } } for _, name := range cfg.Models { - defineModel(name, client) + models = append(models, defineModel(name, client)) } for _, name := range cfg.Embedders { - defineEmbedder(name, client) + embedders = append(embedders, defineEmbedder(name, client)) } - return nil + return models, embedders, nil } -func defineModel(name string, client *genai.Client) { +func defineModel(name string, client *genai.Client) *ai.ModelAction { meta := &ai.ModelMetadata{ Label: "Google AI - " + name, Supports: ai.ModelCapabilities{ @@ -97,11 +102,11 @@ func defineModel(name string, client *genai.Client) { }, } g := generator{model: name, client: client} - ai.DefineModel(provider, name, meta, g.generate) + return ai.DefineModel(provider, name, meta, g.generate) } -func defineEmbedder(name string, client *genai.Client) { - ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) { +func defineEmbedder(name string, client *genai.Client) *ai.EmbedderAction { + return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) { em := client.EmbeddingModel(name) parts, err := convertParts(input.Document.Content) if err != nil { diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 57df79c4c9..90357f91b4 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -40,7 +40,7 @@ func TestLive(t *testing.T) { t.Skipf("no -key provided") } ctx := context.Background() - err := googleai.Init(ctx, googleai.Config{ + _, _, err := googleai.Init(ctx, googleai.Config{ APIKey: *apiKey, Embedders: []string{embeddingModel}, Models: []string{generativeModel}, diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index a9899ad0f1..4bf28d0caf 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -104,7 +104,7 @@ func main() { fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") os.Exit(1) } - err := googleai.Init(context.Background(), googleai.Config{APIKey: apiKey}) + _, _, err := googleai.Init(context.Background(), googleai.Config{APIKey: apiKey}) if err != nil { log.Fatal(err) } diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index 544fdfab8e..463cd018f3 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -75,7 +75,7 @@ func main() { fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") os.Exit(1) } - err := googleai.Init(context.Background(), googleai.Config{APIKey: apiKey}) + _, _, err := googleai.Init(context.Background(), googleai.Config{APIKey: apiKey}) if err != nil { log.Fatal(err) } From c828f5f045d6b32f0eaacf778aece43523dd0810 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 10 Jun 2024 17:05:39 -0400 Subject: [PATCH 2/2] [Go] pgvector sample (#375) A sample demonstrating how to use Postgres's vector extension to build an indexer and retriever. --- go/go.mod | 2 + go/go.sum | 4 + go/samples/pgvector/main.go | 215 +++++++++++++++++++++++++++++++ go/samples/pgvector/pgvector.sql | 22 ++++ 4 files changed, 243 insertions(+) create mode 100644 go/samples/pgvector/main.go create mode 100644 go/samples/pgvector/pgvector.sql diff --git a/go/go.mod b/go/go.mod index f1a645f8ae..e012213c88 100644 --- a/go/go.mod +++ b/go/go.mod @@ -49,7 +49,9 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/pgvector/pgvector-go v0.1.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.opencensus.io v0.24.0 // indirect diff --git a/go/go.sum b/go/go.sum index fa72e192d0..cfc7e1a148 100644 --- a/go/go.sum +++ b/go/go.sum @@ -100,8 +100,12 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/pgvector/pgvector-go v0.1.1 h1:kqJigGctFnlWvskUiYIvJRNwUtQl/aMSUZVs0YWQe+g= +github.com/pgvector/pgvector-go v0.1.1/go.mod h1:wLJgD/ODkdtd2LJK4l6evHXTuG+8PxymYAVomKHOWac= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/go/samples/pgvector/main.go b/go/samples/pgvector/main.go new file mode 100644 index 0000000000..b6dbc985a6 --- /dev/null +++ b/go/samples/pgvector/main.go @@ -0,0 +1,215 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This program shows how to use Postgres's pgvector extension with Genkit. + +// This program can be manually tested like so: +// +// In development mode (with the environment variable GENKIT_ENV="dev"): +// Start the server listening on port 3100: +// +// go run . -dbconn "$DBCONN" -apikey $API_KEY & +// +// Ask a question: +// +// curl -d '{"Show": "Best Friends", "Question": "Who does Alice love?"}' http://localhost:3400/askQuestion +package main + +import ( + "context" + "database/sql" + "errors" + "flag" + "fmt" + "log" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googleai" + _ "github.com/lib/pq" + pgv "github.com/pgvector/pgvector-go" +) + +var ( + connString = flag.String("dbconn", "", "database connection string") + apiKey = flag.String("apikey", "", "Gemini API key") + index = flag.Bool("index", false, "index the existing data") +) + +func main() { + flag.Parse() + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { + if *connString == "" { + return errors.New("need -dbconn") + } + if *apiKey == "" { + return errors.New("need -apikey") + } + ctx := context.Background() + _, ems, err := googleai.Init(ctx, googleai.Config{ + APIKey: *apiKey, + Embedders: []string{"embedding-001"}, + }) + if err != nil { + return err + } + embedder := ems[0] + + db, err := sql.Open("postgres", *connString) + if err != nil { + return err + } + defer db.Close() + + if *index { + indexer := defineIndexer(db, embedder) + if err := indexExistingRows(ctx, db, indexer); err != nil { + return err + } + } + + retriever := defineRetriever(db, embedder) + + type input struct { + Question string + Show string + } + + genkit.DefineFlow("askQuestion", func(ctx context.Context, in input, _ genkit.NoStream) (string, error) { + res, err := ai.Retrieve(ctx, retriever, &ai.RetrieverRequest{ + Document: &ai.Document{Content: []*ai.Part{ai.NewTextPart(in.Question)}}, + Options: in.Show, + }) + if err != nil { + return "", err + } + for _, doc := range res.Documents { + fmt.Printf("%+v %q\n", doc.Metadata, doc.Content[0].Text) + } + // Use documents in RAG prompts. + return "", nil + }) + + return genkit.StartFlowServer("") +} + +const provider = "pgvector" + +func defineRetriever(db *sql.DB, embedder *ai.EmbedderAction) *ai.RetrieverAction { + f := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { + vals, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{Document: req.Document}) + if err != nil { + return nil, err + } + rows, err := db.QueryContext(ctx, ` + SELECT episode_id, season_number, chunk as content + FROM embeddings + WHERE show_id = $1 + ORDER BY embedding <#> $2 + LIMIT 2`, + req.Options, pgv.NewVector(vals)) + if err != nil { + return nil, err + } + defer rows.Close() + + res := &ai.RetrieverResponse{} + for rows.Next() { + var eid, sn int + var content string + if err := rows.Scan(&eid, &sn, &content); err != nil { + return nil, err + } + meta := map[string]any{ + "episode_id": eid, + "season_number": sn, + } + doc := &ai.Document{ + Content: []*ai.Part{ai.NewTextPart(content)}, + Metadata: meta, + } + res.Documents = append(res.Documents, doc) + } + if err := rows.Err(); err != nil { + return nil, err + } + return res, nil + } + return ai.DefineRetriever(provider, "shows", f) +} + +func defineIndexer(db *sql.DB, embedder *ai.EmbedderAction) *ai.IndexerAction { + // The indexer assumes that each Document has a single part, to be embedded, and metadata fields + // for the table primary key: show_id, season_number, episode_id. + const query = ` + UPDATE embeddings + SET embedding = $4 + WHERE show_id = $1 AND season_number = $2 AND episode_id = $3 + ` + return ai.DefineIndexer(provider, "shows", func(ctx context.Context, req *ai.IndexerRequest) error { + for i, doc := range req.Documents { + vals, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{Document: doc}) + if err != nil { + return err + } + args := make([]any, 4) + for j, k := range []string{"show_id", "season_number", "episode_id"} { + if a, ok := doc.Metadata[k]; ok { + args[j] = a + } else { + return fmt.Errorf("doc[%d]: missing metadata key %q", i, k) + } + } + args[3] = pgv.NewVector(vals) + if _, err := db.ExecContext(ctx, query, args...); err != nil { + return err + } + } + return nil + }) +} + +func indexExistingRows(ctx context.Context, db *sql.DB, indexer *ai.IndexerAction) error { + rows, err := db.QueryContext(ctx, `SELECT show_id, season_number, episode_id, chunk FROM embeddings`) + if err != nil { + return err + } + defer rows.Close() + + req := &ai.IndexerRequest{} + for rows.Next() { + var sid, chunk string + var sn, eid int + if err := rows.Scan(&sid, &sn, &eid, &chunk); err != nil { + return err + } + req.Documents = append(req.Documents, &ai.Document{ + Content: []*ai.Part{ai.NewTextPart(chunk)}, + Metadata: map[string]any{ + "show_id": sid, + "season_number": sn, + "episode_id": eid, + }, + }) + } + if err := rows.Err(); err != nil { + return err + } + return ai.Index(ctx, indexer, req) +} diff --git a/go/samples/pgvector/pgvector.sql b/go/samples/pgvector/pgvector.sql new file mode 100644 index 0000000000..a252e9e3d5 --- /dev/null +++ b/go/samples/pgvector/pgvector.sql @@ -0,0 +1,22 @@ +-- This SQL enables the vector extension and creates the table and data used +-- in the accompanying sample. + +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE embeddings ( + show_id TEXT NOT NULL, + season_number INTEGER NOT NULL, + episode_id INTEGER NOT NULL, + chunk TEXT, + embedding vector(768), + PRIMARY KEY (show_id, season_number, episode_id) +); + +INSERT INTO embeddings (show_id, season_number, episode_id, chunk) VALUES + ('La Vie', 1, 1, 'Natasha confesses her love for Pierre.'), + ('La Vie', 1, 2, 'Pierre and Natasha become engaged.'), + ('La Vie', 1, 3, 'Margot and Henri divorce.'), + ('Best Friends', 1, 1, 'Alice confesses her love for Oscar.'), + ('Best Friends', 1, 2, 'Oscar and Alice become engaged.'), + ('Best Friends', 1, 3, 'Bob and Pat divorce.') +;