diff --git a/go/core/action.go b/go/core/action.go index 1740032e1b..dc37026c0f 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -60,29 +60,34 @@ type Action[In, Out, Stream any] struct { metadata map[string]any } -// See js/core/src/action.ts - -// DefineAction creates a new Action and registers it. -func DefineAction[In, Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - return defineAction(registry.Global, provider, name, atype, metadata, fn) -} +type noStream = func(context.Context, struct{}) error -func defineAction[In, Out any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - a := NewAction(provider+"/"+name, atype, metadata, fn) - r.RegisterAction(atype, a) - return a -} +// See js/core/src/action.ts -func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - return defineStreamingAction(registry.Global, provider, name, atype, metadata, fn) +// DefineAction creates a new non-streaming Action and registers it. +func DefineAction[In, Out any]( + provider, name string, + atype atype.ActionType, + metadata map[string]any, + fn func(context.Context, In) (Out, error), +) *Action[In, Out, struct{}] { + return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, nil, + func(ctx context.Context, in In, _ noStream) (Out, error) { + return fn(ctx, in) + }) } -func defineStreamingAction[In, Out, Stream any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - a := NewStreamingAction(provider+"/"+name, atype, metadata, fn) - r.RegisterAction(atype, a) - return a +// DefineStreamingAction creates a new streaming action and registers it. +func DefineStreamingAction[In, Out, Stream any]( + provider, name string, + atype atype.ActionType, + metadata map[string]any, + fn Func[In, Out, Stream], +) *Action[In, Out, Stream] { + return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, nil, fn) } +// DefineCustomAction defines a streaming action with type Custom. func DefineCustomAction[In, Out, Stream any](provider, name string, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { return DefineStreamingAction(provider, name, atype.Custom, metadata, fn) } @@ -98,35 +103,45 @@ func DefineActionWithInputSchema[Out any]( inputSchema *jsonschema.Schema, fn func(context.Context, any) (Out, error), ) *Action[any, Out, struct{}] { - return defineActionWithInputSchema(registry.Global, provider, name, atype, metadata, inputSchema, fn) + return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, inputSchema, + func(ctx context.Context, in any, _ noStream) (Out, error) { + return fn(ctx, in) + }) } -func defineActionWithInputSchema[Out any]( +// DefineActionInRegistry creates an action and registers it with the given Registry. +// For use by the Genkit module only. +func DefineActionInRegistry[In, Out, Stream any]( r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, inputSchema *jsonschema.Schema, - fn func(context.Context, any) (Out, error), -) *Action[any, Out, struct{}] { - a := newActionWithInputSchema(provider+"/"+name, atype, metadata, fn, inputSchema) + fn Func[In, Out, Stream], +) *Action[In, Out, Stream] { + fullName := name + if provider != "" { + fullName = provider + "/" + name + } + a := newAction(fullName, atype, metadata, inputSchema, fn) r.RegisterAction(atype, a) return a } -type noStream = func(context.Context, struct{}) error - -// NewAction creates a new Action with the given name and non-streaming function. -func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb noStream) (Out, error) { - return fn(ctx, in) - }) -} - -// NewStreamingAction creates a new Action with the given name and streaming function. -func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { +// newAction creates a new Action with the given name and arguments. +// If inputSchema is nil, it is inferred from In. +func newAction[In, Out, Stream any]( + name string, + atype atype.ActionType, + metadata map[string]any, + inputSchema *jsonschema.Schema, + fn Func[In, Out, Stream], +) *Action[In, Out, Stream] { var i In var o Out + if inputSchema == nil { + inputSchema = internal.InferJSONSchema(i) + } return &Action[In, Out, Stream]{ name: name, atype: atype, @@ -134,21 +149,6 @@ func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype)) return fn(ctx, input, sc) }, - inputSchema: internal.InferJSONSchema(i), - outputSchema: internal.InferJSONSchema(o), - metadata: metadata, - } -} - -func newActionWithInputSchema[Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, any) (Out, error), inputSchema *jsonschema.Schema) *Action[any, Out, struct{}] { - var o Out - return &Action[any, Out, struct{}]{ - name: name, - atype: atype, - fn: func(ctx context.Context, input any, sc func(context.Context, struct{}) error) (Out, error) { - tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype)) - return fn(ctx, input) - }, inputSchema: inputSchema, outputSchema: internal.InferJSONSchema(o), metadata: metadata, diff --git a/go/core/action_test.go b/go/core/action_test.go index c1dad2722b..b52c6a2dfa 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -25,12 +25,12 @@ import ( "github.com/firebase/genkit/go/internal/registry" ) -func inc(_ context.Context, x int) (int, error) { +func inc(_ context.Context, x int, _ noStream) (int, error) { return x + 1, nil } func TestActionRun(t *testing.T) { - a := NewAction("inc", atype.Custom, nil, inc) + a := newAction("inc", atype.Custom, nil, nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -41,7 +41,7 @@ func TestActionRun(t *testing.T) { } func TestActionRunJSON(t *testing.T) { - a := NewAction("inc", atype.Custom, nil, inc) + a := newAction("inc", atype.Custom, nil, nil, inc) input := []byte("3") want := []byte("4") got, err := a.RunJSON(context.Background(), input, nil) @@ -53,11 +53,6 @@ func TestActionRunJSON(t *testing.T) { } } -func TestNewAction(t *testing.T) { - // Verify that struct{} can occur in the function signature. - _ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil }) -} - // count streams the numbers from 0 to n-1, then returns n. func count(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { if cb != nil { @@ -72,7 +67,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int func TestActionStreaming(t *testing.T) { ctx := context.Background() - a := NewStreamingAction("count", atype.Custom, nil, count) + a := newAction("count", atype.Custom, nil, nil, count) const n = 3 // Non-streaming. @@ -105,7 +100,7 @@ func TestActionStreaming(t *testing.T) { func TestActionTracing(t *testing.T) { ctx := context.Background() const actionName = "TestTracing-inc" - a := NewAction(actionName, atype.Custom, nil, inc) + a := newAction(actionName, atype.Custom, nil, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } diff --git a/go/genkit/flow.go b/go/genkit/flow.go index bcd30dbbb6..532b9557c3 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -146,8 +146,15 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core. outputSchema: internal.InferJSONSchema(o), // TODO(jba): set stateStore? } - a := f.action() - r.RegisterAction(atype.Flow, a) + metadata := map[string]any{ + "inputSchema": f.inputSchema, + "outputSchema": f.outputSchema, + } + afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) { + tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") + return f.runInstruction(ctx, inst, common.StreamingCallback[Stream](cb)) + } + core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc) f.tstate = r.TracingState() r.RegisterFlow(f) return f @@ -288,19 +295,6 @@ type FlowResult[Out any] struct { // FlowResult is called FlowResponse in the javascript. -// action creates an action for the flow. See the comment at the top of this file for more information. -func (f *Flow[In, Out, Stream]) action() *core.Action[*flowInstruction[In], *flowState[In, Out], Stream] { - metadata := map[string]any{ - "inputSchema": f.inputSchema, - "outputSchema": f.outputSchema, - } - cback := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) { - tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") - return f.runInstruction(ctx, inst, common.StreamingCallback[Stream](cb)) - } - return core.NewStreamingAction(f.name, atype.Flow, metadata, cback) -} - // runInstruction performs one of several actions on a flow, as determined by msg. // (Called runEnvelope in the js.) func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowInstruction[In], cb common.StreamingCallback[Stream]) (*flowState[In, Out], error) { diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 88d8c0b0b7..59a29ceba6 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -33,11 +33,11 @@ import ( "github.com/invopop/jsonschema" ) -func inc(_ context.Context, x int) (int, error) { +func inc(_ context.Context, x int, _ noStream) (int, error) { return x + 1, nil } -func dec(_ context.Context, x int) (int, error) { +func dec(_ context.Context, x int, _ noStream) (int, error) { return x - 1, nil } @@ -46,12 +46,12 @@ func TestDevServer(t *testing.T) { if err != nil { t.Fatal(err) } - r.RegisterAction(atype.Custom, core.NewAction("devServer/inc", atype.Custom, map[string]any{ + core.DefineActionInRegistry(r, "devServer", "inc", atype.Custom, map[string]any{ "foo": "bar", - }, inc)) - r.RegisterAction(atype.Custom, core.NewAction("devServer/dec", atype.Custom, map[string]any{ + }, nil, inc) + core.DefineActionInRegistry(r, "devServer", "dec", atype.Custom, map[string]any{ "bar": "baz", - }, dec)) + }, nil, dec) srv := httptest.NewServer(newDevServeMux(r)) defer srv.Close()