diff --git a/go/genkit/flow.go b/go/genkit/flow.go index d39291a185..a359fbb4c4 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -19,6 +19,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "strconv" "sync" "time" @@ -112,6 +113,7 @@ func defineFlow[I, O, S any](r *registry, name string, fn Func[I, O, S]) *Flow[I r.registerAction(ActionTypeFlow, name, a) // TODO(jba): this is a roundabout way to transmit the tracing state. Is there a cleaner way? f.tstate = a.tstate + r.registerFlow(f) return f } @@ -231,7 +233,10 @@ type operation[O any] struct { type FlowResult[O any] struct { Response O `json:"response,omitempty"` Error string `json:"error,omitempty"` - // TODO(jba): keep the actual error around so that RunFlow can use it. + // The Error field above is not used in the code, but it gets marshaled + // into JSON. + // TODO(jba): replace with a type that implements error and json.Marshaler. + err error StackTrace string `json:"stacktrace,omitempty"` } @@ -273,6 +278,50 @@ func (f *Flow[I, O, S]) runInstruction(ctx context.Context, inst *flowInstructio } } +// flow is the type that all Flow[I, O, S] have in common. +type flow interface { + Name() string + + // runJSON uses encoding/json to unmarshal the input, + // calls Flow.start, then returns the marshaled result. + runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) +} + +func (f *Flow[I, O, S]) Name() string { return f.name } + +func (f *Flow[I, O, S]) runJSON(ctx context.Context, input json.RawMessage, cb StreamingCallback[json.RawMessage]) (json.RawMessage, error) { + var in I + if err := json.Unmarshal(input, &in); err != nil { + return nil, &httpError{http.StatusBadRequest, err} + } + // If there is a callback, wrap it to turn an S into a json.RawMessage. + var callback StreamingCallback[S] + if cb != nil { + callback = func(ctx context.Context, s S) error { + bytes, err := json.Marshal(s) + if err != nil { + return err + } + return cb(ctx, json.RawMessage(bytes)) + } + } + fstate, err := f.start(ctx, in, callback) + if err != nil { + return nil, err + } + if fstate.Operation == nil { + return nil, errors.New("nil operation") + } + res := fstate.Operation.Result + if res == nil { + return nil, errors.New("nil result") + } + if res.err != nil { + return nil, res.err + } + return json.Marshal(res.Response) +} + // start starts executing the flow with the given input. func (f *Flow[I, O, S]) start(ctx context.Context, input I, cb StreamingCallback[S]) (_ *flowState[I, O], err error) { flowID, err := generateFlowID() @@ -346,6 +395,7 @@ func (f *Flow[I, O, S]) execute(ctx context.Context, state *flowState[I, O], dis state.Operation.Done = true if err != nil { state.Operation.Result = &FlowResult[O]{ + err: err, Error: err.Error(), // TODO(jba): stack trace? } @@ -535,8 +585,8 @@ func finishedOpResponse[O any](op *operation[O]) (O, error) { if !op.Done { return internal.Zero[O](), fmt.Errorf("flow %s did not finish execution", op.FlowID) } - if op.Result.Error != "" { - return internal.Zero[O](), fmt.Errorf("flow %s: %s", op.FlowID, op.Result.Error) + if op.Result.err != nil { + return internal.Zero[O](), fmt.Errorf("flow %s: %w", op.FlowID, op.Result.err) } return op.Result.Response, nil } diff --git a/go/genkit/flow_test.go b/go/genkit/flow_test.go index 0d24b134ce..b8d59d2d22 100644 --- a/go/genkit/flow_test.go +++ b/go/genkit/flow_test.go @@ -17,6 +17,7 @@ package genkit import ( "context" "encoding/json" + "errors" "slices" "testing" @@ -46,7 +47,10 @@ func TestFlowStart(t *testing.T) { Response: 2, }, } - if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(operation[int]{}, "FlowID")); diff != "" { + diff := cmp.Diff(want, got, + cmpopts.IgnoreFields(operation[int]{}, "FlowID"), + cmpopts.IgnoreUnexported(FlowResult[int]{}, flowState[int, int]{})) + if diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } } @@ -147,6 +151,7 @@ func TestFlowState(t *testing.T) { Metadata: "meta", Result: &FlowResult[int]{ Response: 6, + err: errors.New("err"), Error: "err", StackTrace: "st", }, @@ -161,7 +166,7 @@ func TestFlowState(t *testing.T) { if err := json.Unmarshal(data, &got); err != nil { t.Fatal(err) } - diff := cmp.Diff(fs, got, cmpopts.IgnoreUnexported(flowState[int, int]{})) + diff := cmp.Diff(fs, got, cmpopts.IgnoreUnexported(flowState[int, int]{}, FlowResult[int]{})) if diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } diff --git a/go/genkit/registry.go b/go/genkit/registry.go index 50ac53d9c9..205bb14256 100644 --- a/go/genkit/registry.go +++ b/go/genkit/registry.go @@ -46,6 +46,7 @@ type registry struct { tstate *tracing.State mu sync.Mutex actions map[string]action + flows []flow // TraceStores, at most one for each [Environment]. // Only the prod trace store is actually registered; the dev one is // always created automatically. But it's simpler if we keep them together here. @@ -144,6 +145,20 @@ func (r *registry) listActions() []actionDesc { return ads } +// registerFlow stores the flow for use by the production server (see [NewFlowServeMux]). +// It doesn't check for duplicates because registerAction will do that. +func (r *registry) registerFlow(f flow) { + r.mu.Lock() + defer r.mu.Unlock() + r.flows = append(r.flows, f) +} + +func (r *registry) listFlows() []flow { + r.mu.Lock() + defer r.mu.Unlock() + return r.flows +} + // RegisterTraceStore uses the given trace.Store to record traces in the prod environment. // (A trace.Store that writes to the local filesystem is always installed in the dev environment.) // The returned function should be called before the program ends to ensure that diff --git a/go/genkit/dev_server.go b/go/genkit/servers.go similarity index 64% rename from go/genkit/dev_server.go rename to go/genkit/servers.go index a782161122..7d8107a84b 100644 --- a/go/genkit/dev_server.go +++ b/go/genkit/servers.go @@ -12,18 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package genkit - -// This file implements a server used for development. -// The genkit CLI sends requests to it. +// This file implements production and development servers. // +// The genkit CLI sends requests to the development server. // See js/common/src/reflectionApi.ts. +// +// The production server has a route for each flow. It +// is intended for production deployments. + +package genkit import ( "context" "encoding/json" "errors" "fmt" + "io" "io/fs" "log/slog" "net/http" @@ -39,48 +43,43 @@ import ( "go.opentelemetry.io/otel/trace" ) -// StartDevServer starts the development server (reflection API) listening at the given address. -// If addr is "", it uses ":3100". -// StartDevServer always returns a non-nil error, the one returned by http.ListenAndServe. -func StartDevServer(addr string) error { - mux := newDevServerMux(globalRegistry) - if addr == "" { - port := os.Getenv("GENKIT_REFLECTION_PORT") - if port != "" { - addr = ":" + port - } else { - // Don't use "localhost" here. That only binds the IPv4 address, and the genkit tool - // wants to connect to the IPv6 address even when you tell it to use "localhost". - // Omitting the host works. - addr = ":3100" - } - } - server := &http.Server{ - Addr: addr, - Handler: mux, +// StartFlowServer starts a server serving the routes described in [NewFlowServeMux]. +// It listens on addr, or if empty, the value of the PORT environment variable, +// or if that is empty, ":3400". +// +// In development mode (if the environment variable GENKIT_ENV=dev), it also starts +// a dev server. +// +// StartFlowServer always returns a non-nil error, the one returned by http.ListenAndServe. +func StartFlowServer(addr string) error { + if currentEnvironment() == EnvironmentDev { + go func() { + err := startDevServer("") + slog.Error("dev server stopped", "err", err) + }() } - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGTERM) - go func() { - <-sigCh - slog.Info("received SIGTERM, shutting down server") - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := server.Shutdown(ctx); err != nil { - slog.Error("server shutdown failed", "err", err) - } else { - slog.Info("server shutdown successfully") - } - }() - slog.Info("listening", "addr", addr) - return server.ListenAndServe() + return startProdServer(addr) +} + +// startDevServer starts the development server (reflection API) listening at the given address. +// If addr is "", it uses the value of the environment variable GENKIT_REFLECTION_PORT +// for the port, and if that is empty it uses ":3100". +// startDevServer always returns a non-nil error, the one returned by http.ListenAndServe. +func startDevServer(addr string) error { + slog.Info("starting dev server") + // Don't use "localhost" here. That only binds the IPv4 address, and the genkit tool + // wants to connect to the IPv6 address even when you tell it to use "localhost". + // Omitting the host works. + addr = serverAddress(addr, "GENKIT_REFLECTION_PORT", ":3100") + mux := newDevServeMux(globalRegistry) + return listenAndServe(addr, mux) } type devServer struct { reg *registry } -func newDevServerMux(r *registry) *http.ServeMux { +func newDevServeMux(r *registry) *http.ServeMux { mux := http.NewServeMux() s := &devServer{r} handle(mux, "GET /api/__health", func(w http.ResponseWriter, _ *http.Request) error { @@ -95,52 +94,6 @@ func newDevServerMux(r *registry) *http.ServeMux { return mux } -// requestID is a unique ID for each request. -var requestID atomic.Int64 - -// handle registers pattern on mux with an http.Handler that calls f. -// If f returns a non-nil error, the handler calls http.Error. -// If the error is an httpError, the code it contains is used as the status code; -// otherwise a 500 status is used. -func handle(mux *http.ServeMux, pattern string, f func(w http.ResponseWriter, r *http.Request) error) { - mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { - id := requestID.Add(1) - // Create a logger that always outputs the requestID, and store it in the request context. - log := slog.Default().With("reqID", id) - log.Info("request start", - "method", r.Method, - "path", r.URL.Path) - var err error - defer func() { - if err != nil { - log.Error("request end", "err", err) - } else { - log.Info("request end") - } - }() - err = f(w, r) - if err != nil { - // If the error is an httpError, serve the status code it contains. - // Otherwise, assume this is an unexpected error and serve a 500. - var herr *httpError - if errors.As(err, &herr) { - http.Error(w, herr.Error(), herr.code) - } else { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - } - }) -} - -type httpError struct { - code int - err error -} - -func (e *httpError) Error() string { - return fmt.Sprintf("%s: %s", http.StatusText(e.code), e.err) -} - // handleRunAction looks up an action by name in the registry, runs it with the // provded JSON input, and writes back the JSON-marshaled request. func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) error { @@ -280,6 +233,166 @@ type listFlowStatesResult struct { ContinuationToken string `json:"continuationToken"` } +// startProdServer starts a production server listening at the given address. +// The Server has a route for each defined flow. +// If addr is "", it uses the value of the environment variable PORT +// for the port, and if that is empty it uses ":3400". +// startProdServer always returns a non-nil error, the one returned by http.ListenAndServe. +// +// To construct a server with additional routes, use [NewFlowServeMux]. +func startProdServer(addr string) error { + slog.Info("starting flow server") + addr = serverAddress(addr, "PORT", ":3400") + mux := NewFlowServeMux() + return listenAndServe(addr, mux) +} + +// NewFlowServeMux constructs a [net/http.ServeMux] where each defined flow is a route. +// All routes take a single query parameter, "stream", which if true will stream the +// flow's results back to the client. (Not all flows support streaming, however.) +// +// To use the returned ServeMux as part of a server with other routes, either add routes +// to it, or install it as part of another ServeMux, like so: +// +// mainMux := http.NewServeMux() +// mainMux.Handle("POST /flow/", http.StripPrefix("/flow/", NewFlowServeMux())) +func NewFlowServeMux() *http.ServeMux { + return newFlowServeMux(globalRegistry) +} + +func newFlowServeMux(r *registry) *http.ServeMux { + mux := http.NewServeMux() + for _, f := range r.listFlows() { + handle(mux, "POST /"+f.Name(), nonDurableFlowHandler(f)) + } + return mux +} + +func nonDurableFlowHandler(f flow) func(http.ResponseWriter, *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + input, err := io.ReadAll(r.Body) + if err != nil { + return err + } + stream, err := parseBoolQueryParam(r, "stream") + if err != nil { + return err + } + if stream { + // TODO(jba): implement streaming. + return &httpError{http.StatusNotImplemented, errors.New("streaming")} + } else { + // TODO(jba): telemetry + out, err := f.runJSON(r.Context(), json.RawMessage(input), nil) + if err != nil { + return err + } + // Responses for non-streaming, non-durable flows are passed back + // with the flow result stored in a field called "result." + _, err = fmt.Fprintf(w, `{"result": %s}\n`, out) + return err + } + } +} + +// serverAddress determines a server address. +func serverAddress(arg, envVar, defaultValue string) string { + if arg != "" { + return arg + } + if port := os.Getenv(envVar); port != "" { + return ":" + port + } + return defaultValue +} + +func listenAndServe(addr string, mux *http.ServeMux) error { + server := &http.Server{ + Addr: addr, + Handler: mux, + } + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM) + go func() { + <-sigCh + slog.Info("received SIGTERM, shutting down server") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + slog.Error("server shutdown failed", "err", err) + } else { + slog.Info("server shutdown successfully") + } + }() + slog.Info("listening", "addr", addr) + return server.ListenAndServe() +} + +// requestID is a unique ID for each request. +var requestID atomic.Int64 + +// handle registers pattern on mux with an http.Handler that calls f. +// If f returns a non-nil error, the handler calls http.Error. +// If the error is an httpError, the code it contains is used as the status code; +// otherwise a 500 status is used. +func handle(mux *http.ServeMux, pattern string, f func(w http.ResponseWriter, r *http.Request) error) { + mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { + id := requestID.Add(1) + // Create a logger that always outputs the requestID, and store it in the request context. + log := slog.Default().With("reqID", id) + log.Info("request start", + "method", r.Method, + "path", r.URL.Path) + var err error + defer func() { + if err != nil { + log.Error("request end", "err", err) + } else { + log.Info("request end") + } + }() + err = f(w, r) + if err != nil { + // If the error is an httpError, serve the status code it contains. + // Otherwise, assume this is an unexpected error and serve a 500. + var herr *httpError + if errors.As(err, &herr) { + http.Error(w, herr.Error(), herr.code) + } else { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } + }) +} + +type httpError struct { + code int + err error +} + +func (e *httpError) Error() string { + return fmt.Sprintf("%s: %s", http.StatusText(e.code), e.err) +} + +func parseBoolQueryParam(r *http.Request, name string) (bool, error) { + b := false + if s := r.FormValue(name); s != "" { + var err error + b, err = strconv.ParseBool(s) + if err != nil { + return false, &httpError{http.StatusBadRequest, err} + } + } + return b, nil +} + +func currentEnvironment() Environment { + if v := os.Getenv("GENKIT_ENV"); v != "" { + return Environment(v) + } + return EnvironmentProd +} func writeJSON(ctx context.Context, w http.ResponseWriter, value any) error { data, err := json.MarshalIndent(value, "", " ") if err != nil { diff --git a/go/genkit/dev_server_test.go b/go/genkit/servers_test.go similarity index 83% rename from go/genkit/dev_server_test.go rename to go/genkit/servers_test.go index dee01d9cf8..0514d4bb46 100644 --- a/go/genkit/dev_server_test.go +++ b/go/genkit/servers_test.go @@ -44,7 +44,7 @@ func TestDevServer(t *testing.T) { r.registerAction("test", "devServer", NewAction("dec", map[string]any{ "bar": "baz", }, dec)) - srv := httptest.NewServer(newDevServerMux(r)) + srv := httptest.NewServer(newDevServeMux(r)) defer srv.Close() t.Run("runAction", func(t *testing.T) { @@ -120,6 +120,45 @@ func TestDevServer(t *testing.T) { }) } +func TestProdServer(t *testing.T) { + r, err := newRegistry() + if err != nil { + t.Fatal(err) + } + defineFlow(r, "inc", func(_ context.Context, i int, _ NoStream) (int, error) { + return i + 1, nil + }) + srv := httptest.NewServer(newFlowServeMux(r)) + defer srv.Close() + + check := func(t *testing.T, input string, wantStatus, wantResult int) { + res, err := http.Post(srv.URL+"/inc", "application/json", strings.NewReader(input)) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if g, w := res.StatusCode, wantStatus; g != w { + t.Fatalf("status: got %d, want %d", g, w) + } + if res.StatusCode != 200 { + return + } + type resultType struct { + Result int + } + got, err := readJSON[resultType](res.Body) + if err != nil { + t.Fatal(err) + } + if g, w := got.Result, wantResult; g != w { + t.Errorf("result: got %d, want %d", g, w) + } + } + + t.Run("ok", func(t *testing.T) { check(t, "2", 200, 3) }) + t.Run("bad", func(t *testing.T) { check(t, "true", 400, 0) }) +} + func checkActionTrace(t *testing.T, reg *registry, tid, name string) { ts := reg.lookupTraceStore(EnvironmentDev) td, err := ts.Load(context.Background(), tid) diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 73171c3979..272e796058 100644 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -13,6 +13,8 @@ // limitations under the License. // This program can be manually tested like so: +// +// In development mode (with the environment variable GENKIT_ENV="dev"): // Start the server listening on port 3100: // // go run . & @@ -20,6 +22,16 @@ // Tell it to run an action: // // curl -d '{"key":"/flow/testAllCoffeeFlows/testAllCoffeeFlows", "input":{"start": {"input":null}}}' http://localhost:3100/api/runAction +// +// In production mode (GENKIT_ENV missing or set to "prod"): +// Start the server listening on port 3400: +// +// go run . & +// +// Tell it to run a flow: +// +// curl -d '{"customerName": "Stimpy"}' http://localhost:3400/simpleGreeting + package main import ( @@ -171,8 +183,7 @@ func main() { } return out, nil }) - - if err := genkit.StartDevServer(""); err != nil { + if err := genkit.StartFlowServer(""); err != nil { log.Fatal(err) } } diff --git a/go/samples/flow-sample1/main.go b/go/samples/flow-sample1/main.go index 215d10130a..a42141f189 100644 --- a/go/samples/flow-sample1/main.go +++ b/go/samples/flow-sample1/main.go @@ -73,7 +73,7 @@ func main() { return fmt.Sprintf("done: %d, streamed: %d times", count, i), nil }) - if err := genkit.StartDevServer(""); err != nil { + if err := genkit.StartFlowServer(""); err != nil { log.Fatal(err) } }