From 74e26dee103b3046a5ca7ddcb872376cfc2fba5e Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 2 Jul 2024 16:52:23 -0400 Subject: [PATCH] [Go] improve core action API Define a single, general newAction function. Define a single fully general method for defining an action in a registry. All other Define functions call it. Export it so the genkit package can use it to define flows. It's unfortunate that we have to export it, but code outside the module can't call it anyway because it takes a registry, which has an internal type. --- go/core/action.go | 96 +++++++++++++++++++-------------------- go/core/action_test.go | 15 ++---- go/genkit/flow.go | 24 ++++------ go/genkit/servers_test.go | 12 ++--- 4 files changed, 68 insertions(+), 79 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index 1740032e1b..dc37026c0f 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -60,29 +60,34 @@ type Action[In, Out, Stream any] struct { metadata map[string]any } -// See js/core/src/action.ts - -// DefineAction creates a new Action and registers it. -func DefineAction[In, Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - return defineAction(registry.Global, provider, name, atype, metadata, fn) -} +type noStream = func(context.Context, struct{}) error -func defineAction[In, Out any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - a := NewAction(provider+"/"+name, atype, metadata, fn) - r.RegisterAction(atype, a) - return a -} +// See js/core/src/action.ts -func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - return defineStreamingAction(registry.Global, provider, name, atype, metadata, fn) +// DefineAction creates a new non-streaming Action and registers it. +func DefineAction[In, Out any]( + provider, name string, + atype atype.ActionType, + metadata map[string]any, + fn func(context.Context, In) (Out, error), +) *Action[In, Out, struct{}] { + return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, nil, + func(ctx context.Context, in In, _ noStream) (Out, error) { + return fn(ctx, in) + }) } -func defineStreamingAction[In, Out, Stream any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - a := NewStreamingAction(provider+"/"+name, atype, metadata, fn) - r.RegisterAction(atype, a) - return a +// DefineStreamingAction creates a new streaming action and registers it. +func DefineStreamingAction[In, Out, Stream any]( + provider, name string, + atype atype.ActionType, + metadata map[string]any, + fn Func[In, Out, Stream], +) *Action[In, Out, Stream] { + return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, nil, fn) } +// DefineCustomAction defines a streaming action with type Custom. func DefineCustomAction[In, Out, Stream any](provider, name string, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { return DefineStreamingAction(provider, name, atype.Custom, metadata, fn) } @@ -98,35 +103,45 @@ func DefineActionWithInputSchema[Out any]( inputSchema *jsonschema.Schema, fn func(context.Context, any) (Out, error), ) *Action[any, Out, struct{}] { - return defineActionWithInputSchema(registry.Global, provider, name, atype, metadata, inputSchema, fn) + return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, inputSchema, + func(ctx context.Context, in any, _ noStream) (Out, error) { + return fn(ctx, in) + }) } -func defineActionWithInputSchema[Out any]( +// DefineActionInRegistry creates an action and registers it with the given Registry. +// For use by the Genkit module only. +func DefineActionInRegistry[In, Out, Stream any]( r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, inputSchema *jsonschema.Schema, - fn func(context.Context, any) (Out, error), -) *Action[any, Out, struct{}] { - a := newActionWithInputSchema(provider+"/"+name, atype, metadata, fn, inputSchema) + fn Func[In, Out, Stream], +) *Action[In, Out, Stream] { + fullName := name + if provider != "" { + fullName = provider + "/" + name + } + a := newAction(fullName, atype, metadata, inputSchema, fn) r.RegisterAction(atype, a) return a } -type noStream = func(context.Context, struct{}) error - -// NewAction creates a new Action with the given name and non-streaming function. -func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb noStream) (Out, error) { - return fn(ctx, in) - }) -} - -// NewStreamingAction creates a new Action with the given name and streaming function. -func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { +// newAction creates a new Action with the given name and arguments. +// If inputSchema is nil, it is inferred from In. +func newAction[In, Out, Stream any]( + name string, + atype atype.ActionType, + metadata map[string]any, + inputSchema *jsonschema.Schema, + fn Func[In, Out, Stream], +) *Action[In, Out, Stream] { var i In var o Out + if inputSchema == nil { + inputSchema = internal.InferJSONSchema(i) + } return &Action[In, Out, Stream]{ name: name, atype: atype, @@ -134,21 +149,6 @@ func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype)) return fn(ctx, input, sc) }, - inputSchema: internal.InferJSONSchema(i), - outputSchema: internal.InferJSONSchema(o), - metadata: metadata, - } -} - -func newActionWithInputSchema[Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, any) (Out, error), inputSchema *jsonschema.Schema) *Action[any, Out, struct{}] { - var o Out - return &Action[any, Out, struct{}]{ - name: name, - atype: atype, - fn: func(ctx context.Context, input any, sc func(context.Context, struct{}) error) (Out, error) { - tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype)) - return fn(ctx, input) - }, inputSchema: inputSchema, outputSchema: internal.InferJSONSchema(o), metadata: metadata, diff --git a/go/core/action_test.go b/go/core/action_test.go index c1dad2722b..b52c6a2dfa 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -25,12 +25,12 @@ import ( "github.com/firebase/genkit/go/internal/registry" ) -func inc(_ context.Context, x int) (int, error) { +func inc(_ context.Context, x int, _ noStream) (int, error) { return x + 1, nil } func TestActionRun(t *testing.T) { - a := NewAction("inc", atype.Custom, nil, inc) + a := newAction("inc", atype.Custom, nil, nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -41,7 +41,7 @@ func TestActionRun(t *testing.T) { } func TestActionRunJSON(t *testing.T) { - a := NewAction("inc", atype.Custom, nil, inc) + a := newAction("inc", atype.Custom, nil, nil, inc) input := []byte("3") want := []byte("4") got, err := a.RunJSON(context.Background(), input, nil) @@ -53,11 +53,6 @@ func TestActionRunJSON(t *testing.T) { } } -func TestNewAction(t *testing.T) { - // Verify that struct{} can occur in the function signature. - _ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil }) -} - // count streams the numbers from 0 to n-1, then returns n. func count(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { if cb != nil { @@ -72,7 +67,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int func TestActionStreaming(t *testing.T) { ctx := context.Background() - a := NewStreamingAction("count", atype.Custom, nil, count) + a := newAction("count", atype.Custom, nil, nil, count) const n = 3 // Non-streaming. @@ -105,7 +100,7 @@ func TestActionStreaming(t *testing.T) { func TestActionTracing(t *testing.T) { ctx := context.Background() const actionName = "TestTracing-inc" - a := NewAction(actionName, atype.Custom, nil, inc) + a := newAction(actionName, atype.Custom, nil, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } diff --git a/go/genkit/flow.go b/go/genkit/flow.go index bcd30dbbb6..532b9557c3 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -146,8 +146,15 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core. outputSchema: internal.InferJSONSchema(o), // TODO(jba): set stateStore? } - a := f.action() - r.RegisterAction(atype.Flow, a) + metadata := map[string]any{ + "inputSchema": f.inputSchema, + "outputSchema": f.outputSchema, + } + afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) { + tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") + return f.runInstruction(ctx, inst, common.StreamingCallback[Stream](cb)) + } + core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc) f.tstate = r.TracingState() r.RegisterFlow(f) return f @@ -288,19 +295,6 @@ type FlowResult[Out any] struct { // FlowResult is called FlowResponse in the javascript. -// action creates an action for the flow. See the comment at the top of this file for more information. -func (f *Flow[In, Out, Stream]) action() *core.Action[*flowInstruction[In], *flowState[In, Out], Stream] { - metadata := map[string]any{ - "inputSchema": f.inputSchema, - "outputSchema": f.outputSchema, - } - cback := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) { - tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") - return f.runInstruction(ctx, inst, common.StreamingCallback[Stream](cb)) - } - return core.NewStreamingAction(f.name, atype.Flow, metadata, cback) -} - // runInstruction performs one of several actions on a flow, as determined by msg. // (Called runEnvelope in the js.) func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowInstruction[In], cb common.StreamingCallback[Stream]) (*flowState[In, Out], error) { diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 88d8c0b0b7..59a29ceba6 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -33,11 +33,11 @@ import ( "github.com/invopop/jsonschema" ) -func inc(_ context.Context, x int) (int, error) { +func inc(_ context.Context, x int, _ noStream) (int, error) { return x + 1, nil } -func dec(_ context.Context, x int) (int, error) { +func dec(_ context.Context, x int, _ noStream) (int, error) { return x - 1, nil } @@ -46,12 +46,12 @@ func TestDevServer(t *testing.T) { if err != nil { t.Fatal(err) } - r.RegisterAction(atype.Custom, core.NewAction("devServer/inc", atype.Custom, map[string]any{ + core.DefineActionInRegistry(r, "devServer", "inc", atype.Custom, map[string]any{ "foo": "bar", - }, inc)) - r.RegisterAction(atype.Custom, core.NewAction("devServer/dec", atype.Custom, map[string]any{ + }, nil, inc) + core.DefineActionInRegistry(r, "devServer", "dec", atype.Custom, map[string]any{ "bar": "baz", - }, dec)) + }, nil, dec) srv := httptest.NewServer(newDevServeMux(r)) defer srv.Close()