From 44342b8fc6bdb2dd884e3f7a3288f9bb67d81504 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 10 Jun 2024 14:06:00 -0400 Subject: [PATCH 1/4] [Go] proposal: revised API In this API, configuration (defining actions) should still happen at the beginning of the program, but can be done incrementally by calling functions on a plugin. For example: err := googleai.Init(...) // creates state common to all actions m1 := googleai.DefineModel("gemini-1.0-pro") e1 := googleai.DefineEmbedder("emb-01") When everything has been configured, the program must call genkit.Init. This has two effects: - Attempts to register additional actions will fail. - Servers are started. This API combines the convenience and clarity of building and storing actions individually with the requirement that everything that the program needs is registered early. --- go/core/registry.go | 10 ++ go/core/servers.go | 27 ++++- go/genkit/genkit.go | 24 ++++ go/plugins/googleai/googleai.go | 158 ++++++++++++++++++--------- go/plugins/googleai/googleai_test.go | 46 +++++--- go/plugins/localvec/localvec.go | 46 +++----- go/plugins/localvec/localvec_test.go | 20 ++-- go/plugins/pinecone/genkit.go | 91 ++++++++++----- go/plugins/pinecone/genkit_test.go | 15 ++- go/plugins/vertexai/vertexai.go | 113 +++++++++---------- go/plugins/vertexai/vertexai_test.go | 15 +-- go/samples/coffee-shop/main.go | 6 +- go/samples/flow-sample1/main.go | 2 +- go/samples/menu/main.go | 31 ++---- go/samples/rag/main.go | 32 +++--- 15 files changed, 383 insertions(+), 253 deletions(-) diff --git a/go/core/registry.go b/go/core/registry.go index 55385bfa08..30060d98aa 100644 --- a/go/core/registry.go +++ b/go/core/registry.go @@ -49,6 +49,7 @@ func init() { type registry struct { tstate *tracing.State mu sync.Mutex + frozen bool // when true, no more additions actions map[string]action flows []flow // TraceStores, at most one for each [Environment]. 
@@ -98,6 +99,9 @@ func (r *registry) registerAction(a action) { key := fmt.Sprintf("/%s/%s", a.actionType(), a.Name()) r.mu.Lock() defer r.mu.Unlock() + if r.frozen { + panic(fmt.Sprintf("attempt to register action %s in a frozen registry. Register before calling genkit.Init", key)) + } if _, ok := r.actions[key]; ok { panic(fmt.Sprintf("action %q is already registered", key)) } @@ -108,6 +112,12 @@ func (r *registry) registerAction(a action) { "name", a.Name()) } +func (r *registry) freeze() { + r.mu.Lock() + defer r.mu.Unlock() + r.frozen = true +} + // lookupAction returns the action for the given key, or nil if there is none. func (r *registry) lookupAction(key string) action { r.mu.Lock() diff --git a/go/core/servers.go b/go/core/servers.go index 460948bf34..41458272eb 100644 --- a/go/core/servers.go +++ b/go/core/servers.go @@ -52,13 +52,36 @@ import ( // // StartFlowServer always returns a non-nil error, the one returned by http.ListenAndServe. func StartFlowServer(addr string) error { + return startProdServer(addr) +} + +// InternalInit is for use by the genkit package only. +// It is not subject to compatibility guarantees. +func InternalInit(opts *Options) error { + if opts == nil { + opts = &Options{} + } + globalRegistry.freeze() + if currentEnvironment() == EnvironmentDev { go func() { - err := startDevServer("") + err := startDevServer(opts.DevAddr) slog.Error("dev server stopped", "err", err) }() } - return startProdServer(addr) + if opts.FlowAddr == "-" { + return nil + } + return StartFlowServer(opts.FlowAddr) +} + +// Options are options to [InternalInit]. +type Options struct { + DevAddr string + // If "-", do not start a FlowServer. + // Otherwise, start a FlowServer on the given address, or the + // default if empty. + FlowAddr string } // startDevServer starts the development server (reflection API) listening at the given address. 
diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index ce39922287..d80a633521 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -23,6 +23,30 @@ import ( "github.com/firebase/genkit/go/core" ) +// Options are options to [Init]. +type Options = core.Options + +// Init initializes Genkit. +// After it is called, no further actions can be defined. +// +// Init starts servers depending on the value of the GENKIT_ENV +// environment variable and the provided options. +// +// +// If GENKIT_ENV = "dev", a development server is started +// in a separate goroutine at the address in opts.DevAddr, or the default +// if empty. +// +// If opts.FlowAddr is a value other than "-", a flow server is started (see [StartFlowServer]) +// and the call to Init waits for the server to shut down. +// If opts.FlowAddr == "-", no flow server is started and Init returns immediately. +// +// Thus Init(nil) will start a dev server in the "dev" environment, will always start +// a flow server, and will pause execution until the flow server terminates. +func Init(opts *Options) error { + return core.InternalInit(opts) +} + // DefineFlow creates a Flow that runs fn, and registers it as an action. // // fn takes an input of type In and returns an output of type Out. diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 9b8128b823..25cd93e702 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -16,10 +16,10 @@ package googleai import ( "context" - "errors" "fmt" "path" "slices" + "sync" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/plugins/internal/uri" @@ -30,79 +30,75 @@ import ( const provider = "googleai" -// Config provides configuration options for the Init function. -type Config struct { - // API key. Required. - APIKey string - // Generative models to provide. - // If empty, a complete list will be obtained from the service. - Models []string - // Embedding models to provide.
- // If empty, a complete list will be obtained from the service. - Embedders []string +var state struct { + mu sync.Mutex + initted bool + client *genai.Client + // Results from ListModels + modelNames []string + embedderNames []string } -func Init(ctx context.Context, cfg Config) (err error) { +// Init initializes the plugin. +// After calling Init, call [DefineModel] and [DefineEmbedder] to create and register +// generative models and embedders. +func Init(ctx context.Context, apiKey string) (err error) { + state.mu.Lock() + defer state.mu.Unlock() + if state.initted { + panic("googleai.Init already called") + } defer func() { if err != nil { err = fmt.Errorf("googleai.Init: %w", err) } }() - if cfg.APIKey == "" { - return errors.New("missing API key") - } - - client, err := genai.NewClient(ctx, option.WithAPIKey(cfg.APIKey)) + client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) if err != nil { return err } + state.client = client + state.initted = true + return nil +} - needModels := len(cfg.Models) == 0 - needEmbedders := len(cfg.Embedders) == 0 - if needModels || needEmbedders { - iter := client.ListModels(ctx) - for { - mi, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return err - } - // Model names are of the form "models/name". - name := path.Base(mi.Name) - if needModels && slices.Contains(mi.SupportedGenerationMethods, "generateContent") { - cfg.Models = append(cfg.Models, name) - } - if needEmbedders && slices.Contains(mi.SupportedGenerationMethods, "embedContent") { - cfg.Embedders = append(cfg.Embedders, name) - } - } - } - for _, name := range cfg.Models { - defineModel(name, client) - } - for _, name := range cfg.Embedders { - defineEmbedder(name, client) +// DefineModel defines a model with the given name. 
+func DefineModel(name string) *ai.ModelAction { + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic("googleai.Init not called") } - return nil + return defineModel(name) } -func defineModel(name string, client *genai.Client) { +// requires state.mu +func defineModel(name string) *ai.ModelAction { meta := &ai.ModelMetadata{ Label: "Google AI - " + name, Supports: ai.ModelCapabilities{ Multiturn: true, }, } - g := generator{model: name, client: client} - ai.DefineModel(provider, name, meta, g.generate) + g := generator{model: name, client: state.client} + return ai.DefineModel(provider, name, meta, g.generate) +} + +// DefineEmbedder defines an embedder with a given name. +func DefineEmbedder(name string) *ai.EmbedderAction { + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic("googleai.Init not called") + } + return defineEmbedder(name) } -func defineEmbedder(name string, client *genai.Client) { - ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) { - em := client.EmbeddingModel(name) +// requires state.mu +func defineEmbedder(name string) *ai.EmbedderAction { + return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) { + em := state.client.EmbeddingModel(name) parts, err := convertParts(input.Document.Content) if err != nil { return nil, err @@ -115,6 +111,66 @@ func defineEmbedder(name string, client *genai.Client) { }) } +// DefineAllModels defines all models known to the service. +func DefineAllModels(ctx context.Context) ([]*ai.ModelAction, error) { + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic("googleai.Init not called") + } + if err := listModels(ctx); err != nil { + return nil, err + } + var mas []*ai.ModelAction + for _, mod := range state.modelNames { + mas = append(mas, defineModel(mod)) + } + return mas, nil +} + +// DefineAllEmbedders defines all embedders known to the service. 
+func DefineAllEmbedders(ctx context.Context) ([]*ai.EmbedderAction, error) { + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic("googleai.Init not called") + } + if err := listModels(ctx); err != nil { + return nil, err + } + var eas []*ai.EmbedderAction + for _, em := range state.embedderNames { + eas = append(eas, defineEmbedder(em)) + } + return eas, nil +} + +// requires state.mu +func listModels(ctx context.Context) error { + if len(state.modelNames) > 0 || len(state.embedderNames) > 0 { + // already called + return nil + } + iter := state.client.ListModels(ctx) + for { + mi, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + // Model names are of the form "models/name". + name := path.Base(mi.Name) + if slices.Contains(mi.SupportedGenerationMethods, "generateContent") { + state.modelNames = append(state.modelNames, name) + } else if slices.Contains(mi.SupportedGenerationMethods, "embedContent") { + state.embedderNames = append(state.embedderNames, name) + } + } + return nil +} + // Model returns the [ai.ModelAction] with the given name. // It returns nil if the model was not configured. func Model(name string) *ai.ModelAction { diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 9b68654751..dd69198bdd 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -30,24 +30,24 @@ import ( // The tests here only work with an API key set to a valid value. var apiKey = flag.String("key", "", "Gemini API key") -const ( - embeddingModel = "embedding-001" - generativeModel = "gemini-1.0-pro" -) +// We can't test the DefineAll functions along with the other tests because +// we get duplicate definitions of models. 
+var testAll = flag.Bool("all", false, "test DefineAllXXX functions") func TestLive(t *testing.T) { if *apiKey == "" { t.Skipf("no -key provided") } + if *testAll { + t.Skip("-all provided") + } ctx := context.Background() - err := googleai.Init(ctx, googleai.Config{ - APIKey: *apiKey, - Embedders: []string{embeddingModel}, - Models: []string{generativeModel}, - }) + err := googleai.Init(ctx, *apiKey) if err != nil { t.Fatal(err) } + embedder := googleai.DefineEmbedder("embedding-001") + model := googleai.DefineModel("gemini-1.0-pro") toolDef := &ai.ToolDefinition{ Name: "exponentiation", InputSchema: map[string]any{"base": "float64", "exponent": "int"}, @@ -82,7 +82,7 @@ func TestLive(t *testing.T) { }, ) t.Run("embedder", func(t *testing.T) { - out, err := ai.Embed(ctx, googleai.Embedder(embeddingModel), &ai.EmbedRequest{ + out, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{ Document: ai.DocumentFromText("yellow banana", nil), }) if err != nil { @@ -103,7 +103,6 @@ func TestLive(t *testing.T) { } }) t.Run("generate", func(t *testing.T) { - g := googleai.Model(generativeModel) req := &ai.GenerateRequest{ Candidates: 1, Messages: []*ai.Message{ @@ -114,7 +113,7 @@ func TestLive(t *testing.T) { }, } - resp, err := ai.Generate(ctx, g, req, nil) + resp, err := ai.Generate(ctx, model, req, nil) if err != nil { t.Fatal(err) } @@ -140,8 +139,7 @@ func TestLive(t *testing.T) { out := "" parts := 0 - g := googleai.Model(generativeModel) - final, err := ai.Generate(ctx, g, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error { + final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error { parts++ out += c.Content[0].Text return nil @@ -177,7 +175,7 @@ func TestLive(t *testing.T) { Tools: []*ai.ToolDefinition{toolDef}, } - resp, err := ai.Generate(ctx, googleai.Model(generativeModel), req, nil) + resp, err := ai.Generate(ctx, model, req, nil) if err != nil { t.Fatal(err) } @@ -189,3 +187,21 @@ func TestLive(t 
*testing.T) { } }) } + +func TestAllModels(t *testing.T) { + if !*testAll { + t.Skip("-all not set") + } + ctx := context.Background() + if err := googleai.Init(ctx, *apiKey); err != nil { + t.Fatal(err) + } + mods, err := googleai.DefineAllModels(ctx) + if err != nil || len(mods) == 0 { + t.Fatalf("got %d, %v, want >0, nil", len(mods), err) + } + embs, err := googleai.DefineAllEmbedders(ctx) + if err != nil || len(embs) == 0 { + t.Fatalf("got %d, %v, want >0, nil", len(mods), err) + } +} diff --git a/go/plugins/localvec/localvec.go b/go/plugins/localvec/localvec.go index a63a930c56..3258af9aee 100644 --- a/go/plugins/localvec/localvec.go +++ b/go/plugins/localvec/localvec.go @@ -36,47 +36,29 @@ import ( const provider = "devLocalVectorStore" -// Config provides configuration options for the Init function. type Config struct { - Stores []StoreConfig -} - -type StoreConfig struct { // Where to store the data. Defaults to os.TempDir. - Dir string - // A name that uniquely identifies the store. - Name string + Dir string Embedder *ai.EmbedderAction EmbedderOptions any } -// Init initializes a new local vector database. This will register new -// indexers and retrievers, and return them in the same order as [Config.Stores]. -// You can also call the Name method on an indexer or retriever. -// Each indexer and retriever may only be used by a single goroutine at a time. -func Init(ctx context.Context, cfg Config) (_ []*ai.IndexerAction, _ []*ai.RetrieverAction, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("localvec.Init: %w", err) - } - }() - var ias []*ai.IndexerAction - var ras []*ai.RetrieverAction - for _, sc := range cfg.Stores { - ds, err := newDocStore(sc.Dir, sc.Name, sc.Embedder, sc.EmbedderOptions) - if err != nil { - return nil, nil, err - } - ia := ai.DefineIndexer(provider, sc.Name, ds.index) - ra := ai.DefineRetriever(provider, sc.Name, ds.retrieve) - ias = append(ias, ia) - ras = append(ras, ra) +// Init initializes the plugin. 
+func Init() error { return nil } + +// DefineStore defines an indexer and retriever that share the same underlying storage. +// The name uniquely identifies the indexer and retriever in the registry. +func DefineStore(name string, cfg Config) (*ai.IndexerAction, *ai.RetrieverAction, error) { + ds, err := newDocStore(cfg.Dir, name, cfg.Embedder, cfg.EmbedderOptions) + if err != nil { + return nil, nil, err } - return ias, ras, nil + return ai.DefineIndexer(provider, name, ds.index), + ai.DefineRetriever(provider, name, ds.retrieve), + nil } -// Indexer returns the indexer with the given name. -// The name must match the [StoreConfig.Name] value passed to [Init]. +// Indexer returns the registered indexer with the given name. func Indexer(name string) *ai.IndexerAction { return ai.LookupIndexer(provider, name) } diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index 30774681e5..c108eebaec 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -17,7 +17,6 @@ package localvec import ( "context" "math" - "slices" "strings" "testing" @@ -190,20 +189,19 @@ func TestSimilarity(t *testing.T) { func TestInit(t *testing.T) { embedder := ai.DefineEmbedder("fake", "e", fakeembedder.New().Embed) - is, rs, err := Init(context.Background(), Config{Stores: []StoreConfig{ - {Name: "a", Embedder: embedder}, - {Name: "b", Embedder: embedder}, - }}) + if err := Init(); err != nil { + t.Fatal(err) + } + const name = "mystore" + ind, ret, err := DefineStore(name, Config{Embedder: embedder}) if err != nil { t.Fatal(err) } - want := []string{"devLocalVectorStore/a", "devLocalVectorStore/b"} - - if got := names(is); !slices.Equal(got, want) { - t.Errorf("got %v, want %v", got, want) + if g := ind.Name(); g != name { + t.Errorf("got %q, want %q", g, name) } - if got := names(rs); !slices.Equal(got, want) { - t.Errorf("got %v, want %v", got, want) + if g := ret.Name(); g != name { + t.Errorf("got %q, want
%q", g, name) } } diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index 8dd72347dc..a6f23919f1 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -23,6 +23,7 @@ import ( "maps" "slices" "strings" + "sync" "time" "github.com/firebase/genkit/go/ai" @@ -38,12 +39,39 @@ const provider = "pinecone" // documents in pinecone. const defaultTextKey = "_content" -// Config provides configuration options for the Init function. +var state struct { + mu sync.Mutex + initted bool + client *client +} + +// Init initializes the Pinecone plugin. +// If apiKey is the empty string, it is read from the PINECONE_API_KEY +// environment variable. +func Init(ctx context.Context, apiKey string) (err error) { + // Init initializes the Pinecone plugin. + state.mu.Lock() + defer state.mu.Unlock() + if state.initted { + panic("pinecone.Init already called") + } + defer func() { + if err != nil { + err = fmt.Errorf("pinecone.Init: %w", err) + } + }() + + client, err := newClient(ctx, apiKey) + if err != nil { + return err + } + state.client = client + state.initted = true + return nil +} + +// Config provides configuration options for [DefineIndexer] and [DefineRetriever]. type Config struct { - // API key for Pinecone. - // If it is the empty string, it is read from the PINECONE_API_KEY - // environment variable. - APIKey string // The index ID to use. IndexID string // Embedder to use. Required. @@ -54,45 +82,52 @@ type Config struct { TextKey string } -// Init initializes the Pinecone plugin. 
-func Init(ctx context.Context, cfg Config) (err error) { - defer func() { - if err != nil { - err = fmt.Errorf("pinecone.Init: %w", err) - } - }() +func DefineIndexer(ctx context.Context, cfg Config) (*ai.IndexerAction, error) { + ds, err := newDocStore(ctx, cfg) + if err != nil { + return nil, err + } + return ai.DefineIndexer(provider, cfg.IndexID, ds.Index), nil +} +func DefineRetriever(ctx context.Context, cfg Config) (*ai.RetrieverAction, error) { + ds, err := newDocStore(ctx, cfg) + if err != nil { + return nil, err + } + return ai.DefineRetriever(provider, cfg.IndexID, ds.Retrieve), nil +} + +func newDocStore(ctx context.Context, cfg Config) (*docStore, error) { + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic("pinecone.Init not called") + } if cfg.IndexID == "" { - return errors.New("IndexID required") + return nil, errors.New("IndexID required") } if cfg.Embedder == nil { - return errors.New("Embedder required") + return nil, errors.New("Embedder required") } - - client, err := newClient(ctx, cfg.APIKey) + // TODO(jba): cache these calls so we don't make them twice for the indexer and retriever. + indexData, err := state.client.indexData(ctx, cfg.IndexID) if err != nil { - return err - } - indexData, err := client.indexData(ctx, cfg.IndexID) - if err != nil { - return err + return nil, err } - index, err := client.index(ctx, indexData.Host) + index, err := state.client.index(ctx, indexData.Host) if err != nil { - return err + return nil, err } if cfg.TextKey == "" { cfg.TextKey = defaultTextKey } - r := &docStore{ + return &docStore{ index: index, embedder: cfg.Embedder, embedderOptions: cfg.EmbedderOptions, textKey: cfg.TextKey, - } - ai.DefineIndexer(provider, cfg.IndexID, r.Index) - ai.DefineRetriever(provider, cfg.IndexID, r.Retrieve) - return nil + }, nil } // Indexer returns the indexer with the given index name. 
diff --git a/go/plugins/pinecone/genkit_test.go b/go/plugins/pinecone/genkit_test.go index 00977ecc4a..7597d3c9a8 100644 --- a/go/plugins/pinecone/genkit_test.go +++ b/go/plugins/pinecone/genkit_test.go @@ -72,12 +72,19 @@ func TestGenkit(t *testing.T) { embedder.Register(d2, v2) embedder.Register(d3, v3) + if err := Init(ctx, *testAPIKey); err != nil { + t.Fatal(err) + } cfg := Config{ - APIKey: *testAPIKey, IndexID: *testIndex, Embedder: ai.DefineEmbedder("fake", "embedder3", embedder.Embed), } - if err := Init(ctx, cfg); err != nil { + indexer, err := DefineIndexer(ctx, cfg) + if err != nil { + t.Fatal(err) + } + retriever, err := DefineRetriever(ctx, cfg) + if err != nil { t.Fatal(err) } @@ -90,7 +97,7 @@ func TestGenkit(t *testing.T) { Options: indexerOptions, } t.Logf("index flag = %q, indexData.Host = %q", *testIndex, indexData.Host) - err = ai.Index(ctx, Indexer(*testIndex), indexerReq) + err = ai.Index(ctx, indexer, indexerReq) if err != nil { t.Fatalf("Index operation failed: %v", err) } @@ -127,7 +134,7 @@ func TestGenkit(t *testing.T) { Document: d1, Options: retrieverOptions, } - retrieverResp, err := ai.Retrieve(ctx, Retriever(*testIndex), retrieverReq) + retrieverResp, err := ai.Retrieve(ctx, retriever, retrieverReq) if err != nil { t.Fatalf("Retrieve operation failed: %v", err) } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 084eb7c6a3..55b9d6fd63 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -16,9 +16,9 @@ package vertexai import ( "context" - "errors" "fmt" "runtime" + "sync" aiplatform "cloud.google.com/go/aiplatform/apiv1" "cloud.google.com/go/vertexai/genai" @@ -30,84 +30,77 @@ import ( const provider = "vertexai" -// Config provides configuration options for the Init function. -type Config struct { - // The project holding the resources. - ProjectID string - // The location of the resources. - // Defaults to "us-central1". 
- Location string - // Generative models to provide. - Models []string - // Embedding models to provide. - Embedders []string +var state struct { + mu sync.Mutex + initted bool + projectID string + location string + gclient *genai.Client + pclient *aiplatform.PredictionClient } -func Init(ctx context.Context, cfg Config) (err error) { - defer func() { - if err != nil { - err = fmt.Errorf("vertexai.Init: %w", err) - } - }() - - if cfg.ProjectID == "" { - return errors.New("missing ProjectID") - } - if cfg.Location == "" { - cfg.Location = "us-central1" - } - // TODO(#345): call ListModels. See googleai.go. - if len(cfg.Models) == 0 && len(cfg.Embedders) == 0 { - return errors.New("need at least one model or embedder") - } - - if err := initModels(ctx, cfg); err != nil { - return err +// Init initializes the plugin. +// After calling this function, call [DefineModel] and [DefineEmbedder] to create and register +// generative models and embedders. +func Init(ctx context.Context, projectID, location string) error { + state.mu.Lock() + defer state.mu.Unlock() + if state.initted { + panic("vertexai.Init already called") } - return initEmbedders(ctx, cfg) -} - -func initModels(ctx context.Context, cfg Config) error { + state.projectID = projectID + state.location = location + var err error // Client for Gemini SDK. - gclient, err := genai.NewClient(ctx, cfg.ProjectID, cfg.Location) + state.gclient, err = genai.NewClient(ctx, projectID, location) if err != nil { return err } - for _, name := range cfg.Models { - meta := &ai.ModelMetadata{ - Label: "Vertex AI - " + name, - Supports: ai.ModelCapabilities{ - Multiturn: true, - }, - } - g := &generator{model: name, client: gclient} - ai.DefineModel(provider, name, meta, g.generate) - } - return nil -} - -func initEmbedders(ctx context.Context, cfg Config) error { - // Client for prediction SDK. 
- endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", cfg.Location) + endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) numConns := max(runtime.GOMAXPROCS(0), 4) o := []option.ClientOption{ option.WithEndpoint(endpoint), option.WithGRPCConnectionPool(numConns), } - pclient, err := aiplatform.NewPredictionClient(ctx, o...) + state.pclient, err = aiplatform.NewPredictionClient(ctx, o...) if err != nil { return err } - for _, name := range cfg.Embedders { - fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", cfg.ProjectID, cfg.Location, name) - ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) { - return embed(ctx, fullName, pclient, req) - }) - } + state.initted = true return nil } +// DefineModel defines a model with the given name. +func DefineModel(name string) *ai.ModelAction { + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic("vertexai.Init not called") + } + meta := &ai.ModelMetadata{ + Label: "Vertex AI - " + name, + Supports: ai.ModelCapabilities{ + Multiturn: true, + }, + } + g := &generator{model: name, client: state.gclient} + return ai.DefineModel(provider, name, meta, g.generate) +} + +// DefineEmbedder defines an embedder with the given name. +func DefineEmbedder(name string) *ai.EmbedderAction { + state.mu.Lock() + defer state.mu.Unlock() + if !state.initted { + panic("vertexai.Init not called") + } + fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name) + return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) { + return embed(ctx, fullName, state.pclient, req) + }) +} + // Model returns the [ai.ModelAction] with the given name. // It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction { diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index 8f928484b2..f936839fc9 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -40,15 +40,12 @@ func TestLive(t *testing.T) { ctx := context.Background() const modelName = "gemini-1.0-pro" const embedderName = "textembedding-gecko" - err := vertexai.Init(ctx, vertexai.Config{ - ProjectID: *projectID, - Location: *location, - Models: []string{modelName}, - Embedders: []string{embedderName}, - }) + err := vertexai.Init(ctx, *projectID, *location) if err != nil { t.Fatal(err) } + model := vertexai.DefineModel(modelName) + embedder := vertexai.DefineEmbedder(embedderName) toolDef := &ai.ToolDefinition{ Name: "exponentiation", @@ -94,7 +91,7 @@ func TestLive(t *testing.T) { }, } - resp, err := ai.Generate(ctx, vertexai.Model(modelName), req, nil) + resp, err := ai.Generate(ctx, model, req, nil) if err != nil { t.Fatal(err) } @@ -158,7 +155,7 @@ func TestLive(t *testing.T) { Tools: []*ai.ToolDefinition{toolDef}, } - resp, err := ai.Generate(ctx, vertexai.Model(modelName), req, nil) + resp, err := ai.Generate(ctx, model, req, nil) if err != nil { t.Fatal(err) } @@ -169,7 +166,7 @@ func TestLive(t *testing.T) { } }) t.Run("embedder", func(t *testing.T) { - out, err := ai.Embed(ctx, vertexai.Embedder(embedderName), &ai.EmbedRequest{ + out, err := ai.Embed(ctx, embedder, &ai.EmbedRequest{ Document: ai.DocumentFromText("time flies like an arrow", nil), }) if err != nil { diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 9cfda041f1..fa42fbcc90 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -104,7 +104,7 @@ func main() { fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") os.Exit(1) } - err := googleai.Init(context.Background(), googleai.Config{APIKey: apiKey}) + err := googleai.Init(context.Background(), 
apiKey) if err != nil { log.Fatal(err) } @@ -113,7 +113,7 @@ func main() { AllowAdditionalProperties: false, DoNotReference: true, } - g := googleai.Model("gemini-1.5-pro") + g := googleai.DefineModel("gemini-1.5-pro") simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate, dotprompt.Config{ ModelAction: g, @@ -263,7 +263,7 @@ func main() { } return out, nil }) - if err := genkit.StartFlowServer(""); err != nil { + if err := genkit.Init(nil); err != nil { log.Fatal(err) } } diff --git a/go/samples/flow-sample1/main.go b/go/samples/flow-sample1/main.go index af90d0f077..14490a7b38 100644 --- a/go/samples/flow-sample1/main.go +++ b/go/samples/flow-sample1/main.go @@ -83,7 +83,7 @@ func main() { return fmt.Sprintf("done: %d, streamed: %d times", count, i), nil }) - if err := genkit.StartFlowServer(""); err != nil { + if err := genkit.Init(nil); err != nil { log.Fatal(err) } } diff --git a/go/samples/menu/main.go b/go/samples/menu/main.go index fef2cf2434..e7ecbcdbdb 100644 --- a/go/samples/menu/main.go +++ b/go/samples/menu/main.go @@ -26,11 +26,6 @@ import ( "github.com/invopop/jsonschema" ) -const ( - geminiPro = "gemini-1.0-pro" - embeddingGecko = "textembedding-gecko" -) - // menuItem is the data model for an item on the menu. 
type menuItem struct { Title string `json:"title" jsonschema_description:"The name of the menu item"` @@ -79,16 +74,13 @@ func main() { } ctx := context.Background() - err := vertexai.Init(ctx, vertexai.Config{ - ProjectID: projectID, - Location: os.Getenv("GCLOUD_LOCATION"), - Models: []string{geminiPro, geminiPro + "-vision"}, - Embedders: []string{embeddingGecko}, - }) + err := vertexai.Init(ctx, projectID, os.Getenv("GCLOUD_LOCATION")) if err != nil { log.Fatal(err) } - model := vertexai.Model(geminiPro) + model := vertexai.DefineModel("gemini-1.0-pro") + visionModel := vertexai.DefineModel("gemini-1.0-pro-vision") + embedder := vertexai.DefineEmbedder("textembedding-gecko") if err := setup01(ctx, model); err != nil { log.Fatal(err) } @@ -100,23 +92,22 @@ func main() { log.Fatal(err) } - indexers, retrievers, err := localvec.Init(ctx, localvec.Config{ - Stores: []localvec.StoreConfig{ - {Name: "go-menu-items", Embedder: vertexai.Embedder(embeddingGecko)}, - }, - }) + err = localvec.Init() if err != nil { log.Fatal(err) } - if err := setup04(ctx, indexers[0], retrievers[0], model); err != nil { + indexer, retriever, err := localvec.DefineStore("go-menu_items", localvec.Config{ + Embedder: embedder, + }) + if err := setup04(ctx, indexer, retriever, model); err != nil { log.Fatal(err) } - if err := setup05(ctx, model, vertexai.Model(geminiPro+"-vision")); err != nil { + if err := setup05(ctx, model, visionModel); err != nil { log.Fatal(err) } - if err := genkit.StartFlowServer(""); err != nil { + if err := genkit.Init(nil); err != nil { log.Fatal(err) } } diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index 390f4957d4..f9ae4f5b5f 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -75,14 +75,24 @@ func main() { fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") os.Exit(1) } - err := googleai.Init(context.Background(), googleai.Config{APIKey: apiKey}) + err := googleai.Init(context.Background(), apiKey) if err 
!= nil { log.Fatal(err) } + model := googleai.DefineModel("gemini-1.0-pro") + embedder := googleai.DefineEmbedder("embedding-001") + if err := localvec.Init(); err != nil { + log.Fatal(err) + } + indexer, retriever, err := localvec.DefineStore("simpleQa", localvec.Config{Embedder: embedder}) + if err != nil { + log.Fatal(err) + } + simpleQaPrompt, err := dotprompt.Define("simpleQaPrompt", simpleQaPromptTemplate, dotprompt.Config{ - ModelAction: googleai.Model("gemini-1.0-pro"), + ModelAction: model, InputSchema: jsonschema.Reflect(simpleQaPromptInput{}), OutputFormat: ai.OutputFormatText, }, @@ -91,18 +101,6 @@ func main() { log.Fatal(err) } - indexers, retrievers, err := localvec.Init(context.Background(), localvec.Config{ - Stores: []localvec.StoreConfig{ - { - Name: "simpleQa", - Embedder: googleai.Embedder("embedding-001"), - }, - }, - }) - if err != nil { - log.Fatal(err) - } - genkit.DefineFlow("simpleQaFlow", func(ctx context.Context, input *simpleQaInput) (string, error) { d1 := ai.DocumentFromText("Paris is the capital of France", nil) d2 := ai.DocumentFromText("USA is the largest importer of coffee", nil) @@ -111,7 +109,7 @@ func main() { indexerReq := &ai.IndexerRequest{ Documents: []*ai.Document{d1, d2, d3}, } - err := ai.Index(ctx, indexers[0], indexerReq) + err := ai.Index(ctx, indexer, indexerReq) if err != nil { return "", err } @@ -120,7 +118,7 @@ func main() { retrieverReq := &ai.RetrieverRequest{ Document: dRequest, } - response, err := ai.Retrieve(ctx, retrievers[0], retrieverReq) + response, err := ai.Retrieve(ctx, retriever, retrieverReq) if err != nil { return "", err } @@ -152,7 +150,7 @@ func main() { return text, nil }) - if err := genkit.StartFlowServer(""); err != nil { + if err := genkit.Init(nil); err != nil { log.Fatal(err) } } From bb7f933ed177e735e573510e637f6d285af6fe29 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 18 Jun 2024 03:28:43 -0400 Subject: [PATCH 2/4] reviewer comments --- go/genkit/genkit.go | 13 
+++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index d80a633521..94ffc7e782 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -24,7 +24,13 @@ import ( ) // Options are options to [Init]. -type Options = core.Options +type Options struct { + DevAddr string + // If "-", do not start a FlowServer. + // Otherwise, start a FlowServer on the given address, or the + // default of ":3400" if empty. + FlowAddr string +} // Init initializes Genkit. // After it is called, no further actions can be defined. @@ -32,10 +38,9 @@ type Options = core.Options // Init starts servers depending on the value of the GENKIT_ENV // environment variable and the provided options. // - // If GENKIT_ENV = "dev", a development server is started // in a separate goroutine at the address in opts.DevAddr, or the default -// if empty. +// of ":3100" if empty. // // If opts.FlowAddr is a value other than "-", a flow server is started (see [StartFlowServer]) // and the call to Init waits for the server to shut down. @@ -44,7 +49,7 @@ type Options = core.Options // Thus Init(nil) will start a dev server in the "dev" environment, will always start // a flow server, and will pause execution until the flow server terminates. func Init(opts *Options) error { - return core.InternalInit(opts) + return core.InternalInit((*core.Options)(opts)) } // DefineFlow creates a Flow that runs fn, and registers it as an action. 
From 495c6771e5974f0d64c8d1a7caa331ac60aed94e Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 19 Jun 2024 11:46:47 -0400 Subject: [PATCH 3/4] reviewer comments --- go/core/servers.go | 48 ++++++++++++++++++++++++++--------------- go/core/servers_test.go | 2 +- go/genkit/genkit.go | 16 ++++++++------ 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/go/core/servers.go b/go/core/servers.go index 41458272eb..019adc96b4 100644 --- a/go/core/servers.go +++ b/go/core/servers.go @@ -51,8 +51,8 @@ import ( // a dev server. // // StartFlowServer always returns a non-nil error, the one returned by http.ListenAndServe. -func StartFlowServer(addr string) error { - return startProdServer(addr) +func StartFlowServer(addr string, flows []string) error { + return startProdServer(addr, flows) } // InternalInit is for use by the genkit package only. @@ -65,32 +65,37 @@ func InternalInit(opts *Options) error { if currentEnvironment() == EnvironmentDev { go func() { - err := startDevServer(opts.DevAddr) + err := startDevServer() slog.Error("dev server stopped", "err", err) }() } if opts.FlowAddr == "-" { return nil } - return StartFlowServer(opts.FlowAddr) + return StartFlowServer(opts.FlowAddr, opts.Flows) } // Options are options to [InternalInit]. type Options struct { - DevAddr string // If "-", do not start a FlowServer. // Otherwise, start a FlowServer on the given address, or the // default if empty. FlowAddr string + // The names of flows to serve. + // If empty, all registered flows are served. + Flows []string } -// startDevServer starts the development server (reflection API) listening at the given address. -// If addr is "", it uses the value of the environment variable GENKIT_REFLECTION_PORT -// for the port, and if that is empty it uses ":3100". +// startDevServer starts the development server (reflection API) listening at the +// value of the environment variable GENKIT_REFLECTION_PORT for the port, or ":3100" +// if it is empty. 
// startDevServer always returns a non-nil error, the one returned by http.ListenAndServe. -func startDevServer(addr string) error { +func startDevServer() error { slog.Info("starting dev server") - addr = serverAddress(addr, "GENKIT_REFLECTION_PORT", "127.0.0.1:3100") + // Don't use "localhost" here. That only binds the IPv4 address, and the genkit tool + // wants to connect to the IPv6 address even when you tell it to use "localhost". + // Omitting the host works. + addr := serverAddress("", "GENKIT_REFLECTION_PORT", "127.0.0.1:3100") mux := newDevServeMux(globalRegistry) return listenAndServe(addr, mux) } @@ -266,14 +271,17 @@ type listFlowStatesResult struct { // startProdServer always returns a non-nil error, the one returned by http.ListenAndServe. // // To construct a server with additional routes, use [NewFlowServeMux]. -func startProdServer(addr string) error { +func startProdServer(addr string, flows []string) error { slog.Info("starting flow server") addr = serverAddress(addr, "PORT", "127.0.0.1:3400") - mux := NewFlowServeMux() + mux := NewFlowServeMux(flows) return listenAndServe(addr, mux) } -// NewFlowServeMux constructs a [net/http.ServeMux] where each defined flow is a route. +// NewFlowServeMux constructs a [net/http.ServeMux]. +// If flows is non-empty, then each of the named flows is registered as a route. +// Otherwise, all defined flows are registered. +// // All routes take a single query parameter, "stream", which if true will stream the // flow's results back to the client. (Not all flows support streaming, however.) 
// @@ -282,14 +290,20 @@ func startProdServer(addr string) error { // // mainMux := http.NewServeMux() // mainMux.Handle("POST /flow/", http.StripPrefix("/flow/", NewFlowServeMux())) -func NewFlowServeMux() *http.ServeMux { - return newFlowServeMux(globalRegistry) +func NewFlowServeMux(flows []string) *http.ServeMux { + return newFlowServeMux(globalRegistry, flows) } -func newFlowServeMux(r *registry) *http.ServeMux { +func newFlowServeMux(r *registry, flows []string) *http.ServeMux { mux := http.NewServeMux() + m := map[string]bool{} + for _, f := range flows { + m[f] = true + } for _, f := range r.listFlows() { - handle(mux, "POST /"+f.Name(), nonDurableFlowHandler(f)) + if len(flows) == 0 || m[f.Name()] { + handle(mux, "POST /"+f.Name(), nonDurableFlowHandler(f)) + } } return mux } diff --git a/go/core/servers_test.go b/go/core/servers_test.go index 3db86ca73a..130566399e 100644 --- a/go/core/servers_test.go +++ b/go/core/servers_test.go @@ -129,7 +129,7 @@ func TestProdServer(t *testing.T) { defineFlow(r, "inc", func(_ context.Context, i int, _ NoStream) (int, error) { return i + 1, nil }) - srv := httptest.NewServer(newFlowServeMux(r)) + srv := httptest.NewServer(newFlowServeMux(r, nil)) defer srv.Close() check := func(t *testing.T, input string, wantStatus, wantResult int) { diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 94ffc7e782..50a162c083 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -25,11 +25,13 @@ import ( // Options are options to [Init]. type Options struct { - DevAddr string // If "-", do not start a FlowServer. // Otherwise, start a FlowServer on the given address, or the // default of ":3400" if empty. FlowAddr string + // The names of flows to serve. + // If empty, all registered flows are served. + Flows []string } // Init initializes Genkit. @@ -149,11 +151,13 @@ var errStop = errors.New("stop") // a dev server. // // StartFlowServer always returns a non-nil error, the one returned by http.ListenAndServe. 
-func StartFlowServer(addr string) error { - return core.StartFlowServer(addr) +func StartFlowServer(addr string, flows []string) error { + return core.StartFlowServer(addr, flows) } -// NewFlowServeMux constructs a [net/http.ServeMux] where each defined flow is a route. +// NewFlowServeMux constructs a [net/http.ServeMux]. +// If flows is non-empty, then each of the named flows is registered as a route. +// Otherwise, all defined flows are registered. // All routes take a single query parameter, "stream", which if true will stream the // flow's results back to the client. (Not all flows support streaming, however.) // @@ -162,6 +166,6 @@ func StartFlowServer(addr string) error { // // mainMux := http.NewServeMux() // mainMux.Handle("POST /flow/", http.StripPrefix("/flow/", NewFlowServeMux())) -func NewFlowServeMux() *http.ServeMux { - return core.NewFlowServeMux() +func NewFlowServeMux(flows []string) *http.ServeMux { + return core.NewFlowServeMux(flows) } From f30d1e6a0c1a2d05f5d66636b3e4f743aca9ea1f Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Fri, 21 Jun 2024 02:42:53 -0400 Subject: [PATCH 4/4] fix test --- go/plugins/localvec/localvec_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index c108eebaec..206bc73d01 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -197,11 +197,12 @@ func TestInit(t *testing.T) { if err != nil { t.Fatal(err) } - if g := ind.Name(); g != name { - t.Errorf("got %q, want %q", g, name) + want := "devLocalVectorStore/" + name + if g := ind.Name(); g != want { + t.Errorf("got %q, want %q", g, want) } - if g := ret.Name(); g != name { - t.Errorf("got %q, want %q", g, name) + if g := ret.Name(); g != want { + t.Errorf("got %q, want %q", g, want) } }