diff --git a/cmd/server/server.go b/cmd/server/server.go index 3afe050..6a7fa7f 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -4,12 +4,14 @@ import ( "context" "errors" "fmt" + "io" "log/slog" "net/http" "os" "sort" "strings" + "github.com/mattn/go-isatty" "github.com/spf13/cobra" "github.com/spf13/viper" "golang.org/x/xerrors" @@ -88,6 +90,19 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er return xerrors.Errorf("term height must be at least 10") } + // Read stdin if it's piped, to be used as initial prompt + initialPrompt := viper.GetString(FlagInitialPrompt) + if initialPrompt == "" { + if !isatty.IsTerminal(os.Stdin.Fd()) { + if stdinData, err := io.ReadAll(os.Stdin); err != nil { + return xerrors.Errorf("failed to read stdin: %w", err) + } else if len(stdinData) > 0 { + initialPrompt = string(stdinData) + logger.Info("Read initial prompt from stdin", "bytes", len(stdinData)) + } + } + } + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -112,7 +127,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er ChatBasePath: viper.GetString(FlagChatBasePath), AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), - InitialPrompt: viper.GetString(FlagInitialPrompt), + InitialPrompt: initialPrompt, }) if err != nil { return xerrors.Errorf("failed to create server: %w", err) @@ -213,7 +228,7 @@ func CreateServerCmd() *cobra.Command { {FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"}, // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, - {FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"}, + {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, } for _, spec := range flagSpecs { diff --git a/e2e/echo_test.go b/e2e/echo_test.go index 5784027..eb30294 100644 --- a/e2e/echo_test.go +++ b/e2e/echo_test.go @@ -22,7 +22,7 @@ import ( const ( testTimeout = 30 * time.Second - operationTimeout = 5 * time.Second + operationTimeout = 10 * time.Second healthCheckTimeout = 10 * time.Second ) @@ -40,15 +40,14 @@ func TestE2E(t *testing.T) { t.Run("basic", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - script, apiClient := setup(ctx, t) - require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout)) + script, apiClient := setup(ctx, t, nil) messageReq := agentapisdk.PostMessageParams{ Content: "This is a test message.", Type: agentapisdk.MessageTypeUser, } _, err := apiClient.PostMessage(ctx, messageReq) require.NoError(t, err, "Failed to send message via SDK") - require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout)) + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "post message")) msgResp, err := apiClient.GetMessages(ctx) require.NoError(t, err, "Failed to get messages via SDK") require.Len(t, msgResp.Messages, 3) @@ -61,7 +60,7 @@ func TestE2E(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - script, apiClient := setup(ctx, t) + script, apiClient := setup(ctx, t, nil) messageReq := agentapisdk.PostMessageParams{ Content: "What is the answer to life, the universe, and everything?", Type: agentapisdk.MessageTypeUser, @@ -71,7 +70,7 @@ func TestE2E(t *testing.T) { statusResp, err := apiClient.GetStatus(ctx) require.NoError(t, err) require.Equal(t, agentapisdk.StatusRunning, statusResp.Status) - require.NoError(t, waitAgentAPIStable(ctx, apiClient, 5*time.Second)) + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, 5*time.Second, "post message")) msgResp, err := apiClient.GetMessages(ctx) require.NoError(t, err, "Failed to get messages via SDK") require.Len(t, msgResp.Messages, 3) @@ -82,11 +81,45 @@ func TestE2E(t *testing.T) { require.Equal(t, script[1].ResponseMessage, strings.TrimSpace(parts[0])) require.Equal(t, script[2].ResponseMessage, strings.TrimSpace(parts[1])) }) + + t.Run("stdin", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + script, apiClient := setup(ctx, t, ¶ms{ + cmdFn: func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { + defCmd, defArgs := defaultCmdFn(ctx, t, serverPort, binaryPath, cwd, scriptFilePath) + script := fmt.Sprintf(`echo "hello agent" | %s %s`, defCmd, strings.Join(defArgs, " ")) + return "/bin/sh", []string{"-c", script} + }, + }) + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, 5*time.Second, "stdin")) + msgResp, err := apiClient.GetMessages(ctx) + require.NoError(t, err, "Failed to get messages via SDK") + require.Len(t, msgResp.Messages, 3) + require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content)) + require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content)) + }) } -func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Client) { +type params struct { + cmdFn func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) +} + +func defaultCmdFn(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) { + return binaryPath, []string{"server", fmt.Sprintf("--port=%d", serverPort), "--", "go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath} +} + +func setup(ctx context.Context, t testing.TB, p *params) ([]ScriptEntry, *agentapisdk.Client) { t.Helper() + if p == nil { + p = ¶ms{} + } + if p.cmdFn == nil { + p.cmdFn = defaultCmdFn + } + scriptFilePath := filepath.Join("testdata", filepath.Base(t.Name())+".json") data, err := os.ReadFile(scriptFilePath) require.NoError(t, err, "Failed to read test script file: %s", scriptFilePath) @@ -116,10 +149,9 @@ func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Clien cwd, err := os.Getwd() require.NoError(t, err, "Failed to get current working directory") - cmd := exec.CommandContext(ctx, binaryPath, "server", - fmt.Sprintf("--port=%d", serverPort), - "--", - "go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath) + bin, args := p.cmdFn(ctx, t, serverPort, binaryPath, cwd, scriptFilePath) + t.Logf("Running command: %s %s", bin, strings.Join(args, " ")) + cmd := exec.CommandContext(ctx, bin, args...) // Capture output for debugging stdout, err := cmd.StdoutPipe() @@ -160,7 +192,7 @@ func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Clien apiClient, err := agentapisdk.NewClient(serverURL) require.NoError(t, err, "Failed to create agentapi SDK client") - require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout)) + require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "setup")) return script, apiClient } @@ -198,21 +230,30 @@ func waitForServer(ctx context.Context, t testing.TB, url string, timeout time.D } } -func waitAgentAPIStable(ctx context.Context, apiClient *agentapisdk.Client, waitFor time.Duration) error { +func waitAgentAPIStable(ctx context.Context, t testing.TB, apiClient *agentapisdk.Client, waitFor time.Duration, msg string) error { + t.Helper() waitCtx, waitCancel := context.WithTimeout(ctx, waitFor) defer waitCancel() - tick := time.NewTicker(100 * time.Millisecond) + start := time.Now() + tick := time.NewTicker(time.Millisecond) defer tick.Stop() + var prevStatus agentapisdk.AgentStatus + defer func() { + elapsed := time.Since(start) + t.Logf("%s: agent API status: %s (elapsed: %s)", msg, prevStatus, elapsed.Round(100*time.Millisecond)) + }() for { select { case <-waitCtx.Done(): return waitCtx.Err() case <-tick.C: + tick.Reset(100 * time.Millisecond) sr, err := apiClient.GetStatus(ctx) if err != nil { continue } + prevStatus = sr.Status if sr.Status == agentapisdk.StatusStable { return nil } diff --git a/e2e/testdata/stdin.json b/e2e/testdata/stdin.json new file mode 100644 index 0000000..309624b --- /dev/null +++ b/e2e/testdata/stdin.json @@ -0,0 +1,6 @@ +[ + { + "expectMessage": "hello agent", + "responseMessage": "Hello! I'm ready to help you. Please send me a message to echo back." + } +]