diff --git a/cmd/root/root_test.go b/cmd/root/root_test.go index 68e71edee..07f7abc98 100644 --- a/cmd/root/root_test.go +++ b/cmd/root/root_test.go @@ -9,8 +9,6 @@ import ( func TestDefaultToRun(t *testing.T) { t.Parallel() - rootCmd := NewRootCmd() - tests := []struct { name string args []string @@ -102,7 +100,7 @@ func TestDefaultToRun(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := defaultToRun(rootCmd, tt.args) + got := defaultToRun(NewRootCmd(), tt.args) assert.Equal(t, tt.want, got) }) } diff --git a/pkg/fake/proxy_test.go b/pkg/fake/proxy_test.go index 5927cb637..34fc7dd0c 100644 --- a/pkg/fake/proxy_test.go +++ b/pkg/fake/proxy_test.go @@ -300,6 +300,28 @@ func TestSimulatedStreamCopy_SSEEvents(t *testing.T) { assert.GreaterOrEqual(t, elapsed, 3*chunkDelay, "should have delays between data chunks") } +// notifyWriter wraps an http.ResponseWriter and signals on first Write. +type notifyWriter struct { + http.ResponseWriter + notify chan struct{} + notified bool +} + +func (w *notifyWriter) Write(p []byte) (int, error) { + n, err := w.ResponseWriter.Write(p) + if n > 0 && !w.notified { + w.notified = true + close(w.notify) + } + return n, err +} + +func (w *notifyWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + func TestSimulatedStreamCopy_ContextCancellation(t *testing.T) { // Create a reader that provides some data then blocks // to allow context cancellation to be tested @@ -321,17 +343,24 @@ func TestSimulatedStreamCopy_ContextCancellation(t *testing.T) { rec := httptest.NewRecorder() ctx, cancel := context.WithCancel(t.Context()) req = req.WithContext(ctx) - c := e.NewContext(req, rec) + + // Wrap the recorder so we get notified when the first chunk is written, + // without racing on rec.Body. + firstWrite := make(chan struct{}) + nw := ¬ifyWriter{ResponseWriter: rec, notify: firstWrite} + c := e.NewContext(req, nw) done := make(chan error, 1) go func() { done <- SimulatedStreamCopy(c, resp, 10*time.Millisecond) }() - // Wait until at least the first chunk has been written to the recorder - require.Eventually(t, func() bool { - return rec.Body.Len() > 0 - }, time.Second, 5*time.Millisecond, "expected first chunk to be written") + // Wait until the first chunk has been written to the recorder. + select { + case <-firstWrite: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first chunk to be written") + } // Cancel the context and close the body (simulating client disconnect) cancel() @@ -347,6 +376,6 @@ func TestSimulatedStreamCopy_ContextCancellation(t *testing.T) { t.Fatal("SimulatedStreamCopy did not return after context cancellation") } - // Verify first chunk was written + // Verify first chunk was written (safe to read after goroutine finished) assert.Contains(t, rec.Body.String(), "data: first") } diff --git a/pkg/session/session.go b/pkg/session/session.go index 85b6304d9..037455c7d 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -4,6 +4,7 @@ import ( "log/slog" "os" "strings" + "sync" "time" "github.com/google/uuid" @@ -51,6 +52,9 @@ func (si *Item) IsSubSession() bool { // Session represents the agent's state including conversation history and variables type Session struct { + // mu protects Messages from concurrent read/write access. + mu sync.RWMutex `json:"-"` + // ID is the unique identifier for the session ID string `json:"id"` @@ -216,16 +220,67 @@ type EvalCriteria struct { Setup string `json:"setup,omitempty"` // Optional sh script to run in the container before cagent run --exec } +// deepCopyMessage returns a deep copy of a session Message. +// It copies the inner chat.Message's slice and pointer fields so that the +// returned value shares no mutable state with the original. +func deepCopyMessage(m *Message) *Message { + cp := *m + cp.Message = deepCopyChatMessage(m.Message) + return &cp +} + +// deepCopyChatMessage returns a deep copy of a chat.Message, duplicating +// all slice and pointer fields that would otherwise alias the original. +func deepCopyChatMessage(m chat.Message) chat.Message { + if m.MultiContent != nil { + orig := m.MultiContent + m.MultiContent = make([]chat.MessagePart, len(orig)) + for i, part := range orig { + if part.ImageURL != nil { + imgCopy := *part.ImageURL + part.ImageURL = &imgCopy + } + if part.File != nil { + fileCopy := *part.File + part.File = &fileCopy + } + m.MultiContent[i] = part + } + } + if m.FunctionCall != nil { + fcCopy := *m.FunctionCall + m.FunctionCall = &fcCopy + } + if m.ToolCalls != nil { + m.ToolCalls = append([]tools.ToolCall(nil), m.ToolCalls...) + } + if m.ToolDefinitions != nil { + m.ToolDefinitions = append([]tools.Tool(nil), m.ToolDefinitions...) + } + if m.Usage != nil { + usageCopy := *m.Usage + m.Usage = &usageCopy + } + if m.ThoughtSignature != nil { + m.ThoughtSignature = append([]byte(nil), m.ThoughtSignature...) + } + return m +} + // Session helper methods // AddMessage adds a message to the session func (s *Session) AddMessage(msg *Message) { + s.mu.Lock() s.Messages = append(s.Messages, NewMessageItem(msg)) + s.mu.Unlock() } // AddSubSession adds a sub-session to the session func (s *Session) AddSubSession(subSession *Session) { + s.mu.Lock() s.Messages = append(s.Messages, NewSubSessionItem(subSession)) + s.mu.Unlock() } // Duration calculates the duration of the session from message timestamps. @@ -258,8 +313,19 @@ func (s *Session) AllowedDirectories() []string { // GetAllMessages extracts all messages from the session, including from sub-sessions func (s *Session) GetAllMessages() []Message { + s.mu.RLock() + items := make([]Item, len(s.Messages)) + for i, item := range s.Messages { + if item.Message != nil { + items[i] = Item{Message: deepCopyMessage(item.Message)} + } else { + items[i] = item + } + } + s.mu.RUnlock() + var messages []Message - for _, item := range s.Messages { + for _, item := range items { if item.IsMessage() && item.Message.Message.Role != chat.MessageRoleSystem { messages = append(messages, *item.Message) } else if item.IsSubSession() { @@ -408,6 +474,9 @@ func (s *Session) IsSubSession() bool { // MessageCount returns the number of items that contain a message. func (s *Session) MessageCount() int { + s.mu.RLock() + defer s.mu.RUnlock() + n := 0 for _, item := range s.Messages { if item.IsMessage() { @@ -421,6 +490,9 @@ func (s *Session) MessageCount() int { // sub-sessions, and summary items. It does not use the session-level Cost // field, which exists only for backward-compatible persistence. func (s *Session) TotalCost() float64 { + s.mu.RLock() + defer s.mu.RUnlock() + var cost float64 for _, item := range s.Messages { switch { @@ -439,6 +511,9 @@ func (s *Session) TotalCost() float64 { // This is used for live event emissions where sub-sessions report their // own costs separately. func (s *Session) OwnCost() float64 { + s.mu.RLock() + defer s.mu.RUnlock() + var cost float64 for _, item := range s.Messages { if item.IsMessage() { @@ -609,22 +684,22 @@ func buildContextSpecificSystemMessages(a *agent.Agent, s *Session) []chat.Messa // if one exists. Session summaries are context-specific per session and thus should not have a checkpoint (they will be cached alongside the first user message anyway) // // lastSummaryIndex is the index of the last summary item in s.Messages, or -1 if none exists. -func buildSessionSummaryMessages(s *Session) ([]chat.Message, int) { +func buildSessionSummaryMessages(items []Item) ([]chat.Message, int) { var messages []chat.Message // Find the last summary index to determine where conversation messages start // and to include the summary in session summary messages lastSummaryIndex := -1 - for i := len(s.Messages) - 1; i >= 0; i-- { - if s.Messages[i].Summary != "" { + for i := len(items) - 1; i >= 0; i-- { + if items[i].Summary != "" { lastSummaryIndex = i break } } - if lastSummaryIndex >= 0 && lastSummaryIndex < len(s.Messages) { + if lastSummaryIndex >= 0 && lastSummaryIndex < len(items) { messages = append(messages, chat.Message{ Role: chat.MessageRoleUser, - Content: "Session Summary: " + s.Messages[lastSummaryIndex].Summary, + Content: "Session Summary: " + items[lastSummaryIndex].Summary, CreatedAt: time.Now().Format(time.RFC3339), }) } @@ -643,8 +718,21 @@ func (s *Session) GetMessages(a *agent.Agent) []chat.Message { contextMessages := buildContextSpecificSystemMessages(a, s) markLastMessageAsCacheControl(contextMessages) + // Take a snapshot of Messages under the lock, copying Message structs + // to avoid racing with UpdateMessage which may modify the pointed-to objects. + s.mu.RLock() + items := make([]Item, len(s.Messages)) + for i, item := range s.Messages { + if item.Message != nil { + items[i] = Item{Message: deepCopyMessage(item.Message), Summary: item.Summary, SubSession: item.SubSession, Cost: item.Cost} + } else { + items[i] = item + } + } + s.mu.RUnlock() + // Build session summary messages (vary per session) - summaryMessages, lastSummaryIndex := buildSessionSummaryMessages(s) + summaryMessages, lastSummaryIndex := buildSessionSummaryMessages(items) var messages []chat.Message messages = append(messages, invariantMessages...) @@ -654,8 +742,8 @@ func (s *Session) GetMessages(a *agent.Agent) []chat.Message { startIndex := lastSummaryIndex + 1 // Begin adding conversation messages - for i := startIndex; i < len(s.Messages); i++ { - item := s.Messages[i] + for i := startIndex; i < len(items); i++ { + item := items[i] if item.IsMessage() { messages = append(messages, item.Message.Message) } diff --git a/pkg/session/store.go b/pkg/session/store.go index 7e8d37adf..33fed9e60 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -195,16 +195,41 @@ func (s *InMemorySessionStore) UpdateSession(_ context.Context, session *Session return ErrEmptyID } - // Create a shallow copy of the session - newSession := *session - newSession.Messages = nil // Messages stored separately via AddMessage + // Build a new session with the same metadata but a fresh mutex. + // Messages are stored separately via AddMessage. + newSession := &Session{ + ID: session.ID, + Title: session.Title, + Evals: session.Evals, + CreatedAt: session.CreatedAt, + ToolsApproved: session.ToolsApproved, + Thinking: session.Thinking, + HideToolResults: session.HideToolResults, + WorkingDir: session.WorkingDir, + SendUserMessage: session.SendUserMessage, + MaxIterations: session.MaxIterations, + Starred: session.Starred, + InputTokens: session.InputTokens, + OutputTokens: session.OutputTokens, + Cost: session.Cost, + Permissions: session.Permissions, + AgentModelOverrides: session.AgentModelOverrides, + CustomModelsUsed: session.CustomModelsUsed, + BranchParentSessionID: session.BranchParentSessionID, + BranchParentPosition: session.BranchParentPosition, + BranchCreatedAt: session.BranchCreatedAt, + ParentID: session.ParentID, + } // Preserve existing messages if session already exists if existing, exists := s.sessions.Load(session.ID); exists { - newSession.Messages = existing.Messages + existing.mu.RLock() + newSession.Messages = make([]Item, len(existing.Messages)) + copy(newSession.Messages, existing.Messages) + existing.mu.RUnlock() } - s.sessions.Store(session.ID, &newSession) + s.sessions.Store(session.ID, newSession) return nil } @@ -240,18 +265,25 @@ func (s *InMemorySessionStore) AddMessage(_ context.Context, sessionID string, m // UpdateMessage updates an existing message by its ID. func (s *InMemorySessionStore) UpdateMessage(_ context.Context, messageID int64, msg *Message) error { + // Create a deep copy of the message to avoid mutating the caller's pointer, + // which may be shared with another Session object. + updated := deepCopyMessage(msg) + updated.ID = messageID + // For in-memory store, we need to find the message across all sessions var found bool s.sessions.Range(func(_ string, session *Session) bool { + session.mu.Lock() for i := range session.Messages { - if session.Messages[i].Message != nil && session.Messages[i].Message.ID == messageID { - // Preserve the message ID when updating - msg.ID = messageID - session.Messages[i].Message = msg - found = true - return false + if session.Messages[i].Message == nil || session.Messages[i].Message.ID != messageID { + continue } + session.Messages[i].Message = updated + found = true + session.mu.Unlock() + return false } + session.mu.Unlock() return true }) if !found { @@ -284,7 +316,9 @@ func (s *InMemorySessionStore) AddSummary(_ context.Context, sessionID, summary if !exists { return ErrNotFound } + session.mu.Lock() session.Messages = append(session.Messages, Item{Summary: summary}) + session.mu.Unlock() return nil }