Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions lib/httpapi/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type EventEmitter struct {
agentType mf.AgentType
chans map[int]chan Event
chanIdx int
subscriptionBufSize int
subscriptionBufSize uint
screen string
}

Expand All @@ -81,20 +81,37 @@ func convertStatus(status st.ConversationStatus) AgentStatus {
}
}

// subscriptionBufSize is the size of the buffer for each subscription.
// Once the buffer is full, the channel will be closed.
// Listeners must actively drain the channel, so it's important to
// set this to a value that is large enough to handle the expected
// number of events.
func NewEventEmitter(subscriptionBufSize int) *EventEmitter {
return &EventEmitter{
mu: sync.Mutex{},
const defaultSubscriptionBufSize uint = 1024

type EventEmitterOption func(*EventEmitter)

func WithSubscriptionBufSize(size uint) EventEmitterOption {
return func(e *EventEmitter) {
if size == 0 {
e.subscriptionBufSize = defaultSubscriptionBufSize
} else {
e.subscriptionBufSize = size
}
}
}

func WithAgentType(agentType mf.AgentType) EventEmitterOption {
return func(e *EventEmitter) {
e.agentType = agentType
}
}

func NewEventEmitter(opts ...EventEmitterOption) *EventEmitter {
e := &EventEmitter{
messages: make([]st.ConversationMessage, 0),
status: AgentStatusRunning,
chans: make(map[int]chan Event),
chanIdx: 0,
subscriptionBufSize: subscriptionBufSize,
subscriptionBufSize: defaultSubscriptionBufSize,
}
for _, opt := range opts {
opt(e)
}
return e
}

// Assumes the caller holds the lock.
Expand Down Expand Up @@ -122,7 +139,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) {

// Assumes that only the last message can change or new messages can be added.
// If a new message is injected between existing messages (identified by Id), the behavior is undefined.
func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.ConversationMessage) {
func (e *EventEmitter) EmitMessages(newMessages []st.ConversationMessage) {
e.mu.Lock()
defer e.mu.Unlock()

Expand All @@ -137,6 +154,9 @@ func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.Conversatio
newMsg = newMessages[i]
}
if oldMsg != newMsg {
if i >= len(newMessages) {
continue
}
e.notifyChannels(EventTypeMessageUpdate, MessageUpdateBody{
Id: newMessages[i].Id,
Role: newMessages[i].Role,
Expand All @@ -149,7 +169,7 @@ func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.Conversatio
e.messages = newMessages
}

func (e *EventEmitter) UpdateStatusAndEmitChanges(newStatus st.ConversationStatus, agentType mf.AgentType) {
func (e *EventEmitter) EmitStatus(newStatus st.ConversationStatus) {
e.mu.Lock()
defer e.mu.Unlock()

Expand All @@ -158,12 +178,11 @@ func (e *EventEmitter) UpdateStatusAndEmitChanges(newStatus st.ConversationStatu
return
}

e.notifyChannels(EventTypeStatusChange, StatusChangeBody{Status: newAgentStatus, AgentType: agentType})
e.notifyChannels(EventTypeStatusChange, StatusChangeBody{Status: newAgentStatus, AgentType: e.agentType})
e.status = newAgentStatus
e.agentType = agentType
}

func (e *EventEmitter) UpdateScreenAndEmitChanges(newScreen string) {
func (e *EventEmitter) EmitScreen(newScreen string) {
e.mu.Lock()
defer e.mu.Unlock()

Expand Down
19 changes: 9 additions & 10 deletions lib/httpapi/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ import (
"testing"
"time"

mf "github.com/coder/agentapi/lib/msgfmt"
st "github.com/coder/agentapi/lib/screentracker"
"github.com/stretchr/testify/assert"
)

func TestEventEmitter(t *testing.T) {
t.Run("single-subscription", func(t *testing.T) {
emitter := NewEventEmitter(10)
emitter := NewEventEmitter(WithSubscriptionBufSize(10))
_, ch, stateEvents := emitter.Subscribe()
assert.Empty(t, ch)
assert.Equal(t, []Event{
Expand All @@ -27,7 +26,7 @@ func TestEventEmitter(t *testing.T) {
}, stateEvents)

now := time.Now()
emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{
emitter.EmitMessages([]st.ConversationMessage{
{Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now},
})
newEvent := <-ch
Expand All @@ -36,7 +35,7 @@ func TestEventEmitter(t *testing.T) {
Payload: MessageUpdateBody{Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now},
}, newEvent)

emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{
emitter.EmitMessages([]st.ConversationMessage{
{Id: 1, Message: "Hello, world! (updated)", Role: st.ConversationRoleUser, Time: now},
{Id: 2, Message: "What's up?", Role: st.ConversationRoleAgent, Time: now},
})
Expand All @@ -52,24 +51,24 @@ func TestEventEmitter(t *testing.T) {
Payload: MessageUpdateBody{Id: 2, Message: "What's up?", Role: st.ConversationRoleAgent, Time: now},
}, newEvent)

emitter.UpdateStatusAndEmitChanges(st.ConversationStatusStable, mf.AgentTypeAider)
emitter.EmitStatus(st.ConversationStatusStable)
newEvent = <-ch
assert.Equal(t, Event{
Type: EventTypeStatusChange,
Payload: StatusChangeBody{Status: AgentStatusStable, AgentType: mf.AgentTypeAider},
Payload: StatusChangeBody{Status: AgentStatusStable, AgentType: ""},
}, newEvent)
})

t.Run("multiple-subscriptions", func(t *testing.T) {
emitter := NewEventEmitter(10)
emitter := NewEventEmitter(WithSubscriptionBufSize(10))
channels := make([]<-chan Event, 0, 10)
for i := 0; i < 10; i++ {
_, ch, _ := emitter.Subscribe()
channels = append(channels, ch)
}
now := time.Now()

emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{
emitter.EmitMessages([]st.ConversationMessage{
{Id: 1, Message: "Hello, world!", Role: st.ConversationRoleUser, Time: now},
})
for _, ch := range channels {
Expand All @@ -82,10 +81,10 @@ func TestEventEmitter(t *testing.T) {
})

t.Run("close-channel", func(t *testing.T) {
emitter := NewEventEmitter(1)
emitter := NewEventEmitter(WithSubscriptionBufSize(1))
_, ch, _ := emitter.Subscribe()
for i := range 5 {
emitter.UpdateMessagesAndEmitChanges([]st.ConversationMessage{
emitter.EmitMessages([]st.ConversationMessage{
{Id: i, Message: fmt.Sprintf("Hello, world! %d", i), Role: st.ConversationRoleUser, Time: time.Now()},
})
}
Expand Down
14 changes: 3 additions & 11 deletions lib/httpapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
return mf.FormatToolCall(config.AgentType, message)
}

emitter := NewEventEmitter(1024)
emitter := NewEventEmitter(WithAgentType(config.AgentType))

// Format initial prompt into message parts if provided
var initialPrompt []st.MessagePart
Expand All @@ -262,16 +262,8 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
ReadyForInitialPrompt: isAgentReadyForInitialPrompt,
FormatToolCall: formatToolCall,
InitialPrompt: initialPrompt,
// OnSnapshot uses a callback rather than passing the emitter directly
// to keep the screentracker package decoupled from httpapi concerns.
// This preserves clean package boundaries and avoids import cycles.
OnSnapshot: func(status st.ConversationStatus, messages []st.ConversationMessage, screen string) {
emitter.UpdateStatusAndEmitChanges(status, config.AgentType)
emitter.UpdateMessagesAndEmitChanges(messages)
emitter.UpdateScreenAndEmitChanges(screen)
},
Logger: logger,
})
Logger: logger,
}, emitter)

// Create temporary directory for uploads
tempDir, err := os.MkdirTemp("", "agentapi-uploads-")
Expand Down
7 changes: 7 additions & 0 deletions lib/screentracker/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ type Conversation interface {
Text() string
}

// Emitter receives conversation state updates.
type Emitter interface {
EmitMessages([]ConversationMessage)
EmitStatus(ConversationStatus)
EmitScreen(string)
}

type ConversationMessage struct {
Id int
Message string
Expand Down
26 changes: 17 additions & 9 deletions lib/screentracker/pty_conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ type PTYConversationConfig struct {
FormatToolCall func(message string) (string, []string)
// InitialPrompt is the initial prompt to send to the agent once ready
InitialPrompt []MessagePart
// OnSnapshot is called after each snapshot with current status, messages, and screen content
OnSnapshot func(status ConversationStatus, messages []ConversationMessage, screen string)
Logger *slog.Logger
Logger *slog.Logger
}

func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int {
Expand All @@ -86,7 +84,8 @@ func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int {
// PTYConversation is a conversation that uses a pseudo-terminal (PTY) for communication.
// It uses a combination of polling and diffs to detect changes in the screen.
type PTYConversation struct {
cfg PTYConversationConfig
cfg PTYConversationConfig
emitter Emitter
// How many stable snapshots are required to consider the screen stable
stableSnapshotsThreshold int
snapshotBuffer *RingBuffer[screenSnapshot]
Expand Down Expand Up @@ -115,13 +114,23 @@ type PTYConversation struct {

var _ Conversation = &PTYConversation{}

func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation {
type noopEmitter struct{}

func (noopEmitter) EmitMessages([]ConversationMessage) {}
func (noopEmitter) EmitStatus(ConversationStatus) {}
func (noopEmitter) EmitScreen(string) {}

func NewPTY(ctx context.Context, cfg PTYConversationConfig, emitter Emitter) *PTYConversation {
if cfg.Clock == nil {
cfg.Clock = quartz.NewReal()
}
if emitter == nil {
emitter = noopEmitter{}
}
threshold := cfg.getStableSnapshotsThreshold()
c := &PTYConversation{
cfg: cfg,
emitter: emitter,
stableSnapshotsThreshold: threshold,
snapshotBuffer: NewRingBuffer[screenSnapshot](threshold),
messages: []ConversationMessage{
Expand All @@ -139,9 +148,6 @@ func NewPTY(ctx context.Context, cfg PTYConversationConfig) *PTYConversation {
if len(cfg.InitialPrompt) > 0 {
c.outboundQueue <- outboundMessage{parts: cfg.InitialPrompt, errCh: nil}
}
if c.cfg.OnSnapshot == nil {
c.cfg.OnSnapshot = func(ConversationStatus, []ConversationMessage, string) {}
}
if c.cfg.ReadyForInitialPrompt == nil {
c.cfg.ReadyForInitialPrompt = func(string) bool { return true }
}
Expand Down Expand Up @@ -173,7 +179,9 @@ func (c *PTYConversation) Start(ctx context.Context) {
}
c.lock.Unlock()

c.cfg.OnSnapshot(status, messages, screen)
c.emitter.EmitStatus(status)
c.emitter.EmitMessages(messages)
c.emitter.EmitScreen(screen)
return nil
}, "snapshot")

Expand Down
Loading