diff --git a/go/core/file_flow_state_store.go b/go/core/file_flow_state_store.go index 1caafcb848..2b85ce35f8 100644 --- a/go/core/file_flow_state_store.go +++ b/go/core/file_flow_state_store.go @@ -38,9 +38,11 @@ func NewFileFlowStateStore(dir string) (*FileFlowStateStore, error) { } func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs common.FlowStater) error { - fs.Lock() - defer fs.Unlock() - return internal.WriteJSONFile(filepath.Join(s.dir, internal.Clean(id)), fs) + data, err := fs.ToJSON() + if err != nil { + return err + } + return os.WriteFile(filepath.Join(s.dir, internal.Clean(id)), data, 0666) } func (s *FileFlowStateStore) Load(ctx context.Context, id string, pfs any) error { diff --git a/go/genkit/flow.go b/go/genkit/flow.go index 2f4767ec71..bcd30dbbb6 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -15,6 +15,7 @@ package genkit import ( + "bytes" "context" "encoding/json" "errors" @@ -232,10 +233,29 @@ func newFlowState[In, Out any](id, name string, input In) *flowState[In, Out] { } // flowState implements common.FlowStater. -func (fs *flowState[In, Out]) IsFlowState() {} -func (fs *flowState[In, Out]) Lock() { fs.mu.Lock() } -func (fs *flowState[In, Out]) Unlock() { fs.mu.Unlock() } -func (fs *flowState[In, Out]) GetCache() map[string]json.RawMessage { return fs.Cache } +func (fs *flowState[In, Out]) IsFlowState() {} + +func (fs *flowState[In, Out]) ToJSON() ([]byte, error) { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetIndent("", " ") // make the value easy to read for debugging + if err := enc.Encode(fs); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (fs *flowState[In, Out]) CacheAt(key string) json.RawMessage { + fs.mu.Lock() + defer fs.mu.Unlock() + return fs.Cache[key] +} + +func (fs *flowState[In, Out]) CacheSet(key string, val json.RawMessage) { + fs.mu.Lock() + defer fs.mu.Unlock() + fs.Cache[key] = val +} // An operation describes the state of a Flow that may still be in progress. type operation[Out any] struct { @@ -528,10 +548,8 @@ func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out, // happen because every step has a unique cache key. // TODO(jba): don't memoize a nested flow (see context.ts) fs := fc.stater() - fs.Lock() - j, ok := fs.GetCache()[uName] - fs.Unlock() - if ok { + j := fs.CacheAt(uName) + if j != nil { var t Out if err := json.Unmarshal(j, &t); err != nil { return internal.Zero[Out](), err @@ -547,9 +565,7 @@ func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out, if err != nil { return internal.Zero[Out](), err } - fs.Lock() - fs.GetCache()[uName] = json.RawMessage(bytes) - fs.Unlock() + fs.CacheSet(uName, json.RawMessage(bytes)) tracing.SetCustomMetadataAttr(ctx, "flow:state", "run") return t, nil }) diff --git a/go/internal/common/common.go b/go/internal/common/common.go index 6c2626da04..7feacf9b83 100644 --- a/go/internal/common/common.go +++ b/go/internal/common/common.go @@ -70,9 +70,9 @@ func CurrentEnvironment() Environment { // FlowStater is the common type of all flowState[I, O] types. type FlowStater interface { IsFlowState() - Lock() - Unlock() - GetCache() map[string]json.RawMessage + ToJSON() ([]byte, error) + CacheAt(key string) json.RawMessage + CacheSet(key string, val json.RawMessage) } // StreamingCallback is the type of streaming callbacks.