diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 6c637c2323..1e360b2e62 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -39,19 +39,20 @@ var ( basicText = ai.ModelCapabilities{ Multiturn: true, Tools: true, - SystemRole: true, + SystemRole: false, Media: false, } multimodal = ai.ModelCapabilities{ Multiturn: true, Tools: true, - SystemRole: true, + SystemRole: false, Media: true, } knownCaps = map[string]ai.ModelCapabilities{ "gemini-1.0-pro": basicText, + "gemini-1.5-pro": multimodal, "gemini-1.5-flash": multimodal, } ) @@ -117,6 +118,17 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { return defineModel(name, mc), nil } +// KnownModels returns a slice of all known model names. +func KnownModels() []string { + keys := make([]string, len(knownCaps)) + i := 0 + for k := range knownCaps { + keys[i] = k + i++ + } + return keys +} + // requires state.mu func defineModel(name string, caps ai.ModelCapabilities) *ai.Model { meta := &ai.ModelMetadata{ diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index cc08dc83d2..5a02cc520b 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -30,6 +30,28 @@ import ( const provider = "vertexai" +var ( + basicText = ai.ModelCapabilities{ + Multiturn: true, + Tools: true, + SystemRole: false, + Media: false, + } + + multimodal = ai.ModelCapabilities{ + Multiturn: true, + Tools: true, + SystemRole: false, + Media: true, + } + + knownCaps = map[string]ai.ModelCapabilities{ + "gemini-1.0-pro": basicText, + "gemini-1.5-pro": multimodal, + "gemini-1.5-flash": multimodal, + } +) + var state struct { mu sync.Mutex initted bool @@ -72,20 +94,46 @@ func Init(ctx context.Context, projectID, location string) error { } // DefineModel defines a model with the given name. -func DefineModel(name string) *ai.Model { +func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { state.mu.Lock() defer state.mu.Unlock() if !state.initted { panic("vertexai.Init not called") } + var mc ai.ModelCapabilities + if caps == nil { + var ok bool + mc, ok = knownCaps[name] + if !ok { + return nil, fmt.Errorf("vertextai.DefineModel: called with unknown model %q and nil ModelCapabilities", name) + } + } else { + mc = *caps + } + meta := &ai.ModelMetadata{ - Label: "Vertex AI - " + name, - Supports: ai.ModelCapabilities{ - Multiturn: true, - }, + Label: "Vertex AI - " + name, + Supports: mc, } g := &generator{model: name, client: state.gclient} - return ai.DefineModel(provider, name, meta, g.generate) + return ai.DefineModel(provider, name, meta, g.generate), nil +} + +// IsKnownModel reports whether a model is known to this plugin. +func IsKnownModel(name string) bool { + _, ok := knownCaps[name] + return ok +} + +// KnownModels returns a slice of all known model names. +func KnownModels() []string { + keys := make([]string, len(knownCaps)) + i := 0 + for k := range knownCaps { + keys[i] = k + i++ + } + return keys } // DefineModel defines an embedder with the given name. diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index c2900be4f7..9ade6d89ea 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -44,7 +44,10 @@ func TestLive(t *testing.T) { if err != nil { t.Fatal(err) } - model := vertexai.DefineModel(modelName) + model, err := vertexai.DefineModel(modelName, nil) + if err != nil { + t.Fatal(err) + } embedder := vertexai.DefineEmbedder(embedderName) toolDef := &ai.ToolDefinition{ diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 4fc0d21518..82d5cb08da 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -104,19 +104,21 @@ func main() { fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") os.Exit(1) } - err := googleai.Init(context.Background(), apiKey) - if err != nil { + if err := googleai.Init(context.Background(), apiKey); err != nil { log.Fatal(err) } + for _, mname := range googleai.KnownModels() { + _, err := googleai.DefineModel(mname, nil) + if err != nil { + log.Fatal(err) + } + } r := &jsonschema.Reflector{ AllowAdditionalProperties: false, DoNotReference: true, } - g, err := googleai.DefineModel("gemini-1.5-pro", nil) - if err != nil { - log.Fatal(err) - } + g := googleai.Model("gemini-1.5-flash") simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate, dotprompt.Config{ Model: g, diff --git a/go/samples/menu/main.go b/go/samples/menu/main.go index 12c8badeef..a5104725b8 100644 --- a/go/samples/menu/main.go +++ b/go/samples/menu/main.go @@ -78,8 +78,14 @@ func main() { if err != nil { log.Fatal(err) } - model := vertexai.DefineModel("gemini-1.0-pro") - visionModel := vertexai.DefineModel("gemini-1.0-pro-vision") + model, err := vertexai.DefineModel("gemini-1.0-pro", nil) + if err != nil { + log.Fatal(err) + } + visionModel, err := vertexai.DefineModel("gemini-1.5-flash", nil) + if err != nil { + log.Fatal(err) + } embedder := vertexai.DefineEmbedder("textembedding-gecko") if err := setup01(ctx, model); err != nil { log.Fatal(err) @@ -99,6 +105,9 @@ func main() { indexer, retriever, err := localvec.DefineIndexerAndRetriever("go-menu_items", localvec.Config{ Embedder: embedder, }) + if err != nil { + log.Fatal(err) + } if err := setup04(ctx, indexer, retriever, model); err != nil { log.Fatal(err) }