Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 48 additions & 48 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,34 @@ type Action[In, Out, Stream any] struct {
metadata map[string]any
}

// See js/core/src/action.ts

// DefineAction creates a new Action and registers it.
func DefineAction[In, Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return defineAction(registry.Global, provider, name, atype, metadata, fn)
}
type noStream = func(context.Context, struct{}) error

func defineAction[In, Out any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := NewAction(provider+"/"+name, atype, metadata, fn)
r.RegisterAction(atype, a)
return a
}
// See js/core/src/action.ts

func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return defineStreamingAction(registry.Global, provider, name, atype, metadata, fn)
// DefineAction creates a new non-streaming Action and registers it.
func DefineAction[In, Out any](
provider, name string,
atype atype.ActionType,
metadata map[string]any,
fn func(context.Context, In) (Out, error),
) *Action[In, Out, struct{}] {
return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, nil,
func(ctx context.Context, in In, _ noStream) (Out, error) {
return fn(ctx, in)
})
}

func defineStreamingAction[In, Out, Stream any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := NewStreamingAction(provider+"/"+name, atype, metadata, fn)
r.RegisterAction(atype, a)
return a
// DefineStreamingAction creates a new streaming action and registers it.
func DefineStreamingAction[In, Out, Stream any](
provider, name string,
atype atype.ActionType,
metadata map[string]any,
fn Func[In, Out, Stream],
) *Action[In, Out, Stream] {
return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, nil, fn)
}

// DefineCustomAction defines a streaming action with type Custom.
func DefineCustomAction[In, Out, Stream any](provider, name string, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return DefineStreamingAction(provider, name, atype.Custom, metadata, fn)
}
Expand All @@ -98,57 +103,52 @@ func DefineActionWithInputSchema[Out any](
inputSchema *jsonschema.Schema,
fn func(context.Context, any) (Out, error),
) *Action[any, Out, struct{}] {
return defineActionWithInputSchema(registry.Global, provider, name, atype, metadata, inputSchema, fn)
return DefineActionInRegistry(registry.Global, provider, name, atype, metadata, inputSchema,
func(ctx context.Context, in any, _ noStream) (Out, error) {
return fn(ctx, in)
})
}

func defineActionWithInputSchema[Out any](
// DefineActionInRegistry creates an action and registers it with the given Registry.
// For use by the Genkit module only.
func DefineActionInRegistry[In, Out, Stream any](
r *registry.Registry,
provider, name string,
atype atype.ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn func(context.Context, any) (Out, error),
) *Action[any, Out, struct{}] {
a := newActionWithInputSchema(provider+"/"+name, atype, metadata, fn, inputSchema)
fn Func[In, Out, Stream],
) *Action[In, Out, Stream] {
fullName := name
if provider != "" {
fullName = provider + "/" + name
}
a := newAction(fullName, atype, metadata, inputSchema, fn)
r.RegisterAction(atype, a)
return a
}

type noStream = func(context.Context, struct{}) error

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb noStream) (Out, error) {
return fn(ctx, in)
})
}

// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
// newAction creates a new Action with the given name and arguments.
// If inputSchema is nil, it is inferred from In.
func newAction[In, Out, Stream any](
name string,
atype atype.ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn Func[In, Out, Stream],
) *Action[In, Out, Stream] {
var i In
var o Out
if inputSchema == nil {
inputSchema = internal.InferJSONSchema(i)
}
return &Action[In, Out, Stream]{
name: name,
atype: atype,
fn: func(ctx context.Context, input In, sc func(context.Context, Stream) error) (Out, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input, sc)
},
inputSchema: internal.InferJSONSchema(i),
outputSchema: internal.InferJSONSchema(o),
metadata: metadata,
}
}

