Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] refactor googleai generate and copy to vertexai #650

Merged
Merged 1 commit into the base branch on Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
171 changes: 102 additions & 69 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,13 @@ func defineModel(name string, caps ai.ModelCapabilities) *ai.Model {
Label: labelPrefix + " - " + name,
Supports: caps,
}
g := generator{model: name, client: state.gclient}
return ai.DefineModel(provider, name, meta, g.generate)
return ai.DefineModel(provider, name, meta, func(
ctx context.Context,
input *ai.GenerateRequest,
cb func(context.Context, *ai.GenerateResponseChunk) error,
) (*ai.GenerateResponse, error) {
return generate(ctx, state.gclient, name, input, cb)
})
}

// IsDefinedModel reports whether the named [Model] is defined by this plugin.
Expand All @@ -157,6 +162,8 @@ func IsDefinedModel(name string) bool {

//copy:stop

//copy:start vertexai.go defineEmbedder

// DefineEmbedder defines an embedder with a given name.
func DefineEmbedder(name string) *ai.Embedder {
state.mu.Lock()
Expand All @@ -172,6 +179,8 @@ func IsDefinedEmbedder(name string) bool {
return ai.IsDefinedEmbedder(provider, name)
}

//copy:stop

// requires state.mu
func defineEmbedder(name string) *ai.Embedder {
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
Expand Down Expand Up @@ -213,16 +222,85 @@ func Embedder(name string) *ai.Embedder {

//copy:stop

type generator struct {
model string
client *genai.Client
//session *genai.ChatSession // non-nil if we're in the middle of a chat
}
//copy:start vertexai.go generate

func generate(
ctx context.Context,
client *genai.Client,
model string,
input *ai.GenerateRequest,
cb func(context.Context, *ai.GenerateResponseChunk) error,
) (*ai.GenerateResponse, error) {
gm := newModel(client, model, input)
cs, err := startChat(gm, input)
if err != nil {
return nil, err
}
// The last message gets added to the parts slice.
var parts []genai.Part
if len(input.Messages) > 0 {
last := input.Messages[len(input.Messages)-1]
var err error
parts, err = convertParts(last.Content)
if err != nil {
return nil, err
}
}

gm.Tools, err = convertTools(input.Tools)
if err != nil {
return nil, err
}
// Convert input.Tools and append to gm.Tools

// TODO: gm.ToolConfig?

// Send out the actual request.
if cb == nil {
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
}
r := translateResponse(resp)
r.Request = input
return r, nil
}

func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
gm := g.client.GenerativeModel(g.model)
// Streaming version.
iter := cs.SendMessageStream(ctx, parts...)
var r *ai.GenerateResponse
for {
chunk, err := iter.Next()
if err == iterator.Done {
r = translateResponse(iter.MergedResponse())
break
}
if err != nil {
return nil, err
}
// Send candidates to the callback.
for _, c := range chunk.Candidates {
tc := translateCandidate(c)
err := cb(ctx, &ai.GenerateResponseChunk{
Content: tc.Message.Content,
Index: tc.Index,
})
if err != nil {
return nil, err
}
}
}
if r == nil {
// No candidates were returned. Probably rare, but it might avoid a NPE
// to return an empty instead of nil result.
r = &ai.GenerateResponse{}
}
r.Request = input
return r, nil
}

// Translate from a ai.GenerateRequest to a genai request.
func newModel(client *genai.Client, model string, input *ai.GenerateRequest) *genai.GenerativeModel {
gm := client.GenerativeModel(model)
gm.SetCandidateCount(int32(input.Candidates))
if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil {
if c.MaxOutputTokens != 0 {
Expand All @@ -241,8 +319,11 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
gm.SetTopP(float32(c.TopP))
}
}
return gm
}

// Start a "chat".
// startChat starts a chat session and configures it with the input messages.
func startChat(gm *genai.GenerativeModel, input *ai.GenerateRequest) (*genai.ChatSession, error) {
cs := gm.StartChat()

// All but the last message goes in the history field.
Expand All @@ -259,18 +340,11 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
Role: string(m.Role),
})
}
// The last message gets added to the parts slice.
var parts []genai.Part
if len(messages) > 0 {
var err error
parts, err = convertParts(messages[0].Content)
if err != nil {
return nil, err
}
}

// Convert input.Tools and append to gm.Tools
for _, t := range input.Tools {
return cs, nil
}
func convertTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) {
var outTools []*genai.Tool
for _, t := range inTools {
schema := &genai.Schema{}
schema.Type = genai.TypeObject
schema.Properties = map[string]*genai.Schema{}
Expand All @@ -286,7 +360,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
case "bool":
typ = genai.TypeBoolean
default:
return nil, fmt.Errorf("schema value \"%s\" not allowed", v)
return nil, fmt.Errorf("schema value %q not allowed", v)
}
schema.Properties[k] = &genai.Schema{Type: typ}
}
Expand All @@ -295,54 +369,13 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
Parameters: schema,
Description: t.Description,
}
gm.Tools = append(gm.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}})
}
// TODO: gm.ToolConfig?

// Send out the actual request.
if cb == nil {
resp, err := cs.SendMessage(ctx, parts...)
if err != nil {
return nil, err
}
r := translateResponse(resp)
r.Request = input
return r, nil
}

// Streaming version.
iter := cs.SendMessageStream(ctx, parts...)
var r *ai.GenerateResponse
for {
chunk, err := iter.Next()
if err == iterator.Done {
r = translateResponse(iter.MergedResponse())
break
}
if err != nil {
return nil, err
}
// Send candidates to the callback.
for _, c := range chunk.Candidates {
tc := translateCandidate(c)
err := cb(ctx, &ai.GenerateResponseChunk{
Content: tc.Message.Content,
Index: tc.Index,
})
if err != nil {
return nil, err
}
}
outTools = append(outTools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}})
}
if r == nil {
// No candidates were returned. Probably rare, but it might avoid a NPE
// to return an empty instead of nil result.
r = &ai.GenerateResponse{}
}
r.Request = input
return r, nil
return outTools, nil
}

//copy:stop

//copy:start vertexai.go translateCandidate

// translateCandidate translates from a genai.GenerateContentResponse to an ai.GenerateResponse.
Expand Down
5 changes: 5 additions & 0 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import (
// The tests here only work with an API key set to a valid value.
var apiKey = flag.String("key", "", "Gemini API key")

var header = flag.Bool("header", false, "run test for x-goog-client-api header")

// We can't test the DefineAll functions along with the other tests because
// we get duplicate definitions of models.
var testAll = flag.Bool("all", false, "test DefineAllXXX functions")
Expand Down Expand Up @@ -203,6 +205,9 @@ func TestLive(t *testing.T) {
}

func TestHeader(t *testing.T) {
if !*header {
t.Skip("skipped; to run, pass -header and don't run the live test")
}
ctx := context.Background()
var header http.Header
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading
Loading