Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pkg/leantui/leantui.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ type model struct {
status statusData
sessionState *service.SessionState

usageBySession map[string]usageSnapshot
rootSessionID string
latestUsageSessionID string
sessionStack []string

blocks []*block
busy bool
spinnerFrame int
Expand Down Expand Up @@ -202,6 +207,7 @@ func newModel(term *terminal, cfg Config) *model {
tools: make(map[string]*toolView),
status: statusData{workingDir: cfg.WorkingDir, branch: gitBranch(cfg.WorkingDir)},
sessionState: sessionState,
usageBySession: make(map[string]usageSnapshot),
appName: appName,
disabledCommands: disabled,
}
Expand Down
27 changes: 23 additions & 4 deletions pkg/leantui/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"strconv"
"strings"

"github.com/docker/docker-agent/pkg/tui/components/toolcommon"
)

// statusData is the snapshot of run state shown in the footer.
Expand All @@ -18,12 +20,21 @@ type statusData struct {
contextLength int64
contextLimit int64
tokens int64 // input + output tokens used so far
cost float64
costKnown bool
}

// usageSnapshot is the most recent usage report recorded for one session.
type usageSnapshot struct {
	// contextLength and contextLimit feed the context gauge in the footer.
	contextLength, contextLimit int64
	// tokens is the sum of input and output tokens from the report.
	tokens int64
	// cost is the cost carried by the report.
	cost float64
}

// renderStatus builds the two-line footer:
//
// <working dir> ⎇ <branch> <agent>
// <context bar> <pct> · <tokens> <model> · <effort>
// <context bar> <pct> · <tokens> · <cost> <model> · <effort>
func renderStatus(d statusData, width int) []string {
dir := stSecondary().Render(truncate(shortenPath(d.workingDir), max(10, width/2)))
left1 := dir
Expand Down Expand Up @@ -54,11 +65,12 @@ func renderStatus(d statusData, width int) []string {
}

func renderContext(d statusData) string {
cost := renderCostSuffix(d)
if d.contextLimit <= 0 {
if d.tokens > 0 {
return stMuted().Render(formatTokens(d.tokens) + " tokens")
return stMuted().Render(formatTokens(d.tokens)+" tokens") + cost
}
return stMuted().Render("context: —")
return renderBar(0) + stMuted().Render(" 0% · 0/0") + cost
}

pct := float64(d.contextLength) / float64(d.contextLimit)
Expand All @@ -71,7 +83,14 @@ func renderContext(d statusData) string {
formatTokens(d.contextLength),
formatTokens(d.contextLimit),
)
return bar + stMuted().Render(label)
return bar + stMuted().Render(label) + cost
}

// renderCostSuffix returns the " · <cost>" tail appended to the context
// segment of the footer. It yields an empty string while no cost has been
// recorded (d.costKnown is false); otherwise the separator is rendered muted
// and the amount, formatted by toolcommon.FormatCostUSD, is rendered with the
// accent style.
func renderCostSuffix(d statusData) string {
	if d.costKnown {
		separator := stMuted().Render(" · ")
		amount := stAccent().Render(toolcommon.FormatCostUSD(d.cost))
		return separator + amount
	}
	return ""
}

// contextBarWidth is the cell width of the context-usage gauge.
Expand Down
100 changes: 100 additions & 0 deletions pkg/leantui/status_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package leantui

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"

"github.com/docker/docker-agent/pkg/runtime"
)

