Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"sort"
"strings"

"github.com/mattn/go-isatty"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/xerrors"
Expand Down Expand Up @@ -88,6 +90,19 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
return xerrors.Errorf("term height must be at least 10")
}

// Read stdin if it's piped, to be used as initial prompt
initialPrompt := viper.GetString(FlagInitialPrompt)
if initialPrompt == "" {
if !isatty.IsTerminal(os.Stdin.Fd()) {
if stdinData, err := io.ReadAll(os.Stdin); err != nil {
return xerrors.Errorf("failed to read stdin: %w", err)
} else if len(stdinData) > 0 {
initialPrompt = string(stdinData)
logger.Info("Read initial prompt from stdin", "bytes", len(stdinData))
}
}
}

printOpenAPI := viper.GetBool(FlagPrintOpenAPI)
var process *termexec.Process
if printOpenAPI {
Expand All @@ -112,7 +127,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
ChatBasePath: viper.GetString(FlagChatBasePath),
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
InitialPrompt: viper.GetString(FlagInitialPrompt),
InitialPrompt: initialPrompt,
})
if err != nil {
return xerrors.Errorf("failed to create server: %w", err)
Expand Down Expand Up @@ -213,7 +228,7 @@ func CreateServerCmd() *cobra.Command {
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
{FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"},
{FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"},
}

for _, spec := range flagSpecs {
Expand Down
69 changes: 55 additions & 14 deletions e2e/echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

const (
testTimeout = 30 * time.Second
operationTimeout = 5 * time.Second
operationTimeout = 10 * time.Second
healthCheckTimeout = 10 * time.Second
)

Expand All @@ -40,15 +40,14 @@ func TestE2E(t *testing.T) {
t.Run("basic", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
script, apiClient := setup(ctx, t)
require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout))
script, apiClient := setup(ctx, t, nil)
messageReq := agentapisdk.PostMessageParams{
Content: "This is a test message.",
Type: agentapisdk.MessageTypeUser,
}
_, err := apiClient.PostMessage(ctx, messageReq)
require.NoError(t, err, "Failed to send message via SDK")
require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout))
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "post message"))
msgResp, err := apiClient.GetMessages(ctx)
require.NoError(t, err, "Failed to get messages via SDK")
require.Len(t, msgResp.Messages, 3)
Expand All @@ -61,7 +60,7 @@ func TestE2E(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

script, apiClient := setup(ctx, t)
script, apiClient := setup(ctx, t, nil)
messageReq := agentapisdk.PostMessageParams{
Content: "What is the answer to life, the universe, and everything?",
Type: agentapisdk.MessageTypeUser,
Expand All @@ -71,7 +70,7 @@ func TestE2E(t *testing.T) {
statusResp, err := apiClient.GetStatus(ctx)
require.NoError(t, err)
require.Equal(t, agentapisdk.StatusRunning, statusResp.Status)
require.NoError(t, waitAgentAPIStable(ctx, apiClient, 5*time.Second))
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, 5*time.Second, "post message"))
msgResp, err := apiClient.GetMessages(ctx)
require.NoError(t, err, "Failed to get messages via SDK")
require.Len(t, msgResp.Messages, 3)
Expand All @@ -82,11 +81,45 @@ func TestE2E(t *testing.T) {
require.Equal(t, script[1].ResponseMessage, strings.TrimSpace(parts[0]))
require.Equal(t, script[2].ResponseMessage, strings.TrimSpace(parts[1]))
})

t.Run("stdin", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

script, apiClient := setup(ctx, t, &params{
cmdFn: func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) {
defCmd, defArgs := defaultCmdFn(ctx, t, serverPort, binaryPath, cwd, scriptFilePath)
script := fmt.Sprintf(`echo "hello agent" | %s %s`, defCmd, strings.Join(defArgs, " "))
return "/bin/sh", []string{"-c", script}
},
})
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, 5*time.Second, "stdin"))
msgResp, err := apiClient.GetMessages(ctx)
require.NoError(t, err, "Failed to get messages via SDK")
require.Len(t, msgResp.Messages, 3)
require.Equal(t, script[0].ExpectMessage, strings.TrimSpace(msgResp.Messages[1].Content))
require.Equal(t, script[0].ResponseMessage, strings.TrimSpace(msgResp.Messages[2].Content))
})
}

