diff --git a/go/internal/doc-snippets/dotprompt.go b/go/internal/doc-snippets/dotprompt.go index bd688db997..f7262d66b2 100644 --- a/go/internal/doc-snippets/dotprompt.go +++ b/go/internal/doc-snippets/dotprompt.go @@ -35,9 +35,9 @@ func dot01() error { ctx := context.Background() // The .prompt file specifies vertexai/gemini-1.5-pro, so make sure it's set - // up: - projectID := os.Getenv("GCLOUD_PROJECT") - vertexai.Init(ctx, projectID, "us-central1") + // up. + // Default to the project in GCLOUD_PROJECT and the location "us-central1". + vertexai.Init(ctx, nil) vertexai.DefineModel("gemini-1.5-pro", nil) type GreetingPromptInput struct { diff --git a/go/internal/doc-snippets/init/main.go b/go/internal/doc-snippets/init/main.go index 0343558844..62c2c81c0a 100644 --- a/go/internal/doc-snippets/init/main.go +++ b/go/internal/doc-snippets/init/main.go @@ -30,11 +30,11 @@ import ( func main() { ctx := context.Background() - // Initialize the Google AI plugin. When you pass an empty string for the - // apiKey parameter, the Google AI plugin will use the value from the + // Initialize the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the // GOOGLE_GENAI_API_KEY environment variable, which is the recommended // practice. - if err := googleai.Init(ctx, ""); err != nil { + if err := googleai.Init(ctx, nil); err != nil { log.Fatal(err) } diff --git a/go/internal/doc-snippets/models.go b/go/internal/doc-snippets/models.go index e7258b6f54..dd3262271d 100644 --- a/go/internal/doc-snippets/models.go +++ b/go/internal/doc-snippets/models.go @@ -34,12 +34,13 @@ var gemini15pro *ai.Model func m1() error { // !+init - projectID := os.Getenv("GCLOUD_PROJECT") - if err := vertexai.Init(ctx, projectID, "us-central1"); err != nil { + // Default to the value of GCLOUD_PROJECT for the project, + // and "us-central1" for the location. + // To specify these values directly, pass a vertexai.Config value to Init. + if err := vertexai.Init(ctx, nil); err != nil { return err } // !-init - _ = projectID // !+model gemini15pro := vertexai.Model("gemini-1.5-pro") diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 8b192542a4..69bf559f63 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -59,10 +59,23 @@ var ( knownEmbedders = []string{"text-embedding-004", "embedding-001"} ) +// Config is the configuration for the plugin. +type Config struct { + // The API key to access the service. + // If empty, the values of the environment variables GOOGLE_GENAI_API_KEY + // and GOOGLE_API_KEY will be consulted, in that order. + APIKey string + // Options to the Google AI client. + ClientOptions []option.ClientOption +} + // Init initializes the plugin and all known models and embedders. // After calling Init, you may call [DefineModel] and [DefineEmbedder] to create // and register any additional generative models and embedders -func Init(ctx context.Context, apiKey string) (err error) { +func Init(ctx context.Context, cfg *Config) (err error) { + if cfg == nil { + cfg = &Config{} + } state.mu.Lock() defer state.mu.Unlock() if state.initted { @@ -74,6 +87,7 @@ func Init(ctx context.Context, apiKey string) (err error) { } }() + apiKey := cfg.APIKey if apiKey == "" { apiKey = os.Getenv("GOOGLE_GENAI_API_KEY") if apiKey == "" { @@ -84,7 +98,8 @@ func Init(ctx context.Context, apiKey string) (err error) { } } - client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) + opts := append([]option.ClientOption{option.WithAPIKey(apiKey)}, cfg.ClientOptions...) + client, err := genai.NewClient(ctx, opts...) if err != nil { return err } diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index b2656d1cb8..45317d2e26 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -42,7 +42,7 @@ func TestLive(t *testing.T) { t.Skip("-all provided") } ctx := context.Background() - err := googleai.Init(ctx, *apiKey) + err := googleai.Init(ctx, &googleai.Config{APIKey: *apiKey}) if err != nil { t.Fatal(err) } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index a822189a59..ea472dc5cf 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -72,36 +72,53 @@ var state struct { pclient *aiplatform.PredictionClient } +// Config is the configuration for the plugin. +type Config struct { + // The cloud project to use for Vertex AI. + // If empty, the values of the environment variables GCLOUD_PROJECT + // and GOOGLE_CLOUD_PROJECT will be consulted, in that order. + ProjectID string + // The location of the Vertex AI service. The default is "us-central1". + Location string + // Options to the Vertex AI client. + ClientOptions []option.ClientOption +} + // Init initializes the plugin and all known models and embedders. // After calling Init, you may call [DefineModel] and [DefineEmbedder] to create // and register any additional generative models and embedders -func Init(ctx context.Context, projectID, location string) error { +func Init(ctx context.Context, cfg *Config) error { + if cfg == nil { + cfg = &Config{} + } state.mu.Lock() defer state.mu.Unlock() if state.initted { panic("vertexai.Init already called") } - if projectID == "" { - projectID = os.Getenv("GCLOUD_PROJECT") - if projectID == "" { - projectID = os.Getenv("GOOGLE_CLOUD_PROJECT") - } - if projectID == "" { - return fmt.Errorf("vertexai.Init: Vertex AI requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment") - } + + state.projectID = cfg.ProjectID + if state.projectID == "" { + state.projectID = os.Getenv("GCLOUD_PROJECT") + } + if state.projectID == "" { + state.projectID = os.Getenv("GOOGLE_CLOUD_PROJECT") } - state.projectID = projectID - if location == "" { - location = "us-central1" + if state.projectID == "" { + return fmt.Errorf("vertexai.Init: Vertex AI requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment") + } + + state.location = cfg.Location + if state.location == "" { + state.location = "us-central1" } - state.location = location var err error // Client for Gemini SDK. - state.gclient, err = genai.NewClient(ctx, projectID, location) + state.gclient, err = genai.NewClient(ctx, state.projectID, state.location, cfg.ClientOptions...) if err != nil { return err } - endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) + endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", state.location) numConns := max(runtime.GOMAXPROCS(0), 4) o := []option.ClientOption{ option.WithEndpoint(endpoint), diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index b8675d647c..4238d2fe75 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -39,7 +39,8 @@ func TestLive(t *testing.T) { } ctx := context.Background() const modelName = "gemini-1.0-pro" - err := vertexai.Init(ctx, *projectID, *location) + const embedderName = "textembedding-gecko" + err := vertexai.Init(ctx, &vertexai.Config{ProjectID: *projectID, Location: *location}) if err != nil { t.Fatal(err) } diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 1cb4f90cd5..1649e9c77f 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -97,7 +97,7 @@ type testAllCoffeeFlowsOutput struct { } func main() { - if err := googleai.Init(context.Background(), ""); err != nil { + if err := googleai.Init(context.Background(), nil); err != nil { log.Fatal(err) } diff --git a/go/samples/menu/main.go b/go/samples/menu/main.go index 42d6e3ba84..093c573519 100644 --- a/go/samples/menu/main.go +++ b/go/samples/menu/main.go @@ -16,7 +16,6 @@ package main import ( "context" - "fmt" "log" "os" @@ -67,14 +66,8 @@ type textMenuQuestionInput struct { var textMenuQuestionInputSchema = jsonschema.Reflect(textMenuQuestionInput{}) func main() { - projectID := os.Getenv("GCLOUD_PROJECT") - if projectID == "" { - fmt.Fprintln(os.Stderr, "menu example requires setting GCLOUD_PROJECT in the environment.") - os.Exit(1) - } - ctx := context.Background() - err := vertexai.Init(ctx, projectID, os.Getenv("GCLOUD_LOCATION")) + err := vertexai.Init(ctx, &vertexai.Config{Location: os.Getenv("GCLOUD_LOCATION")}) if err != nil { log.Fatal(err) } diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index 857cc59ff6..44c5a28735 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -38,7 +38,6 @@ import ( "context" "fmt" "log" - "os" "strings" "github.com/firebase/genkit/go/ai" @@ -69,13 +68,7 @@ type simpleQaPromptInput struct { } func main() { - apiKey := os.Getenv("GOOGLE_GENAI_API_KEY") - if apiKey == "" { - fmt.Fprintln(os.Stderr, "rag example requires setting GOOGLE_GENAI_API_KEY in the environment.") - fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") - os.Exit(1) - } - err := googleai.Init(context.Background(), apiKey) + err := googleai.Init(context.Background(), nil) if err != nil { log.Fatal(err) }