Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (

"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/action"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/metrics"
"github.com/firebase/genkit/go/internal/registry"
"github.com/invopop/jsonschema"
Expand Down Expand Up @@ -235,8 +235,8 @@ func (a *Action[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMes
}

// Desc returns a description of the action.
func (a *Action[I, O, S]) Desc() common.ActionDesc {
ad := common.ActionDesc{
func (a *Action[I, O, S]) Desc() action.Desc {
ad := action.Desc{
Name: a.name,
Description: a.description,
Metadata: a.metadata,
Expand Down
3 changes: 1 addition & 2 deletions go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"testing"

"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
)

Expand Down Expand Up @@ -105,7 +104,7 @@ func TestActionTracing(t *testing.T) {
t.Fatal(err)
}
// The dev TraceStore is registered by Init, called from TestMain.
ts := registry.Global.LookupTraceStore(common.EnvironmentDev)
ts := registry.Global.LookupTraceStore(registry.EnvironmentDev)
tds, _, err := ts.List(ctx, nil)
if err != nil {
t.Fatal(err)
Expand Down
3 changes: 1 addition & 2 deletions go/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"context"

"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
)
Expand All @@ -39,7 +38,7 @@ import (
// all pending data is stored.
// RegisterTraceStore panics if called more than once.
func RegisterTraceStore(ts tracing.Store) (shutdown func(context.Context) error) {
registry.Global.RegisterTraceStore(common.EnvironmentProd, ts)
registry.Global.RegisterTraceStore(registry.EnvironmentProd, ts)
return registry.Global.TracingState().AddTraceStoreBatch(ts)
}

Expand Down
3 changes: 1 addition & 2 deletions go/core/file_flow_state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"path/filepath"

"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/common"
)

// A FileFlowStateStore is a FlowStateStore that writes flowStates to files.
Expand All @@ -37,7 +36,7 @@ func NewFileFlowStateStore(dir string) (*FileFlowStateStore, error) {
return &FileFlowStateStore{dir: dir}, nil
}

func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs common.FlowStater) error {
func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs base.FlowStater) error {
data, err := fs.ToJSON()
if err != nil {
return err
Expand Down
8 changes: 4 additions & 4 deletions go/core/flow_state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ package core
import (
"context"

"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/base"
)

// A FlowStateStore stores flow states.
// Every flow state has a unique string identifier.
// A durable FlowStateStore is necessary for durable flows.
type FlowStateStore interface {
// Save saves the FlowState to the store, overwriting an existing one.
Save(ctx context.Context, id string, fs common.FlowStater) error
Save(ctx context.Context, id string, fs base.FlowStater) error
// Load reads the FlowState with the given ID from the store.
// It returns an error that is fs.ErrNotExist if there isn't one.
// pfs must be a pointer to a flowState[I, O] of the correct type.
Expand All @@ -35,5 +35,5 @@ type FlowStateStore interface {
// nopFlowStateStore is a FlowStateStore that does nothing.
type nopFlowStateStore struct{}

func (nopFlowStateStore) Save(ctx context.Context, id string, fs common.FlowStater) error { return nil }
func (nopFlowStateStore) Load(ctx context.Context, id string, pfs any) error { return nil }
func (nopFlowStateStore) Save(ctx context.Context, id string, fs base.FlowStater) error { return nil }
func (nopFlowStateStore) Load(ctx context.Context, id string, pfs any) error { return nil }
3 changes: 1 addition & 2 deletions go/genkit/conformance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
"golang.org/x/exp/maps"
)
Expand Down Expand Up @@ -116,7 +115,7 @@ func TestFlowConformance(t *testing.T) {
if test.Trace == nil {
return
}
ts := r.LookupTraceStore(common.EnvironmentDev)
ts := r.LookupTraceStore(registry.EnvironmentDev)
var gotTrace any
if err := ts.LoadAny(resp.Telemetry.TraceID, &gotTrace); err != nil {
t.Fatal(err)
Expand Down
22 changes: 12 additions & 10 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/metrics"
"github.com/firebase/genkit/go/internal/registry"
"github.com/google/uuid"
Expand Down Expand Up @@ -107,6 +106,9 @@ type Flow[In, Out, Stream any] struct {

type noStream = func(context.Context, struct{}) error

// streamingCallback is the type of streaming callbacks.
type streamingCallback[Stream any] func(context.Context, Stream) error

// DefineFlow creates a Flow that runs fn, and registers it as an action.
//
// fn takes an input of type In and returns an output of type Out.
Expand Down Expand Up @@ -152,7 +154,7 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.
}
afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) {
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
return f.runInstruction(ctx, inst, common.StreamingCallback[Stream](cb))
return f.runInstruction(ctx, inst, streamingCallback[Stream](cb))
}
core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc)
f.tstate = r.TracingState()
Expand Down Expand Up @@ -239,7 +241,7 @@ func newFlowState[In, Out any](id, name string, input In) *flowState[In, Out] {
}
}

// flowState implements common.FlowStater.
// flowState implements base.FlowStater.
func (fs *flowState[In, Out]) IsFlowState() {}

func (fs *flowState[In, Out]) ToJSON() ([]byte, error) {
Expand Down Expand Up @@ -297,7 +299,7 @@ type FlowResult[Out any] struct {

// runInstruction performs one of several actions on a flow, as determined by msg.
// (Called runEnvelope in the js.)
func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowInstruction[In], cb common.StreamingCallback[Stream]) (*flowState[In, Out], error) {
func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowInstruction[In], cb streamingCallback[Stream]) (*flowState[In, Out], error) {
switch {
case inst.Start != nil:
// TODO(jba): pass msg.Start.Labels.
Expand All @@ -322,7 +324,7 @@ func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowIn
// Name returns the name that the flow was defined with.
func (f *Flow[In, Out, Stream]) Name() string { return f.name }

func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb common.StreamingCallback[json.RawMessage]) (json.RawMessage, error) {
func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := base.ValidateJSON(input, f.inputSchema); err != nil {
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
Expand All @@ -332,7 +334,7 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
}
// If there is a callback, wrap it to turn an S into a json.RawMessage.
var callback common.StreamingCallback[Stream]
var callback streamingCallback[Stream]
if cb != nil {
callback = func(ctx context.Context, s Stream) error {
bytes, err := json.Marshal(s)
Expand Down Expand Up @@ -360,7 +362,7 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa
}

// start starts executing the flow with the given input.
func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb common.StreamingCallback[Stream]) (_ *flowState[In, Out], err error) {
func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamingCallback[Stream]) (_ *flowState[In, Out], err error) {
flowID, err := generateFlowID()
if err != nil {
return nil, err
Expand All @@ -377,7 +379,7 @@ func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb common.S
//
// This function corresponds to Flow.executeSteps in the js, but does more:
// it creates the flowContext and saves the state.
func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In, Out], dispatchType string, cb common.StreamingCallback[Stream]) {
func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In, Out], dispatchType string, cb streamingCallback[Stream]) {
fctx := newFlowContext(state, f.stateStore, f.tstate)
defer func() {
if err := fctx.finish(ctx); err != nil {
Expand Down Expand Up @@ -477,7 +479,7 @@ type flowContext[I, O any] struct {
// flowContexter is the type of all flowContext[I, O].
type flowContexter interface {
uniqueStepName(string) string
stater() common.FlowStater
stater() base.FlowStater
tracingState() *tracing.State
}

Expand All @@ -489,7 +491,7 @@ func newFlowContext[I, O any](state *flowState[I, O], store core.FlowStateStore,
seenSteps: map[string]int{},
}
}
func (fc *flowContext[I, O]) stater() common.FlowStater { return fc.state }
func (fc *flowContext[I, O]) stater() base.FlowStater { return fc.state }
func (fc *flowContext[I, O]) tracingState() *tracing.State { return fc.tstate }

// finish is called at the end of a flow execution.
Expand Down
3 changes: 1 addition & 2 deletions go/genkit/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"sync"
"syscall"

"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
)

Expand Down Expand Up @@ -67,7 +66,7 @@ func Init(ctx context.Context, opts *Options) error {
var wg sync.WaitGroup
errCh := make(chan error, 2)

if common.CurrentEnvironment() == common.EnvironmentDev {
if registry.CurrentEnvironment() == registry.EnvironmentDev {
wg.Add(1)
go func() {
defer wg.Done()
Expand Down
20 changes: 10 additions & 10 deletions go/genkit/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ import (

"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/action"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
"go.opentelemetry.io/otel/trace"
)
Expand Down Expand Up @@ -76,7 +76,7 @@ type flow interface {

// runJSON uses encoding/json to unmarshal the input,
// calls Flow.start, then returns the marshaled result.
runJSON(ctx context.Context, input json.RawMessage, cb common.StreamingCallback[json.RawMessage]) (json.RawMessage, error)
runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error)
}

// startServer starts an HTTP server listening on the address.
Expand Down Expand Up @@ -174,7 +174,7 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro
logger.FromContext(ctx).Debug("running action",
"key", body.Key,
"stream", stream)
var callback common.StreamingCallback[json.RawMessage]
var callback streamingCallback[json.RawMessage]
if stream {
// Stream results are newline-separated JSON.
callback = func(ctx context.Context, msg json.RawMessage) error {
Expand Down Expand Up @@ -204,7 +204,7 @@ type telemetry struct {
TraceID string `json:"traceId"`
}

func runAction(ctx context.Context, reg *registry.Registry, key string, input json.RawMessage, cb common.StreamingCallback[json.RawMessage]) (*runActionResponse, error) {
func runAction(ctx context.Context, reg *registry.Registry, key string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (*runActionResponse, error) {
action := reg.LookupAction(key)
if action == nil {
return nil, &base.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no action with key %q", key)}
Expand All @@ -227,7 +227,7 @@ func runAction(ctx context.Context, reg *registry.Registry, key string, input js
// handleListActions lists all the registered actions.
func (s *devServer) handleListActions(w http.ResponseWriter, r *http.Request) error {
descs := s.reg.ListActions()
descMap := map[string]common.ActionDesc{}
descMap := map[string]action.Desc{}
for _, d := range descs {
descMap[d.Key] = d
}
Expand All @@ -237,7 +237,7 @@ func (s *devServer) handleListActions(w http.ResponseWriter, r *http.Request) er
// handleGetTrace returns a single trace from a TraceStore.
func (s *devServer) handleGetTrace(w http.ResponseWriter, r *http.Request) error {
env := r.PathValue("env")
ts := s.reg.LookupTraceStore(common.Environment(env))
ts := s.reg.LookupTraceStore(registry.Environment(env))
if ts == nil {
return &base.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no TraceStore for environment %q", env)}
}
Expand All @@ -255,7 +255,7 @@ func (s *devServer) handleGetTrace(w http.ResponseWriter, r *http.Request) error
// handleListTraces returns a list of traces from a TraceStore.
func (s *devServer) handleListTraces(w http.ResponseWriter, r *http.Request) error {
env := r.PathValue("env")
ts := s.reg.LookupTraceStore(common.Environment(env))
ts := s.reg.LookupTraceStore(registry.Environment(env))
if ts == nil {
return &base.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no TraceStore for environment %q", env)}
}
Expand Down Expand Up @@ -287,12 +287,12 @@ type listTracesResult struct {
}

func (s *devServer) handleListFlowStates(w http.ResponseWriter, r *http.Request) error {
return writeJSON(r.Context(), w, listFlowStatesResult{[]common.FlowStater{}, ""})
return writeJSON(r.Context(), w, listFlowStatesResult{[]base.FlowStater{}, ""})
}

type listFlowStatesResult struct {
FlowStates []common.FlowStater `json:"flowStates"`
ContinuationToken string `json:"continuationToken"`
FlowStates []base.FlowStater `json:"flowStates"`
ContinuationToken string `json:"continuationToken"`
}

// NewFlowServeMux constructs a [net/http.ServeMux].
Expand Down
8 changes: 4 additions & 4 deletions go/genkit/servers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/action"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down Expand Up @@ -87,11 +87,11 @@ func TestDevServer(t *testing.T) {
if res.StatusCode != 200 {
t.Fatalf("got status %d, wanted 200", res.StatusCode)
}
got, err := readJSON[map[string]common.ActionDesc](res.Body)
got, err := readJSON[map[string]action.Desc](res.Body)
if err != nil {
t.Fatal(err)
}
want := map[string]common.ActionDesc{
want := map[string]action.Desc{
"/custom/devServer/inc": {
Key: "/custom/devServer/inc",
Name: "devServer/inc",
Expand Down Expand Up @@ -168,7 +168,7 @@ func TestProdServer(t *testing.T) {
}

func checkActionTrace(t *testing.T, reg *registry.Registry, tid, name string) {
ts := reg.LookupTraceStore(common.EnvironmentDev)
ts := reg.LookupTraceStore(registry.EnvironmentDev)
td, err := ts.Load(context.Background(), tid)
if err != nil {
t.Fatal(err)
Expand Down
Loading