diff --git a/go/ai/document.go b/go/ai/document.go index eb40f37173..a54784b257 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -31,101 +31,80 @@ type Document struct { // A Part is one part of a [Document]. This may be plain text or it // may be a URL (possibly a "data:" URL with embedded data). type Part struct { - kind partKind - contentType string // valid for kind==blob - text string // valid for kind∈{text,blob} - toolRequest *ToolRequest // valid for kind==partToolRequest - toolResponse *ToolResponse // valid for kind==partToolResponse + Kind PartKind `json:"kind,omitempty"` + ContentType string `json:"contentType,omitempty"` // valid for kind==blob + Text string `json:"text,omitempty"` // valid for kind∈{text,blob} + ToolRequest *ToolRequest `json:"toolreq,omitempty"` // valid for kind==partToolRequest + ToolResponse *ToolResponse `json:"toolresp,omitempty"` // valid for kind==partToolResponse } -type partKind int8 +type PartKind int8 const ( - partText partKind = iota - partMedia - partData - partToolRequest - partToolResponse + PartText PartKind = iota + PartMedia + PartData + PartToolRequest + PartToolResponse ) // NewTextPart returns a Part containing text. func NewTextPart(text string) *Part { - return &Part{kind: partText, text: text} + return &Part{Kind: PartText, ContentType: "plain/text", Text: text} +} + +// NewJSONPart returns a Part containing JSON. +func NewJSONPart(text string) *Part { + return &Part{Kind: PartText, ContentType: "application/json", Text: text} } // NewMediaPart returns a Part containing structured data described // by the given mimeType. func NewMediaPart(mimeType, contents string) *Part { - return &Part{kind: partMedia, contentType: mimeType, text: contents} + return &Part{Kind: PartMedia, ContentType: mimeType, Text: contents} } // NewDataPart returns a Part containing raw string data. func NewDataPart(contents string) *Part { - return &Part{kind: partData, text: contents} + return &Part{Kind: PartData, Text: contents} } // NewToolRequestPart returns a Part containing a request from // the model to the client to run a Tool. // (Only genkit plugins should need to use this function.) func NewToolRequestPart(r *ToolRequest) *Part { - return &Part{kind: partToolRequest, toolRequest: r} + return &Part{Kind: PartToolRequest, ToolRequest: r} } // NewToolResponsePart returns a Part containing the results // of applying a Tool that the model requested. func NewToolResponsePart(r *ToolResponse) *Part { - return &Part{kind: partToolResponse, toolResponse: r} + return &Part{Kind: PartToolResponse, ToolResponse: r} } // IsText reports whether the [Part] contains plain text. func (p *Part) IsText() bool { - return p.kind == partText + return p.Kind == PartText } // IsMedia reports whether the [Part] contains structured media data. func (p *Part) IsMedia() bool { - return p.kind == partMedia + return p.Kind == PartMedia } // IsData reports whether the [Part] contains unstructured data. func (p *Part) IsData() bool { - return p.kind == partData + return p.Kind == PartData } // IsToolRequest reports whether the [Part] contains a request to run a tool. func (p *Part) IsToolRequest() bool { - return p.kind == partToolRequest + return p.Kind == PartToolRequest } // IsToolResponse reports whether the [Part] contains the result of running a tool. func (p *Part) IsToolResponse() bool { - return p.kind == partToolResponse -} - -// Text returns the text. This is either plain text or a URL. -func (p *Part) Text() string { - return p.text -} - -// ContentType returns the type of the content. -// This is only interesting if IsBlob() is true. -func (p *Part) ContentType() string { - if p.kind == partText { - return "text/plain" - } - return p.contentType -} - -// ToolRequest returns a request from the model for the client to run a tool. -// Valid only if [IsToolRequest] is true. -func (p *Part) ToolRequest() *ToolRequest { - return p.toolRequest -} - -// ToolResponse returns the tool response the client is sending to the model. -// Valid only if [IsToolResponse] is true. -func (p *Part) ToolResponse() *ToolResponse { - return p.toolResponse + return p.Kind == PartToolResponse } // MarshalJSON is called by the JSON marshaler to write out a Part. @@ -133,44 +112,44 @@ func (p *Part) MarshalJSON() ([]byte, error) { // This is not handled by the schema generator because // Part is defined in TypeScript as a union. - switch p.kind { - case partText: + switch p.Kind { + case PartText: v := textPart{ - Text: p.text, + Text: p.Text, } return json.Marshal(v) - case partMedia: + case PartMedia: v := mediaPart{ Media: &mediaPartMedia{ - ContentType: p.contentType, - Url: p.text, + ContentType: p.ContentType, + Url: p.Text, }, } return json.Marshal(v) - case partData: + case PartData: v := dataPart{ - Data: p.text, + Data: p.Text, } return json.Marshal(v) - case partToolRequest: + case PartToolRequest: // TODO: make sure these types marshal/unmarshal nicely // between Go and javascript. At the very least the // field name needs to change (here and in UnmarshalJSON). v := struct { ToolReq *ToolRequest `json:"toolreq,omitempty"` }{ - ToolReq: p.toolRequest, + ToolReq: p.ToolRequest, } return json.Marshal(v) - case partToolResponse: + case PartToolResponse: v := struct { ToolResp *ToolResponse `json:"toolresp,omitempty"` }{ - ToolResp: p.toolResponse, + ToolResp: p.ToolResponse, } return json.Marshal(v) default: - return nil, fmt.Errorf("invalid part kind %v", p.kind) + return nil, fmt.Errorf("invalid part kind %v", p.Kind) } } @@ -193,23 +172,23 @@ func (p *Part) UnmarshalJSON(b []byte) error { switch { case s.Media != nil: - p.kind = partMedia - p.text = s.Media.Url - p.contentType = s.Media.ContentType + p.Kind = PartMedia + p.Text = s.Media.Url + p.ContentType = s.Media.ContentType case s.ToolReq != nil: - p.kind = partToolRequest - p.toolRequest = s.ToolReq + p.Kind = PartToolRequest + p.ToolRequest = s.ToolReq case s.ToolResp != nil: - p.kind = partToolResponse - p.toolResponse = s.ToolResp + p.Kind = PartToolResponse + p.ToolResponse = s.ToolResp default: - p.kind = partText - p.text = s.Text - p.contentType = "" + p.Kind = PartText + p.Text = s.Text + p.ContentType = "" if s.Data != "" { // Note: if part is completely empty, we use text by default. - p.kind = partData - p.text = s.Data + p.Kind = PartData + p.Text = s.Data } } return nil @@ -220,9 +199,9 @@ func (p *Part) UnmarshalJSON(b []byte) error { func DocumentFromText(text string, metadata map[string]any) *Document { return &Document{ Content: []*Part{ - &Part{ - kind: partText, - text: text, + { + Kind: PartText, + Text: text, }, }, Metadata: metadata, diff --git a/go/ai/document_test.go b/go/ai/document_test.go index bc30eb7f44..8d173fa2b3 100644 --- a/go/ai/document_test.go +++ b/go/ai/document_test.go @@ -32,7 +32,7 @@ func TestDocumentFromText(t *testing.T) { if !p.IsText() { t.Errorf("IsText() == %t, want %t", p.IsText(), true) } - if got := p.Text(); got != data { + if got := p.Text; got != data { t.Errorf("Data() == %q, want %q", got, data) } } @@ -42,28 +42,28 @@ func TestDocumentJSON(t *testing.T) { d := Document{ Content: []*Part{ &Part{ - kind: partText, - text: "hi", + Kind: PartText, + Text: "hi", }, &Part{ - kind: partMedia, - contentType: "text/plain", - text: "data:,bye", + Kind: PartMedia, + ContentType: "text/plain", + Text: "data:,bye", }, &Part{ - kind: partData, - text: "somedata\x00string", + Kind: PartData, + Text: "somedata\x00string", }, &Part{ - kind: partToolRequest, - toolRequest: &ToolRequest{ + Kind: PartToolRequest, + ToolRequest: &ToolRequest{ Name: "tool1", Input: map[string]any{"arg1": 3.3, "arg2": "foo"}, }, }, &Part{ - kind: partToolResponse, - toolResponse: &ToolResponse{ + Kind: PartToolResponse, + ToolResponse: &ToolResponse{ Name: "tool1", Output: map[string]any{"res1": 4.4, "res2": "bar"}, }, @@ -83,22 +83,22 @@ func TestDocumentJSON(t *testing.T) { } cmpPart := func(a, b *Part) bool { - if a.kind != b.kind { + if a.Kind != b.Kind { return false } - switch a.kind { - case partText: - return a.text == b.text - case partMedia: - return a.contentType == b.contentType && a.text == b.text - case partData: - return a.text == b.text - case partToolRequest: - return reflect.DeepEqual(a.toolRequest, b.toolRequest) - case partToolResponse: - return reflect.DeepEqual(a.toolResponse, b.toolResponse) + switch a.Kind { + case PartText: + return a.Text == b.Text + case PartMedia: + return a.ContentType == b.ContentType && a.Text == b.Text + case PartData: + return a.Text == b.Text + case PartToolRequest: + return reflect.DeepEqual(a.ToolRequest, b.ToolRequest) + case PartToolResponse: + return reflect.DeepEqual(a.ToolResponse, b.ToolResponse) default: - t.Fatalf("bad part kind %v", a.kind) + t.Fatalf("bad part kind %v", a.Kind) return false } } diff --git a/go/ai/generator.go b/go/ai/generator.go index cb6cacf0d4..748b79676e 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -16,12 +16,15 @@ package ai import ( "context" + "encoding/json" "errors" "fmt" "slices" + "strconv" "strings" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/logger" ) // Generator is the interface used to query an AI model. @@ -70,14 +73,24 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener } // Generate applies a [Generator] to some input, handling tool requests. -func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { +func Generate(ctx context.Context, g Generator, req *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { + if err := conformOutput(req); err != nil { + return nil, err + } + for { - resp, err := g.Generate(ctx, input, cb) + resp, err := g.Generate(ctx, req, cb) if err != nil { return nil, err } - newReq, err := handleToolRequest(ctx, input, resp) + candidates, err := validCandidates(ctx, resp) + if err != nil { + return nil, err + } + resp.Candidates = candidates + + newReq, err := handleToolRequest(ctx, req, resp) if err != nil { return nil, err } @@ -85,7 +98,7 @@ func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func( return resp, nil } - input = newReq + req = newReq } } @@ -115,14 +128,24 @@ type generatorAction struct { // Generate implements Generator. This is like the [Generate] function, // but invokes the [core.Action] rather than invoking the Generator // directly. -func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { +func (ga *generatorAction) Generate(ctx context.Context, req *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { + if err := conformOutput(req); err != nil { + return nil, err + } + for { - resp, err := ga.action.Run(ctx, input, cb) + resp, err := ga.action.Run(ctx, req, cb) if err != nil { return nil, err } - newReq, err := handleToolRequest(ctx, input, resp) + candidates, err := validCandidates(ctx, resp) + if err != nil { + return nil, err + } + resp.Candidates = candidates + + newReq, err := handleToolRequest(ctx, req, resp) if err != nil { return nil, err } @@ -130,8 +153,83 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, return resp, nil } - input = newReq + req = newReq + } +} + +// conformOutput appends a message to the request indicating conformance to the expected schema. +func conformOutput(req *GenerateRequest) error { + if req.Output.Format == OutputFormatJSON && len(req.Messages) > 0 { + jsonBytes, err := json.Marshal(req.Output.Schema) + if err != nil { + return fmt.Errorf("expected schema is not valid: %w", err) + } + + escapedJSON := strconv.Quote(string(jsonBytes)) + part := NewTextPart(fmt.Sprintf("Output should be in JSON format and conform to the following schema:\n\n```%s```", escapedJSON)) + req.Messages[len(req.Messages)-1].Content = append(req.Messages[len(req.Messages)-1].Content, part) } + return nil +} + +// validCandidates finds all candidates that match the expected schema. +// It will strip JSON markdown delimiters from the response. +func validCandidates(ctx context.Context, resp *GenerateResponse) ([]*Candidate, error) { + var candidates []*Candidate + for i, c := range resp.Candidates { + c, err := validCandidate(c, resp.Request.Output) + if err == nil { + candidates = append(candidates, c) + } else { + logger.FromContext(ctx).Debug("candidate did not match expected schema", "index", i, "error", err.Error()) + } + } + if len(candidates) == 0 { + return nil, errors.New("generation resulted in no candidates matching expected schema") + } + return candidates, nil +} + +// validCandidate will validate the candidate's response against the expected schema. +// It will return an error if it does not match, otherwise it will return a candidate with JSON content and type. +func validCandidate(c *Candidate, output *GenerateRequestOutput) (*Candidate, error) { + if output.Format == OutputFormatJSON { + text, err := c.Text() + if err != nil { + return nil, err + } + text = stripJSONDelimiters(text) + var schemaBytes []byte + schemaBytes, err = json.Marshal(output.Schema) + if err != nil { + return nil, fmt.Errorf("expected schema is not valid: %w", err) + } + if err = core.ValidateRaw([]byte(text), schemaBytes); err != nil { + return nil, err + } + // TODO: Verify that it okay to replace all content with JSON. + c.Message.Content = []*Part{NewJSONPart(text)} + } + return c, nil +} + +// stripJSONDelimiters strips Markdown JSON delimiters that may come back in the response. +func stripJSONDelimiters(s string) string { + s = strings.TrimSpace(s) + delimiters := []string{"```", "~~~"} + for _, delimiter := range delimiters { + if strings.HasPrefix(s, delimiter) && strings.HasSuffix(s, delimiter) { + s = strings.TrimPrefix(s, delimiter) + s = strings.TrimSuffix(s, delimiter) + s = strings.TrimSpace(s) + if strings.HasPrefix(s, "json") { + s = strings.TrimPrefix(s, "json") + s = strings.TrimSpace(s) + } + break + } + } + return s } // handleToolRequest checks if a tool was requested by a generator. @@ -150,7 +248,7 @@ func handleToolRequest(ctx context.Context, req *GenerateRequest, resp *Generate return nil, nil } - toolReq := part.ToolRequest() + toolReq := part.ToolRequest output, err := RunTool(ctx, toolReq.Name, toolReq.Input) if err != nil { return nil, err @@ -180,19 +278,25 @@ func (gr *GenerateResponse) Text() (string, error) { if len(gr.Candidates) == 0 { return "", errors.New("no candidates returned") } - msg := gr.Candidates[0].Message + return gr.Candidates[0].Text() +} + +// Text returns the contents of a [Candidate] as a string. It +// returns an error if the candidate has no message. +func (c *Candidate) Text() (string, error) { + msg := c.Message if msg == nil { - return "", errors.New("candidate with no message") + return "", errors.New("candidate has no message") } if len(msg.Content) == 0 { return "", errors.New("candidate message has no content") } if len(msg.Content) == 1 { - return msg.Content[0].Text(), nil + return msg.Content[0].Text, nil } else { var sb strings.Builder for _, p := range msg.Content { - sb.WriteString(p.Text()) + sb.WriteString(p.Text) } return sb.String(), nil } diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go new file mode 100644 index 0000000000..01d1ab13a6 --- /dev/null +++ b/go/ai/generator_test.go @@ -0,0 +1,203 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ai + +import ( + "strings" + "testing" +) + +func TestValidCandidate(t *testing.T) { + t.Parallel() + + t.Run("Valid candidate with text format", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + NewTextPart("Hello, World!"), + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatText, + } + _, err := validCandidate(candidate, outputSchema) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("Valid candidate with JSON format and matching schema", func(t *testing.T) { + json := `{ + "name": "John", + "age": 30, + "address": { + "street": "123 Main St", + "city": "New York", + "country": "USA" + } + }` + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + NewTextPart(JSONMarkdown(json)), + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + Schema: map[string]any{ + "type": "object", + "required": []string{"name", "age", "address"}, + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + "address": map[string]any{ + "type": "object", + "required": []string{"street", "city", "country"}, + "properties": map[string]any{ + "street": map[string]any{"type": "string"}, + "city": map[string]any{"type": "string"}, + "country": map[string]any{"type": "string"}, + }, + }, + "phone": map[string]any{"type": "string"}, + }, + }, + } + candidate, err := validCandidate(candidate, outputSchema) + if err != nil { + t.Fatal(err) + } + text, err := candidate.Text() + if err != nil { + t.Fatal(err) + } + if text != json { + t.Fatalf("got %q, want %q", json, text) + } + }) + + t.Run("Invalid candidate with JSON format and non-matching schema", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"name": "John", "age": "30"}`)), + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + }, + } + _, err := validCandidate(candidate, outputSchema) + errorContains(t, err, "data did not match expected schema") + }) + + t.Run("Candidate with invalid JSON", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"name": "John", "age": 30`)), // Missing trailing }. + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + } + _, err := validCandidate(candidate, outputSchema) + errorContains(t, err, "data is not valid JSON") + }) + + t.Run("Candidate with no message", func(t *testing.T) { + candidate := &Candidate{} + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + } + _, err := validCandidate(candidate, outputSchema) + errorContains(t, err, "candidate has no message") + }) + + t.Run("Candidate with message but no content", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{}, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + } + _, err := validCandidate(candidate, outputSchema) + errorContains(t, err, "candidate message has no content") + }) + + t.Run("Candidate contains unexpected field", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"name": "John", "height": 190}`)), + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + "additionalProperties": false, + }, + } + _, err := validCandidate(candidate, outputSchema) + errorContains(t, err, "data did not match expected schema") + }) + + t.Run("Invalid expected schema", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + NewTextPart(JSONMarkdown(`{"name": "John", "age": 30}`)), + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + Schema: map[string]any{ + "type": "invalid", + }, + } + _, err := validCandidate(candidate, outputSchema) + errorContains(t, err, "failed to validate data against expected schema") + }) +} + +func JSONMarkdown(text string) string { + return "```json\n" + text + "\n```" +} + +func errorContains(t *testing.T, err error, want string) { + t.Helper() + if err == nil { + t.Error("got nil, want error") + } else if !strings.Contains(err.Error(), want) { + t.Errorf("got error message %q, want it to contain %q", err, want) + } +} diff --git a/go/core/action.go b/go/core/action.go index b4eec57667..10d2a886ec 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -19,7 +19,6 @@ import ( "encoding/json" "fmt" "maps" - "reflect" "time" "github.com/firebase/genkit/go/core/logger" @@ -101,8 +100,6 @@ func (a *Action[In, Out, Stream]) setTracingState(tstate *tracing.State) { a.tst // Run executes the Action's function in a new trace span. func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error) { - // TODO: validate input against JSONSchema for I. - // TODO: validate output against JSONSchema for O. logger.FromContext(ctx).Debug("Action.Run", "name", a.name, "input", fmt.Sprintf("%#v", input)) @@ -120,18 +117,34 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con return tracing.RunInNewSpan(ctx, tstate, a.name, "action", false, input, func(ctx context.Context, input In) (Out, error) { start := time.Now() - out, err := a.fn(ctx, input, cb) + var err error + if err = ValidateValue(input, a.inputSchema); err != nil { + err = fmt.Errorf("invalid input: %w", err) + } + var output Out + if err == nil { + output, err = a.fn(ctx, input, cb) + if err == nil { + if err = ValidateValue(output, a.outputSchema); err != nil { + err = fmt.Errorf("invalid output: %w", err) + } + } + } latency := time.Since(start) if err != nil { writeActionFailure(ctx, a.name, latency, err) return internal.Zero[Out](), err } writeActionSuccess(ctx, a.name, latency) - return out, nil + return output, nil }) } func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { + // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. + if err := ValidateJSON(input, a.inputSchema); err != nil { + return nil, err + } var in In if err := json.Unmarshal(input, &in); err != nil { return nil, err @@ -212,16 +225,8 @@ func (a *Action[I, O, S]) desc() actionDesc { } func inferJSONSchema(x any) (s *jsonschema.Schema) { - var r jsonschema.Reflector - t := reflect.TypeOf(x) - if t.Kind() == reflect.Struct { - if t.NumField() == 0 { - // Make struct{} correspond to ZodVoid. - return &jsonschema.Schema{Type: "null"} - } - // Put a struct definition at the "top level" of the schema, - // instead of nested inside a "$defs" object. - r.ExpandedStruct = true + r := jsonschema.Reflector{ + DoNotReference: true, } s = r.Reflect(x) // TODO: Unwind this change once Monaco Editor supports newer than JSON schema draft-07. diff --git a/go/core/flow.go b/go/core/flow.go index c086b63936..25a38b0dc3 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -28,6 +28,7 @@ import ( "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" "github.com/google/uuid" + "github.com/invopop/jsonschema" otrace "go.opentelemetry.io/otel/trace" ) @@ -88,10 +89,12 @@ import ( // A Flow[In, Out, Stream] represents a function from In to Out. The Stream parameter is for // flows that support streaming: providing their results incrementally. type Flow[In, Out, Stream any] struct { - name string // The last component of the flow's key in the registry. - fn Func[In, Out, Stream] // The function to run. - stateStore FlowStateStore // Where FlowStates are stored, to support resumption. - tstate *tracing.State // set from the action when the flow is defined + name string // The last component of the flow's key in the registry. + fn Func[In, Out, Stream] // The function to run. + stateStore FlowStateStore // Where FlowStates are stored, to support resumption. + tstate *tracing.State // set from the action when the flow is defined + inputSchema *jsonschema.Schema // Schema of the input to the flow + outputSchema *jsonschema.Schema // Schema of the output out of the flow // TODO(jba): scheduler // TODO(jba): experimentalDurable // TODO(jba): authPolicy @@ -104,9 +107,13 @@ func DefineFlow[In, Out, Stream any](name string, fn Func[In, Out, Stream]) *Flo } func defineFlow[In, Out, Stream any](r *registry, name string, fn Func[In, Out, Stream]) *Flow[In, Out, Stream] { + var i In + var o Out f := &Flow[In, Out, Stream]{ - name: name, - fn: fn, + name: name, + fn: fn, + inputSchema: inferJSONSchema(i), + outputSchema: inferJSONSchema(o), // TODO(jba): set stateStore? } a := f.action() @@ -171,9 +178,8 @@ type flowState[In, Out any] struct { FlowID string `json:"flowId,omitempty"` FlowName string `json:"name,omitempty"` // start time in milliseconds since the epoch - StartTime tracing.Milliseconds `json:"startTime,omitempty"` - Input In `json:"input,omitempty"` - + StartTime tracing.Milliseconds `json:"startTime,omitempty"` + Input In `json:"input,omitempty"` mu sync.Mutex Cache map[string]json.RawMessage `json:"cache,omitempty"` EventsTriggered map[string]any `json:"eventsTriggered,omitempty"` @@ -244,11 +250,9 @@ type FlowResult[Out any] struct { // action creates an action for the flow. See the comment at the top of this file for more information. func (f *Flow[In, Out, Stream]) action() *Action[*flowInstruction[In], *flowState[In, Out], Stream] { - var i In - var o Out metadata := map[string]any{ - "inputSchema": inferJSONSchema(i), - "outputSchema": inferJSONSchema(o), + "inputSchema": f.inputSchema, + "outputSchema": f.outputSchema, } cback := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) { tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") @@ -291,6 +295,10 @@ type flow interface { func (f *Flow[In, Out, Stream]) Name() string { return f.name } func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) { + // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. + if err := ValidateJSON(input, f.inputSchema); err != nil { + return nil, &httpError{http.StatusBadRequest, err} + } var in In if err := json.Unmarshal(input, &in); err != nil { return nil, &httpError{http.StatusBadRequest, err} @@ -370,7 +378,19 @@ func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In // TODO(jba): Save rootSpanContext in the state. // TODO(jba): If input is missing, get it from state.input and overwrite metadata.input. start := time.Now() - output, err := f.fn(ctx, input, cb) + var err error + if err = ValidateValue(input, f.inputSchema); err != nil { + err = fmt.Errorf("invalid input: %w", err) + } + var output Out + if err == nil { + output, err = f.fn(ctx, input, cb) + if err == nil { + if err = ValidateValue(output, f.outputSchema); err != nil { + err = fmt.Errorf("invalid output: %w", err) + } + } + } latency := time.Since(start) if err != nil { // TODO(jba): handle InterruptError @@ -533,7 +553,6 @@ func RunFlow[In, Out, Stream any](ctx context.Context, flow *Flow[In, Out, Strea // InternalStreamFlow is for use by genkit.StreamFlow exclusively. // It is not subject to any backwards compatibility guarantees. func InternalStreamFlow[In, Out, Stream any](ctx context.Context, flow *Flow[In, Out, Stream], input In, callback func(context.Context, Stream) error) (Out, error) { - state, err := flow.start(ctx, input, callback) if err != nil { return internal.Zero[Out](), err diff --git a/go/core/validation.go b/go/core/validation.go new file mode 100644 index 0000000000..970b1959d0 --- /dev/null +++ b/go/core/validation.go @@ -0,0 +1,72 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/invopop/jsonschema" + "github.com/xeipuuv/gojsonschema" +) + +// ValidateValue will validate any value against the expected schema. +// It will return an error if it doesn't match the schema, otherwise it will return nil. +func ValidateValue(data any, schema *jsonschema.Schema) error { + dataBytes, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("data is not a valid JSON type: %w", err) + } + return ValidateJSON(dataBytes, schema) +} + +// ValidateJSON will validate JSON against the expected schema. +// It will return an error if it doesn't match the schema, otherwise it will return nil. +func ValidateJSON(dataBytes json.RawMessage, schema *jsonschema.Schema) error { + schemaBytes, err := schema.MarshalJSON() + if err != nil { + return fmt.Errorf("expected schema is not valid: %w", err) + } + return ValidateRaw(dataBytes, schemaBytes) +} + +// ValidateRaw will validate JSON data against the JSON schema. +// It will return an error if it doesn't match the schema, otherwise it will return nil. +func ValidateRaw(dataBytes json.RawMessage, schemaBytes json.RawMessage) error { + var data any + // Do this check separately from below to get a better error message. + if err := json.Unmarshal(dataBytes, &data); err != nil { + return fmt.Errorf("data is not valid JSON: %w", err) + } + + schemaLoader := gojsonschema.NewBytesLoader(schemaBytes) + documentLoader := gojsonschema.NewBytesLoader(dataBytes) + + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + return fmt.Errorf("failed to validate data against expected schema: %w", err) + } + + if !result.Valid() { + var errors []string + for _, err := range result.Errors() { + errors = append(errors, fmt.Sprintf("- %s", err)) + } + return fmt.Errorf("data did not match expected schema:\n%s", strings.Join(errors, "\n")) + } + + return nil +} diff --git a/go/go.mod b/go/go.mod index 6ffc1a32a5..c267ab6caa 100644 --- a/go/go.mod +++ b/go/go.mod @@ -15,6 +15,7 @@ require ( github.com/invopop/jsonschema v0.12.0 github.com/jba/slog v0.2.0 github.com/wk8/go-ordered-map/v2 v2.1.8 + github.com/xeipuuv/gojsonschema v1.2.0 go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel/metric v1.26.0 go.opentelemetry.io/otel/sdk v1.26.0 @@ -49,6 +50,8 @@ require ( github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/kr/text v0.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect diff --git a/go/go.sum b/go/go.sum index 5ec241394d..a041ea68dd 100644 --- a/go/go.sum +++ b/go/go.sum @@ -93,8 +93,6 @@ github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/ github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= -github.com/jba/slog v0.1.0 h1:m7pbPxGRvFcQy4vONykm/9X+0Fx4FGEDl7A6E/C/z9Q= -github.com/jba/slog v0.1.0/go.mod h1:R9u+1Qbl7LcDnJaFNIPer+AJa3yK9eZ8SQUE4waKFiw= github.com/jba/slog v0.2.0 h1:jI0U5NRR3EJKGsbeEVpItJNogk0c4RMeCl7vJmogCJI= github.com/jba/slog v0.2.0/go.mod h1:0Dh7Vyz3Td68Z1OwzadfincHwr7v+PpzadrS2Jua338= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= @@ -112,6 +110,7 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -119,6 +118,12 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go index 8d03f5b24f..aef5ce1076 100644 --- a/go/plugins/dotprompt/genkit_test.go +++ b/go/plugins/dotprompt/genkit_test.go @@ -25,7 +25,7 @@ import ( type testGenerator struct{} func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) { - input := req.Messages[0].Content[0].Text() + input := req.Messages[0].Content[0].Text output := fmt.Sprintf("AI reply to %q", input) r := &ai.GenerateResponse{ @@ -69,7 +69,7 @@ func TestExecute(t *testing.T) { t.FailNow() } } - got := msg.Content[0].Text() + got := msg.Content[0].Text want := `AI reply to "TestExecute"` if got != want { t.Errorf("fake generator replied with %q, want %q", got, want) diff --git a/go/plugins/dotprompt/render.go b/go/plugins/dotprompt/render.go index 82bda0a899..d47c0b38e4 100644 --- a/go/plugins/dotprompt/render.go +++ b/go/plugins/dotprompt/render.go @@ -41,7 +41,7 @@ func (p *Prompt) RenderText(variables map[string]any) (string, error) { if !part.IsText() { return "", errors.New("RenderText: multi-modal prompt can't be rendered as text") } - sb.WriteString(part.Text()) + sb.WriteString(part.Text) } return sb.String(), nil } diff --git a/go/plugins/dotprompt/render_test.go b/go/plugins/dotprompt/render_test.go index 8ffe1022ed..58917ad7e3 100644 --- a/go/plugins/dotprompt/render_test.go +++ b/go/plugins/dotprompt/render_test.go @@ -201,10 +201,10 @@ func TestRenderMessages(t *testing.T) { if a.IsText() != b.IsText() { return false } - if a.Text() != b.Text() { + if a.Text != b.Text { return false } - if a.ContentType() != b.ContentType() { + if a.ContentType != b.ContentType { return false } return true diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 17eea79c1d..28ba7881f0 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -235,19 +235,19 @@ func convertParts(parts []*ai.Part) []genai.Part { func convertPart(p *ai.Part) genai.Part { switch { case p.IsText(): - return genai.Text(p.Text()) + return genai.Text(p.Text) case p.IsMedia(): - return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())} + return genai.Blob{MIMEType: p.ContentType, Data: []byte(p.Text)} case p.IsData(): panic("googleai does not support Data parts") case p.IsToolResponse(): - toolResp := p.ToolResponse() + toolResp := p.ToolResponse return genai.FunctionResponse{ Name: toolResp.Name, Response: toolResp.Output, } case p.IsToolRequest(): - toolReq := p.ToolRequest() + toolReq := p.ToolRequest return genai.FunctionCall{ Name: toolReq.Name, Args: toolReq.Input, diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 96c557a56e..3330a867a3 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -83,7 +83,7 @@ func TestGenerator(t *testing.T) { if err != nil { t.Fatal(err) } - out := resp.Candidates[0].Message.Content[0].Text() + out := resp.Candidates[0].Message.Content[0].Text if out != "France" { t.Errorf("got \"%s\", expecting \"France\"", out) } @@ -115,7 +115,7 @@ func TestGeneratorStreaming(t *testing.T) { parts := 0 _, err = g.Generate(ctx, req, func(ctx context.Context, c *ai.Candidate) error { parts++ - out += c.Message.Content[0].Text() + out += c.Message.Content[0].Text return nil }) if err != nil { @@ -193,7 +193,7 @@ func TestGeneratorTool(t *testing.T) { t.Fatal(err) } - out := resp.Candidates[0].Message.Content[0].Text() + out := resp.Candidates[0].Message.Content[0].Text if !strings.Contains(out, "12.25") { t.Errorf("got %s, expecting it to contain \"12.25\"", out) } diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index 978e4b1301..80f9bf9c94 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -82,7 +82,7 @@ func TestLocalVec(t *testing.T) { t.Errorf("got %d results, expected 2", len(docs)) } for _, d := range docs { - text := d.Content[0].Text() + text := d.Content[0].Text if !strings.HasPrefix(text, "hello") { t.Errorf("returned doc text %q does not start with %q", text, "hello") } diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index e05e8222f9..4413c37a5c 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -149,7 +149,7 @@ func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error { // but it loses the structure of the document. var sb strings.Builder for _, p := range doc.Content { - sb.WriteString(p.Text()) + sb.WriteString(p.Text) } metadata[r.textKey] = sb.String() diff --git a/go/plugins/pinecone/genkit_test.go b/go/plugins/pinecone/genkit_test.go index cb6d5421e7..4c72fa437c 100644 --- a/go/plugins/pinecone/genkit_test.go +++ b/go/plugins/pinecone/genkit_test.go @@ -132,7 +132,7 @@ func TestGenkit(t *testing.T) { t.Errorf("got %d results, expected 2", len(docs)) } for _, d := range docs { - text := d.Content[0].Text() + text := d.Content[0].Text if !strings.HasPrefix(text, "hello") { t.Errorf("returned doc text %q does not start with %q", text, "hello") } diff --git a/go/plugins/vertexai/embed.go b/go/plugins/vertexai/embed.go index baa15ca804..22665f215b 100644 --- a/go/plugins/vertexai/embed.go +++ b/go/plugins/vertexai/embed.go @@ -59,7 +59,7 @@ func (e *embedder) Embed(ctx context.Context, req *ai.EmbedRequest) ([]float32, instances := make([]*structpb.Value, 0, len(req.Document.Content)) for _, part := range req.Document.Content { fields := map[string]any{ - "content": part.Text(), + "content": part.Text, } if title != "" { fields["title"] = title diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index f9e0c1c677..828274d79a 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -203,19 +203,19 @@ func convertParts(parts []*ai.Part) []genai.Part { func convertPart(p *ai.Part) genai.Part { switch { case p.IsText(): - return genai.Text(p.Text()) + return genai.Text(p.Text) case p.IsMedia(): - return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())} + return genai.Blob{MIMEType: p.ContentType, Data: []byte(p.Text)} case p.IsData(): panic("vertexai does not support Data parts") case p.IsToolResponse(): - toolResp := p.ToolResponse() + toolResp := p.ToolResponse return genai.FunctionResponse{ Name: toolResp.Name, Response: toolResp.Output, } case p.IsToolRequest(): - toolReq := p.ToolRequest() + toolReq := p.ToolRequest return genai.FunctionCall{ Name: toolReq.Name, Args: toolReq.Input, diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index be16c25620..c53f141b0c 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -56,7 +56,7 @@ func TestGenerator(t *testing.T) { if err != nil { t.Fatal(err) } - out := resp.Candidates[0].Message.Content[0].Text() + out := resp.Candidates[0].Message.Content[0].Text if !strings.Contains(out, "France") { t.Errorf("got \"%s\", expecting it would contain \"France\"", out) } @@ -128,7 +128,7 @@ func TestGeneratorTool(t *testing.T) { t.Fatal(err) } - out := resp.Candidates[0].Message.Content[0].Text() + out := resp.Candidates[0].Message.Content[0].Text if !strings.Contains(out, "12.25") { t.Errorf("got %s, expecting it to contain \"12.25\"", out) } diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 64a90162b0..1c8a786341 100644 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -41,6 +41,7 @@ package main import ( "context" + "encoding/json" "fmt" "log" "os" @@ -58,10 +59,23 @@ A regular customer named {{customerName}} enters. Greet the customer in one sentence, and recommend a coffee drink. ` +const simpleStructuredGreetingPromptTemplate = ` +You're a barista at a nice coffee shop. +A regular customer named {{customerName}} enters. +Greet the customer in one sentence. +Provide the name of the drink of the day, nothing else. +` + type simpleGreetingInput struct { CustomerName string `json:"customerName"` } +type simpleGreetingOutput struct { + CustomerName string `json:"customerName"` + Greeting string `json:"greeting,omitempty"` + DrinkOfDay string `json:"drinkOfDay"` +} + const greetingWithHistoryPromptTemplate = ` {{role "user"}} Hi, my name is {{customerName}}. The time is {{currentTime}}. Who are you? @@ -96,13 +110,13 @@ func main() { os.Exit(1) } - if err := googleai.Init(context.Background(), "gemini-1.0-pro", apiKey); err != nil { + if err := googleai.Init(context.Background(), "gemini-1.5-pro", apiKey); err != nil { log.Fatal(err) } simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate, &dotprompt.Config{ - Model: "google-genai/gemini-1.0-pro", + Model: "google-genai/gemini-1.5-pro", InputSchema: jsonschema.Reflect(simpleGreetingInput{}), OutputFormat: ai.OutputFormatText, }, @@ -130,7 +144,7 @@ func main() { greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", greetingWithHistoryPromptTemplate, &dotprompt.Config{ - Model: "google-genai/gemini-1.0-pro", + Model: "google-genai/gemini-1.5-pro", InputSchema: jsonschema.Reflect(customerTimeAndHistoryInput{}), OutputFormat: ai.OutputFormatText, }, @@ -156,6 +170,51 @@ func main() { return text, nil }) + r := &jsonschema.Reflector{ + AllowAdditionalProperties: false, + DoNotReference: true, + } + schema := r.Reflect(simpleGreetingOutput{}) + jsonBytes, err := schema.MarshalJSON() + if err != nil { + log.Fatal(err) + } + + var outputSchema map[string]any + err = json.Unmarshal(jsonBytes, &outputSchema) + if err != nil { + log.Fatal(err) + } + + simpleStructuredGreetingPrompt, err := dotprompt.Define("simpleStructuredGreeting", simpleStructuredGreetingPromptTemplate, + &dotprompt.Config{ + Model: "google-genai/gemini-1.5-pro", + InputSchema: jsonschema.Reflect(simpleGreetingInput{}), + OutputFormat: ai.OutputFormatJSON, + OutputSchema: outputSchema, + }, + ) + if err != nil { + log.Fatal(err) + } + + genkit.DefineFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, _ genkit.NoStream) (string, error) { + vars, err := simpleGreetingPrompt.BuildVariables(input) + if err != nil { + return "", err + } + ai := &dotprompt.ActionInput{Variables: vars} + resp, err := simpleStructuredGreetingPrompt.Execute(ctx, ai) + if err != nil { + return "", err + } + text, err := resp.Text() + if err != nil { + return "", fmt.Errorf("simpleStructuredGreeting: %v", err) + } + return text, nil + }) + genkit.DefineFlow("testAllCoffeeFlows", func(ctx context.Context, _ struct{}, _ genkit.NoStream) (*testAllCoffeeFlowsOutput, error) { test1, err := genkit.RunFlow(ctx, simpleGreetingFlow, &simpleGreetingInput{ CustomerName: "Sam",