Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,20 @@ var (
basicText = ai.ModelCapabilities{
Multiturn: true,
Tools: true,
SystemRole: true,
SystemRole: false,
Media: false,
}

multimodal = ai.ModelCapabilities{
Multiturn: true,
Tools: true,
SystemRole: true,
SystemRole: false,
Media: true,
}

knownCaps = map[string]ai.ModelCapabilities{
"gemini-1.0-pro": basicText,
"gemini-1.5-pro": multimodal,
"gemini-1.5-flash": multimodal,
}
)
Expand Down Expand Up @@ -117,6 +118,17 @@ func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
return defineModel(name, mc), nil
}

// KnownModels returns a slice of all known model names.
func KnownModels() []string {
keys := make([]string, len(knownCaps))
i := 0
for k := range knownCaps {
keys[i] = k
i++
}
return keys
}

// requires state.mu
func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
meta := &ai.ModelMetadata{
Expand Down
60 changes: 54 additions & 6 deletions go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@ import (

const provider = "vertexai"

var (
basicText = ai.ModelCapabilities{
Multiturn: true,
Tools: true,
SystemRole: false,
Media: false,
}

multimodal = ai.ModelCapabilities{
Multiturn: true,
Tools: true,
SystemRole: false,
Media: true,
}

knownCaps = map[string]ai.ModelCapabilities{
"gemini-1.0-pro": basicText,
"gemini-1.5-pro": multimodal,
"gemini-1.5-flash": multimodal,
}
)

var state struct {
mu sync.Mutex
initted bool
Expand Down Expand Up @@ -72,20 +94,46 @@ func Init(ctx context.Context, projectID, location string) error {
}

// DefineModel defines a model with the given name.
func DefineModel(name string) *ai.Model {
func DefineModel(name string, caps *ai.ModelCapabilities) (*ai.Model, error) {
state.mu.Lock()
defer state.mu.Unlock()
if !state.initted {
panic("vertexai.Init not called")
}
var mc ai.ModelCapabilities
if caps == nil {
var ok bool
mc, ok = knownCaps[name]
if !ok {
return nil, fmt.Errorf("vertextai.DefineModel: called with unknown model %q and nil ModelCapabilities", name)
}
} else {
mc = *caps
}

meta := &ai.ModelMetadata{
Label: "Vertex AI - " + name,
Supports: ai.ModelCapabilities{
Multiturn: true,
},
Label: "Vertex AI - " + name,
Supports: mc,
}
g := &generator{model: name, client: state.gclient}
return ai.DefineModel(provider, name, meta, g.generate)
return ai.DefineModel(provider, name, meta, g.generate), nil
}

// IsKnownModel reports whether a model is known to this plugin.
func IsKnownModel(name string) bool {
_, ok := knownCaps[name]
return ok
}

// KnownModels returns a slice of all known model names.
func KnownModels() []string {
keys := make([]string, len(knownCaps))
i := 0
for k := range knownCaps {
keys[i] = k
i++
}
return keys
}

// DefineModel defines an embedder with the given name.
Expand Down
5 changes: 4 additions & 1 deletion go/plugins/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ func TestLive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
model := vertexai.DefineModel(modelName)
model, err := vertexai.DefineModel(modelName, nil)
if err != nil {
t.Fatal(err)
}
embedder := vertexai.DefineEmbedder(embedderName)

toolDef := &ai.ToolDefinition{
Expand Down
14 changes: 8 additions & 6 deletions go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,21 @@ func main() {
fmt.Fprintln(os.Stderr, "You can get an API key at https://ai.google.dev.")
os.Exit(1)
}
err := googleai.Init(context.Background(), apiKey)
if err != nil {
if err := googleai.Init(context.Background(), apiKey); err != nil {
log.Fatal(err)
}
for _, mname := range googleai.KnownModels() {
_, err := googleai.DefineModel(mname, nil)
if err != nil {
log.Fatal(err)
}
}

r := &jsonschema.Reflector{
AllowAdditionalProperties: false,
DoNotReference: true,
}
g, err := googleai.DefineModel("gemini-1.5-pro", nil)
if err != nil {
log.Fatal(err)
}
g := googleai.Model("gemini-1.5-flash")
simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate,
dotprompt.Config{
Model: g,
Expand Down
13 changes: 11 additions & 2 deletions go/samples/menu/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,14 @@ func main() {
if err != nil {
log.Fatal(err)
}
model := vertexai.DefineModel("gemini-1.0-pro")
visionModel := vertexai.DefineModel("gemini-1.0-pro-vision")
model, err := vertexai.DefineModel("gemini-1.0-pro", nil)
if err != nil {
log.Fatal(err)
}
visionModel, err := vertexai.DefineModel("gemini-1.5-flash", nil)
if err != nil {
log.Fatal(err)
}
embedder := vertexai.DefineEmbedder("textembedding-gecko")
if err := setup01(ctx, model); err != nil {
log.Fatal(err)
Expand All @@ -99,6 +105,9 @@ func main() {
indexer, retriever, err := localvec.DefineIndexerAndRetriever("go-menu_items", localvec.Config{
Embedder: embedder,
})
if err != nil {
log.Fatal(err)
}
if err := setup04(ctx, indexer, retriever, model); err != nil {
log.Fatal(err)
}
Expand Down