From 0d65c37ec7d3f56787c818e46a95f5d5cc9e3553 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 24 Jun 2024 04:34:54 -0400 Subject: [PATCH] [Go] googleai: require models to have capabilities Every model must be associated with a set of capabilities. DefineModel now takes an *ai.ModelCapabilities as a second argument. It can be omitted for known models, but must be provided for unknown ones. Modeled on #426. --- go/plugins/googleai/googleai.go | 118 +++++++++++---------------- go/plugins/googleai/googleai_test.go | 29 ++----- go/samples/coffee-shop/main.go | 5 +- go/samples/rag/main.go | 5 +- 4 files changed, 61 insertions(+), 96 deletions(-) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 19e2c833c2..07c634fa71 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -17,8 +17,6 @@ package googleai import ( "context" "fmt" - "path" - "slices" "sync" "github.com/firebase/genkit/go/ai" @@ -34,11 +32,29 @@ var state struct { mu sync.Mutex initted bool client *genai.Client - // Results from ListModels - modelNames []string - embedderNames []string } +var ( + basicText = ai.ModelCapabilities{ + Multiturn: true, + Tools: true, + SystemRole: true, + Media: false, + } + + multimodal = ai.ModelCapabilities{ + Multiturn: true, + Tools: true, + SystemRole: true, + Media: true, + } + + knownCaps = map[string]ai.ModelCapabilities{ + "gemini-1.0-pro": basicText, + "gemini-1.5-flash": multimodal, + } +) + // Init initializes the plugin. // After calling Init, call [DefineModel] and [DefineEmbedder] to create and register // generative models and embedders. @@ -63,23 +79,41 @@ func Init(ctx context.Context, apiKey string) (err error) { return nil } +// IsKnownModel reports whether a model is known to this plugin. +func IsKnownModel(name string) bool { + _, ok := knownCaps[name] + return ok +} + // DefineModel defines a model with the given name. -func DefineModel(name string) *ai.ModelAction { +// The second argument describes the capability of the model. +// For known models, it can be nil, or if non-nil it will override the known value. +// It must be supplied for unknown models. +// Use [IsKnownModel] to determine if a model is known. +func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, error) { state.mu.Lock() defer state.mu.Unlock() if !state.initted { panic("googleai.Init not called") } - return defineModel(name) + var mc ai.ModelCapabilities + if caps == nil { + var ok bool + mc, ok = knownCaps[name] + if !ok { + return nil, fmt.Errorf("googleai.DefineModel: called with unknown model %q and nil ModelCapabilities", name) + } + } else { + mc = *caps + } + return defineModel(name, mc), nil } // requires state.mu -func defineModel(name string) *ai.ModelAction { +func defineModel(name string, caps ai.ModelCapabilities) *ai.ModelAction { meta := &ai.ModelMetadata{ - Label: "Google AI - " + name, - Supports: ai.ModelCapabilities{ - Multiturn: true, - }, + Label: "Google AI - " + name, + Supports: caps, } g := generator{model: name, client: state.client} return ai.DefineModel(provider, name, meta, g.generate) @@ -111,66 +145,6 @@ func defineEmbedder(name string) *ai.EmbedderAction { }) } -// DefineAllModels defines all models known to the service. -func DefineAllModels(ctx context.Context) ([]*ai.ModelAction, error) { - state.mu.Lock() - defer state.mu.Unlock() - if !state.initted { - panic("googleai.Init not called") - } - if err := listModels(ctx); err != nil { - return nil, err - } - var mas []*ai.ModelAction - for _, mod := range state.modelNames { - mas = append(mas, defineModel(mod)) - } - return mas, nil -} - -// DefineAllEmbedders defines all embedders known to the service. -func DefineAllEmbedders(ctx context.Context) ([]*ai.EmbedderAction, error) { - state.mu.Lock() - defer state.mu.Unlock() - if !state.initted { - panic("googleai.Init not called") - } - if err := listModels(ctx); err != nil { - return nil, err - } - var eas []*ai.EmbedderAction - for _, em := range state.embedderNames { - eas = append(eas, defineEmbedder(em)) - } - return eas, nil -} - -// requires state.mu -func listModels(ctx context.Context) error { - if len(state.modelNames) > 0 || len(state.embedderNames) > 0 { - // already called - return nil - } - iter := state.client.ListModels(ctx) - for { - mi, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return err - } - // Model names are of the form "models/name". - name := path.Base(mi.Name) - if slices.Contains(mi.SupportedGenerationMethods, "generateContent") { - state.modelNames = append(state.modelNames, name) - } else if slices.Contains(mi.SupportedGenerationMethods, "embedContent") { - state.embedderNames = append(state.embedderNames, name) - } - } - return nil -} - // Model returns the [ai.ModelAction] with the given name. // It returns nil if the model was not configured. func Model(name string) *ai.ModelAction { diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 1ed4ac98db..4e8faf72fd 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -47,7 +47,10 @@ func TestLive(t *testing.T) { t.Fatal(err) } embedder := googleai.DefineEmbedder("embedding-001") - model := googleai.DefineModel("gemini-1.0-pro") + model, err := googleai.DefineModel("gemini-1.0-pro", nil) + if err != nil { + t.Fatal(err) + } toolDef := &ai.ToolDefinition{ Name: "exponentiation", InputSchema: map[string]any{"base": "float64", "exponent": "int"}, @@ -106,7 +109,7 @@ func TestLive(t *testing.T) { req := &ai.GenerateRequest{ Candidates: 1, Messages: []*ai.Message{ - &ai.Message{ + { Content: []*ai.Part{ai.NewTextPart("Which country was Napoleon the emperor of?")}, Role: ai.RoleUser, }, @@ -133,7 +136,7 @@ func TestLive(t *testing.T) { req := &ai.GenerateRequest{ Candidates: 1, Messages: []*ai.Message{ - &ai.Message{ + { Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")}, Role: ai.RoleUser, }, @@ -173,7 +176,7 @@ func TestLive(t *testing.T) { req := &ai.GenerateRequest{ Candidates: 1, Messages: []*ai.Message{ - &ai.Message{ + { Content: []*ai.Part{ai.NewTextPart("what is 3.5 squared? Use the tool provided.")}, Role: ai.RoleUser, }, @@ -193,21 +196,3 @@ func TestLive(t *testing.T) { } }) } - -func TestAllModels(t *testing.T) { - if !*testAll { - t.Skip("-all not set") - } - ctx := context.Background() - if err := googleai.Init(ctx, *apiKey); err != nil { - t.Fatal(err) - } - mods, err := googleai.DefineAllModels(ctx) - if err != nil || len(mods) == 0 { - t.Fatalf("got %d, %v, want >0, nil", len(mods), err) - } - embs, err := googleai.DefineAllEmbedders(ctx) - if err != nil || len(embs) == 0 { - t.Fatalf("got %d, %v, want >0, nil", len(mods), err) - } -} diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index fa42fbcc90..30dffd63cd 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -113,7 +113,10 @@ func main() { AllowAdditionalProperties: false, DoNotReference: true, } - g := googleai.DefineModel("gemini-1.5-pro") + g, err := googleai.DefineModel("gemini-1.5-pro", nil) + if err != nil { + log.Fatal(err) + } simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate, dotprompt.Config{ ModelAction: g, diff --git a/go/samples/rag/main.go b/go/samples/rag/main.go index 0e7e40392e..0857691ef6 100644 --- a/go/samples/rag/main.go +++ b/go/samples/rag/main.go @@ -79,7 +79,10 @@ func main() { if err != nil { log.Fatal(err) } - model := googleai.DefineModel("gemini-1.0-pro") + model, err := googleai.DefineModel("gemini-1.0-pro", nil) + if err != nil { + log.Fatal(err) + } embedder := googleai.DefineEmbedder("embedding-001") if err := localvec.Init(); err != nil { log.Fatal(err)