diff --git a/go/ai/gen.go b/go/ai/gen.go index 5d24d51bf0..b6b9a62c43 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -271,7 +271,7 @@ type ModelResponseChunk struct { Aggregated bool `json:"aggregated,omitempty"` Content []*Part `json:"content,omitempty"` Custom any `json:"custom,omitempty"` - Index int `json:"index,omitempty"` + Index int `json:"index"` Role Role `json:"role,omitempty"` } diff --git a/go/ai/generate_test.go b/go/ai/generate_test.go index 95299f1852..b39778a231 100644 --- a/go/ai/generate_test.go +++ b/go/ai/generate_test.go @@ -98,7 +98,8 @@ func TestStreamingChunksHaveRoleAndIndex(t *testing.T) { From string To string Temperature float64 - }) (float64, error) { + }, + ) (float64, error) { if input.From == "celsius" && input.To == "fahrenheit" { return input.Temperature*9/5 + 32, nil } diff --git a/go/go.mod b/go/go.mod index cbde46d01c..70a8eb0a54 100644 --- a/go/go.mod +++ b/go/go.mod @@ -41,7 +41,7 @@ require ( golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 golang.org/x/tools v0.34.0 google.golang.org/api v0.236.0 - google.golang.org/genai v1.30.0 + google.golang.org/genai v1.36.0 ) require ( diff --git a/go/go.sum b/go/go.sum index 7100070ee1..f528809d99 100644 --- a/go/go.sum +++ b/go/go.sum @@ -537,8 +537,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw= google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= -google.golang.org/genai v1.30.0 h1:7021aneIvl24nEBLbtQFEWleHsMbjzpcQvkT4WcJ1dc= -google.golang.org/genai v1.30.0/go.mod h1:7pAilaICJlQBonjKKJNhftDFv3SREhZcTe9F6nRcjbg= +google.golang.org/genai v1.36.0 h1:sJCIjqTAmwrtAIaemtTiKkg2TO1RxnYEusTmEQ3nGxM= +google.golang.org/genai v1.36.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen.go b/go/internal/cmd/jsonschemagen/jsonschemagen.go index 6714c0dce8..71daf886a9 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen.go @@ -42,6 +42,14 @@ var ( outputDir = flag.String("outdir", "", "directory to write to, or '-' for stdout") noFormat = flag.Bool("nofmt", false, "do not format output") configFile = flag.String("config", "", "config filename") + + // fieldOmitEmptyTag maps schemas (e.g., "ModelResponseChunk") to fields (e.g., "index") + // that should not receive the `omitempty` JSON tag. + fieldOmitEmptyTag = map[string]map[string]struct{}{ + "ModelResponseChunk": { + "index": {}, // fields should be as defined in core/schemas.config + }, + } ) func main() { @@ -241,7 +249,6 @@ func nameAnonymousTypes(schemas map[string]*Schema) { nameFields(prefix+fname, fs.Properties) } } - } for typeName, ts := range schemas { nameFields(typeName, ts.Properties) @@ -407,13 +414,28 @@ func (g *generator) generateStruct(name string, s *Schema, tcfg *itemConfig) err } } g.generateDoc(fs, fcfg) + jsonTag := fmt.Sprintf(`json:"%s,omitempty"`, field) + if skipOmitEmpty(goName, field) { + jsonTag = fmt.Sprintf(`json:"%s"`, field) + } g.pr(fmt.Sprintf(" %s %s `%s`\n", adjustIdentifier(field), typeExpr, jsonTag)) } g.pr("}\n\n") return nil } +// skipOmitEmpty determines whether a schema field should include the +// `omitempty` JSON tag +func skipOmitEmpty(schema, field string) bool { + fields, ok := fieldOmitEmptyTag[schema] + if !ok { + return false + } + _, ok = fields[field] + return ok +} + func (g *generator) generateStringEnum(name string, s *Schema, tcfg *itemConfig) error { g.generateDoc(s, tcfg) goName := tcfg.name diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen_test.go b/go/internal/cmd/jsonschemagen/jsonschemagen_test.go index 09360fc3cc..d903481fa2 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen_test.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen_test.go @@ -57,3 +57,53 @@ func Test(t *testing.T) { } } } + +func TestSkipOmitEmpty(t *testing.T) { + tests := []struct { + name string + schema string + field string + expected bool + }{ + { + name: "ChunkIndexOK", + schema: "ModelResponseChunk", + field: "index", + expected: true, + }, + { + name: "ChunkNoIndex", + schema: "ModelResponseChunk", + field: "text", + expected: false, + }, + { + name: "NotChunkSchema", + schema: "RequestHeader", + field: "ID", + expected: false, + }, + { + name: "ChunkNoField", + schema: "ModelResponseChunk", + field: "", + expected: false, + }, + { + name: "EmptySchema", + schema: "", + field: "index", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := skipOmitEmpty(tt.schema, tt.field) + if actual != tt.expected { + t.Errorf("skipOmitEmpty(schema: %q, field: %q) = %v, want %v", + tt.schema, tt.field, actual, tt.expected) + } + }) + } +} diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index ee430a697f..0f018edf13 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -290,11 +290,12 @@ func generate( // Streaming version. iter := client.Models.GenerateContentStream(ctx, model, contents, gcc) + var r *ai.ModelResponse + var genaiResp *genai.GenerateContentResponse - // merge all streamed responses - var resp *genai.GenerateContentResponse - var chunks []*genai.Part + genaiParts := []*genai.Part{} + chunks := []*ai.Part{} for chunk, err := range iter { // abort stream if error found in the iterator items if err != nil { @@ -307,27 +308,38 @@ func generate( } err = cb(ctx, &ai.ModelResponseChunk{ Content: tc.Message.Content, + Role: ai.RoleModel, }) if err != nil { return nil, err } - chunks = append(chunks, c.Content.Parts...) + genaiParts = append(genaiParts, c.Content.Parts...) + chunks = append(chunks, tc.Message.Content...) } - // keep the last chunk for usage metadata - resp = chunk + genaiResp = chunk + } - // manually merge all candidate responses, iterator does not provide a - // merged response utility + if len(genaiResp.Candidates) == 0 { + return nil, fmt.Errorf("no valid candidates found") + } + + // preserve original parts since they will be included in the + // "custom" response field merged := []*genai.Candidate{ { + FinishReason: genaiResp.Candidates[0].FinishReason, Content: &genai.Content{ - Parts: chunks, + Role: string(ai.RoleModel), + Parts: genaiParts, }, }, } - resp.Candidates = merged - r, err = translateResponse(resp) + + genaiResp.Candidates = merged + r, err = translateResponse(genaiResp) + r.Message.Content = chunks + if err != nil { return nil, fmt.Errorf("failed to generate contents: %w", err) } @@ -715,16 +727,14 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { m.FinishReason = ai.FinishReasonBlocked case genai.FinishReasonOther: m.FinishReason = ai.FinishReasonOther - default: // Unspecified - m.FinishReason = ai.FinishReasonUnknown } + m.FinishMessage = cand.FinishMessage if cand.Content == nil { return nil, fmt.Errorf("no valid candidates were found in the generate response") } msg := &ai.Message{} msg.Role = ai.Role(cand.Content.Role) - // iterate over the candidate parts, only one struct member // must be populated, more than one is considered an error for _, part := range cand.Content.Parts { @@ -799,13 +809,20 @@ func translateResponse(resp *genai.GenerateContentResponse) (*ai.ModelResponse, r.Usage = &ai.GenerationUsage{} } + // populate "custom" with plugin custom information + custom := make(map[string]any) + custom["candidates"] = resp.Candidates + if u := resp.UsageMetadata; u != nil { r.Usage.InputTokens = int(u.PromptTokenCount) r.Usage.OutputTokens = int(u.CandidatesTokenCount) r.Usage.TotalTokens = int(u.TotalTokenCount) r.Usage.CachedContentTokens = int(u.CachedContentTokenCount) r.Usage.ThoughtsTokens = int(u.ThoughtsTokenCount) + custom["usageMetadata"] = resp.UsageMetadata } + + r.Custom = custom return r, nil } diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go index 2ebe49e114..6f63ba42f6 100644 --- a/go/plugins/googlegenai/googleai_live_test.go +++ b/go/plugins/googlegenai/googleai_live_test.go @@ -176,6 +176,34 @@ func TestGoogleAILive(t *testing.T) { t.Errorf("got %q, expecting it to contain %q", out, want) } }) + t.Run("tool stream", func(t *testing.T) { + parts := 0 + out := "" + final, err := genkit.Generate(ctx, g, + ai.WithPrompt("what is a gablorken of 2 over 3.5?"), + ai.WithTools(gablorkenTool), + ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error { + parts++ + out += c.Content[0].Text + return nil + })) + if err != nil { + t.Fatal(err) + } + out2 := "" + for _, p := range final.Message.Content { + out2 += p.Text + } + if out != out2 { + t.Errorf("streaming and final should contain the same text.\nstreaming:%s\nfinal:%s", out, out2) + } + + const want = "11.31" + if !strings.Contains(final.Text(), want) { + t.Errorf("got %q, expecting it to contain %q", out, want) + } + }) + t.Run("tool with thinking", func(t *testing.T) { m := googlegenai.GoogleAIModel(g, "gemini-2.5-flash") resp, err := genkit.Generate(ctx, g,