From 56ee5942578ac42cfb365f49be30a8d8294cb7c8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 23 May 2024 09:04:40 -0700 Subject: [PATCH 01/27] Added input schema validation. --- go/genkit/action.go | 12 ++++++++---- go/genkit/flow.go | 8 +++++++- go/samples/flow-sample1/main.go | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/go/genkit/action.go b/go/genkit/action.go index 13661f02c6..d23af55bbf 100644 --- a/go/genkit/action.go +++ b/go/genkit/action.go @@ -195,10 +195,11 @@ func (a *Action[I, O, S]) desc() actionDesc { } // Required by genkit UI: if ad.Metadata == nil { - ad.Metadata = map[string]any{} + ad.Metadata = map[string]any{ + "inputSchema": nil, + "outputSchema": nil, + } } - ad.Metadata["inputSchema"] = nil - ad.Metadata["outputSchema"] = nil return ad } @@ -214,5 +215,8 @@ func inferJSONSchema(x any) (s *jsonschema.Schema) { // instead of nested inside a "$defs" object. r.ExpandedStruct = true } - return r.Reflect(x) + s = r.Reflect(x) + // TODO: Unwind this change once Monaco Editor supports newer than JSON schema draft-07. + s.Version = "" + return s } diff --git a/go/genkit/flow.go b/go/genkit/flow.go index 5f43db1b65..d16a1c59f0 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -230,7 +230,13 @@ type FlowResult[O any] struct { // action creates an action for the flow. See the comment at the top of this file for more information. func (f *Flow[I, O, S]) action() *Action[*flowInstruction[I], *flowState[I, O], S] { - return NewStreamingAction(f.name, nil, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { + var i I + var o O + metadata := map[string]any{ + "inputSchema": inferJSONSchema(i), + "outputSchema": inferJSONSchema(o), + } + return NewStreamingAction(f.name, metadata, func(ctx context.Context, inst *flowInstruction[I], cb StreamingCallback[S]) (*flowState[I, O], error) { spanMetaKey.fromContext(ctx).SetAttr("flow:wrapperAction", "true") return f.runInstruction(ctx, inst, cb) }) diff --git a/go/samples/flow-sample1/main.go b/go/samples/flow-sample1/main.go index 4d00d3e40c..215d10130a 100644 --- a/go/samples/flow-sample1/main.go +++ b/go/samples/flow-sample1/main.go @@ -26,6 +26,7 @@ import ( "context" "fmt" "log" + "strconv" "github.com/firebase/genkit/go/genkit" ) @@ -43,6 +44,19 @@ func main() { return genkit.RunFlow(ctx, basic, "foo") }) + type complex struct { + Key string `json:"key"` + Value int `json:"value"` + } + + genkit.DefineFlow("complex", func(ctx context.Context, c complex, _ genkit.NoStream) (string, error) { + foo, err := genkit.Run(ctx, "call-llm", func() (string, error) { return c.Key + ": " + strconv.Itoa(c.Value), nil }) + if err != nil { + return "", err + } + return foo, nil + }) + type chunk struct { Count int `json:"count"` } From 740a4f9e3f295792e76cdb2d90aebc3dafe21d9c Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 25 May 2024 17:02:02 -0700 Subject: [PATCH 02/27] Added output conformance. --- go/ai/generator.go | 103 +++++++++++++++++++++++- go/ai/generator_test.go | 168 ++++++++++++++++++++++++++++++++++++++++ go/go.mod | 6 ++ go/go.sum | 9 ++- 4 files changed, 283 insertions(+), 3 deletions(-) create mode 100644 go/ai/generator_test.go diff --git a/go/ai/generator.go b/go/ai/generator.go index c1dc3ab537..dd7ff4ece9 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -16,12 +16,14 @@ package ai import ( "context" + "encoding/json" "errors" "fmt" "slices" "strings" "github.com/firebase/genkit/go/genkit" + "github.com/xeipuuv/gojsonschema" ) // Generator is the interface used to query an AI model. @@ -122,6 +124,12 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, return nil, err } + candidates, err := findValidCandidates(resp) + if err != nil { + return nil, err + } + resp.Candidates = candidates + newReq, err := handleToolRequest(ctx, input, resp) if err != nil { return nil, err @@ -134,6 +142,93 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, } } +// findValidCandidates will return any candidates that match the expected schema. +// It will return an error if there are no valid candidates found. +func findValidCandidates(resp *GenerateResponse) ([]*Candidate, error) { + candidates := []*Candidate{} + for _, c := range resp.Candidates { + if err := validateCandidate(c, resp.Request.Output); err != nil { + candidates = append(candidates, c) + } + } + if len(candidates) == 0 { + return nil, errors.New("generation resulted in no candidates matching provided output schema") + } + return candidates, nil +} + +// validateCandidate will check a candidate against the expected schema. +// It will return an error if it does not match, otherwise it will return nil. +func validateCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput) error { + if outputSchema.Format == OutputFormatText { + return nil + } + + text, err := candidate.Text() + if err != nil { + return err + } + + var jsonData interface{} + err = json.Unmarshal([]byte(text), &jsonData) + if err != nil { + return fmt.Errorf("candidate did not have valid JSON: %w", err) + } + + schemaBytes, err := json.Marshal(outputSchema.Schema) + if err != nil { + return fmt.Errorf("expected schema is not valid: %w", err) + } + + schemaLoader := gojsonschema.NewStringLoader(string(schemaBytes)) + jsonLoader := gojsonschema.NewGoLoader(jsonData) + result, err := gojsonschema.Validate(schemaLoader, jsonLoader) + if err != nil { + return fmt.Errorf("failed to validate expected schema: %w", err) + } + + if !result.Valid() { + var errMsg string + for _, err := range result.Errors() { + errMsg += fmt.Sprintf("- %s\n", err) + } + return fmt.Errorf("candidate did not match expected schema:\n%s", errMsg) + } + + if err = checkUnknownFields(jsonData.(map[string]interface{}), outputSchema.Schema); err != nil { + return err + } + + return nil +} + +// checkUnknownFields checks for unexpected fields that do not appear in the schema. +func checkUnknownFields(jsonData map[string]interface{}, schema map[string]any) error { + for key, value := range jsonData { + schemaProperties, ok := schema["properties"].(map[string]any) + if !ok { + return fmt.Errorf("candidate contains unexpected field: %s", key) + } + + if _, ok := schemaProperties[key]; !ok { + return fmt.Errorf("candidate contains unexpected field: %s", key) + } + + if nestedObj, ok := value.(map[string]interface{}); ok { + nestedSchema, ok := schemaProperties[key].(map[string]any) + if !ok { + return fmt.Errorf("candidate contains unexpected nested object: %s", key) + } + err := checkUnknownFields(nestedObj, nestedSchema) + if err != nil { + return err + } + } + } + + return nil +} + // handleToolRequest checks if a tool was requested by a generator. // If a tool was requested, this runs the tool and returns an // updated GenerateRequest. If no tool was requested this returns nil. @@ -180,7 +275,13 @@ func (gr *GenerateResponse) Text() (string, error) { if len(gr.Candidates) == 0 { return "", errors.New("no candidates returned") } - msg := gr.Candidates[0].Message + return gr.Candidates[0].Text() +} + +// Text returns the contents of a [Candidate] as a string. It +// returns an error if the candidate has no message. +func (c *Candidate) Text() (string, error) { + msg := c.Message if msg == nil { return "", errors.New("candidate with no message") } diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go new file mode 100644 index 0000000000..612fe85529 --- /dev/null +++ b/go/ai/generator_test.go @@ -0,0 +1,168 @@ +package ai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateCandidate(t *testing.T) { + t.Run("Valid candidate with text format", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + {text: "Hello, World!"}, + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatText, + } + err := validateCandidate(candidate, outputSchema) + assert.NoError(t, err) + }) + + t.Run("Valid candidate with JSON format and matching schema", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + {text: `{ + "name": "John", + "age": 30, + "address": { + "street": "123 Main St", + "city": "New York", + "country": "USA" + } + }`}, + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + Schema: map[string]any{ + "type": "object", + "required": []string{"name", "age", "address"}, + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + "address": map[string]any{ + "type": "object", + "required": []string{"street", "city", "country"}, + "properties": map[string]any{ + "street": map[string]any{"type": "string"}, + "city": map[string]any{"type": "string"}, + "country": map[string]any{"type": "string"}, + }, + }, + "phone": map[string]any{"type": "string"}, + }, + }, + } + err := validateCandidate(candidate, outputSchema) + assert.NoError(t, err) + }) + + t.Run("Invalid candidate with JSON format and non-matching schema", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + {text: `{"name": "John", "age": "30"}`}, + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + }, + } + err := validateCandidate(candidate, outputSchema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "candidate did not match expected schema") + }) + + t.Run("Candidate with invalid JSON", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + {text: `{"name": "John", "age": 30`}, // Missing trailing }. + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + } + err := validateCandidate(candidate, outputSchema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "candidate did not have valid JSON") + }) + + t.Run("Candidate with no message", func(t *testing.T) { + candidate := &Candidate{} + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + } + err := validateCandidate(candidate, outputSchema) + assert.Error(t, err) + assert.Equal(t, "candidate with no message", err.Error()) + }) + + t.Run("Candidate with message but no content", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{}, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + } + err := validateCandidate(candidate, outputSchema) + assert.Error(t, err) + assert.Equal(t, "candidate message has no content", err.Error()) + }) + + t.Run("Candidate contains unexpected field", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + {text: `{"name": "John", "height": "190"}`}, + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + }, + } + err := validateCandidate(candidate, outputSchema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "candidate contains unexpected field") + }) + + t.Run("Invalid expected schema", func(t *testing.T) { + candidate := &Candidate{ + Message: &Message{ + Content: []*Part{ + {text: `{"name": "John", "age": 30}`}, + }, + }, + } + outputSchema := &GenerateRequestOutput{ + Format: OutputFormatJSON, + Schema: map[string]any{ + "type": "invalid", + }, + } + err := validateCandidate(candidate, outputSchema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to validate expected schema") + }) +} diff --git a/go/go.mod b/go/go.mod index 6ffc1a32a5..2e2190b210 100644 --- a/go/go.mod +++ b/go/go.mod @@ -14,7 +14,9 @@ require ( github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.12.0 github.com/jba/slog v0.2.0 + github.com/stretchr/testify v1.9.0 github.com/wk8/go-ordered-map/v2 v2.1.8 + github.com/xeipuuv/gojsonschema v1.2.0 go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel/metric v1.26.0 go.opentelemetry.io/otel/sdk v1.26.0 @@ -39,6 +41,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.46.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -49,6 +52,9 @@ require ( github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/kr/text v0.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect diff --git a/go/go.sum b/go/go.sum index 5ec241394d..a041ea68dd 100644 --- a/go/go.sum +++ b/go/go.sum @@ -93,8 +93,6 @@ github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/ github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= -github.com/jba/slog v0.1.0 h1:m7pbPxGRvFcQy4vONykm/9X+0Fx4FGEDl7A6E/C/z9Q= -github.com/jba/slog v0.1.0/go.mod h1:R9u+1Qbl7LcDnJaFNIPer+AJa3yK9eZ8SQUE4waKFiw= github.com/jba/slog v0.2.0 h1:jI0U5NRR3EJKGsbeEVpItJNogk0c4RMeCl7vJmogCJI= github.com/jba/slog v0.2.0/go.mod h1:0Dh7Vyz3Td68Z1OwzadfincHwr7v+PpzadrS2Jua338= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= @@ -112,6 +110,7 @@ github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncj github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -119,6 +118,12 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= From 4cf6db5d8751850e9d5369ffccbdc6f51a261820 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 25 May 2024 17:10:58 -0700 Subject: [PATCH 03/27] Moved error up one level. --- go/ai/generator.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index 0947568880..fdef1872da 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -124,9 +124,9 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, return nil, err } - candidates, err := findValidCandidates(resp) - if err != nil { - return nil, err + candidates := findValidCandidates(resp) + if len(candidates) == 0 { + return nil, errors.New("generation resulted in no candidates matching provided output schema") } resp.Candidates = candidates @@ -142,19 +142,15 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, } } -// findValidCandidates will return any candidates that match the expected schema. -// It will return an error if there are no valid candidates found. -func findValidCandidates(resp *GenerateResponse) ([]*Candidate, error) { +// findValidCandidates finds all candidates that match the expected schema. +func findValidCandidates(resp *GenerateResponse) []*Candidate { candidates := []*Candidate{} for _, c := range resp.Candidates { if err := validateCandidate(c, resp.Request.Output); err != nil { candidates = append(candidates, c) } } - if len(candidates) == 0 { - return nil, errors.New("generation resulted in no candidates matching provided output schema") - } - return candidates, nil + return candidates } // validateCandidate will check a candidate against the expected schema. From 03d8949aedb8d429c72accfe228887fdfc8624c8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 25 May 2024 17:12:53 -0700 Subject: [PATCH 04/27] Formatting. --- go/ai/generator_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index 612fe85529..9953a6b9ed 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -1,3 +1,17 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package ai import ( From be02e923fc9104fcb8655db47f5da2961fc8735e Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 25 May 2024 18:46:06 -0700 Subject: [PATCH 05/27] Fixed inverse condition. --- go/ai/generator.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index fdef1872da..28faf31f2f 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -20,9 +20,11 @@ import ( "errors" "fmt" "slices" + "strconv" "strings" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal" "github.com/xeipuuv/gojsonschema" ) @@ -124,7 +126,7 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, return nil, err } - candidates := findValidCandidates(resp) + candidates := findValidCandidates(ctx, resp) if len(candidates) == 0 { return nil, errors.New("generation resulted in no candidates matching provided output schema") } @@ -143,10 +145,13 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, } // findValidCandidates finds all candidates that match the expected schema. -func findValidCandidates(resp *GenerateResponse) []*Candidate { +func findValidCandidates(ctx context.Context, resp *GenerateResponse) []*Candidate { candidates := []*Candidate{} - for _, c := range resp.Candidates { - if err := validateCandidate(c, resp.Request.Output); err != nil { + for i, c := range resp.Candidates { + err := validateCandidate(c, resp.Request.Output) + if err != nil { + internal.Logger(ctx).Debug("candidate %s did not match provided output schema: %w", strconv.Itoa(i), err.Error()) + } else { candidates = append(candidates, c) } } From 7c5c81e22a66889abe77ecefe71d2adff91beee9 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 25 May 2024 21:56:53 -0700 Subject: [PATCH 06/27] Added conformance message to request and sample code. --- go/ai/generator.go | 72 ++++++++------- go/ai/generator_test.go | 3 +- go/samples/coffee-shop/main.go | 61 ++++++++++++- test.txt | 157 +++++++++++++++++++++++++++++++++ 4 files changed, 259 insertions(+), 34 deletions(-) create mode 100644 test.txt diff --git a/go/ai/generator.go b/go/ai/generator.go index 28faf31f2f..662c8900b8 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -75,12 +75,22 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener // Generate applies a [Generator] to some input, handling tool requests. func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { + if err := conformOutput(input); err != nil { + return nil, err + } + for { resp, err := g.Generate(ctx, input, cb) if err != nil { return nil, err } + candidates := findValidCandidates(ctx, resp) + if len(candidates) == 0 { + return nil, errors.New("generation resulted in no candidates matching provided output schema") + } + resp.Candidates = candidates + newReq, err := handleToolRequest(ctx, input, resp) if err != nil { return nil, err @@ -120,6 +130,10 @@ type generatorAction struct { // but invokes the [core.Action] rather than invoking the Generator // directly. func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { + if err := conformOutput(input); err != nil { + return nil, err + } + for { resp, err := ga.action.Run(ctx, input, cb) if err != nil { @@ -144,15 +158,33 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, } } +// conformOutput appends a message to the request indicating conformance to the expected schema. +func conformOutput(input *GenerateRequest) error { + if len(input.Output.Schema) > 0 && len(input.Messages) > 0 { + jsonBytes, err := json.Marshal(input.Output.Schema) + if err != nil { + return fmt.Errorf("expected schema is not valid: %w", err) + } + + jsonStr := string(jsonBytes) + escapedJSON := strconv.Quote(jsonStr) + part := &Part{ + text: fmt.Sprintf("Output should be in JSON format and conform to the following schema:\n\n```%s```", escapedJSON), + } + input.Messages[len(input.Messages)-1].Content = append(input.Messages[len(input.Messages)-1].Content, part) + } + return nil +} + // findValidCandidates finds all candidates that match the expected schema. func findValidCandidates(ctx context.Context, resp *GenerateResponse) []*Candidate { candidates := []*Candidate{} for i, c := range resp.Candidates { err := validateCandidate(c, resp.Request.Output) - if err != nil { - internal.Logger(ctx).Debug("candidate %s did not match provided output schema: %w", strconv.Itoa(i), err.Error()) - } else { + if err == nil { candidates = append(candidates, c) + } else { + internal.Logger(ctx).Debug("candidate did not match provided output schema", "index", i, "error", err.Error()) } } return candidates @@ -170,6 +202,8 @@ func validateCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput return err } + text = stripJsonDelimiters(text) + var jsonData interface{} err = json.Unmarshal([]byte(text), &jsonData) if err != nil { @@ -196,38 +230,12 @@ func validateCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput return fmt.Errorf("candidate did not match expected schema:\n%s", errMsg) } - if err = checkUnknownFields(jsonData.(map[string]interface{}), outputSchema.Schema); err != nil { - return err - } - return nil } -// checkUnknownFields checks for unexpected fields that do not appear in the schema. -func checkUnknownFields(jsonData map[string]interface{}, schema map[string]any) error { - for key, value := range jsonData { - schemaProperties, ok := schema["properties"].(map[string]any) - if !ok { - return fmt.Errorf("candidate contains unexpected field: %s", key) - } - - if _, ok := schemaProperties[key]; !ok { - return fmt.Errorf("candidate contains unexpected field: %s", key) - } - - if nestedObj, ok := value.(map[string]interface{}); ok { - nestedSchema, ok := schemaProperties[key].(map[string]any) - if !ok { - return fmt.Errorf("candidate contains unexpected nested object: %s", key) - } - err := checkUnknownFields(nestedObj, nestedSchema) - if err != nil { - return err - } - } - } - - return nil +// stripJsonDelimiters strips JSON delimiters that may come back in the response. +func stripJsonDelimiters(s string) string { + return strings.TrimSuffix(strings.TrimPrefix(s, "```json"), "```") } // handleToolRequest checks if a tool was requested by a generator. diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index 9953a6b9ed..fa28e95f33 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -154,11 +154,12 @@ func TestValidateCandidate(t *testing.T) { "name": map[string]any{"type": "string"}, "age": map[string]any{"type": "integer"}, }, + "additionalProperties": false, }, } err := validateCandidate(candidate, outputSchema) assert.Error(t, err) - assert.Contains(t, err.Error(), "candidate contains unexpected field") + assert.Contains(t, err.Error(), "candidate did not match expected schema") }) t.Run("Invalid expected schema", func(t *testing.T) { diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index a1ccb1b4f1..b45fe262d3 100644 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -36,6 +36,7 @@ package main import ( "context" + "encoding/json" "fmt" "log" "os" @@ -53,10 +54,23 @@ A regular customer named {{customerName}} enters. Greet the customer in one sentence, and recommend a coffee drink. ` +const simpleStructuredGreetingPromptTemplate = ` +You're a barista at a nice coffee shop. +A regular customer named {{customerName}} enters. +Greet the customer in one sentence. +Provide the name of the drink of the day, nothing else. +` + type simpleGreetingInput struct { CustomerName string `json:"customerName"` } +type simpleGreetingOutput struct { + CustomerName string `json:"customerName"` + Greeting string `json:"greeting,omitempty"` + DrinkOfDay string `json:"drinkOfDay"` +} + const greetingWithHistoryPromptTemplate = ` {{role "user"}} Hi, my name is {{customerName}}. The time is {{currentTime}}. Who are you? @@ -91,7 +105,7 @@ func main() { os.Exit(1) } - if err := googleai.Init(context.Background(), "gemini-1.0-pro", apiKey); err != nil { + if err := googleai.Init(context.Background(), "gemini-1.5-pro", apiKey); err != nil { log.Fatal(err) } @@ -151,6 +165,51 @@ func main() { return text, nil }) + r := &jsonschema.Reflector{ + AllowAdditionalProperties: false, + ExpandedStruct: true, + } + schema := r.Reflect(simpleGreetingOutput{}) + jsonBytes, err := schema.MarshalJSON() + if err != nil { + log.Fatal(err) + } + + var outputSchema map[string]any + err = json.Unmarshal(jsonBytes, &outputSchema) + if err != nil { + log.Fatal(err) + } + + simpleStructuredGreetingPrompt, err := dotprompt.Define("simpleStructuredGreeting", simpleStructuredGreetingPromptTemplate, + &dotprompt.Config{ + Model: "google-genai/gemini-1.5-pro", + InputSchema: jsonschema.Reflect(simpleGreetingInput{}), + OutputFormat: ai.OutputFormatJSON, + OutputSchema: outputSchema, + }, + ) + if err != nil { + log.Fatal(err) + } + + genkit.DefineFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, _ genkit.NoStream) (string, error) { + vars, err := simpleGreetingPrompt.BuildVariables(input) + if err != nil { + return "", err + } + ai := &dotprompt.ActionInput{Variables: vars} + resp, err := simpleStructuredGreetingPrompt.Execute(ctx, ai) + if err != nil { + return "", err + } + text, err := resp.Text() + if err != nil { + return "", fmt.Errorf("simpleStructuredGreeting: %v", err) + } + return text, nil + }) + genkit.DefineFlow("testAllCoffeeFlows", func(ctx context.Context, _ struct{}, _ genkit.NoStream) (*testAllCoffeeFlowsOutput, error) { test1, err := genkit.RunFlow(ctx, simpleGreetingFlow, &simpleGreetingInput{ CustomerName: "Sam", diff --git a/test.txt b/test.txt new file mode 100644 index 0000000000..2864e4d3b7 --- /dev/null +++ b/test.txt @@ -0,0 +1,157 @@ +{ + "traceId": "3cd65103ed08e50344404fe2b11ce807", + "displayName": "dev-run-action-wrapper", + "startTime": 1716695799448.321, + "endTime": 1716695800962.0513, + "spans": { + "1e0cd2115f1a49c5": { + "spanId": "1e0cd2115f1a49c5", + "traceId": "3cd65103ed08e50344404fe2b11ce807", + "parentSpanId": "c52101826cbbd8c6", + "startTime": 1716695799448.4412, + "endTime": 1716695800960.2124, + "attributes": { + "genkit:input": "{\"start\":{\"input\":{\"customerName\":\"Alex\"}}}", + "genkit:metadata:flow:wrapperAction": "true", + "genkit:metadata:subtype": "flow", + "genkit:name": "simpleStructuredGreeting", + "genkit:output": "{\"flowId\":\"6b3edd5d-d377-4c60-a3e4-060b891abdd8\",\"name\":\"simpleStructuredGreeting\",\"startTime\":1716695799448.457,\"input\":{\"customerName\":\"Alex\"},\"executions\":[{\"startTime\":1716695799448.46,\"traceIds\":[\"3cd65103ed08e50344404fe2b11ce807\"]}],\"operation\":{\"name\":\"6b3edd5d-d377-4c60-a3e4-060b891abdd8\",\"done\":true,\"result\":{\"response\":\"```json\\n{\\\"customerName\\\": \\\"Alex\\\", \\\"greeting\\\": \\\"Hey Alex, good to see you again!\\\", \\\"drinkOfDay\\\": \\\"Caramel Cloud Macchiato\\\"}\\n```\"}}}", + "genkit:path": "/dev-run-action-wrapper/simpleStructuredGreeting", + "genkit:state": "success", + "genkit:type": "action" + }, + "displayName": "simpleStructuredGreeting", + "instrumentationLibrary": { + "name": "genkit-tracer", + "version": "v1" + }, + "spanKind": "INTERNAL", + "sameProcessAsParentSpan": { + "value": true + }, + "status": { + "code": 0 + }, + "timeEvents": {} + }, + "558a78c92f2d7085": { + "spanId": "558a78c92f2d7085", + "traceId": "3cd65103ed08e50344404fe2b11ce807", + "parentSpanId": "1e0cd2115f1a49c5", + "startTime": 1716695799448.478, + "endTime": 1716695800958.5195, + "attributes": { + "genkit:input": "{\"customerName\":\"Alex\"}", + "genkit:isRoot": true, + "genkit:metadata:flow:dispatchType": "start", + "genkit:metadata:flow:execution": "0", + "genkit:metadata:flow:id": "6b3edd5d-d377-4c60-a3e4-060b891abdd8", + "genkit:metadata:flow:name": "simpleStructuredGreeting", + "genkit:metadata:flow:state": "done", + "genkit:metadata:subtype": "prompt", + "genkit:name": "simpleStructuredGreeting", + "genkit:output": "\"```json\\n{\\\"customerName\\\": \\\"Alex\\\", \\\"greeting\\\": \\\"Hey Alex, good to see you again!\\\", \\\"drinkOfDay\\\": \\\"Caramel Cloud Macchiato\\\"}\\n```\"", + "genkit:path": "/dev-run-action-wrapper/simpleStructuredGreeting/simpleStructuredGreeting", + "genkit:state": "success", + "genkit:type": "flow" + }, + "displayName": "simpleStructuredGreeting", + "instrumentationLibrary": { + "name": "genkit-tracer", + "version": "v1" + }, + "spanKind": "INTERNAL", + "sameProcessAsParentSpan": { + "value": true + }, + "status": { + "code": 0 + }, + "timeEvents": {} + }, + "c52101826cbbd8c6": { + "spanId": "c52101826cbbd8c6", + "traceId": "3cd65103ed08e50344404fe2b11ce807", + "startTime": 1716695799448.321, + "endTime": 1716695800962.0513, + "attributes": { + "genkit:input": "{\"start\":{\"input\":{\"customerName\":\"Alex\"}}}", + "genkit:isRoot": true, + "genkit:metadata:genkit-dev-internal": "true", + "genkit:name": "dev-run-action-wrapper", + "genkit:output": "{\"flowId\":\"6b3edd5d-d377-4c60-a3e4-060b891abdd8\",\"name\":\"simpleStructuredGreeting\",\"startTime\":1716695799448.457,\"input\":{\"customerName\":\"Alex\"},\"executions\":[{\"startTime\":1716695799448.46,\"traceIds\":[\"3cd65103ed08e50344404fe2b11ce807\"]}],\"operation\":{\"name\":\"6b3edd5d-d377-4c60-a3e4-060b891abdd8\",\"done\":true,\"result\":{\"response\":\"```json\\n{\\\"customerName\\\": \\\"Alex\\\", \\\"greeting\\\": \\\"Hey Alex, good to see you again!\\\", \\\"drinkOfDay\\\": \\\"Caramel Cloud Macchiato\\\"}\\n```\"}}}", + "genkit:path": "/dev-run-action-wrapper", + "genkit:state": "success" + }, + "displayName": "dev-run-action-wrapper", + "instrumentationLibrary": { + "name": "genkit-tracer", + "version": "v1" + }, + "spanKind": "INTERNAL", + "sameProcessAsParentSpan": { + "value": true + }, + "status": { + "code": 0 + }, + "timeEvents": {} + }, + "e811f98e392f5d5e": { + "spanId": "e811f98e392f5d5e", + "traceId": "3cd65103ed08e50344404fe2b11ce807", + "parentSpanId": "f99f5fe820a3ce71", + "startTime": 1716695799448.968, + "endTime": 1716695800954.0957, + "attributes": { + "http.method": "POST", + "http.request_content_length": 1341, + "http.status_code": 200, + "http.url": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro:streamGenerateContent?%24alt=json%3Benum-encoding%3Dint\u0026key=AIzaSyAehb8ATuR3nkFatuxhEa1HgM5Jgzkvhvg", + "net.peer.name": "generativelanguage.googleapis.com" + }, + "displayName": "HTTP POST", + "instrumentationLibrary": { + "name": "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp", + "version": "0.49.0" + }, + "spanKind": "CLIENT", + "sameProcessAsParentSpan": { + "value": true + }, + "status": { + "code": 0 + }, + "timeEvents": {} + }, + "f99f5fe820a3ce71": { + "spanId": "f99f5fe820a3ce71", + "traceId": "3cd65103ed08e50344404fe2b11ce807", + "parentSpanId": "558a78c92f2d7085", + "startTime": 1716695799448.702, + "endTime": 1716695800956.524, + "attributes": { + "genkit:input": "{\"candidates\":1,\"config\":null,\"messages\":[{\"content\":[{\"text\":\"\\nYou're a barista at a nice coffee shop.\\nA regular customer named Alex enters.\\nGreet the customer in one sentence.\\nProvide the name of the drink of the day, nothing else.\\n\"},{\"text\":\"Output should be in JSON format and conform to the following schema:\\n\\n```\\\"{\\\\\\\"$defs\\\\\\\":{\\\\\\\"simpleGreetingOutput\\\\\\\":{\\\\\\\"additionalProperties\\\\\\\":false,\\\\\\\"properties\\\\\\\":{\\\\\\\"customerName\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"drinkOfDay\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"greeting\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"customerName\\\\\\\",\\\\\\\"greeting\\\\\\\",\\\\\\\"drinkOfDay\\\\\\\"],\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\"}},\\\\\\\"$schema\\\\\\\":\\\\\\\"https://json-schema.org/draft/2020-12/schema\\\\\\\"}\\\"```\"},{\"text\":\"Output should be in JSON format and conform to the following schema:\\n\\n```\\\"{\\\\\\\"$defs\\\\\\\":{\\\\\\\"simpleGreetingOutput\\\\\\\":{\\\\\\\"additionalProperties\\\\\\\":false,\\\\\\\"properties\\\\\\\":{\\\\\\\"customerName\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"drinkOfDay\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"greeting\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"customerName\\\\\\\",\\\\\\\"greeting\\\\\\\",\\\\\\\"drinkOfDay\\\\\\\"],\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\"}},\\\\\\\"$schema\\\\\\\":\\\\\\\"https://json-schema.org/draft/2020-12/schema\\\\\\\"}\\\"```\"}],\"role\":\"user\"}],\"output\":{\"format\":\"json\",\"schema\":{\"$defs\":{\"simpleGreetingOutput\":{\"additionalProperties\":false,\"properties\":{\"customerName\":{\"type\":\"string\"},\"drinkOfDay\":{\"type\":\"string\"},\"greeting\":{\"type\":\"string\"}},\"required\":[\"customerName\",\"greeting\",\"drinkOfDay\"],\"type\":\"object\"}},\"$schema\":\"https://json-schema.org/draft/2020-12/schema\"}}}", + "genkit:metadata:subtype": "model", + "genkit:name": "gemini-1.5-pro", + "genkit:output": "{\"candidates\":[{\"finishReason\":\"stop\",\"message\":{\"content\":[{\"text\":\"```json\\n{\\\"customerName\\\": \\\"Alex\\\", \\\"greeting\\\": \\\"Hey Alex, good to see you again!\\\", \\\"drinkOfDay\\\": \\\"Caramel Cloud Macchiato\\\"}\\n```\"}],\"role\":\"model\"}}],\"request\":{\"candidates\":1,\"config\":null,\"messages\":[{\"content\":[{\"text\":\"\\nYou're a barista at a nice coffee shop.\\nA regular customer named Alex enters.\\nGreet the customer in one sentence.\\nProvide the name of the drink of the day, nothing else.\\n\"},{\"text\":\"Output should be in JSON format and conform to the following schema:\\n\\n```\\\"{\\\\\\\"$defs\\\\\\\":{\\\\\\\"simpleGreetingOutput\\\\\\\":{\\\\\\\"additionalProperties\\\\\\\":false,\\\\\\\"properties\\\\\\\":{\\\\\\\"customerName\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"drinkOfDay\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"greeting\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"customerName\\\\\\\",\\\\\\\"greeting\\\\\\\",\\\\\\\"drinkOfDay\\\\\\\"],\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\"}},\\\\\\\"$schema\\\\\\\":\\\\\\\"https://json-schema.org/draft/2020-12/schema\\\\\\\"}\\\"```\"},{\"text\":\"Output should be in JSON format and conform to the following schema:\\n\\n```\\\"{\\\\\\\"$defs\\\\\\\":{\\\\\\\"simpleGreetingOutput\\\\\\\":{\\\\\\\"additionalProperties\\\\\\\":false,\\\\\\\"properties\\\\\\\":{\\\\\\\"customerName\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"drinkOfDay\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"greeting\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"customerName\\\\\\\",\\\\\\\"greeting\\\\\\\",\\\\\\\"drinkOfDay\\\\\\\"],\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\"}},\\\\\\\"$schema\\\\\\\":\\\\\\\"https://json-schema.org/draft/2020-12/schema\\\\\\\"}\\\"```\"}],\"role\":\"user\"}],\"output\":{\"format\":\"json\",\"schema\":{\"$defs\":{\"simpleGreetingOutput\":{\"additionalProperties\":false,\"properties\":{\"customerName\":{\"type\":\"string\"},\"drinkOfDay\":{\"type\":\"string\"},\"greeting\":{\"type\":\"string\"}},\"required\":[\"customerName\",\"greeting\",\"drinkOfDay\"],\"type\":\"object\"}},\"$schema\":\"https://json-schema.org/draft/2020-12/schema\"}}}}", + "genkit:path": "/dev-run-action-wrapper/simpleStructuredGreeting/simpleStructuredGreeting/gemini-1.5-pro", + "genkit:state": "success", + "genkit:type": "action" + }, + "displayName": "gemini-1.5-pro", + "instrumentationLibrary": { + "name": "genkit-tracer", + "version": "v1" + }, + "spanKind": "INTERNAL", + "sameProcessAsParentSpan": { + "value": true + }, + "status": { + "code": 0 + }, + "timeEvents": {} + } + } +} \ No newline at end of file From 895dc59a4141ff8da388ccde8648fe805ea8f4d8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 25 May 2024 21:59:48 -0700 Subject: [PATCH 07/27] Delete test.txt --- test.txt | 157 ------------------------------------------------------- 1 file changed, 157 deletions(-) delete mode 100644 test.txt diff --git a/test.txt b/test.txt deleted file mode 100644 index 2864e4d3b7..0000000000 --- a/test.txt +++ /dev/null @@ -1,157 +0,0 @@ -{ - "traceId": "3cd65103ed08e50344404fe2b11ce807", - "displayName": "dev-run-action-wrapper", - "startTime": 1716695799448.321, - "endTime": 1716695800962.0513, - "spans": { - "1e0cd2115f1a49c5": { - "spanId": "1e0cd2115f1a49c5", - "traceId": "3cd65103ed08e50344404fe2b11ce807", - "parentSpanId": "c52101826cbbd8c6", - "startTime": 1716695799448.4412, - "endTime": 1716695800960.2124, - "attributes": { - "genkit:input": "{\"start\":{\"input\":{\"customerName\":\"Alex\"}}}", - "genkit:metadata:flow:wrapperAction": "true", - "genkit:metadata:subtype": "flow", - "genkit:name": "simpleStructuredGreeting", - "genkit:output": "{\"flowId\":\"6b3edd5d-d377-4c60-a3e4-060b891abdd8\",\"name\":\"simpleStructuredGreeting\",\"startTime\":1716695799448.457,\"input\":{\"customerName\":\"Alex\"},\"executions\":[{\"startTime\":1716695799448.46,\"traceIds\":[\"3cd65103ed08e50344404fe2b11ce807\"]}],\"operation\":{\"name\":\"6b3edd5d-d377-4c60-a3e4-060b891abdd8\",\"done\":true,\"result\":{\"response\":\"```json\\n{\\\"customerName\\\": \\\"Alex\\\", \\\"greeting\\\": \\\"Hey Alex, good to see you again!\\\", \\\"drinkOfDay\\\": \\\"Caramel Cloud Macchiato\\\"}\\n```\"}}}", - "genkit:path": "/dev-run-action-wrapper/simpleStructuredGreeting", - "genkit:state": "success", - "genkit:type": "action" - }, - "displayName": "simpleStructuredGreeting", - "instrumentationLibrary": { - "name": "genkit-tracer", - "version": "v1" - }, - "spanKind": "INTERNAL", - "sameProcessAsParentSpan": { - "value": true - }, - "status": { - "code": 0 - }, - "timeEvents": {} - }, - "558a78c92f2d7085": { - "spanId": "558a78c92f2d7085", - "traceId": "3cd65103ed08e50344404fe2b11ce807", - "parentSpanId": "1e0cd2115f1a49c5", - "startTime": 1716695799448.478, - "endTime": 1716695800958.5195, - "attributes": { - "genkit:input": "{\"customerName\":\"Alex\"}", - "genkit:isRoot": true, - "genkit:metadata:flow:dispatchType": "start", - "genkit:metadata:flow:execution": "0", - "genkit:metadata:flow:id": "6b3edd5d-d377-4c60-a3e4-060b891abdd8", - "genkit:metadata:flow:name": "simpleStructuredGreeting", - "genkit:metadata:flow:state": "done", - "genkit:metadata:subtype": "prompt", - "genkit:name": "simpleStructuredGreeting", - "genkit:output": "\"```json\\n{\\\"customerName\\\": \\\"Alex\\\", \\\"greeting\\\": \\\"Hey Alex, good to see you again!\\\", \\\"drinkOfDay\\\": \\\"Caramel Cloud Macchiato\\\"}\\n```\"", - "genkit:path": "/dev-run-action-wrapper/simpleStructuredGreeting/simpleStructuredGreeting", - "genkit:state": "success", - "genkit:type": "flow" - }, - "displayName": "simpleStructuredGreeting", - "instrumentationLibrary": { - "name": "genkit-tracer", - "version": "v1" - }, - "spanKind": "INTERNAL", - "sameProcessAsParentSpan": { - "value": true - }, - "status": { - "code": 0 - }, - "timeEvents": {} - }, - "c52101826cbbd8c6": { - "spanId": "c52101826cbbd8c6", - "traceId": "3cd65103ed08e50344404fe2b11ce807", - "startTime": 1716695799448.321, - "endTime": 1716695800962.0513, - "attributes": { - "genkit:input": "{\"start\":{\"input\":{\"customerName\":\"Alex\"}}}", - "genkit:isRoot": true, - "genkit:metadata:genkit-dev-internal": "true", - "genkit:name": "dev-run-action-wrapper", - "genkit:output": "{\"flowId\":\"6b3edd5d-d377-4c60-a3e4-060b891abdd8\",\"name\":\"simpleStructuredGreeting\",\"startTime\":1716695799448.457,\"input\":{\"customerName\":\"Alex\"},\"executions\":[{\"startTime\":1716695799448.46,\"traceIds\":[\"3cd65103ed08e50344404fe2b11ce807\"]}],\"operation\":{\"name\":\"6b3edd5d-d377-4c60-a3e4-060b891abdd8\",\"done\":true,\"result\":{\"response\":\"```json\\n{\\\"customerName\\\": \\\"Alex\\\", \\\"greeting\\\": \\\"Hey Alex, good to see you again!\\\", \\\"drinkOfDay\\\": \\\"Caramel Cloud Macchiato\\\"}\\n```\"}}}", - "genkit:path": "/dev-run-action-wrapper", - "genkit:state": "success" - }, - "displayName": "dev-run-action-wrapper", - "instrumentationLibrary": { - "name": "genkit-tracer", - "version": "v1" - }, - "spanKind": "INTERNAL", - "sameProcessAsParentSpan": { - "value": true - }, - "status": { - "code": 0 - }, - "timeEvents": {} - }, - "e811f98e392f5d5e": { - "spanId": "e811f98e392f5d5e", - "traceId": "3cd65103ed08e50344404fe2b11ce807", - "parentSpanId": "f99f5fe820a3ce71", - "startTime": 1716695799448.968, - "endTime": 1716695800954.0957, - "attributes": { - "http.method": "POST", - "http.request_content_length": 1341, - "http.status_code": 200, - "http.url": "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro:streamGenerateContent?%24alt=json%3Benum-encoding%3Dint\u0026key=AIzaSyAehb8ATuR3nkFatuxhEa1HgM5Jgzkvhvg", - "net.peer.name": "generativelanguage.googleapis.com" - }, - "displayName": "HTTP POST", - "instrumentationLibrary": { - "name": "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp", - "version": "0.49.0" - }, - "spanKind": "CLIENT", - "sameProcessAsParentSpan": { - "value": true - }, - "status": { - "code": 0 - }, - "timeEvents": {} - }, - "f99f5fe820a3ce71": { - "spanId": "f99f5fe820a3ce71", - "traceId": "3cd65103ed08e50344404fe2b11ce807", - "parentSpanId": "558a78c92f2d7085", - "startTime": 1716695799448.702, - "endTime": 1716695800956.524, - "attributes": { - "genkit:input": "{\"candidates\":1,\"config\":null,\"messages\":[{\"content\":[{\"text\":\"\\nYou're a barista at a nice coffee shop.\\nA regular customer named Alex enters.\\nGreet the customer in one sentence.\\nProvide the name of the drink of the day, nothing else.\\n\"},{\"text\":\"Output should be in JSON format and conform to the following schema:\\n\\n```\\\"{\\\\\\\"$defs\\\\\\\":{\\\\\\\"simpleGreetingOutput\\\\\\\":{\\\\\\\"additionalProperties\\\\\\\":false,\\\\\\\"properties\\\\\\\":{\\\\\\\"customerName\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"drinkOfDay\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"greeting\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"customerName\\\\\\\",\\\\\\\"greeting\\\\\\\",\\\\\\\"drinkOfDay\\\\\\\"],\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\"}},\\\\\\\"$schema\\\\\\\":\\\\\\\"https://json-schema.org/draft/2020-12/schema\\\\\\\"}\\\"```\"},{\"text\":\"Output should be in JSON format and conform to the following schema:\\n\\n```\\\"{\\\\\\\"$defs\\\\\\\":{\\\\\\\"simpleGreetingOutput\\\\\\\":{\\\\\\\"additionalProperties\\\\\\\":false,\\\\\\\"properties\\\\\\\":{\\\\\\\"customerName\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"drinkOfDay\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"greeting\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"customerName\\\\\\\",\\\\\\\"greeting\\\\\\\",\\\\\\\"drinkOfDay\\\\\\\"],\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\"}},\\\\\\\"$schema\\\\\\\":\\\\\\\"https://json-schema.org/draft/2020-12/schema\\\\\\\"}\\\"```\"}],\"role\":\"user\"}],\"output\":{\"format\":\"json\",\"schema\":{\"$defs\":{\"simpleGreetingOutput\":{\"additionalProperties\":false,\"properties\":{\"customerName\":{\"type\":\"string\"},\"drinkOfDay\":{\"type\":\"string\"},\"greeting\":{\"type\":\"string\"}},\"required\":[\"customerName\",\"greeting\",\"drinkOfDay\"],\"type\":\"object\"}},\"$schema\":\"https://json-schema.org/draft/2020-12/schema\"}}}", - "genkit:metadata:subtype": "model", - "genkit:name": "gemini-1.5-pro", - "genkit:output": "{\"candidates\":[{\"finishReason\":\"stop\",\"message\":{\"content\":[{\"text\":\"```json\\n{\\\"customerName\\\": \\\"Alex\\\", \\\"greeting\\\": \\\"Hey Alex, good to see you again!\\\", \\\"drinkOfDay\\\": \\\"Caramel Cloud Macchiato\\\"}\\n```\"}],\"role\":\"model\"}}],\"request\":{\"candidates\":1,\"config\":null,\"messages\":[{\"content\":[{\"text\":\"\\nYou're a barista at a nice coffee shop.\\nA regular customer named Alex enters.\\nGreet the customer in one sentence.\\nProvide the name of the drink of the day, nothing else.\\n\"},{\"text\":\"Output should be in JSON format and conform to the following schema:\\n\\n```\\\"{\\\\\\\"$defs\\\\\\\":{\\\\\\\"simpleGreetingOutput\\\\\\\":{\\\\\\\"additionalProperties\\\\\\\":false,\\\\\\\"properties\\\\\\\":{\\\\\\\"customerName\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"drinkOfDay\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"greeting\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"customerName\\\\\\\",\\\\\\\"greeting\\\\\\\",\\\\\\\"drinkOfDay\\\\\\\"],\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\"}},\\\\\\\"$schema\\\\\\\":\\\\\\\"https://json-schema.org/draft/2020-12/schema\\\\\\\"}\\\"```\"},{\"text\":\"Output should be in JSON format and conform to the following schema:\\n\\n```\\\"{\\\\\\\"$defs\\\\\\\":{\\\\\\\"simpleGreetingOutput\\\\\\\":{\\\\\\\"additionalProperties\\\\\\\":false,\\\\\\\"properties\\\\\\\":{\\\\\\\"customerName\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"drinkOfDay\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"},\\\\\\\"greeting\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"customerName\\\\\\\",\\\\\\\"greeting\\\\\\\",\\\\\\\"drinkOfDay\\\\\\\"],\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\"}},\\\\\\\"$schema\\\\\\\":\\\\\\\"https://json-schema.org/draft/2020-12/schema\\\\\\\"}\\\"```\"}],\"role\":\"user\"}],\"output\":{\"format\":\"json\",\"schema\":{\"$defs\":{\"simpleGreetingOutput\":{\"additionalProperties\":false,\"properties\":{\"customerName\":{\"type\":\"string\"},\"drinkOfDay\":{\"type\":\"string\"},\"greeting\":{\"type\":\"string\"}},\"required\":[\"customerName\",\"greeting\",\"drinkOfDay\"],\"type\":\"object\"}},\"$schema\":\"https://json-schema.org/draft/2020-12/schema\"}}}}", - "genkit:path": "/dev-run-action-wrapper/simpleStructuredGreeting/simpleStructuredGreeting/gemini-1.5-pro", - "genkit:state": "success", - "genkit:type": "action" - }, - "displayName": "gemini-1.5-pro", - "instrumentationLibrary": { - "name": "genkit-tracer", - "version": "v1" - }, - "spanKind": "INTERNAL", - "sameProcessAsParentSpan": { - "value": true - }, - "status": { - "code": 0 - }, - "timeEvents": {} - } - } -} \ No newline at end of file From a73a913813f4695b7cc837957df5c70f9f6082d5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 25 May 2024 22:16:25 -0700 Subject: [PATCH 08/27] Update generator.go --- go/ai/generator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index 662c8900b8..9bfca57334 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -193,7 +193,7 @@ func findValidCandidates(ctx context.Context, resp *GenerateResponse) []*Candida // validateCandidate will check a candidate against the expected schema. // It will return an error if it does not match, otherwise it will return nil. func validateCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput) error { - if outputSchema.Format == OutputFormatText { + if outputSchema.Format != OutputFormatJSON { return nil } From df9f1432a44e6869a56c6cc088c1034da1012d2b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sat, 25 May 2024 22:40:49 -0700 Subject: [PATCH 09/27] Fixed $defs/$refs issue in UI. --- go/core/action.go | 6 ++++-- go/samples/coffee-shop/main.go | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index b7361db4bc..1f485ad0d9 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -22,8 +22,8 @@ import ( "reflect" "time" - "github.com/firebase/genkit/go/internal" "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal" "github.com/invopop/jsonschema" ) @@ -211,7 +211,9 @@ func (a *Action[I, O, S]) desc() actionDesc { } func inferJSONSchema(x any) (s *jsonschema.Schema) { - var r jsonschema.Reflector + r := jsonschema.Reflector{ + DoNotReference: true, + } t := reflect.TypeOf(x) if t.Kind() == reflect.Struct { if t.NumField() == 0 { diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index b45fe262d3..d2e08b24c0 100644 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -111,7 +111,7 @@ func main() { simpleGreetingPrompt, err := dotprompt.Define("simpleGreeting", simpleGreetingPromptTemplate, &dotprompt.Config{ - Model: "google-genai/gemini-1.0-pro", + Model: "google-genai/gemini-1.5-pro", InputSchema: jsonschema.Reflect(simpleGreetingInput{}), OutputFormat: ai.OutputFormatText, }, @@ -139,7 +139,7 @@ func main() { greetingWithHistoryPrompt, err := dotprompt.Define("greetingWithHistory", greetingWithHistoryPromptTemplate, &dotprompt.Config{ - Model: "google-genai/gemini-1.0-pro", + Model: "google-genai/gemini-1.5-pro", InputSchema: jsonschema.Reflect(customerTimeAndHistoryInput{}), OutputFormat: ai.OutputFormatText, }, @@ -167,7 +167,7 @@ func main() { r := &jsonschema.Reflector{ AllowAdditionalProperties: false, - ExpandedStruct: true, + DoNotReference: true, } schema := r.Reflect(simpleGreetingOutput{}) jsonBytes, err := schema.MarshalJSON() From 805f6a9cf6ed4696f60403df52231d722f6c8ab5 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sun, 26 May 2024 07:35:19 -0700 Subject: [PATCH 10/27] PR comments. --- go/ai/generator.go | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index 9bfca57334..6761a6c00b 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -85,7 +85,7 @@ func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func( return nil, err } - candidates := findValidCandidates(ctx, resp) + candidates := validCandidates(ctx, resp) if len(candidates) == 0 { return nil, errors.New("generation resulted in no candidates matching provided output schema") } @@ -140,7 +140,7 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, return nil, err } - candidates := findValidCandidates(ctx, resp) + candidates := validCandidates(ctx, resp) if len(candidates) == 0 { return nil, errors.New("generation resulted in no candidates matching provided output schema") } @@ -166,8 +166,7 @@ func conformOutput(input *GenerateRequest) error { return fmt.Errorf("expected schema is not valid: %w", err) } - jsonStr := string(jsonBytes) - escapedJSON := strconv.Quote(jsonStr) + escapedJSON := strconv.Quote(string(jsonBytes)) part := &Part{ text: fmt.Sprintf("Output should be in JSON format and conform to the following schema:\n\n```%s```", escapedJSON), } @@ -176,9 +175,9 @@ func conformOutput(input *GenerateRequest) error { return nil } -// findValidCandidates finds all candidates that match the expected schema. -func findValidCandidates(ctx context.Context, resp *GenerateResponse) []*Candidate { - candidates := []*Candidate{} +// validCandidates finds all candidates that match the expected schema. +func validCandidates(ctx context.Context, resp *GenerateResponse) []*Candidate { + var candidates []*Candidate for i, c := range resp.Candidates { err := validateCandidate(c, resp.Request.Output) if err == nil { @@ -202,11 +201,10 @@ func validateCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput return err } - text = stripJsonDelimiters(text) + text = stripJSONDelimiters(text) - var jsonData interface{} - err = json.Unmarshal([]byte(text), &jsonData) - if err != nil { + var jsonData any + if err = json.Unmarshal([]byte(text), &jsonData); err != nil { return fmt.Errorf("candidate did not have valid JSON: %w", err) } @@ -233,8 +231,8 @@ func validateCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput return nil } -// stripJsonDelimiters strips JSON delimiters that may come back in the response. -func stripJsonDelimiters(s string) string { +// stripJSONDelimiters strips Markdown JSON delimiters that may come back in the response. +func stripJSONDelimiters(s string) string { return strings.TrimSuffix(strings.TrimPrefix(s, "```json"), "```") } From 1d25e600e0b76dc7e4dd6d4ff19db966184ef650 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sun, 26 May 2024 16:56:48 -0700 Subject: [PATCH 11/27] Added action input and output validation. --- go/core/action.go | 76 +++++++++++++++++++++++++++++++++++++++++++++-- go/core/flow.go | 28 +++++++++++++++++ 2 files changed, 101 insertions(+), 3 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 1f485ad0d9..b1f80390f8 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "log" "maps" "reflect" "time" @@ -25,6 +26,7 @@ import ( "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" "github.com/invopop/jsonschema" + "github.com/xeipuuv/gojsonschema" ) // Func is the type of function that Actions and Flows execute. @@ -100,8 +102,6 @@ func (a *Action[I, O, S]) setTracingState(tstate *tracing.State) { a.tstate = ts // Run executes the Action's function in a new trace span. func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Context, S) error) (output O, err error) { - // TODO: validate input against JSONSchema for I. - // TODO: validate output against JSONSchema for O. internal.Logger(ctx).Debug("Action.Run", "name", a.name, "input", fmt.Sprintf("%#v", input)) @@ -119,7 +119,45 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Cont return tracing.RunInNewSpan(ctx, tstate, a.name, "action", false, input, func(ctx context.Context, input I) (O, error) { start := time.Now() - out, err := a.fn(ctx, input, cb) + var err error + inputSchema, ok := a.Metadata["inputSchema"].(*jsonschema.Schema) + if ok { + var userInput any + switch v := any(input).(type) { + case flowInstructioner: + if v.StartInput() != nil { + userInput = v.StartInput() + } + if v.ScheduleInput() != nil { + userInput = v.ScheduleInput() + } + default: + userInput = input + } + err = ValidateObject(userInput, inputSchema) + if err != nil { + err = fmt.Errorf("invalid input: %w", err) + } + } + var out O + if err == nil { + out, err = a.fn(ctx, input, cb) + log.Printf("alexpascal: out: %+v", out) + outputSchema, ok := a.Metadata["outputSchema"].(*jsonschema.Schema) + if ok { + var result any + switch v := any(out).(type) { + case flowStater: + result = v.result() + default: + result = out + } + err = ValidateObject(result, outputSchema) + if err != nil { + err = fmt.Errorf("invalid output: %w", err) + } + } + } latency := time.Since(start) if err != nil { writeActionFailure(ctx, a.name, latency, err) @@ -229,3 +267,35 @@ func inferJSONSchema(x any) (s *jsonschema.Schema) { s.Version = "" return s } + +// ValidateObject will take any object and validate it against the expected schema. +// It will return an error if it doesn't match the schema, otherwise it will return nil. +func ValidateObject(obj any, schema *jsonschema.Schema) error { + schemaBytes, err := schema.MarshalJSON() + if err != nil { + return fmt.Errorf("schema is not valid: %w", err) + } + + jsonBytes, err := json.Marshal(obj) + if err != nil { + return fmt.Errorf("object is not a valid JSON type: %w", err) + } + + schemaLoader := gojsonschema.NewBytesLoader(schemaBytes) + documentLoader := gojsonschema.NewBytesLoader(jsonBytes) + + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + return err + } + + if !result.Valid() { + var errMsg string + for _, err := range result.Errors() { + errMsg += fmt.Sprintf("- %s\n", err) + } + return fmt.Errorf("object did not match expected schema:\n%s", errMsg) + } + + return nil +} diff --git a/go/core/flow.go b/go/core/flow.go index 1457906b1a..2fa2e2835f 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -164,6 +164,27 @@ type retryInstruction struct { FlowID string `json:"flowId,omitempty"` } +// flowInstructioner is the common type of all flowInstruction[I] types. +type flowInstructioner interface { + IsFlowInstruction() + StartInput() any + ScheduleInput() any +} + +func (fi *flowInstruction[I]) IsFlowInstruction() {} +func (fi *flowInstruction[I]) StartInput() any { + if fi.Start != nil { + return fi.Start.Input + } + return nil +} +func (fi *flowInstruction[I]) ScheduleInput() any { + if fi.Schedule != nil { + return fi.Schedule.Input + } + return nil +} + // A flowState is a persistent representation of a flow that may be in the middle of running. // It contains all the information needed to resume a flow, including the original input // and a cache of all completed steps. @@ -203,6 +224,7 @@ type flowStater interface { lock() unlock() cache() map[string]json.RawMessage + result() any } // isFlowState implements flowStater. @@ -210,6 +232,12 @@ func (fs *flowState[I, O]) isFlowState() {} func (fs *flowState[I, O]) lock() { fs.mu.Lock() } func (fs *flowState[I, O]) unlock() { fs.mu.Unlock() } func (fs *flowState[I, O]) cache() map[string]json.RawMessage { return fs.Cache } +func (fs *flowState[I, O]) result() any { + if fs.Operation.Done { + return fs.Operation.Result.Response + } + return nil +} // An operation describes the state of a Flow that may still be in progress. type operation[O any] struct { From 2a613f08da546f0ebeff873b006a46885bffee31 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Sun, 26 May 2024 16:57:26 -0700 Subject: [PATCH 12/27] Update action.go --- go/core/action.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index b1f80390f8..352185d0b3 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "log" "maps" "reflect" "time" @@ -142,7 +141,6 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Cont var out O if err == nil { out, err = a.fn(ctx, input, cb) - log.Printf("alexpascal: out: %+v", out) outputSchema, ok := a.Metadata["outputSchema"].(*jsonschema.Schema) if ok { var result any From 6030500eb2f68df959a4257185546bbe48f6a817 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 27 May 2024 08:34:08 -0700 Subject: [PATCH 13/27] Simplified input/output validation. --- go/ai/document.go | 102 ++++++++++----------------- go/ai/document_test.go | 38 +++++----- go/ai/generator.go | 10 ++- go/ai/generator_test.go | 12 ++-- go/core/action.go | 77 ++++---------------- go/core/flow.go | 44 ++++++++---- go/core/validation.go | 60 ++++++++++++++++ go/plugins/dotprompt/genkit_test.go | 4 +- go/plugins/dotprompt/render.go | 2 +- go/plugins/dotprompt/render_test.go | 4 +- go/plugins/googleai/googleai.go | 8 +-- go/plugins/googleai/googleai_test.go | 6 +- go/plugins/localvec/localvec_test.go | 2 +- go/plugins/pinecone/genkit.go | 2 +- go/plugins/pinecone/genkit_test.go | 2 +- go/plugins/vertexai/embed.go | 2 +- go/plugins/vertexai/vertexai.go | 8 +-- go/plugins/vertexai/vertexai_test.go | 4 +- 18 files changed, 193 insertions(+), 194 deletions(-) create mode 100644 go/core/validation.go diff --git a/go/ai/document.go b/go/ai/document.go index eb40f37173..7910802e63 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -31,11 +31,11 @@ type Document struct { // A Part is one part of a [Document]. This may be plain text or it // may be a URL (possibly a "data:" URL with embedded data). type Part struct { - kind partKind - contentType string // valid for kind==blob - text string // valid for kind∈{text,blob} - toolRequest *ToolRequest // valid for kind==partToolRequest - toolResponse *ToolResponse // valid for kind==partToolResponse + Kind partKind `json:"kind,omitempty"` + ContentType string `json:"contentType,omitempty"` // valid for kind==blob + Text string `json:"text,omitempty"` // valid for kind∈{text,blob} + ToolRequest *ToolRequest `json:"toolreq,omitempty"` // valid for kind==partToolRequest + ToolResponse *ToolResponse `json:"toolresp,omitempty"` // valid for kind==partToolResponse } type partKind int8 @@ -50,82 +50,56 @@ const ( // NewTextPart returns a Part containing text. func NewTextPart(text string) *Part { - return &Part{kind: partText, text: text} + return &Part{Kind: partText, ContentType: "plain/text", Text: text} } // NewMediaPart returns a Part containing structured data described // by the given mimeType. func NewMediaPart(mimeType, contents string) *Part { - return &Part{kind: partMedia, contentType: mimeType, text: contents} + return &Part{Kind: partMedia, ContentType: mimeType, Text: contents} } // NewDataPart returns a Part containing raw string data. func NewDataPart(contents string) *Part { - return &Part{kind: partData, text: contents} + return &Part{Kind: partData, Text: contents} } // NewToolRequestPart returns a Part containing a request from // the model to the client to run a Tool. // (Only genkit plugins should need to use this function.) func NewToolRequestPart(r *ToolRequest) *Part { - return &Part{kind: partToolRequest, toolRequest: r} + return &Part{Kind: partToolRequest, ToolRequest: r} } // NewToolResponsePart returns a Part containing the results // of applying a Tool that the model requested. func NewToolResponsePart(r *ToolResponse) *Part { - return &Part{kind: partToolResponse, toolResponse: r} + return &Part{Kind: partToolResponse, ToolResponse: r} } // IsText reports whether the [Part] contains plain text. func (p *Part) IsText() bool { - return p.kind == partText + return p.Kind == partText } // IsMedia reports whether the [Part] contains structured media data. func (p *Part) IsMedia() bool { - return p.kind == partMedia + return p.Kind == partMedia } // IsData reports whether the [Part] contains unstructured data. func (p *Part) IsData() bool { - return p.kind == partData + return p.Kind == partData } // IsToolRequest reports whether the [Part] contains a request to run a tool. func (p *Part) IsToolRequest() bool { - return p.kind == partToolRequest + return p.Kind == partToolRequest } // IsToolResponse reports whether the [Part] contains the result of running a tool. func (p *Part) IsToolResponse() bool { - return p.kind == partToolResponse -} - -// Text returns the text. This is either plain text or a URL. -func (p *Part) Text() string { - return p.text -} - -// ContentType returns the type of the content. -// This is only interesting if IsBlob() is true. -func (p *Part) ContentType() string { - if p.kind == partText { - return "text/plain" - } - return p.contentType -} - -// ToolRequest returns a request from the model for the client to run a tool. -// Valid only if [IsToolRequest] is true. -func (p *Part) ToolRequest() *ToolRequest { - return p.toolRequest -} - -// ToolResponse returns the tool response the client is sending to the model. -// Valid only if [IsToolResponse] is true. -func (p *Part) ToolResponse() *ToolResponse { - return p.toolResponse + return p.Kind == partToolResponse } // MarshalJSON is called by the JSON marshaler to write out a Part. @@ -133,23 +107,23 @@ func (p *Part) MarshalJSON() ([]byte, error) { // This is not handled by the schema generator because // Part is defined in TypeScript as a union. - switch p.kind { + switch p.Kind { case partText: v := textPart{ - Text: p.text, + Text: p.Text, } return json.Marshal(v) case partMedia: v := mediaPart{ Media: &mediaPartMedia{ - ContentType: p.contentType, - Url: p.text, + ContentType: p.ContentType, + Url: p.Text, }, } return json.Marshal(v) case partData: v := dataPart{ - Data: p.text, + Data: p.Text, } return json.Marshal(v) case partToolRequest: @@ -159,18 +133,18 @@ func (p *Part) MarshalJSON() ([]byte, error) { v := struct { ToolReq *ToolRequest `json:"toolreq,omitempty"` }{ - ToolReq: p.toolRequest, + ToolReq: p.ToolRequest, } return json.Marshal(v) case partToolResponse: v := struct { ToolResp *ToolResponse `json:"toolresp,omitempty"` }{ - ToolResp: p.toolResponse, + ToolResp: p.ToolResponse, } return json.Marshal(v) default: - return nil, fmt.Errorf("invalid part kind %v", p.kind) + return nil, fmt.Errorf("invalid part kind %v", p.Kind) } } @@ -193,23 +167,23 @@ func (p *Part) UnmarshalJSON(b []byte) error { switch { case s.Media != nil: - p.kind = partMedia - p.text = s.Media.Url - p.contentType = s.Media.ContentType + p.Kind = partMedia + p.Text = s.Media.Url + p.ContentType = s.Media.ContentType case s.ToolReq != nil: - p.kind = partToolRequest - p.toolRequest = s.ToolReq + p.Kind = partToolRequest + p.ToolRequest = s.ToolReq case s.ToolResp != nil: - p.kind = partToolResponse - p.toolResponse = s.ToolResp + p.Kind = partToolResponse + p.ToolResponse = s.ToolResp default: - p.kind = partText - p.text = s.Text - p.contentType = "" + p.Kind = partText + p.Text = s.Text + p.ContentType = "" if s.Data != "" { // Note: if part is completely empty, we use text by default. - p.kind = partData - p.text = s.Data + p.Kind = partData + p.Text = s.Data } } return nil @@ -220,9 +194,9 @@ func (p *Part) UnmarshalJSON(b []byte) error { func DocumentFromText(text string, metadata map[string]any) *Document { return &Document{ Content: []*Part{ - &Part{ - kind: partText, - text: text, + { + Kind: partText, + Text: text, }, }, Metadata: metadata, diff --git a/go/ai/document_test.go b/go/ai/document_test.go index bc30eb7f44..c2c12dbc7b 100644 --- a/go/ai/document_test.go +++ b/go/ai/document_test.go @@ -42,28 +42,28 @@ func TestDocumentJSON(t *testing.T) { d := Document{ Content: []*Part{ &Part{ - kind: partText, - text: "hi", + Kind: partText, + Text: "hi", }, &Part{ - kind: partMedia, - contentType: "text/plain", - text: "data:,bye", + Kind: partMedia, + ContentType: "text/plain", + Text: "data:,bye", }, &Part{ - kind: partData, - text: "somedata\x00string", + Kind: partData, + Text: "somedata\x00string", }, &Part{ - kind: partToolRequest, - toolRequest: &ToolRequest{ + Kind: partToolRequest, + ToolRequest: &ToolRequest{ Name: "tool1", Input: map[string]any{"arg1": 3.3, "arg2": "foo"}, }, }, &Part{ - kind: partToolResponse, - toolResponse: &ToolResponse{ + Kind: partToolResponse, + ToolResponse: &ToolResponse{ Name: "tool1", Output: map[string]any{"res1": 4.4, "res2": "bar"}, }, @@ -83,22 +83,22 @@ func TestDocumentJSON(t *testing.T) { } cmpPart := func(a, b *Part) bool { - if a.kind != b.kind { + if a.Kind != b.Kind { return false } - switch a.kind { + switch a.Kind { case partText: - return a.text == b.text + return a.Text == b.Text case partMedia: - return a.contentType == b.contentType && a.text == b.text + return a.ContentType == b.ContentType && a.Text == b.Text case partData: - return a.text == b.text + return a.Text == b.Text case partToolRequest: - return reflect.DeepEqual(a.toolRequest, b.toolRequest) + return reflect.DeepEqual(a.ToolRequest, b.ToolRequest) case partToolResponse: - return reflect.DeepEqual(a.toolResponse, b.toolResponse) + return reflect.DeepEqual(a.ToolResponse, b.ToolResponse) default: - t.Fatalf("bad part kind %v", a.kind) + t.Fatalf("bad part kind %v", a.Kind) return false } } diff --git a/go/ai/generator.go b/go/ai/generator.go index 6761a6c00b..a4015b0bb4 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -167,9 +167,7 @@ func conformOutput(input *GenerateRequest) error { } escapedJSON := strconv.Quote(string(jsonBytes)) - part := &Part{ - text: fmt.Sprintf("Output should be in JSON format and conform to the following schema:\n\n```%s```", escapedJSON), - } + part := NewTextPart(fmt.Sprintf("Output should be in JSON format and conform to the following schema:\n\n```%s```", escapedJSON)) input.Messages[len(input.Messages)-1].Content = append(input.Messages[len(input.Messages)-1].Content, part) } return nil @@ -252,7 +250,7 @@ func handleToolRequest(ctx context.Context, req *GenerateRequest, resp *Generate return nil, nil } - toolReq := part.ToolRequest() + toolReq := part.ToolRequest output, err := RunTool(ctx, toolReq.Name, toolReq.Input) if err != nil { return nil, err @@ -296,11 +294,11 @@ func (c *Candidate) Text() (string, error) { return "", errors.New("candidate message has no content") } if len(msg.Content) == 1 { - return msg.Content[0].Text(), nil + return msg.Content[0].Text, nil } else { var sb strings.Builder for _, p := range msg.Content { - sb.WriteString(p.Text()) + sb.WriteString(p.Text) } return sb.String(), nil } diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index fa28e95f33..b744f56d2c 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -25,7 +25,7 @@ func TestValidateCandidate(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {text: "Hello, World!"}, + {Text: "Hello, World!"}, }, }, } @@ -40,7 +40,7 @@ func TestValidateCandidate(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {text: `{ + {Text: `{ "name": "John", "age": 30, "address": { @@ -81,7 +81,7 @@ func TestValidateCandidate(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {text: `{"name": "John", "age": "30"}`}, + {Text: `{"name": "John", "age": "30"}`}, }, }, } @@ -104,7 +104,7 @@ func TestValidateCandidate(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {text: `{"name": "John", "age": 30`}, // Missing trailing }. + {Text: `{"name": "John", "age": 30`}, // Missing trailing }. }, }, } @@ -142,7 +142,7 @@ func TestValidateCandidate(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {text: `{"name": "John", "height": "190"}`}, + {Text: `{"name": "John", "height": "190"}`}, }, }, } @@ -166,7 +166,7 @@ func TestValidateCandidate(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {text: `{"name": "John", "age": 30}`}, + {Text: `{"name": "John", "age": 30}`}, }, }, } diff --git a/go/core/action.go b/go/core/action.go index 352185d0b3..2695afa950 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -25,7 +25,6 @@ import ( "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" "github.com/invopop/jsonschema" - "github.com/xeipuuv/gojsonschema" ) // Func is the type of function that Actions and Flows execute. @@ -119,39 +118,14 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Cont func(ctx context.Context, input I) (O, error) { start := time.Now() var err error - inputSchema, ok := a.Metadata["inputSchema"].(*jsonschema.Schema) - if ok { - var userInput any - switch v := any(input).(type) { - case flowInstructioner: - if v.StartInput() != nil { - userInput = v.StartInput() - } - if v.ScheduleInput() != nil { - userInput = v.ScheduleInput() - } - default: - userInput = input - } - err = ValidateObject(userInput, inputSchema) - if err != nil { - err = fmt.Errorf("invalid input: %w", err) - } + if err = ValidateObject(input, a.inputSchema); err != nil { + err = fmt.Errorf("invalid input: %w", err) } - var out O + var output O if err == nil { - out, err = a.fn(ctx, input, cb) - outputSchema, ok := a.Metadata["outputSchema"].(*jsonschema.Schema) - if ok { - var result any - switch v := any(out).(type) { - case flowStater: - result = v.result() - default: - result = out - } - err = ValidateObject(result, outputSchema) - if err != nil { + output, err = a.fn(ctx, input, cb) + if err != nil { + if err = ValidateObject(output, a.outputSchema); err != nil { err = fmt.Errorf("invalid output: %w", err) } } @@ -162,11 +136,16 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Cont return internal.Zero[O](), err } writeActionSuccess(ctx, a.name, latency) - return out, nil + return output, nil }) } +// runJSON runs an action with JSON input. This is only used in development mode. func (a *Action[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { + // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. + if err := ValidateJSON(input, a.inputSchema); err != nil { + return nil, err + } var in I if err := json.Unmarshal(input, &in); err != nil { return nil, err @@ -265,35 +244,3 @@ func inferJSONSchema(x any) (s *jsonschema.Schema) { s.Version = "" return s } - -// ValidateObject will take any object and validate it against the expected schema. -// It will return an error if it doesn't match the schema, otherwise it will return nil. -func ValidateObject(obj any, schema *jsonschema.Schema) error { - schemaBytes, err := schema.MarshalJSON() - if err != nil { - return fmt.Errorf("schema is not valid: %w", err) - } - - jsonBytes, err := json.Marshal(obj) - if err != nil { - return fmt.Errorf("object is not a valid JSON type: %w", err) - } - - schemaLoader := gojsonschema.NewBytesLoader(schemaBytes) - documentLoader := gojsonschema.NewBytesLoader(jsonBytes) - - result, err := gojsonschema.Validate(schemaLoader, documentLoader) - if err != nil { - return err - } - - if !result.Valid() { - var errMsg string - for _, err := range result.Errors() { - errMsg += fmt.Sprintf("- %s\n", err) - } - return fmt.Errorf("object did not match expected schema:\n%s", errMsg) - } - - return nil -} diff --git a/go/core/flow.go b/go/core/flow.go index 2fa2e2835f..b218cdc386 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -28,6 +28,7 @@ import ( "github.com/firebase/genkit/go/gtime" "github.com/firebase/genkit/go/internal" "github.com/google/uuid" + "github.com/invopop/jsonschema" otrace "go.opentelemetry.io/otel/trace" ) @@ -88,10 +89,12 @@ import ( // A Flow[I, O, S] represents a function from I to O. The S parameter is for // flows that support streaming: providing their results incrementally. type Flow[I, O, S any] struct { - name string // The last component of the flow's key in the registry. - fn Func[I, O, S] // The function to run. - stateStore FlowStateStore // Where FlowStates are stored, to support resumption. - tstate *tracing.State // set from the action when the flow is defined + name string // The last component of the flow's key in the registry. + fn Func[I, O, S] // The function to run. + stateStore FlowStateStore // Where FlowStates are stored, to support resumption. + tstate *tracing.State // Set from the action when the flow is defined + inputSchema *jsonschema.Schema // Schema of the input to the flow + outputSchema *jsonschema.Schema // Schema of the output out of the flow // TODO(jba): scheduler // TODO(jba): experimentalDurable // TODO(jba): authPolicy @@ -104,9 +107,13 @@ func DefineFlow[I, O, S any](name string, fn Func[I, O, S]) *Flow[I, O, S] { } func defineFlow[I, O, S any](r *registry, name string, fn Func[I, O, S]) *Flow[I, O, S] { + var i I + var o O f := &Flow[I, O, S]{ - name: name, - fn: fn, + name: name, + fn: fn, + inputSchema: inferJSONSchema(i), + outputSchema: inferJSONSchema(o), // TODO(jba): set stateStore? } a := f.action() @@ -272,11 +279,9 @@ type FlowResult[O any] struct { // action creates an action for the flow. See the comment at the top of this file for more information. func (f *Flow[I, O, S]) action() *Action[*flowInstruction[I], *flowState[I, O], S] { - var i I - var o O metadata := map[string]any{ - "inputSchema": inferJSONSchema(i), - "outputSchema": inferJSONSchema(o), + "inputSchema": f.inputSchema, + "outputSchema": f.outputSchema, } cback := func(ctx context.Context, inst *flowInstruction[I], cb func(context.Context, S) error) (*flowState[I, O], error) { tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") @@ -319,6 +324,10 @@ type flow interface { func (f *Flow[I, O, S]) Name() string { return f.name } func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) { + // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. + if err := ValidateJSON(input, f.inputSchema); err != nil { + return nil, &httpError{http.StatusBadRequest, err} + } var in I if err := json.Unmarshal(input, &in); err != nil { return nil, &httpError{http.StatusBadRequest, err} @@ -398,7 +407,19 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis // TODO(jba): Save rootSpanContext in the state. // TODO(jba): If input is missing, get it from state.input and overwrite metadata.input. start := time.Now() - output, err := f.fn(ctx, input, cb) + var err error + if err = ValidateObject(input, f.inputSchema); err != nil { + err = fmt.Errorf("invalid input: %w", err) + } + var output O + if err == nil { + output, err = f.fn(ctx, input, cb) + if err != nil { + if err = ValidateObject(output, f.outputSchema); err != nil { + err = fmt.Errorf("invalid output: %w", err) + } + } + } latency := time.Since(start) if err != nil { // TODO(jba): handle InterruptError @@ -561,7 +582,6 @@ func RunFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I) (O, // InternalStreamFlow is for use by genkit.StreamFlow exclusively. // It is not subject to any backwards compatibility guarantees. func InternalStreamFlow[I, O, S any](ctx context.Context, flow *Flow[I, O, S], input I, callback func(context.Context, S) error) (O, error) { - state, err := flow.start(ctx, input, callback) if err != nil { return internal.Zero[O](), err diff --git a/go/core/validation.go b/go/core/validation.go new file mode 100644 index 0000000000..b9ea2b6bc4 --- /dev/null +++ b/go/core/validation.go @@ -0,0 +1,60 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "encoding/json" + "fmt" + + "github.com/invopop/jsonschema" + "github.com/xeipuuv/gojsonschema" +) + +// ValidateObject will validate any object against the expected schema. +// It will return an error if it doesn't match the schema, otherwise it will return nil. +func ValidateObject(obj any, schema *jsonschema.Schema) error { + jsonBytes, err := json.Marshal(obj) + if err != nil { + return fmt.Errorf("object is not a valid JSON type: %w", err) + } + return ValidateJSON(jsonBytes, schema) +} + +// ValidateJSON will validate JSON against the expected schema. +// It will return an error if it doesn't match the schema, otherwise it will return nil. +func ValidateJSON(jsonBytes json.RawMessage, schema *jsonschema.Schema) error { + schemaBytes, err := schema.MarshalJSON() + if err != nil { + return fmt.Errorf("schema is not valid: %w", err) + } + + schemaLoader := gojsonschema.NewBytesLoader(schemaBytes) + documentLoader := gojsonschema.NewBytesLoader(jsonBytes) + + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + return err + } + + if !result.Valid() { + var errors string + for _, err := range result.Errors() { + errors += fmt.Sprintf("- %s\n", err) + } + return fmt.Errorf("data did not match the schema:\n%s", errors) + } + + return nil +} diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go index 8d03f5b24f..aef5ce1076 100644 --- a/go/plugins/dotprompt/genkit_test.go +++ b/go/plugins/dotprompt/genkit_test.go @@ -25,7 +25,7 @@ import ( type testGenerator struct{} func (testGenerator) Generate(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) { - input := req.Messages[0].Content[0].Text() + input := req.Messages[0].Content[0].Text output := fmt.Sprintf("AI reply to %q", input) r := &ai.GenerateResponse{ @@ -69,7 +69,7 @@ func TestExecute(t *testing.T) { t.FailNow() } } - got := msg.Content[0].Text() + got := msg.Content[0].Text want := `AI reply to "TestExecute"` if got != want { t.Errorf("fake generator replied with %q, want %q", got, want) diff --git a/go/plugins/dotprompt/render.go b/go/plugins/dotprompt/render.go index 82bda0a899..d47c0b38e4 100644 --- a/go/plugins/dotprompt/render.go +++ b/go/plugins/dotprompt/render.go @@ -41,7 +41,7 @@ func (p *Prompt) RenderText(variables map[string]any) (string, error) { if !part.IsText() { return "", errors.New("RenderText: multi-modal prompt can't be rendered as text") } - sb.WriteString(part.Text()) + sb.WriteString(part.Text) } return sb.String(), nil } diff --git a/go/plugins/dotprompt/render_test.go b/go/plugins/dotprompt/render_test.go index 8ffe1022ed..58917ad7e3 100644 --- a/go/plugins/dotprompt/render_test.go +++ b/go/plugins/dotprompt/render_test.go @@ -201,10 +201,10 @@ func TestRenderMessages(t *testing.T) { if a.IsText() != b.IsText() { return false } - if a.Text() != b.Text() { + if a.Text != b.Text { return false } - if a.ContentType() != b.ContentType() { + if a.ContentType != b.ContentType { return false } return true diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 17eea79c1d..28ba7881f0 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -235,19 +235,19 @@ func convertParts(parts []*ai.Part) []genai.Part { func convertPart(p *ai.Part) genai.Part { switch { case p.IsText(): - return genai.Text(p.Text()) + return genai.Text(p.Text) case p.IsMedia(): - return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())} + return genai.Blob{MIMEType: p.ContentType, Data: []byte(p.Text)} case p.IsData(): panic("googleai does not support Data parts") case p.IsToolResponse(): - toolResp := p.ToolResponse() + toolResp := p.ToolResponse return genai.FunctionResponse{ Name: toolResp.Name, Response: toolResp.Output, } case p.IsToolRequest(): - toolReq := p.ToolRequest() + toolReq := p.ToolRequest return genai.FunctionCall{ Name: toolReq.Name, Args: toolReq.Input, diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 96c557a56e..3330a867a3 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -83,7 +83,7 @@ func TestGenerator(t *testing.T) { if err != nil { t.Fatal(err) } - out := resp.Candidates[0].Message.Content[0].Text() + out := resp.Candidates[0].Message.Content[0].Text if out != "France" { t.Errorf("got \"%s\", expecting \"France\"", out) } @@ -115,7 +115,7 @@ func TestGeneratorStreaming(t *testing.T) { parts := 0 _, err = g.Generate(ctx, req, func(ctx context.Context, c *ai.Candidate) error { parts++ - out += c.Message.Content[0].Text() + out += c.Message.Content[0].Text return nil }) if err != nil { @@ -193,7 +193,7 @@ func TestGeneratorTool(t *testing.T) { t.Fatal(err) } - out := resp.Candidates[0].Message.Content[0].Text() + out := resp.Candidates[0].Message.Content[0].Text if !strings.Contains(out, "12.25") { t.Errorf("got %s, expecting it to contain \"12.25\"", out) } diff --git a/go/plugins/localvec/localvec_test.go b/go/plugins/localvec/localvec_test.go index 978e4b1301..80f9bf9c94 100644 --- a/go/plugins/localvec/localvec_test.go +++ b/go/plugins/localvec/localvec_test.go @@ -82,7 +82,7 @@ func TestLocalVec(t *testing.T) { t.Errorf("got %d results, expected 2", len(docs)) } for _, d := range docs { - text := d.Content[0].Text() + text := d.Content[0].Text if !strings.HasPrefix(text, "hello") { t.Errorf("returned doc text %q does not start with %q", text, "hello") } diff --git a/go/plugins/pinecone/genkit.go b/go/plugins/pinecone/genkit.go index e05e8222f9..4413c37a5c 100644 --- a/go/plugins/pinecone/genkit.go +++ b/go/plugins/pinecone/genkit.go @@ -149,7 +149,7 @@ func (r *retriever) Index(ctx context.Context, req *ai.IndexerRequest) error { // but it loses the structure of the document. var sb strings.Builder for _, p := range doc.Content { - sb.WriteString(p.Text()) + sb.WriteString(p.Text) } metadata[r.textKey] = sb.String() diff --git a/go/plugins/pinecone/genkit_test.go b/go/plugins/pinecone/genkit_test.go index cb6d5421e7..4c72fa437c 100644 --- a/go/plugins/pinecone/genkit_test.go +++ b/go/plugins/pinecone/genkit_test.go @@ -132,7 +132,7 @@ func TestGenkit(t *testing.T) { t.Errorf("got %d results, expected 2", len(docs)) } for _, d := range docs { - text := d.Content[0].Text() + text := d.Content[0].Text if !strings.HasPrefix(text, "hello") { t.Errorf("returned doc text %q does not start with %q", text, "hello") } diff --git a/go/plugins/vertexai/embed.go b/go/plugins/vertexai/embed.go index baa15ca804..22665f215b 100644 --- a/go/plugins/vertexai/embed.go +++ b/go/plugins/vertexai/embed.go @@ -59,7 +59,7 @@ func (e *embedder) Embed(ctx context.Context, req *ai.EmbedRequest) ([]float32, instances := make([]*structpb.Value, 0, len(req.Document.Content)) for _, part := range req.Document.Content { fields := map[string]any{ - "content": part.Text(), + "content": part.Text, } if title != "" { fields["title"] = title diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index f9e0c1c677..828274d79a 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -203,19 +203,19 @@ func convertParts(parts []*ai.Part) []genai.Part { func convertPart(p *ai.Part) genai.Part { switch { case p.IsText(): - return genai.Text(p.Text()) + return genai.Text(p.Text) case p.IsMedia(): - return genai.Blob{MIMEType: p.ContentType(), Data: []byte(p.Text())} + return genai.Blob{MIMEType: p.ContentType, Data: []byte(p.Text)} case p.IsData(): panic("vertexai does not support Data parts") case p.IsToolResponse(): - toolResp := p.ToolResponse() + toolResp := p.ToolResponse return genai.FunctionResponse{ Name: toolResp.Name, Response: toolResp.Output, } case p.IsToolRequest(): - toolReq := p.ToolRequest() + toolReq := p.ToolRequest return genai.FunctionCall{ Name: toolReq.Name, Args: toolReq.Input, diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index be16c25620..c53f141b0c 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -56,7 +56,7 @@ func TestGenerator(t *testing.T) { if err != nil { t.Fatal(err) } - out := resp.Candidates[0].Message.Content[0].Text() + out := resp.Candidates[0].Message.Content[0].Text if !strings.Contains(out, "France") { t.Errorf("got \"%s\", expecting it would contain \"France\"", out) } @@ -128,7 +128,7 @@ func TestGeneratorTool(t *testing.T) { t.Fatal(err) } - out := resp.Candidates[0].Message.Content[0].Text() + out := resp.Candidates[0].Message.Content[0].Text if !strings.Contains(out, "12.25") { t.Errorf("got %s, expecting it to contain \"12.25\"", out) } From be2f5a792c2b5bde9b2d43a556c37d96406275cc Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Mon, 27 May 2024 09:12:55 -0700 Subject: [PATCH 14/27] Reused new validation functions in output conformance. --- go/ai/generator.go | 18 ++++++++++++++++-- go/core/validation.go | 15 ++++++++++----- go/plugins/dotprompt/dotprompt.go | 2 +- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index a4015b0bb4..dc389364cb 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -160,7 +160,7 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, // conformOutput appends a message to the request indicating conformance to the expected schema. func conformOutput(input *GenerateRequest) error { - if len(input.Output.Schema) > 0 && len(input.Messages) > 0 { + if input.Output.Format == OutputFormatJSON && len(input.Messages) > 0 { jsonBytes, err := json.Marshal(input.Output.Schema) if err != nil { return fmt.Errorf("expected schema is not valid: %w", err) @@ -174,10 +174,24 @@ func conformOutput(input *GenerateRequest) error { } // validCandidates finds all candidates that match the expected schema. +// It will strip JSON markdown delimiters from the response. func validCandidates(ctx context.Context, resp *GenerateResponse) []*Candidate { var candidates []*Candidate for i, c := range resp.Candidates { - err := validateCandidate(c, resp.Request.Output) + var err error + if resp.Request.Output.Format == OutputFormatJSON { + var text string + text, err = c.Text() + if err == nil { + text = stripJSONDelimiters(text) + // TODO: Replace the text in the candidate with the stripped version. + var jsonBytes []byte + jsonBytes, err = json.Marshal(resp.Request.Output.Schema) + if err == nil { + err = core.ValidateRaw([]byte(text), jsonBytes) + } + } + } if err == nil { candidates = append(candidates, c) } else { diff --git a/go/core/validation.go b/go/core/validation.go index b9ea2b6bc4..639cb3608a 100644 --- a/go/core/validation.go +++ b/go/core/validation.go @@ -24,24 +24,29 @@ import ( // ValidateObject will validate any object against the expected schema. // It will return an error if it doesn't match the schema, otherwise it will return nil. -func ValidateObject(obj any, schema *jsonschema.Schema) error { - jsonBytes, err := json.Marshal(obj) +func ValidateObject(data any, schema *jsonschema.Schema) error { + dataBytes, err := json.Marshal(data) if err != nil { return fmt.Errorf("object is not a valid JSON type: %w", err) } - return ValidateJSON(jsonBytes, schema) + return ValidateJSON(dataBytes, schema) } // ValidateJSON will validate JSON against the expected schema. // It will return an error if it doesn't match the schema, otherwise it will return nil. -func ValidateJSON(jsonBytes json.RawMessage, schema *jsonschema.Schema) error { +func ValidateJSON(dataBytes json.RawMessage, schema *jsonschema.Schema) error { schemaBytes, err := schema.MarshalJSON() if err != nil { return fmt.Errorf("schema is not valid: %w", err) } + return ValidateRaw(dataBytes, schemaBytes) +} +// ValidateRaw will validate JSON data against the JSON schema. +// It will return an error if it doesn't match the schema, otherwise it will return nil. +func ValidateRaw(dataBytes json.RawMessage, schemaBytes json.RawMessage) error { schemaLoader := gojsonschema.NewBytesLoader(schemaBytes) - documentLoader := gojsonschema.NewBytesLoader(jsonBytes) + documentLoader := gojsonschema.NewBytesLoader(dataBytes) result, err := gojsonschema.Validate(schemaLoader, documentLoader) if err != nil { diff --git a/go/plugins/dotprompt/dotprompt.go b/go/plugins/dotprompt/dotprompt.go index ab251637e2..b663810fb4 100644 --- a/go/plugins/dotprompt/dotprompt.go +++ b/go/plugins/dotprompt/dotprompt.go @@ -98,7 +98,7 @@ type Config struct { OutputFormat ai.OutputFormat // Desired output schema, for JSON output. - OutputSchema map[string]any // TODO: use *jsonschema.Schema + OutputSchema map[string]any // Cannot use *jsonschema.Schema because it will need to self-reflect. // Arbitrary metadata. Metadata map[string]any From 640ffb440b290420c25754956117dd440429e769 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 14:36:42 -0700 Subject: [PATCH 15/27] Fixes for validation and updated tests. --- go/ai/document.go | 5 ++ go/ai/document_test.go | 2 +- go/ai/generator.go | 112 ++++++++++++++++++---------------------- go/ai/generator_test.go | 62 ++++++++++++---------- go/core/validation.go | 8 +-- 5 files changed, 95 insertions(+), 94 deletions(-) diff --git a/go/ai/document.go b/go/ai/document.go index 7910802e63..40ba03deef 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -53,6 +53,11 @@ func NewTextPart(text string) *Part { return &Part{Kind: partText, ContentType: "plain/text", Text: text} } +// NewJSONPart returns a Part containing JSON. +func NewJSONPart(text string) *Part { + return &Part{Kind: partText, ContentType: "application/json", Text: text} +} + // NewMediaPart returns a Part containing structured data described // by the given mimeType. func NewMediaPart(mimeType, contents string) *Part { diff --git a/go/ai/document_test.go b/go/ai/document_test.go index c2c12dbc7b..502bb14d5e 100644 --- a/go/ai/document_test.go +++ b/go/ai/document_test.go @@ -32,7 +32,7 @@ func TestDocumentFromText(t *testing.T) { if !p.IsText() { t.Errorf("IsText() == %t, want %t", p.IsText(), true) } - if got := p.Text(); got != data { + if got := p.Text; got != data { t.Errorf("Data() == %q, want %q", got, data) } } diff --git a/go/ai/generator.go b/go/ai/generator.go index dc389364cb..cae4ffb015 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -19,13 +19,13 @@ import ( "encoding/json" "errors" "fmt" + "log" "slices" "strconv" "strings" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/internal" - "github.com/xeipuuv/gojsonschema" ) // Generator is the interface used to query an AI model. @@ -85,9 +85,9 @@ func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func( return nil, err } - candidates := validCandidates(ctx, resp) - if len(candidates) == 0 { - return nil, errors.New("generation resulted in no candidates matching provided output schema") + candidates, err := validCandidates(ctx, resp) + if err != nil { + return nil, err } resp.Candidates = candidates @@ -140,9 +140,9 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, return nil, err } - candidates := validCandidates(ctx, resp) - if len(candidates) == 0 { - return nil, errors.New("generation resulted in no candidates matching provided output schema") + candidates, err := validCandidates(ctx, resp) + if err != nil { + return nil, err } resp.Candidates = candidates @@ -175,77 +175,63 @@ func conformOutput(input *GenerateRequest) error { // validCandidates finds all candidates that match the expected schema. // It will strip JSON markdown delimiters from the response. -func validCandidates(ctx context.Context, resp *GenerateResponse) []*Candidate { +func validCandidates(ctx context.Context, resp *GenerateResponse) ([]*Candidate, error) { var candidates []*Candidate for i, c := range resp.Candidates { - var err error - if resp.Request.Output.Format == OutputFormatJSON { - var text string - text, err = c.Text() - if err == nil { - text = stripJSONDelimiters(text) - // TODO: Replace the text in the candidate with the stripped version. - var jsonBytes []byte - jsonBytes, err = json.Marshal(resp.Request.Output.Schema) - if err == nil { - err = core.ValidateRaw([]byte(text), jsonBytes) - } - } - } + c, err := validCandidate(c, resp.Request.Output) if err == nil { candidates = append(candidates, c) } else { - internal.Logger(ctx).Debug("candidate did not match provided output schema", "index", i, "error", err.Error()) + internal.Logger(ctx).Debug("candidate did not match expected schema", "index", i, "error", err.Error()) } } - return candidates -} - -// validateCandidate will check a candidate against the expected schema. -// It will return an error if it does not match, otherwise it will return nil. -func validateCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput) error { - if outputSchema.Format != OutputFormatJSON { - return nil - } - - text, err := candidate.Text() - if err != nil { - return err - } - - text = stripJSONDelimiters(text) - - var jsonData any - if err = json.Unmarshal([]byte(text), &jsonData); err != nil { - return fmt.Errorf("candidate did not have valid JSON: %w", err) - } - - schemaBytes, err := json.Marshal(outputSchema.Schema) - if err != nil { - return fmt.Errorf("expected schema is not valid: %w", err) - } - - schemaLoader := gojsonschema.NewStringLoader(string(schemaBytes)) - jsonLoader := gojsonschema.NewGoLoader(jsonData) - result, err := gojsonschema.Validate(schemaLoader, jsonLoader) - if err != nil { - return fmt.Errorf("failed to validate expected schema: %w", err) + if len(candidates) == 0 { + return nil, errors.New("generation resulted in no candidates matching expected schema") } + return candidates, nil +} - if !result.Valid() { - var errMsg string - for _, err := range result.Errors() { - errMsg += fmt.Sprintf("- %s\n", err) +// validCandidate will validate the candidate's response against the expected schema. +// It will return an error if it does not match, otherwise it will return a candidate with pure JSON content. +func validCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput) (*Candidate, error) { + if outputSchema.Format == OutputFormatJSON { + text, err := candidate.Text() + if err != nil { + return nil, err + } + text = stripJSONDelimiters(text) + log.Printf("alexpascal: %s", text) + var schemaBytes []byte + schemaBytes, err = json.Marshal(outputSchema.Schema) + if err != nil { + return nil, fmt.Errorf("expected schema is not valid: %w", err) + } + if err = core.ValidateRaw([]byte(text), schemaBytes); err != nil { + return nil, err } - return fmt.Errorf("candidate did not match expected schema:\n%s", errMsg) + // TODO: Verify that it okay to replace all content with JSON. + candidate.Message.Content = []*Part{NewJSONPart(text)} } - - return nil + return candidate, nil } // stripJSONDelimiters strips Markdown JSON delimiters that may come back in the response. func stripJSONDelimiters(s string) string { - return strings.TrimSuffix(strings.TrimPrefix(s, "```json"), "```") + s = strings.TrimSpace(s) + delimiters := []string{"```", "~~~"} + for _, delimiter := range delimiters { + if strings.HasPrefix(s, delimiter) && strings.HasSuffix(s, delimiter) { + s = strings.TrimPrefix(s, delimiter) + s = strings.TrimSuffix(s, delimiter) + s = strings.TrimSpace(s) + if strings.HasPrefix(s, "json") { + s = strings.TrimPrefix(s, "json") + s = strings.TrimSpace(s) + } + break + } + } + return s } // handleToolRequest checks if a tool was requested by a generator. diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index b744f56d2c..211bd76fa6 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -20,35 +20,38 @@ import ( "github.com/stretchr/testify/assert" ) -func TestValidateCandidate(t *testing.T) { +func TestValidCandidate(t *testing.T) { + t.Parallel() + t.Run("Valid candidate with text format", func(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {Text: "Hello, World!"}, + NewTextPart("Hello, World!"), }, }, } outputSchema := &GenerateRequestOutput{ Format: OutputFormatText, } - err := validateCandidate(candidate, outputSchema) + _, err := validCandidate(candidate, outputSchema) assert.NoError(t, err) }) t.Run("Valid candidate with JSON format and matching schema", func(t *testing.T) { + json := `{ + "name": "John", + "age": 30, + "address": { + "street": "123 Main St", + "city": "New York", + "country": "USA" + } + }` candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {Text: `{ - "name": "John", - "age": 30, - "address": { - "street": "123 Main St", - "city": "New York", - "country": "USA" - } - }`}, + NewTextPart(JSONMarkdown(json)), }, }, } @@ -73,15 +76,18 @@ func TestValidateCandidate(t *testing.T) { }, }, } - err := validateCandidate(candidate, outputSchema) + candidate, err := validCandidate(candidate, outputSchema) + assert.NoError(t, err) + text, err := candidate.Text() assert.NoError(t, err) + assert.EqualValues(t, text, json) }) t.Run("Invalid candidate with JSON format and non-matching schema", func(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {Text: `{"name": "John", "age": "30"}`}, + NewTextPart(JSONMarkdown(`{"name": "John", "age": "30"}`)), }, }, } @@ -95,23 +101,23 @@ func TestValidateCandidate(t *testing.T) { }, }, } - err := validateCandidate(candidate, outputSchema) + _, err := validCandidate(candidate, outputSchema) assert.Error(t, err) - assert.Contains(t, err.Error(), "candidate did not match expected schema") + assert.Contains(t, err.Error(), "data did not match expected schema") }) t.Run("Candidate with invalid JSON", func(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {Text: `{"name": "John", "age": 30`}, // Missing trailing }. + NewTextPart(JSONMarkdown(`{"name": "John", "age": 30`)), // Missing trailing }. }, }, } outputSchema := &GenerateRequestOutput{ Format: OutputFormatJSON, } - err := validateCandidate(candidate, outputSchema) + _, err := validCandidate(candidate, outputSchema) assert.Error(t, err) assert.Contains(t, err.Error(), "candidate did not have valid JSON") }) @@ -121,7 +127,7 @@ func TestValidateCandidate(t *testing.T) { outputSchema := &GenerateRequestOutput{ Format: OutputFormatJSON, } - err := validateCandidate(candidate, outputSchema) + _, err := validCandidate(candidate, outputSchema) assert.Error(t, err) assert.Equal(t, "candidate with no message", err.Error()) }) @@ -133,7 +139,7 @@ func TestValidateCandidate(t *testing.T) { outputSchema := &GenerateRequestOutput{ Format: OutputFormatJSON, } - err := validateCandidate(candidate, outputSchema) + _, err := validCandidate(candidate, outputSchema) assert.Error(t, err) assert.Equal(t, "candidate message has no content", err.Error()) }) @@ -142,7 +148,7 @@ func TestValidateCandidate(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {Text: `{"name": "John", "height": "190"}`}, + NewTextPart(JSONMarkdown(`{"name": "John", "height": 190}`)), }, }, } @@ -157,16 +163,16 @@ func TestValidateCandidate(t *testing.T) { "additionalProperties": false, }, } - err := validateCandidate(candidate, outputSchema) + _, err := validCandidate(candidate, outputSchema) assert.Error(t, err) - assert.Contains(t, err.Error(), "candidate did not match expected schema") + assert.Contains(t, err.Error(), "data did not match expected schema") }) t.Run("Invalid expected schema", func(t *testing.T) { candidate := &Candidate{ Message: &Message{ Content: []*Part{ - {Text: `{"name": "John", "age": 30}`}, + NewTextPart(JSONMarkdown(`{"name": "John", "age": 30}`)), }, }, } @@ -176,8 +182,12 @@ func TestValidateCandidate(t *testing.T) { "type": "invalid", }, } - err := validateCandidate(candidate, outputSchema) + _, err := validCandidate(candidate, outputSchema) assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to validate expected schema") + assert.Contains(t, err.Error(), "failed to validate data against expected schema") }) } + +func JSONMarkdown(text string) string { + return "```json\n" + text + "\n```" +} diff --git a/go/core/validation.go b/go/core/validation.go index 639cb3608a..94884fa699 100644 --- a/go/core/validation.go +++ b/go/core/validation.go @@ -27,7 +27,7 @@ import ( func ValidateObject(data any, schema *jsonschema.Schema) error { dataBytes, err := json.Marshal(data) if err != nil { - return fmt.Errorf("object is not a valid JSON type: %w", err) + return fmt.Errorf("data is not a valid JSON type: %w", err) } return ValidateJSON(dataBytes, schema) } @@ -37,7 +37,7 @@ func ValidateObject(data any, schema *jsonschema.Schema) error { func ValidateJSON(dataBytes json.RawMessage, schema *jsonschema.Schema) error { schemaBytes, err := schema.MarshalJSON() if err != nil { - return fmt.Errorf("schema is not valid: %w", err) + return fmt.Errorf("expected schema is not valid: %w", err) } return ValidateRaw(dataBytes, schemaBytes) } @@ -50,7 +50,7 @@ func ValidateRaw(dataBytes json.RawMessage, schemaBytes json.RawMessage) error { result, err := gojsonschema.Validate(schemaLoader, documentLoader) if err != nil { - return err + return fmt.Errorf("failed to validate data against expected schema: %w", err) } if !result.Valid() { @@ -58,7 +58,7 @@ func ValidateRaw(dataBytes json.RawMessage, schemaBytes json.RawMessage) error { for _, err := range result.Errors() { errors += fmt.Sprintf("- %s\n", err) } - return fmt.Errorf("data did not match the schema:\n%s", errors) + return fmt.Errorf("data did not match expected schema:\n%s", errors) } return nil From ddadeae550abbbb0c730f58335a0bfe93c52ddc7 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 14:37:36 -0700 Subject: [PATCH 16/27] Update generator.go --- go/ai/generator.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index cae4ffb015..fc29569a97 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -19,7 +19,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "slices" "strconv" "strings" @@ -200,7 +199,6 @@ func validCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput) ( return nil, err } text = stripJSONDelimiters(text) - log.Printf("alexpascal: %s", text) var schemaBytes []byte schemaBytes, err = json.Marshal(outputSchema.Schema) if err != nil { From 4ea1f868ec2ec60cc0395dd2ff24b708dccd8627 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 14:39:48 -0700 Subject: [PATCH 17/27] Update generator.go --- go/ai/generator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index fc29569a97..208641400c 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -24,7 +24,7 @@ import ( "strings" "github.com/firebase/genkit/go/core" - "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/core/logger" ) // Generator is the interface used to query an AI model. @@ -181,7 +181,7 @@ func validCandidates(ctx context.Context, resp *GenerateResponse) ([]*Candidate, if err == nil { candidates = append(candidates, c) } else { - internal.Logger(ctx).Debug("candidate did not match expected schema", "index", i, "error", err.Error()) + logger.FromContext(ctx).Debug("candidate did not match expected schema", "index", i, "error", err.Error()) } } if len(candidates) == 0 { From 473ff978614d6a7d54a3b8bd0fd465c14fc37a83 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 15:13:13 -0700 Subject: [PATCH 18/27] Fixed broken test. --- go/ai/generator.go | 2 +- go/ai/generator_test.go | 2 +- go/core/validation.go | 6 ++++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index 208641400c..c077d0061e 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -191,7 +191,7 @@ func validCandidates(ctx context.Context, resp *GenerateResponse) ([]*Candidate, } // validCandidate will validate the candidate's response against the expected schema. -// It will return an error if it does not match, otherwise it will return a candidate with pure JSON content. +// It will return an error if it does not match, otherwise it will return a candidate with JSON content and type. func validCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput) (*Candidate, error) { if outputSchema.Format == OutputFormatJSON { text, err := candidate.Text() diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index 211bd76fa6..e8e85fedd1 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -119,7 +119,7 @@ func TestValidCandidate(t *testing.T) { } _, err := validCandidate(candidate, outputSchema) assert.Error(t, err) - assert.Contains(t, err.Error(), "candidate did not have valid JSON") + assert.Contains(t, err.Error(), "data is not valid JSON") }) t.Run("Candidate with no message", func(t *testing.T) { diff --git a/go/core/validation.go b/go/core/validation.go index 94884fa699..3966279ad0 100644 --- a/go/core/validation.go +++ b/go/core/validation.go @@ -45,6 +45,12 @@ func ValidateJSON(dataBytes json.RawMessage, schema *jsonschema.Schema) error { // ValidateRaw will validate JSON data against the JSON schema. // It will return an error if it doesn't match the schema, otherwise it will return nil. func ValidateRaw(dataBytes json.RawMessage, schemaBytes json.RawMessage) error { + var data any + // Do this check separately from below to get a better error message. + if err := json.Unmarshal(dataBytes, &data); err != nil { + return fmt.Errorf("data is not valid JSON: %w", err) + } + schemaLoader := gojsonschema.NewBytesLoader(schemaBytes) documentLoader := gojsonschema.NewBytesLoader(dataBytes) From d626ac8f4442dd4ef2b3b841bc231b5c96d27fbd Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 18:38:54 -0700 Subject: [PATCH 19/27] Resolved comments. --- go/ai/generator.go | 32 ++++++++++++++-------------- go/ai/generator_test.go | 46 +++++++++++++++++++++++++---------------- go/core/flow.go | 35 +++---------------------------- go/core/validation.go | 11 +++++----- 4 files changed, 53 insertions(+), 71 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index c077d0061e..a55d38b074 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -128,13 +128,13 @@ type generatorAction struct { // Generate implements Generator. This is like the [Generate] function, // but invokes the [core.Action] rather than invoking the Generator // directly. -func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { - if err := conformOutput(input); err != nil { +func (ga *generatorAction) Generate(ctx context.Context, req *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { + if err := conformOutput(req); err != nil { return nil, err } for { - resp, err := ga.action.Run(ctx, input, cb) + resp, err := ga.action.Run(ctx, req, cb) if err != nil { return nil, err } @@ -145,7 +145,7 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, } resp.Candidates = candidates - newReq, err := handleToolRequest(ctx, input, resp) + newReq, err := handleToolRequest(ctx, req, resp) if err != nil { return nil, err } @@ -153,21 +153,21 @@ func (ga *generatorAction) Generate(ctx context.Context, input *GenerateRequest, return resp, nil } - input = newReq + req = newReq } } // conformOutput appends a message to the request indicating conformance to the expected schema. -func conformOutput(input *GenerateRequest) error { - if input.Output.Format == OutputFormatJSON && len(input.Messages) > 0 { - jsonBytes, err := json.Marshal(input.Output.Schema) +func conformOutput(req *GenerateRequest) error { + if req.Output.Format == OutputFormatJSON && len(req.Messages) > 0 { + jsonBytes, err := json.Marshal(req.Output.Schema) if err != nil { return fmt.Errorf("expected schema is not valid: %w", err) } escapedJSON := strconv.Quote(string(jsonBytes)) part := NewTextPart(fmt.Sprintf("Output should be in JSON format and conform to the following schema:\n\n```%s```", escapedJSON)) - input.Messages[len(input.Messages)-1].Content = append(input.Messages[len(input.Messages)-1].Content, part) + req.Messages[len(req.Messages)-1].Content = append(req.Messages[len(req.Messages)-1].Content, part) } return nil } @@ -192,15 +192,15 @@ func validCandidates(ctx context.Context, resp *GenerateResponse) ([]*Candidate, // validCandidate will validate the candidate's response against the expected schema. // It will return an error if it does not match, otherwise it will return a candidate with JSON content and type. -func validCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput) (*Candidate, error) { - if outputSchema.Format == OutputFormatJSON { - text, err := candidate.Text() +func validCandidate(c *Candidate, output *GenerateRequestOutput) (*Candidate, error) { + if output.Format == OutputFormatJSON { + text, err := c.Text() if err != nil { return nil, err } text = stripJSONDelimiters(text) var schemaBytes []byte - schemaBytes, err = json.Marshal(outputSchema.Schema) + schemaBytes, err = json.Marshal(output.Schema) if err != nil { return nil, fmt.Errorf("expected schema is not valid: %w", err) } @@ -208,9 +208,9 @@ func validCandidate(candidate *Candidate, outputSchema *GenerateRequestOutput) ( return nil, err } // TODO: Verify that it okay to replace all content with JSON. - candidate.Message.Content = []*Part{NewJSONPart(text)} + c.Message.Content = []*Part{NewJSONPart(text)} } - return candidate, nil + return c, nil } // stripJSONDelimiters strips Markdown JSON delimiters that may come back in the response. @@ -286,7 +286,7 @@ func (gr *GenerateResponse) Text() (string, error) { func (c *Candidate) Text() (string, error) { msg := c.Message if msg == nil { - return "", errors.New("candidate with no message") + return "", errors.New("candidate has no message") } if len(msg.Content) == 0 { return "", errors.New("candidate message has no content") diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index e8e85fedd1..a95f4e763b 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -15,9 +15,8 @@ package ai import ( + "strings" "testing" - - "github.com/stretchr/testify/assert" ) func TestValidCandidate(t *testing.T) { @@ -35,7 +34,9 @@ func TestValidCandidate(t *testing.T) { Format: OutputFormatText, } _, err := validCandidate(candidate, outputSchema) - assert.NoError(t, err) + if err != nil { + t.Error(err) + } }) t.Run("Valid candidate with JSON format and matching schema", func(t *testing.T) { @@ -77,10 +78,16 @@ func TestValidCandidate(t *testing.T) { }, } candidate, err := validCandidate(candidate, outputSchema) - assert.NoError(t, err) + if err != nil { + t.Error(err) + } text, err := candidate.Text() - assert.NoError(t, err) - assert.EqualValues(t, text, json) + if err != nil { + t.Error(err) + } + if text != json { + t.Errorf("mismatch (-want, +got) -%s +%s", json, text) + } }) t.Run("Invalid candidate with JSON format and non-matching schema", func(t *testing.T) { @@ -102,8 +109,7 @@ func TestValidCandidate(t *testing.T) { }, } _, err := validCandidate(candidate, outputSchema) - assert.Error(t, err) - assert.Contains(t, err.Error(), "data did not match expected schema") + errorContains(t, err, "data did not match expected schema") }) t.Run("Candidate with invalid JSON", func(t *testing.T) { @@ -118,8 +124,7 @@ func TestValidCandidate(t *testing.T) { Format: OutputFormatJSON, } _, err := validCandidate(candidate, outputSchema) - assert.Error(t, err) - assert.Contains(t, err.Error(), "data is not valid JSON") + errorContains(t, err, "data is not valid JSON") }) t.Run("Candidate with no message", func(t *testing.T) { @@ -128,8 +133,7 @@ func TestValidCandidate(t *testing.T) { Format: OutputFormatJSON, } _, err := validCandidate(candidate, outputSchema) - assert.Error(t, err) - assert.Equal(t, "candidate with no message", err.Error()) + errorContains(t, err, "candidate has no message") }) t.Run("Candidate with message but no content", func(t *testing.T) { @@ -140,8 +144,7 @@ func TestValidCandidate(t *testing.T) { Format: OutputFormatJSON, } _, err := validCandidate(candidate, outputSchema) - assert.Error(t, err) - assert.Equal(t, "candidate message has no content", err.Error()) + errorContains(t, err, "candidate message has no content") }) t.Run("Candidate contains unexpected field", func(t *testing.T) { @@ -164,8 +167,7 @@ func TestValidCandidate(t *testing.T) { }, } _, err := validCandidate(candidate, outputSchema) - assert.Error(t, err) - assert.Contains(t, err.Error(), "data did not match expected schema") + errorContains(t, err, "data did not match expected schema") }) t.Run("Invalid expected schema", func(t *testing.T) { @@ -183,11 +185,19 @@ func TestValidCandidate(t *testing.T) { }, } _, err := validCandidate(candidate, outputSchema) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to validate data against expected schema") + errorContains(t, err, "failed to validate data against expected schema") }) } func JSONMarkdown(text string) string { return "```json\n" + text + "\n```" } + +func errorContains(t *testing.T, err error, want string) { + t.Helper() + if err == nil { + t.Error("got nil, want error") + } else if !strings.Contains(err.Error(), want) { + t.Errorf("got error message %q, want it to contain %q", err, want) + } +} diff --git a/go/core/flow.go b/go/core/flow.go index 3391f04ded..af0e3e4e61 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -171,27 +171,6 @@ type retryInstruction struct { FlowID string `json:"flowId,omitempty"` } -// flowInstructioner is the common type of all flowInstruction[I] types. -type flowInstructioner interface { - IsFlowInstruction() - StartInput() any - ScheduleInput() any -} - -func (fi *flowInstruction[I]) IsFlowInstruction() {} -func (fi *flowInstruction[I]) StartInput() any { - if fi.Start != nil { - return fi.Start.Input - } - return nil -} -func (fi *flowInstruction[I]) ScheduleInput() any { - if fi.Schedule != nil { - return fi.Schedule.Input - } - return nil -} - // A flowState is a persistent representation of a flow that may be in the middle of running. // It contains all the information needed to resume a flow, including the original input // and a cache of all completed steps. @@ -199,9 +178,8 @@ type flowState[I, O any] struct { FlowID string `json:"flowId,omitempty"` FlowName string `json:"name,omitempty"` // start time in milliseconds since the epoch - StartTime tracing.Milliseconds `json:"startTime,omitempty"` - Input I `json:"input,omitempty"` - + StartTime tracing.Milliseconds `json:"startTime,omitempty"` + Input I `json:"input,omitempty"` mu sync.Mutex Cache map[string]json.RawMessage `json:"cache,omitempty"` EventsTriggered map[string]any `json:"eventsTriggered,omitempty"` @@ -231,7 +209,6 @@ type flowStater interface { lock() unlock() cache() map[string]json.RawMessage - result() any } // isFlowState implements flowStater. @@ -239,12 +216,6 @@ func (fs *flowState[I, O]) isFlowState() {} func (fs *flowState[I, O]) lock() { fs.mu.Lock() } func (fs *flowState[I, O]) unlock() { fs.mu.Unlock() } func (fs *flowState[I, O]) cache() map[string]json.RawMessage { return fs.Cache } -func (fs *flowState[I, O]) result() any { - if fs.Operation.Done { - return fs.Operation.Result.Response - } - return nil -} // An operation describes the state of a Flow that may still be in progress. type operation[O any] struct { @@ -414,7 +385,7 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis var output O if err == nil { output, err = f.fn(ctx, input, cb) - if err != nil { + if err == nil { if err = ValidateObject(output, f.outputSchema); err != nil { err = fmt.Errorf("invalid output: %w", err) } diff --git a/go/core/validation.go b/go/core/validation.go index 3966279ad0..970b1959d0 100644 --- a/go/core/validation.go +++ b/go/core/validation.go @@ -17,14 +17,15 @@ package core import ( "encoding/json" "fmt" + "strings" "github.com/invopop/jsonschema" "github.com/xeipuuv/gojsonschema" ) -// ValidateObject will validate any object against the expected schema. +// ValidateValue will validate any value against the expected schema. // It will return an error if it doesn't match the schema, otherwise it will return nil. -func ValidateObject(data any, schema *jsonschema.Schema) error { +func ValidateValue(data any, schema *jsonschema.Schema) error { dataBytes, err := json.Marshal(data) if err != nil { return fmt.Errorf("data is not a valid JSON type: %w", err) @@ -60,11 +61,11 @@ func ValidateRaw(dataBytes json.RawMessage, schemaBytes json.RawMessage) error { } if !result.Valid() { - var errors string + var errors []string for _, err := range result.Errors() { - errors += fmt.Sprintf("- %s\n", err) + errors = append(errors, fmt.Sprintf("- %s", err)) } - return fmt.Errorf("data did not match expected schema:\n%s", errors) + return fmt.Errorf("data did not match expected schema:\n%s", strings.Join(errors, "\n")) } return nil From 632ac640846958dafc0e4dba28f514ad9776db89 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 18:39:58 -0700 Subject: [PATCH 20/27] Fix. --- go/core/action.go | 4 ++-- go/core/flow.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 51a27e258a..3cd6649791 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -119,14 +119,14 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Cont func(ctx context.Context, input I) (O, error) { start := time.Now() var err error - if err = ValidateObject(input, a.inputSchema); err != nil { + if err = ValidateValue(input, a.inputSchema); err != nil { err = fmt.Errorf("invalid input: %w", err) } var output O if err == nil { output, err = a.fn(ctx, input, cb) if err != nil { - if err = ValidateObject(output, a.outputSchema); err != nil { + if err = ValidateValue(output, a.outputSchema); err != nil { err = fmt.Errorf("invalid output: %w", err) } } diff --git a/go/core/flow.go b/go/core/flow.go index af0e3e4e61..6a33ff6249 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -379,14 +379,14 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis // TODO(jba): If input is missing, get it from state.input and overwrite metadata.input. start := time.Now() var err error - if err = ValidateObject(input, f.inputSchema); err != nil { + if err = ValidateValue(input, f.inputSchema); err != nil { err = fmt.Errorf("invalid input: %w", err) } var output O if err == nil { output, err = f.fn(ctx, input, cb) if err == nil { - if err = ValidateObject(output, f.outputSchema); err != nil { + if err = ValidateValue(output, f.outputSchema); err != nil { err = fmt.Errorf("invalid output: %w", err) } } From 56479679af05db1293fc3b0f32fa62890703cff0 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 18:42:45 -0700 Subject: [PATCH 21/27] Exported all part related types. --- go/ai/document.go | 58 +++++++++++++++++++++--------------------- go/ai/document_test.go | 20 +++++++-------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/go/ai/document.go b/go/ai/document.go index 40ba03deef..a54784b257 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -31,80 +31,80 @@ type Document struct { // A Part is one part of a [Document]. This may be plain text or it // may be a URL (possibly a "data:" URL with embedded data). type Part struct { - Kind partKind `json:"kind,omitempty"` + Kind PartKind `json:"kind,omitempty"` ContentType string `json:"contentType,omitempty"` // valid for kind==blob Text string `json:"text,omitempty"` // valid for kind∈{text,blob} ToolRequest *ToolRequest `json:"toolreq,omitempty"` // valid for kind==partToolRequest ToolResponse *ToolResponse `json:"toolresp,omitempty"` // valid for kind==partToolResponse } -type partKind int8 +type PartKind int8 const ( - partText partKind = iota - partMedia - partData - partToolRequest - partToolResponse + PartText PartKind = iota + PartMedia + PartData + PartToolRequest + PartToolResponse ) // NewTextPart returns a Part containing text. func NewTextPart(text string) *Part { - return &Part{Kind: partText, ContentType: "plain/text", Text: text} + return &Part{Kind: PartText, ContentType: "plain/text", Text: text} } // NewJSONPart returns a Part containing JSON. func NewJSONPart(text string) *Part { - return &Part{Kind: partText, ContentType: "application/json", Text: text} + return &Part{Kind: PartText, ContentType: "application/json", Text: text} } // NewMediaPart returns a Part containing structured data described // by the given mimeType. func NewMediaPart(mimeType, contents string) *Part { - return &Part{Kind: partMedia, ContentType: mimeType, Text: contents} + return &Part{Kind: PartMedia, ContentType: mimeType, Text: contents} } // NewDataPart returns a Part containing raw string data. func NewDataPart(contents string) *Part { - return &Part{Kind: partData, Text: contents} + return &Part{Kind: PartData, Text: contents} } // NewToolRequestPart returns a Part containing a request from // the model to the client to run a Tool. // (Only genkit plugins should need to use this function.) func NewToolRequestPart(r *ToolRequest) *Part { - return &Part{Kind: partToolRequest, ToolRequest: r} + return &Part{Kind: PartToolRequest, ToolRequest: r} } // NewToolResponsePart returns a Part containing the results // of applying a Tool that the model requested. func NewToolResponsePart(r *ToolResponse) *Part { - return &Part{Kind: partToolResponse, ToolResponse: r} + return &Part{Kind: PartToolResponse, ToolResponse: r} } // IsText reports whether the [Part] contains plain text. func (p *Part) IsText() bool { - return p.Kind == partText + return p.Kind == PartText } // IsMedia reports whether the [Part] contains structured media data. func (p *Part) IsMedia() bool { - return p.Kind == partMedia + return p.Kind == PartMedia } // IsData reports whether the [Part] contains unstructured data. func (p *Part) IsData() bool { - return p.Kind == partData + return p.Kind == PartData } // IsToolRequest reports whether the [Part] contains a request to run a tool. func (p *Part) IsToolRequest() bool { - return p.Kind == partToolRequest + return p.Kind == PartToolRequest } // IsToolResponse reports whether the [Part] contains the result of running a tool. func (p *Part) IsToolResponse() bool { - return p.Kind == partToolResponse + return p.Kind == PartToolResponse } // MarshalJSON is called by the JSON marshaler to write out a Part. @@ -113,12 +113,12 @@ func (p *Part) MarshalJSON() ([]byte, error) { // Part is defined in TypeScript as a union. switch p.Kind { - case partText: + case PartText: v := textPart{ Text: p.Text, } return json.Marshal(v) - case partMedia: + case PartMedia: v := mediaPart{ Media: &mediaPartMedia{ ContentType: p.ContentType, @@ -126,12 +126,12 @@ func (p *Part) MarshalJSON() ([]byte, error) { }, } return json.Marshal(v) - case partData: + case PartData: v := dataPart{ Data: p.Text, } return json.Marshal(v) - case partToolRequest: + case PartToolRequest: // TODO: make sure these types marshal/unmarshal nicely // between Go and javascript. At the very least the // field name needs to change (here and in UnmarshalJSON). @@ -141,7 +141,7 @@ func (p *Part) MarshalJSON() ([]byte, error) { ToolReq: p.ToolRequest, } return json.Marshal(v) - case partToolResponse: + case PartToolResponse: v := struct { ToolResp *ToolResponse `json:"toolresp,omitempty"` }{ @@ -172,22 +172,22 @@ func (p *Part) UnmarshalJSON(b []byte) error { switch { case s.Media != nil: - p.Kind = partMedia + p.Kind = PartMedia p.Text = s.Media.Url p.ContentType = s.Media.ContentType case s.ToolReq != nil: - p.Kind = partToolRequest + p.Kind = PartToolRequest p.ToolRequest = s.ToolReq case s.ToolResp != nil: - p.Kind = partToolResponse + p.Kind = PartToolResponse p.ToolResponse = s.ToolResp default: - p.Kind = partText + p.Kind = PartText p.Text = s.Text p.ContentType = "" if s.Data != "" { // Note: if part is completely empty, we use text by default. - p.Kind = partData + p.Kind = PartData p.Text = s.Data } } @@ -200,7 +200,7 @@ func DocumentFromText(text string, metadata map[string]any) *Document { return &Document{ Content: []*Part{ { - Kind: partText, + Kind: PartText, Text: text, }, }, diff --git a/go/ai/document_test.go b/go/ai/document_test.go index 502bb14d5e..8d173fa2b3 100644 --- a/go/ai/document_test.go +++ b/go/ai/document_test.go @@ -42,27 +42,27 @@ func TestDocumentJSON(t *testing.T) { d := Document{ Content: []*Part{ &Part{ - Kind: partText, + Kind: PartText, Text: "hi", }, &Part{ - Kind: partMedia, + Kind: PartMedia, ContentType: "text/plain", Text: "data:,bye", }, &Part{ - Kind: partData, + Kind: PartData, Text: "somedata\x00string", }, &Part{ - Kind: partToolRequest, + Kind: PartToolRequest, ToolRequest: &ToolRequest{ Name: "tool1", Input: map[string]any{"arg1": 3.3, "arg2": "foo"}, }, }, &Part{ - Kind: partToolResponse, + Kind: PartToolResponse, ToolResponse: &ToolResponse{ Name: "tool1", Output: map[string]any{"res1": 4.4, "res2": "bar"}, @@ -87,15 +87,15 @@ func TestDocumentJSON(t *testing.T) { return false } switch a.Kind { - case partText: + case PartText: return a.Text == b.Text - case partMedia: + case PartMedia: return a.ContentType == b.ContentType && a.Text == b.Text - case partData: + case PartData: return a.Text == b.Text - case partToolRequest: + case PartToolRequest: return reflect.DeepEqual(a.ToolRequest, b.ToolRequest) - case partToolResponse: + case PartToolResponse: return reflect.DeepEqual(a.ToolResponse, b.ToolResponse) default: t.Fatalf("bad part kind %v", a.Kind) From 1a6fc37996a696da1bb7843638b335cb2e695518 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 18:57:05 -0700 Subject: [PATCH 22/27] Update go.mod --- go/go.mod | 3 --- 1 file changed, 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index 2e2190b210..c267ab6caa 100644 --- a/go/go.mod +++ b/go/go.mod @@ -14,7 +14,6 @@ require ( github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.12.0 github.com/jba/slog v0.2.0 - github.com/stretchr/testify v1.9.0 github.com/wk8/go-ordered-map/v2 v2.1.8 github.com/xeipuuv/gojsonschema v1.2.0 go.opentelemetry.io/otel v1.26.0 @@ -41,7 +40,6 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.46.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -52,7 +50,6 @@ require ( github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/kr/text v0.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.opencensus.io v0.24.0 // indirect From a22d424a6153adccd4c7bc56383fab69818af396 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Tue, 28 May 2024 19:02:07 -0700 Subject: [PATCH 23/27] Update dotprompt.go --- go/plugins/dotprompt/dotprompt.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/plugins/dotprompt/dotprompt.go b/go/plugins/dotprompt/dotprompt.go index b663810fb4..ab251637e2 100644 --- a/go/plugins/dotprompt/dotprompt.go +++ b/go/plugins/dotprompt/dotprompt.go @@ -98,7 +98,7 @@ type Config struct { OutputFormat ai.OutputFormat // Desired output schema, for JSON output. - OutputSchema map[string]any // Cannot use *jsonschema.Schema because it will need to self-reflect. + OutputSchema map[string]any // TODO: use *jsonschema.Schema // Arbitrary metadata. Metadata map[string]any From 5a07723aa682cbab67198d69aba018c1730d2ffc Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 29 May 2024 09:15:20 -0700 Subject: [PATCH 24/27] Removed manual "null" case in schema generation. This made it impossible to proceed through input validation for action -> flow. --- go/core/action.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 3cd6649791..dff5ebd9e6 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -19,7 +19,6 @@ import ( "encoding/json" "fmt" "maps" - "reflect" "time" "github.com/firebase/genkit/go/core/logger" @@ -230,16 +229,6 @@ func inferJSONSchema(x any) (s *jsonschema.Schema) { r := jsonschema.Reflector{ DoNotReference: true, } - t := reflect.TypeOf(x) - if t.Kind() == reflect.Struct { - if t.NumField() == 0 { - // Make struct{} correspond to ZodVoid. - return &jsonschema.Schema{Type: "null"} - } - // Put a struct definition at the "top level" of the schema, - // instead of nested inside a "$defs" object. - r.ExpandedStruct = true - } s = r.Reflect(x) // TODO: Unwind this change once Monaco Editor supports newer than JSON schema draft-07. s.Version = "" From b0414971ee135b7a41de0400340396965950d0fa Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 29 May 2024 09:23:43 -0700 Subject: [PATCH 25/27] Fixes. --- go/ai/generator_test.go | 8 ++++---- go/core/action.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index a95f4e763b..01d1ab13a6 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -35,7 +35,7 @@ func TestValidCandidate(t *testing.T) { } _, err := validCandidate(candidate, outputSchema) if err != nil { - t.Error(err) + t.Fatal(err) } }) @@ -79,14 +79,14 @@ func TestValidCandidate(t *testing.T) { } candidate, err := validCandidate(candidate, outputSchema) if err != nil { - t.Error(err) + t.Fatal(err) } text, err := candidate.Text() if err != nil { - t.Error(err) + t.Fatal(err) } if text != json { - t.Errorf("mismatch (-want, +got) -%s +%s", json, text) + t.Fatalf("got %q, want %q", json, text) } }) diff --git a/go/core/action.go b/go/core/action.go index dff5ebd9e6..b9e8d71c23 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -124,7 +124,7 @@ func (a *Action[I, O, S]) Run(ctx context.Context, input I, cb func(context.Cont var output O if err == nil { output, err = a.fn(ctx, input, cb) - if err != nil { + if err == nil { if err = ValidateValue(output, a.outputSchema); err != nil { err = fmt.Errorf("invalid output: %w", err) } From 4238ab8c82b42aa9feb00e91c3b5023c0e8496b2 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 29 May 2024 09:27:33 -0700 Subject: [PATCH 26/27] Update generator.go --- go/ai/generator.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/go/ai/generator.go b/go/ai/generator.go index a55d38b074..748b79676e 100644 --- a/go/ai/generator.go +++ b/go/ai/generator.go @@ -73,13 +73,13 @@ func RegisterGenerator(provider, name string, metadata *GeneratorMetadata, gener } // Generate applies a [Generator] to some input, handling tool requests. -func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { - if err := conformOutput(input); err != nil { +func Generate(ctx context.Context, g Generator, req *GenerateRequest, cb func(context.Context, *Candidate) error) (*GenerateResponse, error) { + if err := conformOutput(req); err != nil { return nil, err } for { - resp, err := g.Generate(ctx, input, cb) + resp, err := g.Generate(ctx, req, cb) if err != nil { return nil, err } @@ -90,7 +90,7 @@ func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func( } resp.Candidates = candidates - newReq, err := handleToolRequest(ctx, input, resp) + newReq, err := handleToolRequest(ctx, req, resp) if err != nil { return nil, err } @@ -98,7 +98,7 @@ func Generate(ctx context.Context, g Generator, input *GenerateRequest, cb func( return resp, nil } - input = newReq + req = newReq } } From 48a0bf697567ffbced3ea36a537baef6f5c8636c Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 29 May 2024 09:35:10 -0700 Subject: [PATCH 27/27] Fixed O type. --- go/core/action.go | 4 ++-- go/core/flow.go | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 526076ca2f..10d2a886ec 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -121,7 +121,7 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con if err = ValidateValue(input, a.inputSchema); err != nil { err = fmt.Errorf("invalid input: %w", err) } - var output O + var output Out if err == nil { output, err = a.fn(ctx, input, cb) if err == nil { @@ -145,7 +145,7 @@ func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMes if err := ValidateJSON(input, a.inputSchema); err != nil { return nil, err } - var in In + var in In if err := json.Unmarshal(input, &in); err != nil { return nil, err } diff --git a/go/core/flow.go b/go/core/flow.go index c6724f229e..25a38b0dc3 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -89,12 +89,12 @@ import ( // A Flow[In, Out, Stream] represents a function from In to Out. The Stream parameter is for // flows that support streaming: providing their results incrementally. type Flow[In, Out, Stream any] struct { - name string // The last component of the flow's key in the registry. - fn Func[In, Out, Stream] // The function to run. - stateStore FlowStateStore // Where FlowStates are stored, to support resumption. - tstate *tracing.State // set from the action when the flow is defined - inputSchema *jsonschema.Schema // Schema of the input to the flow - outputSchema *jsonschema.Schema // Schema of the output out of the flow + name string // The last component of the flow's key in the registry. + fn Func[In, Out, Stream] // The function to run. + stateStore FlowStateStore // Where FlowStates are stored, to support resumption. + tstate *tracing.State // set from the action when the flow is defined + inputSchema *jsonschema.Schema // Schema of the input to the flow + outputSchema *jsonschema.Schema // Schema of the output out of the flow // TODO(jba): scheduler // TODO(jba): experimentalDurable // TODO(jba): authPolicy @@ -107,11 +107,11 @@ func DefineFlow[In, Out, Stream any](name string, fn Func[In, Out, Stream]) *Flo } func defineFlow[In, Out, Stream any](r *registry, name string, fn Func[In, Out, Stream]) *Flow[In, Out, Stream] { - var i In + var i In var o Out f := &Flow[In, Out, Stream]{ - name: name, - fn: fn, + name: name, + fn: fn, inputSchema: inferJSONSchema(i), outputSchema: inferJSONSchema(o), // TODO(jba): set stateStore? @@ -295,7 +295,7 @@ type flow interface { func (f *Flow[In, Out, Stream]) Name() string { return f.name } func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) { - // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. + // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. if err := ValidateJSON(input, f.inputSchema); err != nil { return nil, &httpError{http.StatusBadRequest, err} } @@ -382,7 +382,7 @@ func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In if err = ValidateValue(input, f.inputSchema); err != nil { err = fmt.Errorf("invalid input: %w", err) } - var output O + var output Out if err == nil { output, err = f.fn(ctx, input, cb) if err == nil {