Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions go/internal/doc-snippets/dotprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func dot01() error {
ctx := context.Background()

// The .prompt file specifies vertexai/gemini-1.5-pro, so make sure it's set
// up:
projectID := os.Getenv("GCLOUD_PROJECT")
vertexai.Init(ctx, projectID, "us-central1")
// up.
// Default to the project in GCLOUD_PROJECT and the location "us-central1".
vertexai.Init(ctx, nil)
vertexai.DefineModel("gemini-1.5-pro", nil)

type GreetingPromptInput struct {
Expand Down
6 changes: 3 additions & 3 deletions go/internal/doc-snippets/init/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ import (
func main() {
ctx := context.Background()

// Initialize the Google AI plugin. When you pass an empty string for the
// apiKey parameter, the Google AI plugin will use the value from the
// Initialize the Google AI plugin. When you pass nil for the
// Config parameter, the Google AI plugin will get the API key from the
// GOOGLE_GENAI_API_KEY environment variable, which is the recommended
// practice.
if err := googleai.Init(ctx, ""); err != nil {
if err := googleai.Init(ctx, nil); err != nil {
log.Fatal(err)
}

Expand Down
7 changes: 4 additions & 3 deletions go/internal/doc-snippets/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ var gemini15pro *ai.Model

func m1() error {
// !+init
projectID := os.Getenv("GCLOUD_PROJECT")
if err := vertexai.Init(ctx, projectID, "us-central1"); err != nil {
// Default to the value of GCLOUD_PROJECT for the project,
// and "us-central1" for the location.
// To specify these values directly, pass a vertexai.Config value to Init.
if err := vertexai.Init(ctx, nil); err != nil {
return err
}
// !-init
_ = projectID

// !+model
gemini15pro := vertexai.Model("gemini-1.5-pro")
Expand Down
19 changes: 17 additions & 2 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,23 @@ var (
knownEmbedders = []string{"text-embedding-004", "embedding-001"}
)

// Config is the configuration for the plugin.
type Config struct {
// The API key to access the service.
// If empty, the values of the environment variables GOOGLE_GENAI_API_KEY
// and GOOGLE_API_KEY will be consulted, in that order.
APIKey string
// Options to the Google AI client.
ClientOptions []option.ClientOption
}

// Init initializes the plugin and all known models and embedders.
// After calling Init, you may call [DefineModel] and [DefineEmbedder] to create
// and register any additional generative models and embedders
func Init(ctx context.Context, apiKey string) (err error) {
func Init(ctx context.Context, cfg *Config) (err error) {
if cfg == nil {
cfg = &Config{}
}
state.mu.Lock()
defer state.mu.Unlock()
if state.initted {
Expand All @@ -74,6 +87,7 @@ func Init(ctx context.Context, apiKey string) (err error) {
}
}()

apiKey := cfg.APIKey
if apiKey == "" {
apiKey = os.Getenv("GOOGLE_GENAI_API_KEY")
if apiKey == "" {
Expand All @@ -84,7 +98,8 @@ func Init(ctx context.Context, apiKey string) (err error) {
}
}

client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
opts := append([]option.ClientOption{option.WithAPIKey(apiKey)}, cfg.ClientOptions...)
client, err := genai.NewClient(ctx, opts...)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestLive(t *testing.T) {
t.Skip("-all provided")
}
ctx := context.Background()
err := googleai.Init(ctx, *apiKey)
err := googleai.Init(ctx, &googleai.Config{APIKey: *apiKey})
if err != nil {
t.Fatal(err)
}
Expand Down
47 changes: 32 additions & 15 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,53 @@ var state struct {
pclient *aiplatform.PredictionClient
}

// Config is the configuration for the plugin.
type Config struct {
// The cloud project to use for Vertex AI.
// If empty, the values of the environment variables GCLOUD_PROJECT
// and GOOGLE_CLOUD_PROJECT will be consulted, in that order.
ProjectID string
// The location of the Vertex AI service. The default is "us-central1".
Location string
// Options to the Vertex AI client.
ClientOptions []option.ClientOption
}

// Init initializes the plugin and all known models and embedders.
// After calling Init, you may call [DefineModel] and [DefineEmbedder] to create
// and register any additional generative models and embedders
func Init(ctx context.Context, projectID, location string) error {
func Init(ctx context.Context, cfg *Config) error {
if cfg == nil {
cfg = &Config{}
}
state.mu.Lock()
defer state.mu.Unlock()
if state.initted {
panic("vertexai.Init already called")
}
if projectID == "" {
projectID = os.Getenv("GCLOUD_PROJECT")
if projectID == "" {
projectID = os.Getenv("GOOGLE_CLOUD_PROJECT")
}
if projectID == "" {
return fmt.Errorf("vertexai.Init: Vertex AI requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment")
}

state.projectID = cfg.ProjectID
if state.projectID == "" {
state.projectID = os.Getenv("GCLOUD_PROJECT")
}
if state.projectID == "" {
state.projectID = os.Getenv("GOOGLE_CLOUD_PROJECT")
}
state.projectID = projectID
if location == "" {
location = "us-central1"
if state.projectID == "" {
return fmt.Errorf("vertexai.Init: Vertex AI requires setting GCLOUD_PROJECT or GOOGLE_CLOUD_PROJECT in the environment")
}

state.location = cfg.Location
if state.location == "" {
state.location = "us-central1"
}
state.location = location
var err error
// Client for Gemini SDK.
state.gclient, err = genai.NewClient(ctx, projectID, location)
state.gclient, err = genai.NewClient(ctx, state.projectID, state.location, cfg.ClientOptions...)
if err != nil {
return err
}
endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
endpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", state.location)
numConns := max(runtime.GOMAXPROCS(0), 4)
o := []option.ClientOption{
option.WithEndpoint(endpoint),
Expand Down
3 changes: 2 additions & 1 deletion go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ func TestLive(t *testing.T) {
}
ctx := context.Background()
const modelName = "gemini-1.0-pro"
err := vertexai.Init(ctx, *projectID, *location)
const embedderName = "textembedding-gecko"
err := vertexai.Init(ctx, &vertexai.Config{ProjectID: *projectID, Location: *location})
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ type testAllCoffeeFlowsOutput struct {
}

func main() {
if err := googleai.Init(context.Background(), ""); err != nil {
if err := googleai.Init(context.Background(), nil); err != nil {
log.Fatal(err)
}

Expand Down
9 changes: 1 addition & 8 deletions go/samples/menu/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package main

import (
"context"
"fmt"
"log"
"os"

Expand Down Expand Up @@ -67,14 +66,8 @@ type textMenuQuestionInput struct {
var textMenuQuestionInputSchema = jsonschema.Reflect(textMenuQuestionInput{})

func main() {
projectID := os.Getenv("GCLOUD_PROJECT")
if projectID == "" {
fmt.Fprintln(os.Stderr, "menu example requires setting GCLOUD_PROJECT in the environment.")
os.Exit(1)
}

ctx := context.Background()
err := vertexai.Init(ctx, projectID, os.Getenv("GCLOUD_LOCATION"))
err := vertexai.Init(ctx, &vertexai.Config{Location: os.Getenv("GCLOUD_LOCATION")})
if err != nil {
log.Fatal(err)
}
Expand Down
9 changes: 1 addition & 8 deletions go/samples/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"context"
"fmt"
"log"
"os"
"strings"

"github.com/firebase/genkit/go/ai"
Expand Down Expand Up @@ -69,13 +68,7 @@ type simpleQaPromptInput struct {
}

func main() {
apiKey := os.Getenv("GOOGLE_GENAI_API_KEY")
if apiKey == "" {
fmt.Fprintln(os.Stderr, "rag example requires setting GOOGLE_GENAI_API_KEY in the environment.")
fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.")
os.Exit(1)
}
err := googleai.Init(context.Background(), apiKey)
err := googleai.Init(context.Background(), nil)
if err != nil {
log.Fatal(err)
}
Expand Down