From 3309d312b0fb767570f7ae9e38f1b4449a9c854e Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Mon, 10 Jun 2024 13:23:08 -0700 Subject: [PATCH 1/4] implement vertexai streaming Fixes #344 --- go/go.mod | 8 +- go/go.sum | 8 ++ go/plugins/vertexai/vertexai.go | 115 +++++++++++++++++++++------ go/plugins/vertexai/vertexai_test.go | 40 ++++++++++ 4 files changed, 142 insertions(+), 29 deletions(-) diff --git a/go/go.mod b/go/go.mod index f1a645f8ae..e2dd67b775 100644 --- a/go/go.mod +++ b/go/go.mod @@ -3,9 +3,9 @@ module github.com/firebase/genkit/go go 1.22.0 require ( - cloud.google.com/go/aiplatform v1.66.0 + cloud.google.com/go/aiplatform v1.67.0 cloud.google.com/go/logging v1.9.0 - cloud.google.com/go/vertexai v0.7.1 + cloud.google.com/go/vertexai v0.10.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0 github.com/aymerick/raymond v2.0.2+incompatible @@ -22,7 +22,7 @@ require ( go.opentelemetry.io/otel/sdk/metric v1.26.0 go.opentelemetry.io/otel/trace v1.26.0 golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 - google.golang.org/api v0.178.0 + google.golang.org/api v0.180.0 google.golang.org/protobuf v1.34.1 gopkg.in/yaml.v3 v3.0.1 ) @@ -30,7 +30,7 @@ require ( require ( cloud.google.com/go v0.113.0 // indirect cloud.google.com/go/ai v0.5.0 // indirect - cloud.google.com/go/auth v0.4.0 // indirect + cloud.google.com/go/auth v0.4.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect cloud.google.com/go/iam v1.1.7 // indirect diff --git a/go/go.sum b/go/go.sum index fa72e192d0..26304299d6 100644 --- a/go/go.sum +++ b/go/go.sum @@ -5,8 +5,12 @@ cloud.google.com/go/ai v0.5.0 h1:x8s4rDn5t9OVZvBCgtr5bZTH5X0O7JdE6zYo+O+MpRw= cloud.google.com/go/ai v0.5.0/go.mod h1:96VBphk70e0zdXZrbtgPuKYRZsQ3UktSUXhuojwiKA8= cloud.google.com/go/aiplatform v1.66.0 h1:bbFYY4JInclG10czRFUYj2rjD+obhh3Gi9zVlyoMgEc= cloud.google.com/go/aiplatform v1.66.0/go.mod h1:bPQS0UjaXaTAq57UgP3XWDCtYFOIbXXpkMsl6uP4JAc= +cloud.google.com/go/aiplatform v1.67.0 h1:YWeqD4BjYwrmY4fa+isGcw0P81lJ3dKVxbWxdBchoiU= +cloud.google.com/go/aiplatform v1.67.0/go.mod h1:s/sJ6btBEr6bKnrNWdK9ZgHCvwbZNdP90b3DDtxxw+Y= cloud.google.com/go/auth v0.4.0 h1:vcJWEguhY8KuiHoSs/udg1JtIRYm3YAWPBE1moF1m3U= cloud.google.com/go/auth v0.4.0/go.mod h1:tO/chJN3obc5AbRYFQDsuFbL4wW5y8LfbPtDCfgwOVE= +cloud.google.com/go/auth v0.4.1 h1:Z7YNIhlWRtrnKlZke7z3GMqzvuYzdc2z98F9D1NV5Hg= +cloud.google.com/go/auth v0.4.1/go.mod h1:QVBuVEKpCn4Zp58hzRGvL0tjRGU0YqdRTdCHM1IHnro= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= @@ -23,6 +27,8 @@ cloud.google.com/go/trace v1.10.6 h1:XF0Ejdw0NpRfAvuZUeQe3ClAG4R/9w5JYICo7l2weaw cloud.google.com/go/trace v1.10.6/go.mod h1:EABXagUjxGuKcZMy4pXyz0fJpE5Ghog3jzTxcEsVJS4= cloud.google.com/go/vertexai v0.7.1 h1:CSdqsEwjklLIlI1e5SrsnkwG/I+CeJekkBbMTzeYhVg= cloud.google.com/go/vertexai v0.7.1/go.mod h1:HfnfYR9aPS+qF2436S6Hzuw0Fp+PORjzK3ggqymdzSU= +cloud.google.com/go/vertexai v0.10.0 h1:k157bLrtyajGtAAZnqdEn8lwFlUTG3BgHc7kvWbP/3s= +cloud.google.com/go/vertexai v0.10.0/go.mod h1:w/Zb22QvOVvxx5CGM4fPzH3WA6gwUkId9juA7pigzFI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 h1:n3T26hyfDl9RdgcOjWvOFMh1lCBNuZ0JQ/3DM5pou8Y= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0/go.mod h1:3S7qK2nHOO2cLID3xk6H8f55D38XswhVFzKEk0nqIbY= @@ -186,6 +192,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.178.0 h1:yoW/QMI4bRVCHF+NWOTa4cL8MoWL3Jnuc7FlcFF91Ok= google.golang.org/api v0.178.0/go.mod h1:84/k2v8DFpDRebpGcooklv/lais3MEfqpaBLA12gl2U= +google.golang.org/api v0.180.0 h1:M2D87Yo0rGBPWpo1orwfCLehUUL6E7/TYe5gvMQWDh4= +google.golang.org/api v0.180.0/go.mod h1:51AiyoEg1MJPSZ9zvklA8VnRILPXxn1iVen9v25XHAE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 6069a386ec..46335b663a 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "reflect" "runtime" aiplatform "cloud.google.com/go/aiplatform/apiv1" @@ -125,9 +126,6 @@ type generator struct { } func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) { - if cb != nil { - panic("streaming not supported yet") // TODO: streaming - } gm := g.client.GenerativeModel(g.model) // Translate from a ai.GenerateRequest to a genai request. @@ -143,7 +141,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb gm.SetTemperature(float32(c.Temperature)) } if c.TopK != 0 { - gm.SetTopK(float32(c.TopK)) + gm.SetTopK(int32(c.TopK)) } if c.TopP != 0 { gm.SetTopP(float32(c.TopP)) @@ -213,13 +211,77 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb // TODO: gm.ToolConfig? // Send out the actual request. - resp, err := cs.SendMessage(ctx, parts...) - if err != nil { - return nil, err + if cb == nil { + resp, err := cs.SendMessage(ctx, parts...) + if err != nil { + return nil, err + } + + r := translateResponse(resp) + r.Request = input + return r, nil } - r := translateResponse(resp) - r.Request = input + // Streaming version. + iter := cs.SendMessageStream(ctx, parts...) + r := &ai.GenerateResponse{Request: input, Candidates: make([]*ai.Candidate, input.Candidates)} + for { + chunk, err := iter.Next() + if err != nil { + if err.Error() == "no more items in iterator" { + break + } + return nil, err + } + + // Process each candidate. + for _, c := range chunk.Candidates { + tc := translateCandidate(c) + + // Call callback with the candidate info. + err := cb(ctx, &ai.GenerateResponseChunk{ + Content: tc.Message.Content, + Index: tc.Index, + }) + if err != nil { + return nil, err + } + // Save candidate in full response structure. + if old := r.Candidates[tc.Index]; old == nil { + r.Candidates[tc.Index] = tc + } else { + // Need to merge two "parts" of a candidate. + // Currently, we: + // - append the Message content + // - merge the FinishReason + // - assert everything else is unchanged + // (We do that 3rd step first.) + c1 := *r.Candidates[tc.Index] + c2 := *tc + m1 := *c1.Message + m2 := *c2.Message + c1.Message = &m1 + c2.Message = &m2 + m1.Content = nil + m2.Content = nil + c1.FinishReason = ai.FinishReasonUnknown + c2.FinishReason = ai.FinishReasonUnknown + if !reflect.DeepEqual(&c1, &c2) { + return nil, fmt.Errorf("some candidate fields unexpectedly changed") + } + + // Append the Parts to the final candidate. + old.Message.Content = append(old.Message.Content, tc.Message.Content...) + // Merge the FinishReasons. + if old.FinishReason == ai.FinishReasonUnknown { + old.FinishReason = tc.FinishReason + } else if old.FinishReason != tc.FinishReason { + return nil, fmt.Errorf("invalid finish reason transition: %s to %s", old.FinishReason, tc.FinishReason) + } + } + } + // TODO: use chunk.PromptFeedback, chunk.UsageMetadata + } return r, nil } @@ -242,23 +304,25 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { c.FinishReason = ai.FinishReasonUnknown } m := &ai.Message{} - m.Role = ai.Role(cand.Content.Role) - for _, part := range cand.Content.Parts { - var p *ai.Part - switch part := part.(type) { - case genai.Text: - p = ai.NewTextPart(string(part)) - case genai.Blob: - p = ai.NewMediaPart(part.MIMEType, string(part.Data)) - case genai.FunctionCall: - p = ai.NewToolRequestPart(&ai.ToolRequest{ - Name: part.Name, - Input: part.Args, - }) - default: - panic(fmt.Sprintf("unknown part %#v", part)) + if cand.Content != nil { + m.Role = ai.Role(cand.Content.Role) + for _, part := range cand.Content.Parts { + var p *ai.Part + switch part := part.(type) { + case genai.Text: + p = ai.NewTextPart(string(part)) + case genai.Blob: + p = ai.NewMediaPart(part.MIMEType, string(part.Data)) + case genai.FunctionCall: + p = ai.NewToolRequestPart(&ai.ToolRequest{ + Name: part.Name, + Input: part.Args, + }) + default: + panic(fmt.Sprintf("unknown part %#v", part)) + } + m.Content = append(m.Content, p) } - m.Content = append(m.Content, p) } c.Message = m return c @@ -266,6 +330,7 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { // Translate from a genai.GenerateContentResponse to a ai.GenerateResponse. func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse { + // Note: this path doesn't get used when streaming. r := &ai.GenerateResponse{} for _, c := range resp.Candidates { r.Candidates = append(r.Candidates, translateCandidate(c)) diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index 460ba3b31b..8f928484b2 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -106,6 +106,46 @@ func TestLive(t *testing.T) { t.Error("Request field not set properly") } }) + t.Run("streaming", func(t *testing.T) { + req := &ai.GenerateRequest{ + Candidates: 1, + Messages: []*ai.Message{ + &ai.Message{ + Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")}, + Role: ai.RoleUser, + }, + }, + } + + out := "" + parts := 0 + model := vertexai.Model(modelName) + final, err := ai.Generate(ctx, model, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error { + parts++ + for _, p := range c.Content { + out += p.Text + } + return nil + }) + if err != nil { + t.Fatal(err) + } + out2 := "" + for _, p := range final.Candidates[0].Message.Content { + out2 += p.Text + } + if out != out2 { + t.Errorf("streaming and final should contain the same text.\nstreaming:%s\nfinal:%s", out, out2) + } + const want = "Golden" + if !strings.Contains(out, want) { + t.Errorf("got %q, expecting it to contain %q", out, want) + } + if parts == 1 { + // Check if streaming actually occurred. + t.Errorf("expecting more than one part") + } + }) t.Run("tool", func(t *testing.T) { req := &ai.GenerateRequest{ Candidates: 1, From 125b1b828e1aba3de4941d3d3b830193a2310ad2 Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Mon, 10 Jun 2024 16:58:46 -0700 Subject: [PATCH 2/4] round of review --- go/go.sum | 8 -------- go/plugins/vertexai/vertexai.go | 5 +++-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/go/go.sum b/go/go.sum index 26304299d6..7667d35852 100644 --- a/go/go.sum +++ b/go/go.sum @@ -3,12 +3,8 @@ cloud.google.com/go v0.113.0 h1:g3C70mn3lWfckKBiCVsAshabrDg01pQ0pnX1MNtnMkA= cloud.google.com/go v0.113.0/go.mod h1:glEqlogERKYeePz6ZdkcLJ28Q2I6aERgDDErBg9GzO8= cloud.google.com/go/ai v0.5.0 h1:x8s4rDn5t9OVZvBCgtr5bZTH5X0O7JdE6zYo+O+MpRw= cloud.google.com/go/ai v0.5.0/go.mod h1:96VBphk70e0zdXZrbtgPuKYRZsQ3UktSUXhuojwiKA8= -cloud.google.com/go/aiplatform v1.66.0 h1:bbFYY4JInclG10czRFUYj2rjD+obhh3Gi9zVlyoMgEc= -cloud.google.com/go/aiplatform v1.66.0/go.mod h1:bPQS0UjaXaTAq57UgP3XWDCtYFOIbXXpkMsl6uP4JAc= cloud.google.com/go/aiplatform v1.67.0 h1:YWeqD4BjYwrmY4fa+isGcw0P81lJ3dKVxbWxdBchoiU= cloud.google.com/go/aiplatform v1.67.0/go.mod h1:s/sJ6btBEr6bKnrNWdK9ZgHCvwbZNdP90b3DDtxxw+Y= -cloud.google.com/go/auth v0.4.0 h1:vcJWEguhY8KuiHoSs/udg1JtIRYm3YAWPBE1moF1m3U= -cloud.google.com/go/auth v0.4.0/go.mod h1:tO/chJN3obc5AbRYFQDsuFbL4wW5y8LfbPtDCfgwOVE= cloud.google.com/go/auth v0.4.1 h1:Z7YNIhlWRtrnKlZke7z3GMqzvuYzdc2z98F9D1NV5Hg= cloud.google.com/go/auth v0.4.1/go.mod h1:QVBuVEKpCn4Zp58hzRGvL0tjRGU0YqdRTdCHM1IHnro= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= @@ -25,8 +21,6 @@ cloud.google.com/go/monitoring v1.18.1 h1:0yvFXK+xQd95VKo6thndjwnJMno7c7Xw1CwMBy cloud.google.com/go/monitoring v1.18.1/go.mod h1:52hTzJ5XOUMRm7jYi7928aEdVxBEmGwA0EjNJXIBvt8= cloud.google.com/go/trace v1.10.6 h1:XF0Ejdw0NpRfAvuZUeQe3ClAG4R/9w5JYICo7l2weaw= cloud.google.com/go/trace v1.10.6/go.mod h1:EABXagUjxGuKcZMy4pXyz0fJpE5Ghog3jzTxcEsVJS4= -cloud.google.com/go/vertexai v0.7.1 h1:CSdqsEwjklLIlI1e5SrsnkwG/I+CeJekkBbMTzeYhVg= -cloud.google.com/go/vertexai v0.7.1/go.mod h1:HfnfYR9aPS+qF2436S6Hzuw0Fp+PORjzK3ggqymdzSU= cloud.google.com/go/vertexai v0.10.0 h1:k157bLrtyajGtAAZnqdEn8lwFlUTG3BgHc7kvWbP/3s= cloud.google.com/go/vertexai v0.10.0/go.mod h1:w/Zb22QvOVvxx5CGM4fPzH3WA6gwUkId9juA7pigzFI= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -190,8 +184,6 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3 golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.178.0 h1:yoW/QMI4bRVCHF+NWOTa4cL8MoWL3Jnuc7FlcFF91Ok= -google.golang.org/api v0.178.0/go.mod h1:84/k2v8DFpDRebpGcooklv/lais3MEfqpaBLA12gl2U= google.golang.org/api v0.180.0 h1:M2D87Yo0rGBPWpo1orwfCLehUUL6E7/TYe5gvMQWDh4= google.golang.org/api v0.180.0/go.mod h1:51AiyoEg1MJPSZ9zvklA8VnRILPXxn1iVen9v25XHAE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 46335b663a..e483e1d609 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -25,6 +25,7 @@ import ( "cloud.google.com/go/vertexai/genai" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/plugins/internal/uri" + "google.golang.org/api/iterator" "google.golang.org/api/option" ) @@ -228,7 +229,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb for { chunk, err := iter.Next() if err != nil { - if err.Error() == "no more items in iterator" { + if err == iterator.Done { break } return nil, err @@ -267,7 +268,7 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb c1.FinishReason = ai.FinishReasonUnknown c2.FinishReason = ai.FinishReasonUnknown if !reflect.DeepEqual(&c1, &c2) { - return nil, fmt.Errorf("some candidate fields unexpectedly changed") + return nil, fmt.Errorf("some candidate fields unexpectedly changed\n%#v\n%#v", c1, c2) } // Append the Parts to the final candidate. From f8fbf66ba203f448dd19dfd58d11efba1072f30b Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Thu, 13 Jun 2024 13:06:04 -0700 Subject: [PATCH 3/4] use newly-added MergedResponse api --- go/go.mod | 28 ++++++++--------- go/go.sum | 56 ++++++++++++++++----------------- go/plugins/vertexai/vertexai.go | 44 +++++--------------------- 3 files changed, 50 insertions(+), 78 deletions(-) diff --git a/go/go.mod b/go/go.mod index e2dd67b775..60f1a33dcb 100644 --- a/go/go.mod +++ b/go/go.mod @@ -3,9 +3,9 @@ module github.com/firebase/genkit/go go 1.22.0 require ( - cloud.google.com/go/aiplatform v1.67.0 - cloud.google.com/go/logging v1.9.0 - cloud.google.com/go/vertexai v0.10.0 + cloud.google.com/go/aiplatform v1.68.0 + cloud.google.com/go/logging v1.10.0 + cloud.google.com/go/vertexai v0.12.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0 github.com/aymerick/raymond v2.0.2+incompatible @@ -22,21 +22,21 @@ require ( go.opentelemetry.io/otel/sdk/metric v1.26.0 go.opentelemetry.io/otel/trace v1.26.0 golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 - google.golang.org/api v0.180.0 + google.golang.org/api v0.183.0 google.golang.org/protobuf v1.34.1 gopkg.in/yaml.v3 v3.0.1 ) require ( - cloud.google.com/go v0.113.0 // indirect + cloud.google.com/go v0.114.0 // indirect cloud.google.com/go/ai v0.5.0 // indirect - cloud.google.com/go/auth v0.4.1 // indirect + cloud.google.com/go/auth v0.5.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect - cloud.google.com/go/iam v1.1.7 // indirect + cloud.google.com/go/iam v1.1.8 // indirect cloud.google.com/go/longrunning v0.5.7 // indirect - cloud.google.com/go/monitoring v1.18.1 // indirect - cloud.google.com/go/trace v1.10.6 // indirect + cloud.google.com/go/monitoring v1.19.0 // indirect + cloud.google.com/go/trace v1.10.7 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.46.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect @@ -57,14 +57,14 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect golang.org/x/crypto v0.23.0 // indirect golang.org/x/net v0.25.0 // indirect - golang.org/x/oauth2 v0.20.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.20.0 // indirect golang.org/x/text v0.15.0 // indirect golang.org/x/time v0.5.0 // indirect - google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240506185236-b8a5c65736ae // indirect - google.golang.org/grpc v1.63.2 // indirect + google.golang.org/genproto v0.0.0-20240528184218-531527333157 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect + google.golang.org/grpc v1.64.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go/go.sum b/go/go.sum index 7667d35852..ea68f9f6e6 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,28 +1,28 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.113.0 h1:g3C70mn3lWfckKBiCVsAshabrDg01pQ0pnX1MNtnMkA= -cloud.google.com/go v0.113.0/go.mod h1:glEqlogERKYeePz6ZdkcLJ28Q2I6aERgDDErBg9GzO8= +cloud.google.com/go v0.114.0 h1:OIPFAdfrFDFO2ve2U7r/H5SwSbBzEdrBdE7xkgwc+kY= +cloud.google.com/go v0.114.0/go.mod h1:ZV9La5YYxctro1HTPug5lXH/GefROyW8PPD4T8n9J8E= cloud.google.com/go/ai v0.5.0 h1:x8s4rDn5t9OVZvBCgtr5bZTH5X0O7JdE6zYo+O+MpRw= cloud.google.com/go/ai v0.5.0/go.mod h1:96VBphk70e0zdXZrbtgPuKYRZsQ3UktSUXhuojwiKA8= -cloud.google.com/go/aiplatform v1.67.0 h1:YWeqD4BjYwrmY4fa+isGcw0P81lJ3dKVxbWxdBchoiU= -cloud.google.com/go/aiplatform v1.67.0/go.mod h1:s/sJ6btBEr6bKnrNWdK9ZgHCvwbZNdP90b3DDtxxw+Y= -cloud.google.com/go/auth v0.4.1 h1:Z7YNIhlWRtrnKlZke7z3GMqzvuYzdc2z98F9D1NV5Hg= -cloud.google.com/go/auth v0.4.1/go.mod h1:QVBuVEKpCn4Zp58hzRGvL0tjRGU0YqdRTdCHM1IHnro= +cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rBPcG8U= +cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME= +cloud.google.com/go/auth v0.5.1 h1:0QNO7VThG54LUzKiQxv8C6x1YX7lUrzlAa1nVLF8CIw= +cloud.google.com/go/auth v0.5.1/go.mod h1:vbZT8GjzDf3AVqCcQmqeeM32U9HBFc32vVVAbwDsa6s= cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/iam v1.1.7 h1:z4VHOhwKLF/+UYXAJDFwGtNF0b6gjsW1Pk9Ml0U/IoM= -cloud.google.com/go/iam v1.1.7/go.mod h1:J4PMPg8TtyurAUvSmPj8FF3EDgY1SPRZxcUGrn7WXGA= -cloud.google.com/go/logging v1.9.0 h1:iEIOXFO9EmSiTjDmfpbRjOxECO7R8C7b8IXUGOj7xZw= -cloud.google.com/go/logging v1.9.0/go.mod h1:1Io0vnZv4onoUnsVUQY3HZ3Igb1nBchky0A0y7BBBhE= +cloud.google.com/go/iam v1.1.8 h1:r7umDwhj+BQyz0ScZMp4QrGXjSTI3ZINnpgU2nlB/K0= +cloud.google.com/go/iam v1.1.8/go.mod h1:GvE6lyMmfxXauzNq8NbgJbeVQNspG+tcdL/W8QO1+zE= +cloud.google.com/go/logging v1.10.0 h1:f+ZXMqyrSJ5vZ5pE/zr0xC8y/M9BLNzQeLBwfeZ+wY4= +cloud.google.com/go/logging v1.10.0/go.mod h1:EHOwcxlltJrYGqMGfghSet736KR3hX1MAj614mrMk9I= cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= -cloud.google.com/go/monitoring v1.18.1 h1:0yvFXK+xQd95VKo6thndjwnJMno7c7Xw1CwMByg0B+8= -cloud.google.com/go/monitoring v1.18.1/go.mod h1:52hTzJ5XOUMRm7jYi7928aEdVxBEmGwA0EjNJXIBvt8= -cloud.google.com/go/trace v1.10.6 h1:XF0Ejdw0NpRfAvuZUeQe3ClAG4R/9w5JYICo7l2weaw= -cloud.google.com/go/trace v1.10.6/go.mod h1:EABXagUjxGuKcZMy4pXyz0fJpE5Ghog3jzTxcEsVJS4= -cloud.google.com/go/vertexai v0.10.0 h1:k157bLrtyajGtAAZnqdEn8lwFlUTG3BgHc7kvWbP/3s= -cloud.google.com/go/vertexai v0.10.0/go.mod h1:w/Zb22QvOVvxx5CGM4fPzH3WA6gwUkId9juA7pigzFI= +cloud.google.com/go/monitoring v1.19.0 h1:NCXf8hfQi+Kmr56QJezXRZ6GPb80ZI7El1XztyUuLQI= +cloud.google.com/go/monitoring v1.19.0/go.mod h1:25IeMR5cQ5BoZ8j1eogHE5VPJLlReQ7zFp5OiLgiGZw= +cloud.google.com/go/trace v1.10.7 h1:gK8z2BIJQ3KIYGddw9RJLne5Fx0FEXkrEQzPaeEYVvk= +cloud.google.com/go/trace v1.10.7/go.mod h1:qk3eiKmZX0ar2dzIJN/3QhY2PIFh1eqcIdaN5uEjQPM= +cloud.google.com/go/vertexai v0.12.0 h1:zTadEo/CtsoyRXNx3uGCncoWAP1H2HakGqwznt+iMo8= +cloud.google.com/go/vertexai v0.12.0/go.mod h1:8u+d0TsvBfAAd2x5R6GMgbYhsLgo3J7lmP4bR8g2ig8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 h1:n3T26hyfDl9RdgcOjWvOFMh1lCBNuZ0JQ/3DM5pou8Y= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0/go.mod h1:3S7qK2nHOO2cLID3xk6H8f55D38XswhVFzKEk0nqIbY= @@ -159,8 +159,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo= -golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -184,26 +184,26 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3 golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.180.0 h1:M2D87Yo0rGBPWpo1orwfCLehUUL6E7/TYe5gvMQWDh4= -google.golang.org/api v0.180.0/go.mod h1:51AiyoEg1MJPSZ9zvklA8VnRILPXxn1iVen9v25XHAE= +google.golang.org/api v0.183.0 h1:PNMeRDwo1pJdgNcFQ9GstuLe/noWKIc89pRWRLMvLwE= +google.golang.org/api v0.183.0/go.mod h1:q43adC5/pHoSZTx5h2mSmdF7NcyfW9JuDyIOJAgS9ZQ= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda h1:wu/KJm9KJwpfHWhkkZGohVC6KRrc1oJNr4jwtQMOQXw= -google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda/go.mod h1:g2LLCvCeCSir/JJSWosk19BR4NVxGqHUC6rxIRsd7Aw= -google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae h1:AH34z6WAGVNkllnKs5raNq3yRq93VnjBG6rpfub/jYk= -google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240506185236-b8a5c65736ae h1:c55+MER4zkBS14uJhSZMGGmya0yJx5iHV4x/fpOSNRk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240506185236-b8a5c65736ae/go.mod h1:I7Y+G38R2bu5j1aLzfFmQfTcU/WnFuqDwLZAbvKTKpM= +google.golang.org/genproto v0.0.0-20240528184218-531527333157 h1:u7WMYrIrVvs0TF5yaKwKNbcJyySYf+HAIFXxWltJOXE= +google.golang.org/genproto v0.0.0-20240528184218-531527333157/go.mod h1:ubQlAQnzejB8uZzszhrTCU2Fyp6Vi7ZE5nn0c3W8+qQ= +google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 h1:+rdxYoE3E5htTEWIe15GlN6IfvbURM//Jt0mmkmm6ZU= +google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117/go.mod h1:OimBR/bc1wPO9iV4NC2bpyjy3VnAwZh5EBPQdtaE5oo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 h1:1GBuWVLM/KMVUv1t1En5Gs+gFZCNd360GGb4sSxtrhU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= -google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= +google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= +google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index e483e1d609..8fb10466db 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -18,7 +18,6 @@ import ( "context" "errors" "fmt" - "reflect" "runtime" aiplatform "cloud.google.com/go/aiplatform/apiv1" @@ -225,11 +224,12 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb // Streaming version. iter := cs.SendMessageStream(ctx, parts...) - r := &ai.GenerateResponse{Request: input, Candidates: make([]*ai.Candidate, input.Candidates)} + var r *ai.GenerateResponse for { chunk, err := iter.Next() if err != nil { if err == iterator.Done { + r = translateResponse(iter.MergedResponse()) break } return nil, err @@ -247,42 +247,14 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb if err != nil { return nil, err } - // Save candidate in full response structure. - if old := r.Candidates[tc.Index]; old == nil { - r.Candidates[tc.Index] = tc - } else { - // Need to merge two "parts" of a candidate. - // Currently, we: - // - append the Message content - // - merge the FinishReason - // - assert everything else is unchanged - // (We do that 3rd step first.) - c1 := *r.Candidates[tc.Index] - c2 := *tc - m1 := *c1.Message - m2 := *c2.Message - c1.Message = &m1 - c2.Message = &m2 - m1.Content = nil - m2.Content = nil - c1.FinishReason = ai.FinishReasonUnknown - c2.FinishReason = ai.FinishReasonUnknown - if !reflect.DeepEqual(&c1, &c2) { - return nil, fmt.Errorf("some candidate fields unexpectedly changed\n%#v\n%#v", c1, c2) - } - - // Append the Parts to the final candidate. - old.Message.Content = append(old.Message.Content, tc.Message.Content...) - // Merge the FinishReasons. - if old.FinishReason == ai.FinishReasonUnknown { - old.FinishReason = tc.FinishReason - } else if old.FinishReason != tc.FinishReason { - return nil, fmt.Errorf("invalid finish reason transition: %s to %s", old.FinishReason, tc.FinishReason) - } - } } - // TODO: use chunk.PromptFeedback, chunk.UsageMetadata } + if r == nil { + // No candidates were returned. Probably rare, but it might avoid a NPE + // to return an empty instead of nil result. + r = &ai.GenerateResponse{} + } + r.Request = input return r, nil } From e3fcb9077f5c067013c938b648fed4b872a45416 Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Thu, 13 Jun 2024 13:08:51 -0700 Subject: [PATCH 4/4] small cleanups --- go/plugins/vertexai/vertexai.go | 35 +++++++++++++++------------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 8fb10466db..e538c7d614 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -277,25 +277,23 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { c.FinishReason = ai.FinishReasonUnknown } m := &ai.Message{} - if cand.Content != nil { - m.Role = ai.Role(cand.Content.Role) - for _, part := range cand.Content.Parts { - var p *ai.Part - switch part := part.(type) { - case genai.Text: - p = ai.NewTextPart(string(part)) - case genai.Blob: - p = ai.NewMediaPart(part.MIMEType, string(part.Data)) - case genai.FunctionCall: - p = ai.NewToolRequestPart(&ai.ToolRequest{ - Name: part.Name, - Input: part.Args, - }) - default: - panic(fmt.Sprintf("unknown part %#v", part)) - } - m.Content = append(m.Content, p) + m.Role = ai.Role(cand.Content.Role) + for _, part := range cand.Content.Parts { + var p *ai.Part + switch part := part.(type) { + case genai.Text: + p = ai.NewTextPart(string(part)) + case genai.Blob: + p = ai.NewMediaPart(part.MIMEType, string(part.Data)) + case genai.FunctionCall: + p = ai.NewToolRequestPart(&ai.ToolRequest{ + Name: part.Name, + Input: part.Args, + }) + default: + panic(fmt.Sprintf("unknown part %#v", part)) } + m.Content = append(m.Content, p) } c.Message = m return c @@ -303,7 +301,6 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { // Translate from a genai.GenerateContentResponse to a ai.GenerateResponse. func translateResponse(resp *genai.GenerateContentResponse) *ai.GenerateResponse { - // Note: this path doesn't get used when streaming. r := &ai.GenerateResponse{} for _, c := range resp.Candidates { r.Candidates = append(r.Candidates, translateCandidate(c))