func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Client) {
type params struct {
cmdFn func(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string)
}

func defaultCmdFn(ctx context.Context, t testing.TB, serverPort int, binaryPath, cwd, scriptFilePath string) (string, []string) {
return binaryPath, []string{"server", fmt.Sprintf("--port=%d", serverPort), "--", "go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath}
}

func setup(ctx context.Context, t testing.TB, p *params) ([]ScriptEntry, *agentapisdk.Client) {
t.Helper()

if p == nil {
p = &params{}
}
if p.cmdFn == nil {
p.cmdFn = defaultCmdFn
}

scriptFilePath := filepath.Join("testdata", filepath.Base(t.Name())+".json")
data, err := os.ReadFile(scriptFilePath)
require.NoError(t, err, "Failed to read test script file: %s", scriptFilePath)
Expand Down Expand Up @@ -116,10 +149,9 @@ func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Clien
cwd, err := os.Getwd()
require.NoError(t, err, "Failed to get current working directory")

cmd := exec.CommandContext(ctx, binaryPath, "server",
fmt.Sprintf("--port=%d", serverPort),
"--",
"go", "run", filepath.Join(cwd, "echo.go"), scriptFilePath)
bin, args := p.cmdFn(ctx, t, serverPort, binaryPath, cwd, scriptFilePath)
t.Logf("Running command: %s %s", bin, strings.Join(args, " "))
cmd := exec.CommandContext(ctx, bin, args...)

// Capture output for debugging
stdout, err := cmd.StdoutPipe()
Expand Down Expand Up @@ -160,7 +192,7 @@ func setup(ctx context.Context, t testing.TB) ([]ScriptEntry, *agentapisdk.Clien
apiClient, err := agentapisdk.NewClient(serverURL)
require.NoError(t, err, "Failed to create agentapi SDK client")

require.NoError(t, waitAgentAPIStable(ctx, apiClient, operationTimeout))
require.NoError(t, waitAgentAPIStable(ctx, t, apiClient, operationTimeout, "setup"))
return script, apiClient
}

Expand Down Expand Up @@ -198,21 +230,30 @@ func waitForServer(ctx context.Context, t testing.TB, url string, timeout time.D
}
}

func waitAgentAPIStable(ctx context.Context, apiClient *agentapisdk.Client, waitFor time.Duration) error {
func waitAgentAPIStable(ctx context.Context, t testing.TB, apiClient *agentapisdk.Client, waitFor time.Duration, msg string) error {
t.Helper()
waitCtx, waitCancel := context.WithTimeout(ctx, waitFor)
defer waitCancel()

tick := time.NewTicker(100 * time.Millisecond)
start := time.Now()
tick := time.NewTicker(time.Millisecond)
defer tick.Stop()
var prevStatus agentapisdk.AgentStatus
defer func() {
elapsed := time.Since(start)
t.Logf("%s: agent API status: %s (elapsed: %s)", msg, prevStatus, elapsed.Round(100*time.Millisecond))
}()
for {
select {
case <-waitCtx.Done():
return waitCtx.Err()
case <-tick.C:
tick.Reset(100 * time.Millisecond)
sr, err := apiClient.GetStatus(ctx)
if err != nil {
continue
}
prevStatus = sr.Status
if sr.Status == agentapisdk.StatusStable {
return nil
}
Expand Down
6 changes: 6 additions & 0 deletions e2e/testdata/stdin.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[
{
"expectMessage": "hello agent",
"responseMessage": "Hello! I'm ready to help you. Please send me a message to echo back."
}
]