Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cmd/root/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
func TestDefaultToRun(t *testing.T) {
t.Parallel()

rootCmd := NewRootCmd()

tests := []struct {
name string
args []string
Expand Down Expand Up @@ -102,7 +100,7 @@ func TestDefaultToRun(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

got := defaultToRun(rootCmd, tt.args)
got := defaultToRun(NewRootCmd(), tt.args)
assert.Equal(t, tt.want, got)
})
}
Expand Down
41 changes: 35 additions & 6 deletions pkg/fake/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,28 @@ func TestSimulatedStreamCopy_SSEEvents(t *testing.T) {
assert.GreaterOrEqual(t, elapsed, 3*chunkDelay, "should have delays between data chunks")
}

// notifyWriter wraps an http.ResponseWriter and signals on first Write.
type notifyWriter struct {
http.ResponseWriter
notify chan struct{}
notified bool
}

func (w *notifyWriter) Write(p []byte) (int, error) {
n, err := w.ResponseWriter.Write(p)
if n > 0 && !w.notified {
w.notified = true
close(w.notify)
}
return n, err
}

func (w *notifyWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}

func TestSimulatedStreamCopy_ContextCancellation(t *testing.T) {
// Create a reader that provides some data then blocks
// to allow context cancellation to be tested
Expand All @@ -321,17 +343,24 @@ func TestSimulatedStreamCopy_ContextCancellation(t *testing.T) {
rec := httptest.NewRecorder()
ctx, cancel := context.WithCancel(t.Context())
req = req.WithContext(ctx)
c := e.NewContext(req, rec)

// Wrap the recorder so we get notified when the first chunk is written,
// without racing on rec.Body.
firstWrite := make(chan struct{})
nw := &notifyWriter{ResponseWriter: rec, notify: firstWrite}
c := e.NewContext(req, nw)

done := make(chan error, 1)
go func() {
done <- SimulatedStreamCopy(c, resp, 10*time.Millisecond)
}()

// Wait until at least the first chunk has been written to the recorder
require.Eventually(t, func() bool {
return rec.Body.Len() > 0
}, time.Second, 5*time.Millisecond, "expected first chunk to be written")
// Wait until the first chunk has been written to the recorder.
select {
case <-firstWrite:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for first chunk to be written")
}

// Cancel the context and close the body (simulating client disconnect)
cancel()
Expand All @@ -347,6 +376,6 @@ func TestSimulatedStreamCopy_ContextCancellation(t *testing.T) {
t.Fatal("SimulatedStreamCopy did not return after context cancellation")
}

// Verify first chunk was written
// Verify first chunk was written (safe to read after goroutine finished)
assert.Contains(t, rec.Body.String(), "data: first")
}
106 changes: 97 additions & 9 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"log/slog"
"os"
"strings"
"sync"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -51,6 +52,9 @@ func (si *Item) IsSubSession() bool {

// Session represents the agent's state including conversation history and variables
type Session struct {
// mu protects Messages from concurrent read/write access.
mu sync.RWMutex `json:"-"`

// ID is the unique identifier for the session
ID string `json:"id"`

Expand Down Expand Up @@ -216,16 +220,67 @@ type EvalCriteria struct {
Setup string `json:"setup,omitempty"` // Optional sh script to run in the container before cagent run --exec
}

// deepCopyMessage returns a deep copy of a session Message.
// It copies the inner chat.Message's slice and pointer fields so that the
// returned value shares no mutable state with the original.
func deepCopyMessage(m *Message) *Message {
cp := *m
cp.Message = deepCopyChatMessage(m.Message)
return &cp
}

// deepCopyChatMessage returns a deep copy of a chat.Message, duplicating
// all slice and pointer fields that would otherwise alias the original.
func deepCopyChatMessage(m chat.Message) chat.Message {
if m.MultiContent != nil {
orig := m.MultiContent
m.MultiContent = make([]chat.MessagePart, len(orig))
for i, part := range orig {
if part.ImageURL != nil {
imgCopy := *part.ImageURL
part.ImageURL = &imgCopy
}
if part.File != nil {
fileCopy := *part.File
part.File = &fileCopy
}
m.MultiContent[i] = part
}
}
if m.FunctionCall != nil {
fcCopy := *m.FunctionCall
m.FunctionCall = &fcCopy
}
if m.ToolCalls != nil {
m.ToolCalls = append([]tools.ToolCall(nil), m.ToolCalls...)
}
if m.ToolDefinitions != nil {
m.ToolDefinitions = append([]tools.Tool(nil), m.ToolDefinitions...)
}
if m.Usage != nil {
usageCopy := *m.Usage
m.Usage = &usageCopy
}
if m.ThoughtSignature != nil {
m.ThoughtSignature = append([]byte(nil), m.ThoughtSignature...)
}
return m
}

// Session helper methods

// AddMessage adds a message to the session
func (s *Session) AddMessage(msg *Message) {
s.mu.Lock()
s.Messages = append(s.Messages, NewMessageItem(msg))
s.mu.Unlock()
}

// AddSubSession adds a sub-session to the session
func (s *Session) AddSubSession(subSession *Session) {
s.mu.Lock()
s.Messages = append(s.Messages, NewSubSessionItem(subSession))
s.mu.Unlock()
}

// Duration calculates the duration of the session from message timestamps.
Expand Down Expand Up @@ -258,8 +313,19 @@ func (s *Session) AllowedDirectories() []string {

// GetAllMessages extracts all messages from the session, including from sub-sessions
func (s *Session) GetAllMessages() []Message {
s.mu.RLock()
items := make([]Item, len(s.Messages))
for i, item := range s.Messages {
if item.Message != nil {
items[i] = Item{Message: deepCopyMessage(item.Message)}
} else {
items[i] = item
}
}
s.mu.RUnlock()

var messages []Message
for _, item := range s.Messages {
for _, item := range items {
if item.IsMessage() && item.Message.Message.Role != chat.MessageRoleSystem {
messages = append(messages, *item.Message)
} else if item.IsSubSession() {
Expand Down Expand Up @@ -408,6 +474,9 @@ func (s *Session) IsSubSession() bool {

// MessageCount returns the number of items that contain a message.
func (s *Session) MessageCount() int {
s.mu.RLock()
defer s.mu.RUnlock()

n := 0
for _, item := range s.Messages {
if item.IsMessage() {
Expand All @@ -421,6 +490,9 @@ func (s *Session) MessageCount() int {
// sub-sessions, and summary items. It does not use the session-level Cost
// field, which exists only for backward-compatible persistence.
func (s *Session) TotalCost() float64 {
s.mu.RLock()
defer s.mu.RUnlock()

var cost float64
for _, item := range s.Messages {
switch {
Expand All @@ -439,6 +511,9 @@ func (s *Session) TotalCost() float64 {
// This is used for live event emissions where sub-sessions report their
// own costs separately.
func (s *Session) OwnCost() float64 {
s.mu.RLock()
defer s.mu.RUnlock()

var cost float64
for _, item := range s.Messages {
if item.IsMessage() {
Expand Down Expand Up @@ -609,22 +684,22 @@ func buildContextSpecificSystemMessages(a *agent.Agent, s *Session) []chat.Messa
// if one exists. Session summaries are context-specific per session and thus should not have a checkpoint (they will be cached alongside the first user message anyway)
//
// lastSummaryIndex is the index of the last summary item in s.Messages, or -1 if none exists.
func buildSessionSummaryMessages(s *Session) ([]chat.Message, int) {
func buildSessionSummaryMessages(items []Item) ([]chat.Message, int) {
var messages []chat.Message
// Find the last summary index to determine where conversation messages start
// and to include the summary in session summary messages
lastSummaryIndex := -1
for i := len(s.Messages) - 1; i >= 0; i-- {
if s.Messages[i].Summary != "" {
for i := len(items) - 1; i >= 0; i-- {
if items[i].Summary != "" {
lastSummaryIndex = i
break
}
}

if lastSummaryIndex >= 0 && lastSummaryIndex < len(s.Messages) {
if lastSummaryIndex >= 0 && lastSummaryIndex < len(items) {
messages = append(messages, chat.Message{
Role: chat.MessageRoleUser,
Content: "Session Summary: " + s.Messages[lastSummaryIndex].Summary,
Content: "Session Summary: " + items[lastSummaryIndex].Summary,
CreatedAt: time.Now().Format(time.RFC3339),
})
}
Expand All @@ -643,8 +718,21 @@ func (s *Session) GetMessages(a *agent.Agent) []chat.Message {
contextMessages := buildContextSpecificSystemMessages(a, s)
markLastMessageAsCacheControl(contextMessages)

// Take a snapshot of Messages under the lock, copying Message structs
// to avoid racing with UpdateMessage which may modify the pointed-to objects.
s.mu.RLock()
items := make([]Item, len(s.Messages))
for i, item := range s.Messages {
if item.Message != nil {
items[i] = Item{Message: deepCopyMessage(item.Message), Summary: item.Summary, SubSession: item.SubSession, Cost: item.Cost}
} else {
items[i] = item
}
}
s.mu.RUnlock()

// Build session summary messages (vary per session)
summaryMessages, lastSummaryIndex := buildSessionSummaryMessages(s)
summaryMessages, lastSummaryIndex := buildSessionSummaryMessages(items)

var messages []chat.Message
messages = append(messages, invariantMessages...)
Expand All @@ -654,8 +742,8 @@ func (s *Session) GetMessages(a *agent.Agent) []chat.Message {
startIndex := lastSummaryIndex + 1

// Begin adding conversation messages
for i := startIndex; i < len(s.Messages); i++ {
item := s.Messages[i]
for i := startIndex; i < len(items); i++ {
item := items[i]
if item.IsMessage() {
messages = append(messages, item.Message.Message)
}
Expand Down
56 changes: 45 additions & 11 deletions pkg/session/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,41 @@ func (s *InMemorySessionStore) UpdateSession(_ context.Context, session *Session
return ErrEmptyID
}

// Create a shallow copy of the session
newSession := *session
newSession.Messages = nil // Messages stored separately via AddMessage
// Build a new session with the same metadata but a fresh mutex.
// Messages are stored separately via AddMessage.
newSession := &Session{
ID: session.ID,
Title: session.Title,
Evals: session.Evals,
CreatedAt: session.CreatedAt,
ToolsApproved: session.ToolsApproved,
Thinking: session.Thinking,
HideToolResults: session.HideToolResults,
WorkingDir: session.WorkingDir,
SendUserMessage: session.SendUserMessage,
MaxIterations: session.MaxIterations,
Starred: session.Starred,
InputTokens: session.InputTokens,
OutputTokens: session.OutputTokens,
Cost: session.Cost,
Permissions: session.Permissions,
AgentModelOverrides: session.AgentModelOverrides,
CustomModelsUsed: session.CustomModelsUsed,
BranchParentSessionID: session.BranchParentSessionID,
BranchParentPosition: session.BranchParentPosition,
BranchCreatedAt: session.BranchCreatedAt,
ParentID: session.ParentID,
}

// Preserve existing messages if session already exists
if existing, exists := s.sessions.Load(session.ID); exists {
newSession.Messages = existing.Messages
existing.mu.RLock()
newSession.Messages = make([]Item, len(existing.Messages))
copy(newSession.Messages, existing.Messages)
existing.mu.RUnlock()
}

s.sessions.Store(session.ID, &newSession)
s.sessions.Store(session.ID, newSession)
return nil
}

Expand Down Expand Up @@ -240,18 +265,25 @@ func (s *InMemorySessionStore) AddMessage(_ context.Context, sessionID string, m

// UpdateMessage updates an existing message by its ID.
func (s *InMemorySessionStore) UpdateMessage(_ context.Context, messageID int64, msg *Message) error {
// Create a deep copy of the message to avoid mutating the caller's pointer,
// which may be shared with another Session object.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 MEDIUM SEVERITY: Shallow copy in UpdateMessage shares data with caller

updated := *msg creates a shallow copy of the Message struct, but the chat.Message field contains slices/pointers that remain shared with the caller's msg parameter. If the caller modifies msg after this call, it races with readers accessing the stored message.

Fix: Deep copy the chat.Message field when creating updated, or document that callers must not modify msg after calling this function.

updated := deepCopyMessage(msg)
updated.ID = messageID

// For in-memory store, we need to find the message across all sessions
var found bool
s.sessions.Range(func(_ string, session *Session) bool {
session.mu.Lock()
for i := range session.Messages {
if session.Messages[i].Message != nil && session.Messages[i].Message.ID == messageID {
// Preserve the message ID when updating
msg.ID = messageID
session.Messages[i].Message = msg
found = true
return false
if session.Messages[i].Message == nil || session.Messages[i].Message.ID != messageID {
continue
}
session.Messages[i].Message = updated
found = true
session.mu.Unlock()
return false
}
session.mu.Unlock()
return true
})
if !found {
Expand Down Expand Up @@ -284,7 +316,9 @@ func (s *InMemorySessionStore) AddSummary(_ context.Context, sessionID, summary
if !exists {
return ErrNotFound
}
session.mu.Lock()
session.Messages = append(session.Messages, Item{Summary: summary})
session.mu.Unlock()
return nil
}

Expand Down