From db154030eb60c25240ebd66dbe68cf2a3c7f0e14 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 25 Jun 2024 12:16:36 -0400 Subject: [PATCH 1/2] feat: [Go] added DefineAllKnownModels function to googeai and vertexai --- go/plugins/googleai/googleai.go | 16 +++++++-- go/plugins/vertexai/vertexai.go | 54 ++++++++++++++++++++++++---- go/plugins/vertexai/vertexai_test.go | 5 ++- go/samples/coffee-shop/main.go | 11 +++--- go/samples/menu/main.go | 13 +++++-- 5 files changed, 82 insertions(+), 17 deletions(-) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 6c637c2323..54142487a4 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -39,19 +39,20 @@ var ( basicText = ai.ModelCapabilities{ Multiturn: true, Tools: true, - SystemRole: true, + SystemRole: false, Media: false, } multimodal = ai.ModelCapabilities{ Multiturn: true, Tools: true, - SystemRole: true, + SystemRole: false, Media: true, } knownCaps = map[string]ai.ModelCapabilities{ "gemini-1.0-pro": basicText, + "gemini-1.5-pro": multimodal, "gemini-1.5-flash": multimodal, } ) @@ -117,6 +118,17 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { return defineModel(name, mc), nil } +// DefineAllKnownModels initializes and registers all known models. +func DefineAllKnownModels() error { + for modelName, caps := range knownCaps { + _, err := DefineModel(modelName, &caps) + if err != nil { + return err + } + } + return nil +} + // requires state.mu func defineModel(name string, caps ai.ModelCapabilities) *ai.Model { meta := &ai.ModelMetadata{ diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index cc08dc83d2..e17b2d0a7c 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -30,6 +30,28 @@ import ( const provider = "vertexai" +var ( + basicText = ai.ModelCapabilities{ + Multiturn: true, + Tools: true, + SystemRole: false, + Media: false, + } + + multimodal = ai.ModelCapabilities{ + Multiturn: true, + Tools: true, + SystemRole: false, + Media: true, + } + + knownCaps = map[string]ai.ModelCapabilities{ + "gemini-1.0-pro": basicText, + "gemini-1.5-pro": multimodal, + "gemini-1.5-flash": multimodal, + } +) + var state struct { mu sync.Mutex initted bool @@ -72,20 +94,40 @@ func Init(ctx context.Context, projectID, location string) error { } // DefineModel defines a model with the given name. -func DefineModel(name string) *ai.Model { +func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { state.mu.Lock() defer state.mu.Unlock() if !state.initted { panic("vertexai.Init not called") } + var mc ai.ModelCapabilities + if caps == nil { + var ok bool + mc, ok = knownCaps[name] + if !ok { + return nil, fmt.Errorf("vertextai.DefineModel: called with unknown model %q and nil ModelCapabilities", name) + } + } else { + mc = *caps + } + meta := &ai.ModelMetadata{ - Label: "Vertex AI - " + name, - Supports: ai.ModelCapabilities{ - Multiturn: true, - }, + Label: "Vertex AI - " + name, + Supports: mc, } g := &generator{model: name, client: state.gclient} - return ai.DefineModel(provider, name, meta, g.generate) + return ai.DefineModel(provider, name, meta, g.generate), nil +} + +// DefineAllKnownModels initializes and registers all known models. +func DefineAllKnownModels() error { + for modelName, caps := range knownCaps { + _, err := DefineModel(modelName, &caps) + if err != nil { + return err + } + } + return nil } // DefineModel defines an embedder with the given name. diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index c2900be4f7..9ade6d89ea 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -44,7 +44,10 @@ func TestLive(t *testing.T) { if err != nil { t.Fatal(err) } - model := vertexai.DefineModel(modelName) + model, err := vertexai.DefineModel(modelName, nil) + if err != nil { + t.Fatal(err) + } embedder := vertexai.DefineEmbedder(embedderName) toolDef := &ai.ToolDefinition{ diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 4fc0d21518..ff1aae53c1 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -104,8 +104,10 @@ func main() { fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.") os.Exit(1) } - err := googleai.Init(context.Background(), apiKey) - if err != nil { + if err := googleai.Init(context.Background(), apiKey); err != nil { + log.Fatal(err) + } + if err := googleai.DefineAllKnownModels(); err != nil { log.Fatal(err) } @@ -113,10 +115,7 @@ func main() { AllowAdditionalProperties: false, DoNotReference: true, } - g, err := googleai.DefineModel("gemini-1.5-pro", nil) - if err != nil { - log.Fatal(err) - } + g := googleai.Model("gemini-1.5-flash") simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate, dotprompt.Config{ Model: g, diff --git a/go/samples/menu/main.go b/go/samples/menu/main.go index 12c8badeef..a5104725b8 100644 --- a/go/samples/menu/main.go +++ b/go/samples/menu/main.go @@ -78,8 +78,14 @@ func main() { if err != nil { log.Fatal(err) } - model := vertexai.DefineModel("gemini-1.0-pro") - visionModel := vertexai.DefineModel("gemini-1.0-pro-vision") + model, err := vertexai.DefineModel("gemini-1.0-pro", nil) + if err != nil { + log.Fatal(err) + } + visionModel, err := vertexai.DefineModel("gemini-1.5-flash", nil) + if err != nil { + log.Fatal(err) + } embedder := vertexai.DefineEmbedder("textembedding-gecko") if err := setup01(ctx, model); err != nil { log.Fatal(err) @@ -99,6 +105,9 @@ func main() { indexer, retriever, err := localvec.DefineIndexerAndRetriever("go-menu_items", localvec.Config{ Embedder: embedder, }) + if err != nil { + log.Fatal(err) + } if err := setup04(ctx, indexer, retriever, model); err != nil { log.Fatal(err) } From 80a110f59d8f7e5b8b5e0e9627ef2a1b5936d36b Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 25 Jun 2024 12:49:09 -0400 Subject: [PATCH 2/2] KnownModels --- go/plugins/googleai/googleai.go | 16 ++++++++-------- go/plugins/vertexai/vertexai.go | 22 ++++++++++++++-------- go/samples/coffee-shop/main.go | 7 +++++-- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 54142487a4..1e360b2e62 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -118,15 +118,15 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { return defineModel(name, mc), nil } -// DefineAllKnownModels initializes and registers all known models. -func DefineAllKnownModels() error { - for modelName, caps := range knownCaps { - _, err := DefineModel(modelName, &caps) - if err != nil { - return err - } +// KnownModels returns a slice of all known model names. +func KnownModels() []string { + keys := make([]string, len(knownCaps)) + i := 0 + for k := range knownCaps { + keys[i] = k + i++ } - return nil + return keys } // requires state.mu diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index e17b2d0a7c..5a02cc520b 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -119,15 +119,21 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) { return ai.DefineModel(provider, name, meta, g.generate), nil } -// DefineAllKnownModels initializes and registers all known models. -func DefineAllKnownModels() error { - for modelName, caps := range knownCaps { - _, err := DefineModel(modelName, &caps) - if err != nil { - return err - } +// IsKnownModel reports whether a model is known to this plugin. +func IsKnownModel(name string) bool { + _, ok := knownCaps[name] + return ok +} + +// KnownModels returns a slice of all known model names. +func KnownModels() []string { + keys := make([]string, len(knownCaps)) + i := 0 + for k := range knownCaps { + keys[i] = k + i++ } - return nil + return keys } // DefineModel defines an embedder with the given name. diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index ff1aae53c1..82d5cb08da 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -107,8 +107,11 @@ func main() { if err := googleai.Init(context.Background(), apiKey); err != nil { log.Fatal(err) } - if err := googleai.DefineAllKnownModels(); err != nil { - log.Fatal(err) + for _, mname := range googleai.KnownModels() { + _, err := googleai.DefineModel(mname, nil) + if err != nil { + log.Fatal(err) + } } r := &jsonschema.Reflector{