Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@ import (
"google.golang.org/api/option"
)

var (
defaultModelCapabilities = ai.ModelCapabilities{
Multiturn: true,
}
basicText = ai.ModelCapabilities{
Multiturn: true,
Tools: true,
SystemRole: true,
Media: false,
}
multimodal = ai.ModelCapabilities{
Multiturn: true,
Tools: true,
SystemRole: true,
Media: true,
}
knownModelsCapabilities = map[string]ai.ModelCapabilities{
"gemini-1.0-pro": basicText,
"gemini-1.5-flash": multimodal,
}
)

const provider = "googleai"

// Config provides configuration options for the Init function.
Expand All @@ -36,7 +58,7 @@ type Config struct {
APIKey string
// Generative models to provide.
// If empty, a complete list will be obtained from the service.
Models []string
Models map[string]*ai.ModelCapabilities
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the user provides both the model name and its capabilities? How do they know? What if they're wrong?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for new models that haven't been released by us yet.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alex, the idea here is that developers who are just starting out would leave Models blank, and the plugin would call the ListModels API and fetch all available models. It's good for developers who are just starting out, but the idea is that once they've built everything they would provide the final list of model they are using so that the framework doesn't have to do ListModels on cold start.

Reading that back, now I'm convinced that ListModels is not ideal. The bottom line is that we need ai.ModelCapabilities one way or the other. Either the user must provide them (which is bad DX and error prone) or we have to hardcode them. The approach we took on JS side is the latter -- all known models are hardcoded in the plugin with their capabilities. If a new model comes out that we're not aware of only then the developer will have to manually specify ai.ModelCapabilities (or wait until we add the model to the plugin).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like/agree with the last takeaway. Hardcoding what we know and allowing some way to use newly released models with manual settings feels like the right way to go since new models don't drop every week.

I didn't know about the ListModels API.

// Embedding models to provide.
// If empty, a complete list will be obtained from the service.
Embedders []string
Expand All @@ -58,7 +80,10 @@ func Init(ctx context.Context, cfg Config) (err error) {
return err
}

needModels := len(cfg.Models) == 0
needModels := cfg.Models == nil || len(cfg.Models) == 0
if needModels {
cfg.Models = map[string]*ai.ModelCapabilities{}
}
needEmbedders := len(cfg.Embedders) == 0
if needModels || needEmbedders {
iter := client.ListModels(ctx)
Expand All @@ -73,28 +98,35 @@ func Init(ctx context.Context, cfg Config) (err error) {
// Model names are of the form "models/name".
name := path.Base(mi.Name)
if needModels && slices.Contains(mi.SupportedGenerationMethods, "generateContent") {
cfg.Models = append(cfg.Models, name)
cfg.Models[name] = nil
}
if needEmbedders && slices.Contains(mi.SupportedGenerationMethods, "embedContent") {
cfg.Embedders = append(cfg.Embedders, name)
}
}
}
for _, name := range cfg.Models {
defineModel(name, client)
for name, c := range cfg.Models {
defineModel(name, client, c)
}
for _, name := range cfg.Embedders {
defineEmbedder(name, client)
}
return nil
}

func defineModel(name string, client *genai.Client) {
func defineModel(name string, client *genai.Client, capabilities *ai.ModelCapabilities) {
c := defaultModelCapabilities
if capabilities == nil {
foundCapability, ok := knownModelsCapabilities[name]
if ok {
c = foundCapability
}
} else {
c = *capabilities
}
meta := &ai.ModelMetadata{
Label: "Google AI - " + name,
Supports: ai.ModelCapabilities{
Multiturn: true,
},
Label: "Google AI - " + name,
Supports: c,
}
g := generator{model: name, client: client}
ai.DefineModel(provider, name, meta, g.generate)
Expand Down
2 changes: 1 addition & 1 deletion js/flow/tests/durable_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ describe('durable', () => {
name: 'testFlow',
inputSchema: z.string(),
outputSchema: z.string(),
experimentalDurable: true,
experimentalDurable: true,
},
async (input) => {
const response = await interrupt(
Expand Down