Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 46 additions & 72 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ package googleai
import (
"context"
"fmt"
"path"
"slices"
"sync"

"github.com/firebase/genkit/go/ai"
Expand All @@ -34,11 +32,29 @@ var state struct {
mu sync.Mutex
initted bool
client *genai.Client
// Results from ListModels
modelNames []string
embedderNames []string
}

var (
basicText = ai.ModelCapabilities{
Multiturn: true,
Tools: true,
SystemRole: true,
Media: false,
}

multimodal = ai.ModelCapabilities{
Multiturn: true,
Tools: true,
SystemRole: true,
Media: true,
}

knownCaps = map[string]ai.ModelCapabilities{
"gemini-1.0-pro": basicText,
"gemini-1.5-flash": multimodal,
}
)

// Init initializes the plugin.
// After calling Init, call [DefineModel] and [DefineEmbedder] to create and register
// generative models and embedders.
Expand All @@ -63,23 +79,41 @@ func Init(ctx context.Context, apiKey string) (err error) {
return nil
}

// IsKnownModel reports whether a model is known to this plugin.
func IsKnownModel(name string) bool {
_, ok := knownCaps[name]
return ok
}

// DefineModel defines a model with the given name.
func DefineModel(name string) *ai.ModelAction {
// The second argument describes the capability of the model.
// For known models, it can be nil, or if non-nil it will override the known value.
// It must be supplied for unknown models.
// Use [IsKnownModel] to determine if a model is known.
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.ModelAction, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("googleai.Init not called")
}
return defineModel(name)
var mc ai.ModelCapabilities
if caps == nil {
var ok bool
mc, ok = knownCaps[name]
if !ok {
return nil, fmt.Errorf("googleai.DefineModel: called with unknown model %q and nil ModelCapabilities", name)
}
} else {
mc = *caps
}
return defineModel(name, mc), nil
}

// requires state.mu
func defineModel(name string) *ai.ModelAction {
func defineModel(name string, caps ai.ModelCapabilities) *ai.ModelAction {
meta := &ai.ModelMetadata{
Label: "Google AI - " + name,
Supports: ai.ModelCapabilities{
Multiturn: true,
},
Label: "Google AI - " + name,
Supports: caps,
}
g := generator{model: name, client: state.client}
return ai.DefineModel(provider, name, meta, g.generate)
Expand Down Expand Up @@ -111,66 +145,6 @@ func defineEmbedder(name string) *ai.EmbedderAction {
})
}

// DefineAllModels defines all models known to the service.
func DefineAllModels(ctx context.Context) ([]*ai.ModelAction, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("googleai.Init not called")
}
if err := listModels(ctx); err != nil {
return nil, err
}
var mas []*ai.ModelAction
for _, mod := range state.modelNames {
mas = append(mas, defineModel(mod))
}
return mas, nil
}

// DefineAllEmbedders defines all embedders known to the service.
func DefineAllEmbedders(ctx context.Context) ([]*ai.EmbedderAction, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("googleai.Init not called")
}
if err := listModels(ctx); err != nil {
return nil, err
}
var eas []*ai.EmbedderAction
for _, em := range state.embedderNames {
eas = append(eas, defineEmbedder(em))
}
return eas, nil
}

// requires state.mu
func listModels(ctx context.Context) error {
if len(state.modelNames) > 0 || len(state.embedderNames) > 0 {
// already called
return nil
}
iter := state.client.ListModels(ctx)
for {
mi, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
// Model names are of the form "models/name".
name := path.Base(mi.Name)
if slices.Contains(mi.SupportedGenerationMethods, "generateContent") {
state.modelNames = append(state.modelNames, name)
} else if slices.Contains(mi.SupportedGenerationMethods, "embedContent") {
state.embedderNames = append(state.embedderNames, name)
}
}
return nil
}

// Model returns the [ai.ModelAction] with the given name.
// It returns nil if the model was not configured.
func Model(name string) *ai.ModelAction {
Expand Down
29 changes: 7 additions & 22 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ func TestLive(t *testing.T) {
t.Fatal(err)
}
embedder := googleai.DefineEmbedder("embedding-001")
model := googleai.DefineModel("gemini-1.0-pro")
model, err := googleai.DefineModel("gemini-1.0-pro", nil)
if err != nil {
t.Fatal(err)
}
toolDef := &ai.ToolDefinition{
Name: "exponentiation",
InputSchema: map[string]any{"base": "float64", "exponent": "int"},
Expand Down Expand Up @@ -106,7 +109,7 @@ func TestLive(t *testing.T) {
req := &ai.GenerateRequest{
Candidates: 1,
Messages: []*ai.Message{
&ai.Message{
{
Content: []*ai.Part{ai.NewTextPart("Which country was Napoleon the emperor of?")},
Role: ai.RoleUser,
},
Expand All @@ -133,7 +136,7 @@ func TestLive(t *testing.T) {
req := &ai.GenerateRequest{
Candidates: 1,
Messages: []*ai.Message{
&ai.Message{
{
Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")},
Role: ai.RoleUser,
},
Expand Down Expand Up @@ -173,7 +176,7 @@ func TestLive(t *testing.T) {
req := &ai.GenerateRequest{
Candidates: 1,
Messages: []*ai.Message{
&ai.Message{
{
Content: []*ai.Part{ai.NewTextPart("what is 3.5 squared? Use the tool provided.")},
Role: ai.RoleUser,
},
Expand All @@ -193,21 +196,3 @@ func TestLive(t *testing.T) {
}
})
}

func TestAllModels(t *testing.T) {
if !*testAll {
t.Skip("-all not set")
}
ctx := context.Background()
if err := googleai.Init(ctx, *apiKey); err != nil {
t.Fatal(err)
}
mods, err := googleai.DefineAllModels(ctx)
if err != nil || len(mods) == 0 {
t.Fatalf("got %d, %v, want >0, nil", len(mods), err)
}
embs, err := googleai.DefineAllEmbedders(ctx)
if err != nil || len(embs) == 0 {
t.Fatalf("got %d, %v, want >0, nil", len(mods), err)
}
}
5 changes: 4 additions & 1 deletion go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ func main() {
AllowAdditionalProperties: false,
DoNotReference: true,
}
g := googleai.DefineModel("gemini-1.5-pro")
g, err := googleai.DefineModel("gemini-1.5-pro", nil)
if err != nil {
log.Fatal(err)
}
simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate,
dotprompt.Config{
ModelAction: g,
Expand Down
5 changes: 4 additions & 1 deletion go/samples/rag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ func main() {
if err != nil {
log.Fatal(err)
}
model := googleai.DefineModel("gemini-1.0-pro")
model, err := googleai.DefineModel("gemini-1.0-pro", nil)
if err != nil {
log.Fatal(err)
}
embedder := googleai.DefineEmbedder("embedding-001")
if err := localvec.Init(); err != nil {
log.Fatal(err)
Expand Down