Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 4 additions & 83 deletions cmd/root/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@ import (
"fmt"
"log/slog"
"os"
"path/filepath"
"time"

"github.com/spf13/cobra"

"github.com/docker/cagent/pkg/agentfile"
"github.com/docker/cagent/pkg/cli"
"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/remote"
"github.com/docker/cagent/pkg/server"
"github.com/docker/cagent/pkg/session"
"github.com/docker/cagent/pkg/teamloader"
"github.com/docker/cagent/pkg/telemetry"
)

Expand Down Expand Up @@ -73,96 +69,21 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {

slog.Debug("Starting server", "agents", agentsPath, "addr", ln.Addr().String())

resolvedPath, err := agentfile.Resolve(ctx, out, agentsPath)
if err != nil {
return err
}

sessionStore, err := session.NewSQLiteSessionStore(f.sessionDB)
if err != nil {
return fmt.Errorf("failed to create session store: %w", err)
}

var opts []server.Opt

if !agentfile.IsOCIReference(agentsPath) {
stat, err := os.Stat(resolvedPath)
if err != nil {
return fmt.Errorf("failed to stat agents path: %w", err)
}
if stat.IsDir() {
// For directories: only set agentsDir, not agentsPath
opts = append(opts, server.WithAgentsDir(resolvedPath))
} else {
opts = append(opts, server.WithAgentsPath(resolvedPath), server.WithAgentsDir(filepath.Dir(resolvedPath)))
}
}

teams, err := teamloader.LoadTeams(ctx, resolvedPath, &f.runConfig)
sources, err := agentfile.ResolveSources(ctx, nil, agentsPath)
if err != nil {
return fmt.Errorf("failed to load teams: %w", err)
}

// For OCI refs: store the reference for later per-session reloading, then clean up temp file
if agentfile.IsOCIReference(agentsPath) {
teamKey := filepath.Base(resolvedPath)
opts = append(opts, server.WithOCIRef(teamKey, agentsPath))

if err := os.Remove(resolvedPath); err != nil {
slog.Warn("Failed to remove temporary OCI file", "path", resolvedPath, "error", err)
} else {
slog.Debug("Cleaned up temporary OCI file", "path", resolvedPath)
}
return fmt.Errorf("failed to resolve agent sources: %w", err)
}

defer func() {
for _, team := range teams {
if err := team.StopToolSets(ctx); err != nil {
slog.Error("Failed to stop tool sets", "error", err)
}
}
}()

s, err := server.New(sessionStore, &f.runConfig, teams, opts...)
s, err := server.New(sessionStore, &f.runConfig, sources)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}

// Start background auto-pull for OCI references if enabled
if f.pullIntervalMins > 0 {
go func() {
ticker := time.NewTicker(time.Duration(f.pullIntervalMins) * time.Minute)
defer ticker.Stop()

slog.Info("Auto-pull enabled for OCI reference", "reference", agentsPath, "interval_minutes", f.pullIntervalMins)

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
slog.Info("Auto-pulling OCI reference", "reference", agentsPath)
if _, err := remote.Pull(ctx, agentsPath, false); err != nil {
slog.Error("Failed to auto-pull OCI reference", "reference", agentsPath, "error", err)
continue
}

// Resolve the OCI reference to get the updated file path
newResolvedPath, err := agentfile.Resolve(ctx, out, agentsPath)
if err != nil {
slog.Error("Failed to resolve OCI reference after pull", "reference", agentsPath, "error", err)
continue
}

if err := s.ReloadTeams(ctx, newResolvedPath); err != nil {
slog.Error("Failed to reload teams", "reference", agentsPath, "error", err)
} else {
slog.Info("Successfully reloaded teams from updated OCI reference", "reference", agentsPath)
}
}
}
}()
}
// TODO(rumpl): implement pull interval

return s.Serve(ctx, ln)
}
8 changes: 6 additions & 2 deletions cmd/root/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"github.com/google/go-containerregistry/pkg/crane"
"github.com/spf13/cobra"

"github.com/docker/cagent/pkg/agentfile"
"github.com/docker/cagent/pkg/cli"
"github.com/docker/cagent/pkg/content"
"github.com/docker/cagent/pkg/remote"
"github.com/docker/cagent/pkg/telemetry"
)
Expand Down Expand Up @@ -52,7 +52,11 @@ func (f *pullFlags) runPullCommand(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to pull artifact: %w", err)
}

yamlFile, err := agentfile.FromStore(registryRef)
store, err := content.NewStore()
if err != nil {
return fmt.Errorf("failed to open content store: %w", err)
}
yamlFile, err := store.GetArtifact(registryRef)
if err != nil {
return fmt.Errorf("failed to get agent yaml: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/acp/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
type Agent struct {
conn *acp.AgentSideConnection
team *team.Team
agentFilename string
source teamloader.AgentSource
runtimeConfig *config.RuntimeConfig
sessions map[string]*Session
mu sync.Mutex
Expand All @@ -39,9 +39,9 @@ type Session struct {
}

// NewAgent creates a new ACP agent
func NewAgent(agentFilename string, runtimeConfig *config.RuntimeConfig) *Agent {
func NewAgent(source teamloader.AgentSource, runtimeConfig *config.RuntimeConfig) *Agent {
return &Agent{
agentFilename: agentFilename,
source: source,
runtimeConfig: runtimeConfig,
sessions: make(map[string]*Session),
}
Expand Down Expand Up @@ -69,7 +69,7 @@ func (a *Agent) Initialize(ctx context.Context, params acp.InitializeRequest) (a

a.mu.Lock()
defer a.mu.Unlock()
t, err := teamloader.Load(ctx, a.agentFilename, a.runtimeConfig, teamloader.WithToolsetRegistry(createToolsetRegistry(a)))
t, err := teamloader.LoadFrom(ctx, a.source, a.runtimeConfig, teamloader.WithToolsetRegistry(createToolsetRegistry(a)))
if err != nil {
return acp.InitializeResponse{}, fmt.Errorf("failed to load teams: %w", err)
}
Expand Down
8 changes: 2 additions & 6 deletions pkg/acp/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,15 @@ import (
"github.com/docker/cagent/pkg/config"
)

type discardOutput struct{}

func (d *discardOutput) Printf(string, ...any) {}

func Run(ctx context.Context, agentFilename string, stdin io.Reader, stdout io.Writer, runConfig *config.RuntimeConfig) error {
slog.Debug("Starting ACP server", "agent", agentFilename)

agentFilename, err := agentfile.Resolve(ctx, &discardOutput{}, agentFilename)
source, err := agentfile.ResolveSource(ctx, nil, agentFilename)
if err != nil {
return err
}

acpAgent := NewAgent(agentFilename, runConfig)
acpAgent := NewAgent(source, runConfig)
conn := acpsdk.NewAgentSideConnection(acpAgent, stdout, stdin)
conn.SetLogger(slog.Default())
acpAgent.SetAgentConnection(conn)
Expand Down
Loading
Loading