From 6240c139043683e3bcb571e53c12c09784c4cb49 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Jul 2024 11:57:00 -0400 Subject: [PATCH] [Go] add support GCP client options Allow users to pass client options to the GCP clients for the googleai and vertexai plugins. To allow this and also to enable additional config in the future without introducing breaking changes, add a Config struct to each plugin, and have Init take it as an argument. Also, slightly rewrite the samples to avoid redundant code for looking up environment variables. --- go/internal/doc-snippets/dotprompt.go | 6 ++-- go/internal/doc-snippets/init/main.go | 6 ++-- go/internal/doc-snippets/models.go | 7 ++-- go/plugins/googleai/googleai.go | 19 +++++++++-- go/plugins/googleai/googleai_test.go | 2 +- go/plugins/vertexai/vertexai.go | 47 ++++++++++++++++++--------- go/plugins/vertexai/vertexai_test.go | 3 +- go/samples/coffee-shop/main.go | 2 +- go/samples/menu/main.go | 9 +---- go/samples/rag/main.go | 9 +---- 10 files changed, 65 insertions(+), 45 deletions(-) diff --git a/go/internal/doc-snippets/dotprompt.go b/go/internal/doc-snippets/dotprompt.go index bd688db997..f7262d66b2 100644 --- a/go/internal/doc-snippets/dotprompt.go +++ b/go/internal/doc-snippets/dotprompt.go @@ -35,9 +35,9 @@ func dot01() error { ctx := context.Background() // The .prompt file specifies vertexai/gemini-1.5-pro, so make sure it's set - // up: - projectID := os.Getenv("GCLOUD_PROJECT") - vertexai.Init(ctx, projectID, "us-central1") + // up. + // Default to the project in GCLOUD_PROJECT and the location "us-central1". + vertexai.Init(ctx, nil) vertexai.DefineModel("gemini-1.5-pro", nil) type GreetingPromptInput struct { diff --git a/go/internal/doc-snippets/init/main.go b/go/internal/doc-snippets/init/main.go index 0343558844..62c2c81c0a 100644 --- a/go/internal/doc-snippets/init/main.go +++ b/go/internal/doc-snippets/init/main.go @@ -30,11 +30,11 @@ import ( func main() { ctx := context.Background() - // Initialize the Google AI plugin. When you pass an empty string for the - // apiKey parameter, the Google AI plugin will use the value from the + // Initialize the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the // GOOGLE_GENAI_API_KEY environment variable, which is the recommended // practice. - if err := googleai.Init(ctx, ""); err != nil { + if err := googleai.Init(ctx, nil); err != nil { log.Fatal(err) } diff --git a/go/internal/doc-snippets/models.go b/go/internal/doc-snippets/models.go index e7258b6f54..dd3262271d 100644 --- a/go/internal/doc-snippets/models.go +++ b/go/internal/doc-snippets/models.go @@ -34,12 +34,13 @@ var gemini15pro *ai.Model func m1() error { // !+init - projectID := os.Getenv("GCLOUD_PROJECT") - if err := vertexai.Init(ctx, projectID, "us-central1"); err != nil { + // Default to the value of GCLOUD_PROJECT for the project, + // and "us-central1" for the location. + // To specify these values directly, pass a vertexai.Config value to Init. + if err := vertexai.Init(ctx, nil); err != nil { return err } // !-init - _ = projectID // !+model gemini15pro := vertexai.Model("gemini-1.5-pro") diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 8b192542a4..69bf559f63 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -59,10 +59,23 @@ var ( knownEmbedders = []string{"text-embedding-004", "embedding-001"} ) +// Config is the configuration for the plugin. +type Config struct { + // The API key to access the service. + // If empty, the values of the environment variables GOOGLE_GENAI_API_KEY + // and GOOGLE_API_KEY will be consulted, in that order. + APIKey string + // Options to the Google AI client. + ClientOptions []option.ClientOption +} + // Init initializes the plugin and all known models and embedders. // After calling Init, you may call [DefineModel] and [DefineEmbedder] to create // and register any additional generative models and embedders -func Init(ctx context.Context, apiKey string) (err error) { +func Init(ctx context.Context, cfg *Config) (err error) { + if cfg == nil { + cfg = &Config{} + } state.mu.Lock() defer state.mu.Unlock() if state.initted { @@ -74,6 +87,7 @@ func Init(ctx context.Context, apiKey string) (err error) { } }() + apiKey := cfg.APIKey if apiKey == "" { apiKey = os.Getenv("GOOGLE_GENAI_API_KEY") if apiKey == "" { @@ -84,7 +98,8 @@ func Init(ctx context.Context, apiKey string) (err error) { } } - client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) + opts := append([]option.ClientOption{option.WithAPIKey(apiKey)}, cfg.ClientOptions...) + client, err := genai.NewClient(ctx, opts...) if err != nil { return err } diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index b2656d1cb8..45317d2e26 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -42,7 +42,7 @@ func TestLive(t *testing.T) { t.Skip("-all provided") } ctx := context.Background() - err := googleai.Init(ctx, *apiKey) + err := googleai.Init(ctx, &googleai.Config{APIKey: *apiKey}) if err != nil { t.Fatal(err) } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index a822189a59..ea472dc5cf 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -72,36 +72,53 @@ var state struct { pclient *aiplatform.PredictionClient } +// Config is the configuration for the plugin. +type Config struct { + // The cloud project to use for Vertex AI. + // If empty, the values of the environment variables GCLOUD_PROJECT + // and GOOGLE_CLOUD_PROJECT will be consulted, in that order. + ProjectID string + // The location of the Vertex AI service. The default is "us-central1". + Location string + // Options to the Vertex AI client. + ClientOptions []option.ClientOption +} + // Init initializes the plugin and all known models and embedders. // After calling Init, you may call [DefineModel] and [DefineEmbedder] to create // and register any additional generative models and embedders -func Init(ctx context.Context, projectID, location string) error { +func Init(ctx context.Context, cfg *Config) error { + if cfg == nil { + cfg = &Config{} + } state.mu.Lock() defer state.mu.Unlock() if state.initted { panic("vertexai.Init already called") } - if projectID == "" { - projectID = os.Getenv("GCLOUD_PROJECT") - if projectID == "" { - projectID = os.Getenv("GOOGLE_CLOUD_PROJECT") - } - if projectID == "" { - return fmt.Errorf("vertexai.Init: Vertex AI requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment") - } + + state.projectID = cfg.ProjectID + if state.projectID == "" { + state.projectID = os.Getenv("GCLOUD_PROJECT") + } + if state.projectID == "" { + state.projectID = os.Getenv("GOOGLE_CLOUD_PROJECT") } - state.projectID = projectID - if location == "" { - location = "us-central1" + if state.projectID == "" { + return fmt.Errorf("vertexai.Init: Vertex AI requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment") + } + + state.location = cfg.Location + if state.location == "" { + state.location = "us-central1" } - state.location = location var err error // Client for Gemini SDK. - state.gclient, err = genai.NewClient(ctx, projectID, location) + state.gclient, err = genai.NewClient(ctx, state.projectID, state.location, cfg.ClientOptions...) if err != nil { return err } - endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) + endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", state.location) numConns := max(runtime.GOMAXPROCS(0), 4) o := []option.ClientOption{ option.WithEndpoint(endpoint), diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index b8675d647c..4238d2fe75 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -39,7 +39,8 @@ func TestLive(t *testing.T) { } ctx := context.Background() const modelName = "gemini-1.0-pro" - err := vertexai.Init(ctx, *projectID, *location) + const embedderName = "textembedding-gecko" + err := vertexai.Init(ctx, &vertexai.Config{ProjectID: *projectID, Location: *location}) if err != nil { t.Fatal(err) } diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 1cb4f90cd5..1649e9c77f 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -97,7 +97,7 @@ type testAllCoffeeFlowsOutput struct { } func main() { - if err := googleai.Init(context.Background(), ""); err != nil { + if err := googleai.Init(context.Background(), nil); err != nil { log.Fatal(err) } diff --git a/go/samples/menu/main.go b/go/samples/menu/main.go index 42d6e3ba84..093c573519 100644 --- a/go/samples/menu/main.go +++ b/go/samples/menu/main.go @@ -16,7 +16,6 @@ package main import ( "context" - "fmt" "log" "os" @@ -67,14 +66,8 @@ type textMenuQuestionInput struct { var textMenuQuestionInputSchema = jsonschema.Reflect(textMenuQuestionInput{}) func main() { - projectID := os.Getenv("GCLOUD_PROJECT") - if projectID == "" { - fmt.Fprintln(os.Stderr, "menu example requires setting GCLOUD_PROJECT in the environment.") - os.Exit(1) - } - ctx := context.Background() - err := vertexai.Init(ctx, projectID, os.Getenv("GCLOUD_LOCATION")) + err := vertexai.Init(ctx, &vertexai.Config{Location: os.Getenv("GCLOUD_LOCATION")}) if err != nil { log.Fatal(err) } diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index 857cc59ff6..44c5a28735 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -38,7 +38,6 @@ import ( "context" "fmt" "log" - "os" "strings" "github.com/firebase/genkit/go/ai" @@ -69,13 +68,7 @@ type simpleQaPromptInput struct { } func main() { - apiKey := os.Getenv("GOOGLE_GENAI_API_KEY") - if apiKey == "" { - fmt.Fprintln(os.Stderr, "rag example requires setting GOOGLE_GENAI_API_KEY in the environment.") - fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") - os.Exit(1) - } - err := googleai.Init(context.Background(), apiKey) + err := googleai.Init(context.Background(), nil) if err != nil { log.Fatal(err) }