func TestFormatTokens(t *testing.T) {
Expand Down Expand Up @@ -40,6 +43,23 @@ func TestRenderBarWidth(t *testing.T) {
assert.Equal(t, contextBarWidth, displayWidth(renderBar(1.5))) // clamped
}

func TestRenderContextShowsZerosBeforeUsage(t *testing.T) {
t.Parallel()
out := renderContext(statusData{})
assert.NotContains(t, out, "context")
assert.Contains(t, out, "0% · 0/0")
}

// TestAgentInfoContextLimitShownBeforeUsage checks that the context limit
// carried by an AgentInfo event reaches the footer before any token usage has
// been reported.
func TestAgentInfoContextLimitShownBeforeUsage(t *testing.T) {
	t.Parallel()

	const limit = 200_000
	m := bareModel(24)

	info := runtime.AgentInfo("root", "test/model", "", "", limit)
	m.handleEvent(t.Context(), info)

	assert.Equal(t, int64(limit), m.status.contextLimit)
	rendered := renderContext(m.status)
	assert.Contains(t, rendered, "0% · 0/200.0k")
}

func TestRenderStatusFitsWidth(t *testing.T) {
t.Parallel()
d := statusData{
Expand All @@ -51,10 +71,90 @@ func TestRenderStatusFitsWidth(t *testing.T) {
contextLength: 24_000,
contextLimit: 200_000,
tokens: 24_000,
cost: 0.05,
costKnown: true,
}
lines := renderStatus(d, 80)
assert.Len(t, lines, 2)
assert.Contains(t, strings.Join(lines, "\n"), "$0.05")
for _, l := range lines {
assert.LessOrEqual(t, displayWidth(l), 80)
}
}

// TestTokenUsageEventAggregatesSessionCost drives a root stream followed by a
// nested child stream. While the child is active the footer shows the child's
// token count, the cost is the sum over both sessions, and once the child
// stops the token count falls back to the root session's value.
func TestTokenUsageEventAggregatesSessionCost(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	m := bareModel(24)

	rootUsage := &runtime.Usage{
		InputTokens:   2_000,
		OutputTokens:  1_000,
		ContextLength: 3_000,
		ContextLimit:  10_000,
		Cost:          0.10,
	}
	childUsage := &runtime.Usage{
		InputTokens:   800,
		OutputTokens:  200,
		ContextLength: 1_000,
		ContextLimit:  20_000,
		Cost:          0.05,
	}

	m.handleEvent(ctx, runtime.StreamStarted("root-session", "root"))
	m.handleEvent(ctx, runtime.NewTokenUsageEvent("root-session", "root", rootUsage))
	m.handleEvent(ctx, runtime.StreamStarted("child-session", "developer"))
	m.handleEvent(ctx, runtime.NewTokenUsageEvent("child-session", "developer", childUsage))

	// Child stream is on top of the stack: its tokens, combined cost.
	assert.Equal(t, int64(1_000), m.status.tokens)
	assert.InDelta(t, 0.15, m.status.cost, 0.0001)
	assert.True(t, m.status.costKnown)
	footer := strings.Join(renderStatus(m.status, 80), "\n")
	assert.Contains(t, footer, "$0.15")

	m.handleEvent(ctx, runtime.StreamStopped("child-session", "developer", "normal"))

	// Back on the root stream: root tokens, cost still aggregated.
	assert.Equal(t, int64(3_000), m.status.tokens)
	assert.InDelta(t, 0.15, m.status.cost, 0.0001)
}

// TestTokenUsageBeforeStreamUsesFirstSessionAsRoot sends usage events without
// any preceding StreamStarted event. The first session seen becomes the root,
// the footer keeps showing the root's tokens, and the cost sums both sessions.
func TestTokenUsageBeforeStreamUsesFirstSessionAsRoot(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	m := bareModel(24)

	first := &runtime.Usage{
		InputTokens:   2_000,
		OutputTokens:  1_000,
		ContextLength: 3_000,
		ContextLimit:  10_000,
		Cost:          0.10,
	}
	second := &runtime.Usage{
		InputTokens:   800,
		OutputTokens:  200,
		ContextLength: 1_000,
		ContextLimit:  20_000,
		Cost:          0.05,
	}

	m.handleEvent(ctx, runtime.NewTokenUsageEvent("root-session", "root", first))
	m.handleEvent(ctx, runtime.NewTokenUsageEvent("child-session", "developer", second))

	assert.Equal(t, "root-session", m.rootSessionID)
	assert.Equal(t, int64(3_000), m.status.tokens)
	assert.InDelta(t, 0.15, m.status.cost, 0.0001)
}

// TestEmptySessionUsageDoesNotOverrideSessionScopedUsage checks that a usage
// event carrying an empty session ID is ignored once usage for a named
// session has already been recorded.
func TestEmptySessionUsageDoesNotOverrideSessionScopedUsage(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	m := bareModel(24)

	scoped := &runtime.Usage{
		InputTokens:   2_000,
		OutputTokens:  1_000,
		ContextLength: 3_000,
		ContextLimit:  10_000,
		Cost:          0.10,
	}
	unscoped := &runtime.Usage{
		InputTokens:   50,
		ContextLength: 50,
		Cost:          0.99,
	}

	m.handleEvent(ctx, runtime.NewTokenUsageEvent("root-session", "root", scoped))
	m.handleEvent(ctx, runtime.NewTokenUsageEvent("", "root", unscoped))

	assert.Equal(t, int64(3_000), m.status.tokens)
	assert.InDelta(t, 0.10, m.status.cost, 0.0001)
}
17 changes: 9 additions & 8 deletions pkg/leantui/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ func bareModel(height int) *model {
var buf bytes.Buffer
w := bufio.NewWriter(&buf)
return &model{
width: width,
height: height,
r: newRenderer(w, width, height),
editor: newEditor("type here"),
ac: newAutocomplete(),
tools: map[string]*toolView{},
status: statusData{workingDir: "/tmp/project"},
sessionState: service.NewSessionState(nil),
width: width,
height: height,
r: newRenderer(w, width, height),
editor: newEditor("type here"),
ac: newAutocomplete(),
tools: map[string]*toolView{},
status: statusData{workingDir: "/tmp/project"},
sessionState: service.NewSessionState(nil),
usageBySession: map[string]usageSnapshot{},
}
}

Expand Down
118 changes: 112 additions & 6 deletions pkg/leantui/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ func (m *model) handleEvent(ctx context.Context, ev any) {
switch e := ev.(type) {
case *runtime.StreamStartedEvent:
m.busy = true
m.trackStreamStarted(e.SessionID)
case *runtime.StreamStoppedEvent:
m.trackStreamStopped()
m.handleStreamStopped(ctx)
case *runtime.AgentChoiceReasoningEvent:
m.appendPending(blockReasoning, e.Content)
Expand Down Expand Up @@ -314,11 +316,7 @@ func (m *model) handleEvent(ctx context.Context, ev any) {
toolView: *newToolView(e.GetAgentName(), e.ToolCall, toolDef, tuitypes.ToolStatusConfirmation),
}
case *runtime.TokenUsageEvent:
if e.Usage != nil {
m.status.contextLength = e.Usage.ContextLength
m.status.contextLimit = e.Usage.ContextLimit
m.status.tokens = e.Usage.InputTokens + e.Usage.OutputTokens
}
m.setTokenUsage(e.SessionID, e.Usage)
case *runtime.AgentInfoEvent:
m.status.agent = e.AgentName
if m.sessionState != nil {
Expand All @@ -327,6 +325,9 @@ func (m *model) handleEvent(ctx context.Context, ev any) {
if e.Model != "" {
m.status.model = e.Model
}
if e.ContextLimit > 0 {
m.status.contextLimit = e.ContextLimit
}
case *runtime.TeamInfoEvent:
m.applyTeamInfo(ctx, e)
case *runtime.SessionCompactionEvent:
Expand Down Expand Up @@ -355,11 +356,107 @@ func (m *model) handleStreamStopped(ctx context.Context) {
return
}

if m.app.ShouldExitAfterFirstResponse() {
if m.app != nil && m.app.ShouldExitAfterFirstResponse() {
m.quit()
}
}

// trackStreamStarted records that a stream for sessionID has begun. Events
// without a session ID are ignored. A session that starts while the stack is
// empty becomes the root session; every session is pushed on the stack and
// the footer usage is then recomputed.
func (m *model) trackStreamStarted(sessionID string) {
	if sessionID == "" {
		return
	}

	outermost := len(m.sessionStack) == 0
	m.sessionStack = append(m.sessionStack, sessionID)
	if outermost {
		m.rootSessionID = sessionID
	}

	m.applyUsageSnapshot()
}

// trackStreamStopped pops the most recently started session, if any, off the
// session stack and recomputes the footer usage.
func (m *model) trackStreamStopped() {
	if depth := len(m.sessionStack); depth != 0 {
		m.sessionStack = m.sessionStack[:depth-1]
	}

	m.applyUsageSnapshot()
}

func (m *model) setTokenUsage(sessionID string, usage *runtime.Usage) {
if usage == nil {
return
}

snapshot := usageSnapshot{
contextLength: usage.ContextLength,
contextLimit: usage.ContextLimit,
tokens: usage.InputTokens + usage.OutputTokens,
cost: usage.Cost,
}
if sessionID == "" {
Comment thread
rumpl marked this conversation as resolved.
// Once session-scoped usage exists, it is authoritative for the chat
// footer. Empty-session usage comes from side work such as RAG indexing.
if len(m.usageBySession) == 0 {
m.applyStatusUsage(snapshot, usage.Cost, true)
}
return
}
if m.usageBySession == nil {
m.usageBySession = make(map[string]usageSnapshot)
}
if m.rootSessionID == "" && len(m.usageBySession) == 0 {
m.rootSessionID = sessionID
}
m.usageBySession[sessionID] = snapshot
m.latestUsageSessionID = sessionID
m.applyUsageSnapshot()
}

// applyUsageSnapshot refreshes the footer from the recorded per-session
// usage. It does nothing while no session has reported usage. The cost shown
// is the sum over all recorded sessions; context and token figures come from
// the active session when activeUsage finds one, and are left untouched
// otherwise.
func (m *model) applyUsageSnapshot() {
	if len(m.usageBySession) == 0 {
		return
	}

	total := 0.0
	for _, snap := range m.usageBySession {
		total += snap.cost
	}

	active, found := m.activeUsage()
	if !found {
		// No snapshot for the active session yet: only the cost changes.
		m.status.cost, m.status.costKnown = total, true
		return
	}
	m.applyStatusUsage(active, total, true)
}

// activeUsage picks the snapshot whose context and token figures the footer
// should display. The session is chosen in priority order: the top of the
// session stack, then the root session, then the session that reported most
// recently. The boolean is false when the chosen session has no recorded
// snapshot. With none of those set, a lone recorded snapshot is returned;
// otherwise the zero snapshot and false.
func (m *model) activeUsage() (usageSnapshot, bool) {
	var key string
	switch depth := len(m.sessionStack); {
	case depth > 0:
		key = m.sessionStack[depth-1]
	case m.rootSessionID != "":
		key = m.rootSessionID
	case m.latestUsageSessionID != "":
		key = m.latestUsageSessionID
	default:
		if len(m.usageBySession) == 1 {
			// Exactly one entry: the first iteration yields it.
			for _, only := range m.usageBySession {
				return only, true
			}
		}
		return usageSnapshot{}, false
	}

	snap, found := m.usageBySession[key]
	return snap, found
}

// applyStatusUsage copies the context and token figures of usage into the
// footer status and overwrites the displayed cost and its known flag with the
// supplied values.
func (m *model) applyStatusUsage(usage usageSnapshot, cost float64, costKnown bool) {
	st := &m.status
	st.contextLength, st.contextLimit = usage.contextLength, usage.contextLimit
	st.tokens = usage.tokens
	st.cost, st.costKnown = cost, costKnown
}

func (m *model) handleSessionCompaction(ctx context.Context, e *runtime.SessionCompactionEvent) {
switch e.Status {
case "started":
Expand Down Expand Up @@ -571,6 +668,15 @@ func (m *model) resetConversation() {
m.queue = nil
m.busy = false
m.confirm = nil
m.usageBySession = make(map[string]usageSnapshot)
m.rootSessionID = ""
m.latestUsageSessionID = ""
m.sessionStack = nil
m.status.contextLength = 0
m.status.contextLimit = 0
m.status.tokens = 0
m.status.cost = 0
m.status.costKnown = false
}

func (m *model) clearScreen() {
Expand Down
Loading
Loading