func newActionWithInputSchema[Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, any) (Out, error), inputSchema *jsonschema.Schema) *Action[any, Out, struct{}] {
var o Out
return &Action[any, Out, struct{}]{
name: name,
atype: atype,
fn: func(ctx context.Context, input any, sc func(context.Context, struct{}) error) (Out, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input)
},
inputSchema: inputSchema,
outputSchema: internal.InferJSONSchema(o),
metadata: metadata,
Expand Down
15 changes: 5 additions & 10 deletions go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ import (
"github.com/firebase/genkit/go/internal/registry"
)

func inc(_ context.Context, x int) (int, error) {
func inc(_ context.Context, x int, _ noStream) (int, error) {
return x + 1, nil
}

func TestActionRun(t *testing.T) {
a := NewAction("inc", atype.Custom, nil, inc)
a := newAction("inc", atype.Custom, nil, nil, inc)
got, err := a.Run(context.Background(), 3, nil)
if err != nil {
t.Fatal(err)
Expand All @@ -41,7 +41,7 @@ func TestActionRun(t *testing.T) {
}

func TestActionRunJSON(t *testing.T) {
a := NewAction("inc", atype.Custom, nil, inc)
a := newAction("inc", atype.Custom, nil, nil, inc)
input := []byte("3")
want := []byte("4")
got, err := a.RunJSON(context.Background(), input, nil)
Expand All @@ -53,11 +53,6 @@ func TestActionRunJSON(t *testing.T) {
}
}

func TestNewAction(t *testing.T) {
// Verify that struct{} can occur in the function signature.
_ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
}

// count streams the numbers from 0 to n-1, then returns n.
func count(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) {
if cb != nil {
Expand All @@ -72,7 +67,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int

func TestActionStreaming(t *testing.T) {
ctx := context.Background()
a := NewStreamingAction("count", atype.Custom, nil, count)
a := newAction("count", atype.Custom, nil, nil, count)
const n = 3

// Non-streaming.
Expand Down Expand Up @@ -105,7 +100,7 @@ func TestActionStreaming(t *testing.T) {
func TestActionTracing(t *testing.T) {
ctx := context.Background()
const actionName = "TestTracing-inc"
a := NewAction(actionName, atype.Custom, nil, inc)
a := newAction(actionName, atype.Custom, nil, nil, inc)
if _, err := a.Run(context.Background(), 3, nil); err != nil {
t.Fatal(err)
}
Expand Down
24 changes: 9 additions & 15 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,15 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.
outputSchema: internal.InferJSONSchema(o),
// TODO(jba): set stateStore?
}
a := f.action()
r.RegisterAction(atype.Flow, a)
metadata := map[string]any{
"inputSchema": f.inputSchema,
"outputSchema": f.outputSchema,
}
afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) {
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
return f.runInstruction(ctx, inst, common.StreamingCallback[Stream](cb))
}
core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc)
f.tstate = r.TracingState()
r.RegisterFlow(f)
return f
Expand Down Expand Up @@ -288,19 +295,6 @@ type FlowResult[Out any] struct {

// FlowResult is called FlowResponse in the javascript.

// action creates an action for the flow. See the comment at the top of this file for more information.
func (f *Flow[In, Out, Stream]) action() *core.Action[*flowInstruction[In], *flowState[In, Out], Stream] {
metadata := map[string]any{
"inputSchema": f.inputSchema,
"outputSchema": f.outputSchema,
}
cback := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) {
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
return f.runInstruction(ctx, inst, common.StreamingCallback[Stream](cb))
}
return core.NewStreamingAction(f.name, atype.Flow, metadata, cback)
}

// runInstruction performs one of several actions on a flow, as determined by msg.
// (Called runEnvelope in the js.)
func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowInstruction[In], cb common.StreamingCallback[Stream]) (*flowState[In, Out], error) {
Expand Down
12 changes: 6 additions & 6 deletions go/genkit/servers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ import (
"github.com/invopop/jsonschema"
)

func inc(_ context.Context, x int) (int, error) {
func inc(_ context.Context, x int, _ noStream) (int, error) {
return x + 1, nil
}

func dec(_ context.Context, x int) (int, error) {
func dec(_ context.Context, x int, _ noStream) (int, error) {
return x - 1, nil
}

Expand All @@ -46,12 +46,12 @@ func TestDevServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r.RegisterAction(atype.Custom, core.NewAction("devServer/inc", atype.Custom, map[string]any{
core.DefineActionInRegistry(r, "devServer", "inc", atype.Custom, map[string]any{
"foo": "bar",
}, inc))
r.RegisterAction(atype.Custom, core.NewAction("devServer/dec", atype.Custom, map[string]any{
}, nil, inc)
core.DefineActionInRegistry(r, "devServer", "dec", atype.Custom, map[string]any{
"bar": "baz",
}, dec))
}, nil, dec)
srv := httptest.NewServer(newDevServeMux(r))
defer srv.Close()

Expand Down