diff --git a/go/ai/document.go b/go/ai/document.go index f63cb91010..89c3ff897e 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -16,6 +16,7 @@ package ai import ( "encoding/json" + "fmt" ) // A Document is a piece of data that can be embedded, indexed, or retrieved. @@ -30,21 +31,64 @@ type Document struct { // A Part is one part of a [Document]. This may be plain text or it // may be a URL (possibly a "data:" URL with embedded data). type Part struct { - isText bool - contentType string - text string + kind partKind + contentType string // valid for kind==blob + text string // valid for kind∈{text,blob} + toolRequest *ToolRequest // valid for kind==partToolRequest + toolResponse *ToolResponse // valid for kind==partToolResponse } +type partKind int8 + +const ( + partText partKind = iota + partBlob + partToolRequest + partToolResponse +) + +// NewTextPart returns a Part containing raw string data. func NewTextPart(text string) *Part { - return &Part{isText: true, text: text} + return &Part{kind: partText, text: text} } + +// NewBlobPart returns a Part containing structured data described +// by the given mimeType. func NewBlobPart(mimeType, contents string) *Part { - return &Part{isText: false, contentType: mimeType, text: contents} + return &Part{kind: partBlob, contentType: mimeType, text: contents} +} + +// NewToolRequestPart returns a Part containing a request from +// the model to the client to run a Tool. +// (Only genkit plugins should need to use this function.) +func NewToolRequestPart(r *ToolRequest) *Part { + return &Part{kind: partToolRequest, toolRequest: r} +} + +// NewToolResponsePart returns a Part containing the results +// of applying a Tool that the model requested. +func NewToolResponsePart(r *ToolResponse) *Part { + return &Part{kind: partToolResponse, toolResponse: r} } // IsText reports whether the [Part] contains plain text. func (p *Part) IsPlainText() bool { - return p.isText + return p.kind == partText +} + +// IsBlob reports whether the [Part] contains blob (non-plain-text) data. +func (p *Part) IsBlob() bool { + return p.kind == partBlob +} + +// IsToolRequest reports whether the [Part] contains a request to run a tool. +func (p *Part) IsToolRequest() bool { + return p.kind == partToolRequest +} + +// IsToolResponse reports whether the [Part] contains the result of running a tool. +func (p *Part) IsToolResponse() bool { + return p.kind == partToolResponse } // Text returns the text. This is either plain text or a URL. @@ -53,25 +97,38 @@ func (p *Part) Text() string { } // ContentType returns the type of the content. -// This is only interesting if IsText is false. +// This is only interesting if IsBlob() is true. func (p *Part) ContentType() string { - if p.isText { + if p.kind == partText { return "text/plain" } return p.contentType } +// ToolRequest returns a request from the model for the client to run a tool. +// Valid only if [IsToolRequest] is true. +func (p *Part) ToolRequest() *ToolRequest { + return p.toolRequest +} + +// ToolResponse returns the tool response the client is sending to the model. +// Valid only if [IsToolResponse] is true. +func (p *Part) ToolResponse() *ToolResponse { + return p.toolResponse +} + // MarshalJSON is called by the JSON marshaler to write out a Part. func (p *Part) MarshalJSON() ([]byte, error) { // This is not handled by the schema generator because // Part is defined in TypeScript as a union. - if p.isText { + switch p.kind { + case partText: v := textPart{ Text: p.text, } return json.Marshal(v) - } else { + case partBlob: v := mediaPart{ Media: &mediaPartMedia{ ContentType: p.contentType, @@ -79,6 +136,25 @@ func (p *Part) MarshalJSON() ([]byte, error) { }, } return json.Marshal(v) + case partToolRequest: + // TODO: make sure these types marshal/unmarshal nicely + // between Go and javascript. At the very least the + // field name needs to change (here and in UnmarshalJSON). + v := struct { + ToolReq *ToolRequest `json:"toolreq,omitempty"` + }{ + ToolReq: p.toolRequest, + } + return json.Marshal(v) + case partToolResponse: + v := struct { + ToolResp *ToolResponse `json:"toolresp,omitempty"` + }{ + ToolResp: p.toolResponse, + } + return json.Marshal(v) + default: + return nil, fmt.Errorf("invalid part kind %v", p.kind) } } @@ -88,20 +164,29 @@ func (p *Part) UnmarshalJSON(b []byte) error { // Part is defined in TypeScript as a union. var s struct { - Text string `json:"text,omitempty"` - Media *mediaPartMedia `json:"media,omitempty"` + Text string `json:"text,omitempty"` + Media *mediaPartMedia `json:"media,omitempty"` + ToolReq *ToolRequest `json:"toolreq,omitempty"` + ToolResp *ToolResponse `json:"toolresp,omitempty"` } if err := json.Unmarshal(b, &s); err != nil { return err } - if s.Media != nil { - p.isText = false + switch { + case s.Media != nil: + p.kind = partBlob p.text = s.Media.Url p.contentType = s.Media.ContentType - } else { - p.isText = true + case s.ToolReq != nil: + p.kind = partToolRequest + p.toolRequest = s.ToolReq + case s.ToolResp != nil: + p.kind = partToolResponse + p.toolResponse = s.ToolResp + default: + p.kind = partText p.text = s.Text p.contentType = "" } @@ -114,8 +199,8 @@ func DocumentFromText(text string, metadata map[string]any) *Document { return &Document{ Content: []*Part{ &Part{ - isText: true, - text: text, + kind: partText, + text: text, }, }, Metadata: metadata, diff --git a/go/ai/document_test.go b/go/ai/document_test.go index 55e98bf9fd..e41ac1e836 100644 --- a/go/ai/document_test.go +++ b/go/ai/document_test.go @@ -16,6 +16,7 @@ package ai import ( "encoding/json" + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -41,13 +42,27 @@ func TestDocumentJSON(t *testing.T) { d := Document{ Content: []*Part{ &Part{ - isText: true, + kind: partText, text: "hi", }, &Part{ - isText: false, + kind: partBlob, contentType: "text/plain", - text: "data:,bye", + text: "data:,bye", + }, + &Part{ + kind: partToolRequest, + toolRequest: &ToolRequest{ + Name: "tool1", + Input: map[string]any{"arg1": 3.3, "arg2": "foo"}, + }, + }, + &Part{ + kind: partToolResponse, + toolResponse: &ToolResponse{ + Name: "tool1", + Output: map[string]any{"res1": 4.4, "res2": "bar"}, + }, }, }, } @@ -56,6 +71,7 @@ func TestDocumentJSON(t *testing.T) { if err != nil { t.Fatal(err) } + t.Logf("marshaled:%s\n", string(b)) var d2 Document if err := json.Unmarshal(b, &d2); err != nil { @@ -63,13 +79,21 @@ func TestDocumentJSON(t *testing.T) { } cmpPart := func(a, b *Part) bool { - if a.isText != b.isText { + if a.kind != b.kind { return false } - if a.isText { + switch a.kind { + case partText: return a.text == b.text - } else { + case partBlob: return a.contentType == b.contentType && a.text == b.text + case partToolRequest: + return reflect.DeepEqual(a.toolRequest, b.toolRequest) + case partToolResponse: + return reflect.DeepEqual(a.toolResponse, b.toolResponse) + default: + t.Fatalf("bad part kind %v", a.kind) + return false } } diff --git a/go/ai/gen.go b/go/ai/gen.go index b3163b5c13..b69af845d2 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -135,22 +135,22 @@ type ToolDefinition struct { OutputSchema map[string]any `json:"outputSchema,omitempty"` } -type ToolRequestPart struct { - ToolRequest *ToolRequestPartToolRequest `json:"toolRequest,omitempty"` +// A ToolRequest is a request from the model that the client should run +// a specific tool and pass a [ToolResponse] to the model on the next request it makes. +// Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client. +type ToolRequest struct { + // Input is a JSON object describing the input values to the tool. + // An example might be map[string]any{"country":"USA", "president":3}. + Input map[string]any `json:"input,omitempty"` + Name string `json:"name,omitempty"` } -type ToolRequestPartToolRequest struct { - Input any `json:"input,omitempty"` - Name string `json:"name,omitempty"` - Ref string `json:"ref,omitempty"` -} - -type ToolResponsePart struct { - ToolResponse *ToolResponsePartToolResponse `json:"toolResponse,omitempty"` -} - -type ToolResponsePartToolResponse struct { - Name string `json:"name,omitempty"` - Output any `json:"output,omitempty"` - Ref string `json:"ref,omitempty"` +// A ToolResponse is a response from the client to the model containing +// the results of running a specific tool on the arguments passed to the client +// by the model in a [ToolRequest]. +type ToolResponse struct { + Name string `json:"name,omitempty"` + // Output is a JSON object describing the results of running the tool. + // An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}. + Output map[string]any `json:"output,omitempty"` } diff --git a/go/genkit/schemas.config b/go/genkit/schemas.config index 14f90a28cb..5f4b558290 100644 --- a/go/genkit/schemas.config +++ b/go/genkit/schemas.config @@ -81,6 +81,35 @@ RoleTool indicates this message was generated by a local tool, likely triggered from the model in one of its previous responses. . +ToolRequestPart omit +ToolRequestPartToolRequest name ToolRequest +ToolResponsePart omit +ToolResponsePartToolResponse name ToolResponse + +ToolRequestPartToolRequest.input type map[string]any +ToolRequestPartToolRequest.input doc +Input is a JSON object describing the input values to the tool. +An example might be map[string]any{"country":"USA", "president":3}. +. +ToolResponsePartToolResponse.output type map[string]any +ToolResponsePartToolResponse.output doc +Output is a JSON object describing the results of running the tool. +An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}. +. +ToolRequestPartToolRequest.ref omit +ToolResponsePartToolResponse.ref omit + +ToolRequestPartToolRequest doc +A ToolRequest is a message from the model to the client that it should run a +specific tool and pass a [ToolResponse] to the model on the next chat request it makes. +Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client. +. +ToolResponsePartToolResponse doc +A ToolResponse is a message from the client to the model containing +the results of running a specific tool on the arguments passed to the client +by the model in a [ToolRequest]. +. + Candidate pkg ai CandidateFinishReason pkg ai DocumentData pkg ai diff --git a/go/internal/cmd/jsonschemagen/jsonschemagen.go b/go/internal/cmd/jsonschemagen/jsonschemagen.go index 99b5d4e46c..cbaf2dc963 100644 --- a/go/internal/cmd/jsonschemagen/jsonschemagen.go +++ b/go/internal/cmd/jsonschemagen/jsonschemagen.go @@ -341,6 +341,9 @@ func (g *generator) generateStruct(name string, s *Schema, tcfg *itemConfig) err if fs.Not != nil { continue } + if fcfg.omit { + continue + } typeExpr := fcfg.typeExpr if typeExpr == "" { var err error diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 577c12841f..c5c55430c6 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -16,6 +16,7 @@ package googleai import ( "context" + "fmt" "github.com/google/generative-ai-go/genai" "github.com/google/genkit/go/ai" @@ -59,6 +60,7 @@ func NewEmbedder(ctx context.Context, model, apiKey string) (ai.Embedder, error) type generator struct { model string client *genai.Client + //session *genai.ChatSession // non-nil if we're in the middle of a chat } func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb genkit.StreamingCallback[*ai.Candidate]) (*ai.GenerateResponse, error) { @@ -92,7 +94,32 @@ func (g *generator) Generate(ctx context.Context, input *ai.GenerateRequest, cb if len(messages) > 0 { parts = convertParts(messages[0].Content) } - //TODO: convert input.Tools and append to gm.Tools + + // Convert input.Tools and append to gm.Tools + for _, t := range input.Tools { + schema := &genai.Schema{} + schema.Type = genai.TypeObject + schema.Properties = map[string]*genai.Schema{} + for k, v := range t.InputSchema { + typ := genai.TypeUnspecified + switch v { + case "string": + typ = genai.TypeString + case "float64": + typ = genai.TypeNumber + case "int": + typ = genai.TypeInteger + case "bool": + typ = genai.TypeBoolean + default: + return nil, fmt.Errorf("schema value \"%s\" not allowed", v) + } + schema.Properties[k] = &genai.Schema{Type: typ} + } + fd := &genai.FunctionDeclaration{Name: t.Name, Parameters: schema} + gm.Tools = append(gm.Tools, &genai.Tool{FunctionDeclarations: []*genai.FunctionDeclaration{fd}}) + } + // TODO: gm.ToolConfig? // Send out the actual request. if cb == nil { @@ -167,10 +194,13 @@ func translateCandidate(cand *genai.Candidate) *ai.Candidate { p = ai.NewTextPart(string(part)) case genai.Blob: p = ai.NewBlobPart(part.MIMEType, string(part.Data)) - case genai.FunctionResponse: - p = ai.NewBlobPart("TODO", string(part.Name)) + case genai.FunctionCall: + p = ai.NewToolRequestPart(&ai.ToolRequest{ + Name: part.Name, + Input: part.Args, + }) default: - panic("unknown part type") + panic(fmt.Sprintf("unknown part %#v", part)) } m.Content = append(m.Content, p) } @@ -235,7 +265,15 @@ func convertPart(p *ai.Part) genai.Part { switch { case p.IsPlainText(): return genai.Text(p.Text()) - default: + case p.IsBlob(): return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())} + case p.IsToolResponse(): + toolResp := p.ToolResponse() + return genai.FunctionResponse{Name: toolResp.Name, Response: toolResp.Output} + case p.IsToolRequest(): + toolReq := p.ToolRequest() + return genai.FunctionCall{Name: toolReq.Name, Args: toolReq.Input} + default: + panic("unknown part type in a request") } } diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 72c76deb9c..279865b531 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -17,6 +17,7 @@ package googleai_test import ( "context" "flag" + "math" "strings" "testing" @@ -84,6 +85,9 @@ func TestGenerator(t *testing.T) { if out != "France" { t.Errorf("got \"%s\", expecting \"France\"", out) } + if resp.Request != req { + t.Error("Request field not set properly") + } } func TestGeneratorStreaming(t *testing.T) { @@ -123,3 +127,75 @@ func TestGeneratorStreaming(t *testing.T) { t.Errorf("expecting more than one part") } } + +func TestGeneratorTool(t *testing.T) { + if *apiKey == "" { + t.Skipf("no -key provided") + } + ctx := context.Background() + g, err := googleai.NewGenerator(ctx, "gemini-1.0-pro", *apiKey) + if err != nil { + t.Fatal(err) + } + req := &ai.GenerateRequest{ + Candidates: 1, + Messages: []*ai.Message{ + &ai.Message{ + Content: []*ai.Part{ai.NewTextPart("what is 3.5 squared? Use the tool provided.")}, + Role: ai.RoleUser, + }, + }, + Tools: []*ai.ToolDefinition{ + &ai.ToolDefinition{ + Name: "exponentiation", + InputSchema: map[string]any{"base": "float64", "exponent": "int"}, + OutputSchema: map[string]any{"output": "float64"}, + }, + }, + } + + resp, err := g.Generate(ctx, req, nil) + if err != nil { + t.Fatal(err) + } + p := resp.Candidates[0].Message.Content[0] + if !p.IsToolRequest() { + t.Fatalf("tool not requested") + } + toolReq := p.ToolRequest() + if toolReq.Name != "exponentiation" { + t.Errorf("tool name is %q, want \"exponentiation\"", toolReq.Name) + } + if toolReq.Input["base"] != 3.5 { + t.Errorf("base is %f, want 3.5", toolReq.Input["base"]) + } + if toolReq.Input["exponent"] != 2 && toolReq.Input["exponent"] != 2.0 { + // Note: 2.0 is wrong given the schema, but Gemini returns a float anyway. + t.Errorf("exponent is %f, want 2", toolReq.Input["exponent"]) + } + + // Update our conversation with the tool request the model made and our tool response. + // (Our "tool" is just math.Pow.) + req.Messages = append(req.Messages, + resp.Candidates[0].Message, + &ai.Message{ + Content: []*ai.Part{ai.NewToolResponsePart(&ai.ToolResponse{ + Name: "exponentiation", + Output: map[string]any{"output": math.Pow(3.5, 2)}, + })}, + Role: ai.RoleTool, + }, + ) + + // Issue our request again. + resp, err = g.Generate(ctx, req, nil) + if err != nil { + t.Fatal(err) + } + + // Check final response. + out := resp.Candidates[0].Message.Content[0].Text() + if !strings.Contains(out, "12.25") { + t.Errorf("got %s, expecting it to contain \"12.25\"", out) + } +}