diff --git a/README.md b/README.md index 2191865..3d450cf 100644 --- a/README.md +++ b/README.md @@ -50,9 +50,25 @@ ana auth login --endpoint https://app.textql.com ana org show ana connector list ana chat send "show me last month's revenue" +ana api textql.rpc.public.auth.PublicAuthService/GetOrganization # raw JSON passthrough ana update # replace the running binary with the latest release ``` +### `ana api` — raw authenticated passthrough + +`ana api ` sends an authenticated HTTP request and prints the response. +Two path forms: + +- `/` — Connect-RPC short form, prefixed with `/rpc/public/`. +- `/v1/...` (or any leading-slash path) — sent verbatim. Covers both the + documented REST API (`docs.textql.com/api-reference`) and pre-resolved RPC + paths. + +Body can be supplied with `--data ''` or `--data-stdin`. Default method +is `POST` with a `{}` body (so short-form RPC calls Just Work); GET/HEAD +auto-omit the body. `--raw` passes the response through verbatim instead of +pretty-printing. + Run `ana --help` or `ana --help` for command-specific flags. `ana` checks GitHub for a newer release after each verb and prints a one-line diff --git a/cmd/ana/main.go b/cmd/ana/main.go index a8c70be..2062448 100644 --- a/cmd/ana/main.go +++ b/cmd/ana/main.go @@ -16,6 +16,7 @@ import ( "os/signal" "time" + "github.com/highperformance-tech/ana-cli/internal/api" "github.com/highperformance-tech/ana-cli/internal/audit" "github.com/highperformance-tech/ana-cli/internal/auth" "github.com/highperformance-tech/ana-cli/internal/chat" @@ -208,6 +209,7 @@ func drainNudge(ch chan string, timeout time.Duration, verbErr error, stderr io. // self-hosted and non-prod profiles at the right place. func buildVerbs(client *transport.Client, env func(string) string, cfgPath, profileName, endpoint string) map[string]cli.Command { return map[string]cli.Command{ + "api": api.New(api.Deps{DoRaw: client.DoRaw}), "auth": auth.New(authDeps(client, env, cfgPath, profileName)), "profile": profile.New(profileDeps(env, cfgPath)), "org": org.New(org.Deps{Unary: client.Unary}), diff --git a/cmd/ana/main_test.go b/cmd/ana/main_test.go index 5fa507f..f0aa622 100644 --- a/cmd/ana/main_test.go +++ b/cmd/ana/main_test.go @@ -287,7 +287,7 @@ func TestBuildVerbs_Shape(t *testing.T) { t.Parallel() client := transport.New("https://example", func(context.Context) (string, error) { return "", nil }) verbs := buildVerbs(client, func(string) string { return "" }, "", "default", "https://example") - want := []string{"auth", "profile", "org", "connector", "chat", "dashboard", "playbook", "ontology", "feed", "audit", "version", "update"} + want := []string{"api", "auth", "profile", "org", "connector", "chat", "dashboard", "playbook", "ontology", "feed", "audit", "version", "update"} for _, v := range want { if _, ok := verbs[v]; !ok { t.Errorf("missing verb: %q", v) diff --git a/docs/features.md b/docs/features.md index 31e860e..b74c20b 100644 --- a/docs/features.md +++ b/docs/features.md @@ -7,6 +7,7 @@ Per-endpoint request/response schemas live in `api-catalog/` (~95 endpoints as o ## API shape (global) - **Style:** Connect-RPC (buf-connect). All calls are `POST https://app.textql.com/rpc/public//`. +- **CLI raw access:** `ana api ` is the untyped escape hatch. Short form (`/`) maps to Connect-RPC; a leading-slash path (`/v1/...`) hits the documented REST API at `docs.textql.com/api-reference`. Both share host + bearer auth, so one verb covers both surfaces. - **Content-Type:** `application/json` request + response. - **Field casing:** protobuf JSON — **camelCase only**. Sending both `chatId` and `chat_id` → 400 `"duplicate field"`. CLI must emit camelCase. - **Error shape:** `{"code": "", "message": ""}` (e.g. `invalid_argument`, `not_found`, `internal`). diff --git a/internal/CLAUDE.md b/internal/CLAUDE.md index fac24bb..b3a3127 100644 --- a/internal/CLAUDE.md +++ b/internal/CLAUDE.md @@ -1,6 +1,6 @@ # internal -All domain logic for the `ana` CLI. Each verb package is pure dispatch: it declares a narrow `Deps` struct, registers its Connect-RPC service prefix, and exposes a `New(deps) *cli.Group` that `cmd/ana/main.go` wires up. Verb packages do not import `internal/transport` or `internal/config` (except `cli`, which is the dispatch core, and `profile`, whose whole purpose is config management). +All domain logic for the `ana` CLI. Each verb package is pure dispatch: it declares a narrow `Deps` struct, registers its Connect-RPC service prefix, and exposes a `New(deps) *cli.Group` (or a `cli.Command` leaf when there are no subcommands, as in `api/`) that `cmd/ana/main.go` wires up. Verb packages do not import `internal/transport` or `internal/config` (except `cli`, which is the dispatch core, and `profile`, whose whole purpose is config management). ## Test layout convention @@ -13,7 +13,8 @@ Multi-file verb packages use one `_test.go` per source file (e.g. `list. | `cli/` | Dispatch core: `Command` interface, `Group`, `ParseFlags`, `ParseGlobal`, `Dispatch`, exit-code mapping. | | `testcli/` | Test helpers for verb packages (stdlib `httptest` analogue): `FailingWriter`, `FailingIO`, `NewIO`. | | `config/` | Multi-profile config file reader/writer. XDG path resolution, `Resolve` precedence. | -| `transport/` | Connect-RPC HTTP client. Unary JSON + server-streaming JSON framing. | +| `transport/` | Connect-RPC HTTP client. Unary JSON + server-streaming JSON framing + `DoRaw` passthrough. Bearer + User-Agent applied via RoundTripper middleware. | +| `api/` | `ana api ` — raw authenticated HTTP passthrough for Connect-RPC short form + documented REST. Single leaf. | | `auth/` | `ana auth` verb tree — login/logout/whoami/keys/service-accounts. | | `profile/` | `ana profile` verb tree — add/use/remove/list/show. Imports `internal/config` by design. | | `org/` | `ana org` — list/show + nested members/roles/permissions. | diff --git a/internal/api/CLAUDE.md b/internal/api/CLAUDE.md new file mode 100644 index 0000000..680bf8b --- /dev/null +++ b/internal/api/CLAUDE.md @@ -0,0 +1,10 @@ +# internal/api + +The `ana api ` verb — authenticated raw-JSON passthrough over the shared transport client. Single leaf, no subcommands. Covers both Connect-RPC (`textql.rpc.public./` or `/rpc/public/...`) and the documented REST API (`/v1/...`) — one verb, two surfaces, distinguished by leading slash. + +## Files + +- `api.go` — `Deps` (single `DoRaw` function field), `New` (returns a leaf `cli.Command`, not a `*cli.Group` — no subcommands), and the `/rpc/public/` prefix constant. +- `call.go` — the leaf: flag parsing, path dispatch (leading slash → verbatim; else prefix-prepend), body resolution (`--data` / `--data-stdin` / default `{}` for POST, `nil` for GET/HEAD), and the `emitError`/`emitSuccess` split. Non-2xx writes the server body to stderr and returns an `api: HTTP ` summary error; 2xx writes pretty JSON to stdout (fallthrough to raw if the body isn't valid JSON; `--raw` skips pretty-print entirely). +- `api_test.go` — shared `fakeDeps` + `TestNew*`/`TestHelp*`. +- `call_test.go` — per-source test file covering both path forms, every body-resolution branch, mutual-exclusion, non-2xx stderr + trailing-newline branches, and the raw/pretty/non-JSON 2xx paths. diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 0000000..ee5d99e --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,37 @@ +// Package api provides the `ana api` verb — an authenticated raw-JSON +// passthrough over the shared transport client. Dispatches to a single leaf. +// Two path forms: +// +// - Leading slash → sent verbatim (REST e.g. `/v1/...`, or a pre-resolved +// RPC path e.g. `/rpc/public//`). +// - No leading slash → treated as a fully-qualified Connect-RPC short form +// (e.g. `textql.rpc.public.auth.PublicAuthService/GetOrganization`) and +// prefixed with `/rpc/public/`. +// +// Like every other verb package, api never imports internal/transport — the +// caller adapts its transport client to the narrow Deps.DoRaw field. +package api + +import ( + "context" + + "github.com/highperformance-tech/ana-cli/internal/cli" +) + +// connectRPCPrefix is the path prefix applied to Connect-RPC short-form paths +// (those without a leading slash). Matches what every typed verb hard-codes. +const connectRPCPrefix = "/rpc/public/" + +// Deps is the injection boundary. A real wiring layer adapts +// transport.Client.DoRaw; tests pass fakes that record (method, path, body) +// so assertions can inspect the outbound request and the returned response. +type Deps struct { + DoRaw func(ctx context.Context, method, path string, body []byte) (int, []byte, error) +} + +// New returns the `api` verb as a single leaf command. Unlike other verb +// packages this is not a *cli.Group — there are no subcommands, just a path +// positional. +func New(deps Deps) cli.Command { + return &callCmd{deps: deps} +} diff --git a/internal/api/api_test.go b/internal/api/api_test.go new file mode 100644 index 0000000..6d93e5d --- /dev/null +++ b/internal/api/api_test.go @@ -0,0 +1,77 @@ +package api + +import ( + "context" + "strings" + "testing" + + "github.com/highperformance-tech/ana-cli/internal/cli" +) + +// --- fakes and helpers --- + +// fakeDeps is the package-wide fake for Deps. DoRaw delegates to doRawFn if +// set; either way it records (method, path, body) for post-call assertions. +type fakeDeps struct { + doRawFn func(ctx context.Context, method, path string, body []byte) (int, []byte, error) + lastMethod string + lastPath string + lastBody []byte +} + +func (f *fakeDeps) deps() Deps { + return Deps{ + DoRaw: func(ctx context.Context, method, path string, body []byte) (int, []byte, error) { + f.lastMethod = method + f.lastPath = path + // Copy so callers that reuse the slice don't mutate recorded state. + if body != nil { + f.lastBody = append(f.lastBody[:0], body...) + } else { + f.lastBody = nil + } + if f.doRawFn != nil { + return f.doRawFn(ctx, method, path, body) + } + return 200, []byte(`{"ok":true}`), nil + }, + } +} + +// --- New / leaf surface --- + +func TestNewReturnsLeaf(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + if cmd == nil { + t.Fatalf("New returned nil") + } + // Must NOT be a *cli.Group — api is a single-leaf verb. + if _, isGroup := cmd.(*cli.Group); isGroup { + t.Fatalf("api.New returned *cli.Group; want a leaf Command") + } + if _, isFlagger := cmd.(cli.Flagger); !isFlagger { + t.Fatalf("api leaf should implement cli.Flagger so --help can render a Flags block") + } +} + +// --- Help() --- + +func TestHelpContainsBothPathForms(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + h := cmd.Help() + for _, want := range []string{ + "Usage", + "/", + "/rpc/public/", + "/v1/", + "--raw", + } { + if !strings.Contains(h, want) { + t.Errorf("help missing %q in:\n%s", want, h) + } + } +} diff --git a/internal/api/call.go b/internal/api/call.go new file mode 100644 index 0000000..2acd5ba --- /dev/null +++ b/internal/api/call.go @@ -0,0 +1,176 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "strings" + + "github.com/highperformance-tech/ana-cli/internal/cli" +) + +// callCmd implements the `ana api ` leaf. Flags: +// +// --method HTTP verb (default POST; any non-empty string accepted). +// --data literal JSON request body. +// --data-stdin read the request body from stdin (mutually exclusive with --data). +// --raw emit the response body verbatim (skip json.Indent). +// +// The global --json flag is a no-op here — the default output IS pretty JSON; +// --raw is the opposite. Documented in Help(). +type callCmd struct { + deps Deps + + method string + data string + dataStdin bool + raw bool +} + +func (c *callCmd) Help() string { + return "api Authenticated HTTP passthrough. JSON pretty-printed by default; --raw for verbatim bytes.\n" + + "Usage: ana api [--method M] [--data JSON | --data-stdin] [--raw]\n" + + "\n" + + "Paths:\n" + + " / Connect-RPC short form; prefixed with /rpc/public/\n" + + " e.g. textql.rpc.public.auth.PublicAuthService/GetOrganization\n" + + " /rpc/public/<...> Connect-RPC full path; sent as-is\n" + + " /v1/<...> Documented REST API path (docs.textql.com/api-reference)\n" + + "\n" + + "Note: the global --json flag is ignored here; output is JSON by default." +} + +// Flags declares this leaf's flags. Implementing cli.Flagger lets dispatchChild +// render a `Flags:` block under --help so the four knobs are discoverable. +func (c *callCmd) Flags(fs *flag.FlagSet) { + fs.StringVar(&c.method, "method", "POST", "HTTP method (default POST)") + fs.StringVar(&c.data, "data", "", "literal JSON request body") + fs.BoolVar(&c.dataStdin, "data-stdin", false, "read the request body from stdin") + fs.BoolVar(&c.raw, "raw", false, "pass the response body through verbatim (skip pretty-print)") +} + +func (c *callCmd) Run(ctx context.Context, args []string, stdio cli.IO) error { + fs := cli.NewFlagSet("api") + c.Flags(fs) + cli.ApplyAncestorFlags(ctx, fs) + if err := cli.ParseFlags(fs, args); err != nil { + return err + } + + if fs.NArg() == 0 { + return cli.UsageErrf("api: positional argument required") + } + if fs.NArg() > 1 { + return cli.UsageErrf("api: unexpected positional arguments: %v", fs.Args()[1:]) + } + // Trim once and reuse — otherwise a whitespace-padded arg like + // `" /v1/things "` passes the blank check but gets forwarded to the + // transport verbatim, which joinURL would then stitch into a malformed URL. + path := strings.TrimSpace(fs.Arg(0)) + if path == "" { + return cli.UsageErrf("api: positional argument required") + } + + if c.method == "" { + return cli.UsageErrf("api: --method must not be empty") + } + dataSet := cli.FlagWasSet(fs, "data") + if dataSet && c.dataStdin { + return cli.UsageErrf("api: --data and --data-stdin are mutually exclusive") + } + + // Path dispatch: leading slash → verbatim; otherwise Connect-RPC short form. + resolvedPath := path + if !strings.HasPrefix(path, "/") { + resolvedPath = connectRPCPrefix + path + } + + body, err := resolveBody(c, dataSet, stdio.Stdin) + if err != nil { + return err + } + + status, respBody, err := c.deps.DoRaw(ctx, c.method, resolvedPath, body) + if err != nil { + return fmt.Errorf("api: %w", err) + } + + if status < 200 || status >= 300 { + return c.emitError(stdio, status, respBody) + } + return c.emitSuccess(stdio, respBody) +} + +// resolveBody picks the outbound body bytes. Precedence (after the +// --data/--data-stdin mutual-exclusion check in the caller): +// +// - --data-stdin: io.ReadAll so the bytes round-trip exactly (ReadToken +// would trim whitespace, which matters for binary-ish JSON payloads). +// - --data set (even to ""): use the literal bytes. +// - neither: nil for GET/HEAD (no body), `{}` otherwise so Connect-RPC's +// required-body contract is still satisfied. +func resolveBody(c *callCmd, dataSet bool, stdin io.Reader) ([]byte, error) { + switch { + case c.dataStdin: + b, err := io.ReadAll(stdin) + if err != nil { + return nil, fmt.Errorf("api: read stdin: %w", err) + } + return b, nil + case dataSet: + return []byte(c.data), nil + } + if strings.EqualFold(c.method, "GET") || strings.EqualFold(c.method, "HEAD") { + return nil, nil + } + return []byte("{}"), nil +} + +// emitError writes the server's error body to stderr (so stdout stays empty +// for `| jq`) and returns a summary error. Main's fallback printer adds the +// `api: HTTP ` line on its own stderr write — body + status together. +// A trailing newline is appended if the body didn't already end with one so +// the status line doesn't get glued to the last byte of the response. +func (c *callCmd) emitError(stdio cli.IO, status int, body []byte) error { + if len(body) > 0 { + if _, werr := stdio.Stderr.Write(body); werr != nil { + return fmt.Errorf("api: %w", werr) + } + if !bytes.HasSuffix(body, []byte("\n")) { + if _, werr := fmt.Fprintln(stdio.Stderr); werr != nil { + return fmt.Errorf("api: %w", werr) + } + } + } + return fmt.Errorf("api: HTTP %d", status) +} + +// emitSuccess writes the 2xx body. With --raw (or when the body is empty) +// the bytes are passed through verbatim; otherwise json.Indent pretty-prints. +// A body that isn't valid JSON (e.g. 204 empty, or some future text endpoint) +// falls through to the raw path so we don't fail an otherwise-successful call. +func (c *callCmd) emitSuccess(stdio cli.IO, body []byte) error { + if c.raw || len(body) == 0 { + if _, werr := stdio.Stdout.Write(body); werr != nil { + return fmt.Errorf("api: %w", werr) + } + return nil + } + var pretty bytes.Buffer + if err := json.Indent(&pretty, body, "", " "); err != nil { + if _, werr := stdio.Stdout.Write(body); werr != nil { + return fmt.Errorf("api: %w", werr) + } + return nil + } + if _, werr := stdio.Stdout.Write(pretty.Bytes()); werr != nil { + return fmt.Errorf("api: %w", werr) + } + if _, werr := stdio.Stdout.Write([]byte("\n")); werr != nil { + return fmt.Errorf("api: %w", werr) + } + return nil +} diff --git a/internal/api/call_test.go b/internal/api/call_test.go new file mode 100644 index 0000000..029b154 --- /dev/null +++ b/internal/api/call_test.go @@ -0,0 +1,520 @@ +package api + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/highperformance-tech/ana-cli/internal/cli" + "github.com/highperformance-tech/ana-cli/internal/testcli" +) + +// --- path dispatch --- + +func TestRPCShortFormPrefixed(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"foo.Bar/Baz"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if f.lastPath != "/rpc/public/foo.Bar/Baz" { + t.Errorf("path = %q, want /rpc/public/foo.Bar/Baz", f.lastPath) + } + if f.lastMethod != "POST" { + t.Errorf("method = %q, want POST", f.lastMethod) + } + // No --data, no GET → default body is `{}`. + if string(f.lastBody) != "{}" { + t.Errorf("body = %q, want {}", f.lastBody) + } +} + +func TestRESTLeadingSlashPassthrough(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"/v1/things"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if f.lastPath != "/v1/things" { + t.Errorf("path = %q", f.lastPath) + } +} + +func TestRPCLeadingSlashPassthrough(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"/rpc/public/foo/Bar"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if f.lastPath != "/rpc/public/foo/Bar" { + t.Errorf("path = %q", f.lastPath) + } +} + +// --- method / body --- + +func TestMethodGETHasNoBody(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"--method", "GET", "/v1/things"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if f.lastMethod != "GET" { + t.Errorf("method = %q", f.lastMethod) + } + if f.lastBody != nil { + t.Errorf("body = %q, want nil for GET", f.lastBody) + } +} + +func TestMethodHEADHasNoBody(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"--method", "HEAD", "/v1/things"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if f.lastBody != nil { + t.Errorf("body = %q, want nil for HEAD", f.lastBody) + } +} + +func TestDataFlagUsedAsBody(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), + []string{"--data", `{"x":1}`, "foo/Bar"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if string(f.lastBody) != `{"x":1}` { + t.Errorf("body = %q", f.lastBody) + } +} + +// TestDataEmptyStringSkipsDefaultBody pins the "explicit empty body" contract: +// `--data ""` must hand DoRaw an empty slice (which DoRaw then treats like +// nil), NOT the POST default `{}`. Guards against the body-resolution +// refactor silently drifting back to using `c.data != ""` as the presence +// check. +func TestDataEmptyStringSkipsDefaultBody(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), + []string{"--data", "", "foo/Bar"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if string(f.lastBody) != "" { + t.Errorf("body = %q, want empty (explicit --data \"\" wins over the POST default {})", f.lastBody) + } +} + +func TestDataStdinUsedAsBody(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(strings.NewReader(`{"chatId":"abc"}`)) + err := cmd.Run(context.Background(), + []string{"--data-stdin", "foo/Bar"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if string(f.lastBody) != `{"chatId":"abc"}` { + t.Errorf("body = %q", f.lastBody) + } +} + +// errReader always fails — exercises the io.ReadAll branch in resolveBody. +type errReader struct{} + +func (errReader) Read(_ []byte) (int, error) { return 0, errors.New("stdin boom") } + +func TestDataStdinReadErr(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(errReader{}) + err := cmd.Run(context.Background(), + []string{"--data-stdin", "foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "read stdin") { + t.Fatalf("want read stdin error, got %v", err) + } +} + +func TestDataAndDataStdinMutuallyExclusive(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(strings.NewReader("")) + err := cmd.Run(context.Background(), + []string{"--data", "{}", "--data-stdin", "foo/Bar"}, stdio) + if !errors.Is(err, cli.ErrUsage) { + t.Errorf("err = %v, want ErrUsage", err) + } + if !strings.Contains(err.Error(), "mutually exclusive") { + t.Errorf("err = %v, want mutual-exclusion message", err) + } +} + +// --- validation --- + +func TestMissingPathIsUsageError(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), nil, stdio) + if !errors.Is(err, cli.ErrUsage) { + t.Errorf("err = %v, want ErrUsage", err) + } +} + +func TestBlankPathIsUsageError(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + // A whitespace-only arg is just as useless as a missing one. + err := cmd.Run(context.Background(), []string{" "}, stdio) + if !errors.Is(err, cli.ErrUsage) { + t.Errorf("err = %v, want ErrUsage", err) + } +} + +// TestPathIsTrimmedBeforeDispatch pins the contract that a +// whitespace-padded path is trimmed before being forwarded. Without this, +// `ana api " /v1/things "` would pass the blank-check and then hit a +// malformed URL downstream. +func TestPathIsTrimmedBeforeDispatch(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{" /v1/things\t"}, stdio) + if err != nil { + t.Fatalf("Run: %v", err) + } + if f.lastPath != "/v1/things" { + t.Errorf("path = %q, want %q (trimmed)", f.lastPath, "/v1/things") + } +} + +// TestExtraPositionalsRejected — silently dropping extras masks typos like +// `ana api /v1/things stray` (user probably meant a flag). Exactly one +// non-blank positional is required. +func TestExtraPositionalsRejected(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"/v1/things", "stray"}, stdio) + if !errors.Is(err, cli.ErrUsage) { + t.Errorf("err = %v, want ErrUsage", err) + } + if !strings.Contains(err.Error(), "stray") { + t.Errorf("err should name the unexpected arg: %v", err) + } +} + +func TestEmptyMethodIsUsageError(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"--method", "", "foo/Bar"}, stdio) + if !errors.Is(err, cli.ErrUsage) { + t.Errorf("err = %v, want ErrUsage", err) + } +} + +func TestUnknownFlagIsUsageError(t *testing.T) { + t.Parallel() + f := &fakeDeps{} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"--nope", "foo/Bar"}, stdio) + if !errors.Is(err, cli.ErrUsage) { + t.Errorf("err = %v, want ErrUsage", err) + } +} + +// --- DoRaw error / non-2xx --- + +func TestTransportErrorWrapped(t *testing.T) { + t.Parallel() + boom := errors.New("dial fail") + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 0, nil, boom + }} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !errors.Is(err, boom) { + t.Fatalf("want wrapped %v, got %v", boom, err) + } + if !strings.HasPrefix(err.Error(), "api:") { + t.Errorf("err missing api: prefix: %v", err) + } +} + +func TestNon2xxBodyToStderr(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 404, []byte(`{"code":"not_found"}`), nil + }} + cmd := New(f.deps()) + stdio, out, errb := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "HTTP 404") { + t.Fatalf("err = %v, want HTTP 404", err) + } + if out.Len() != 0 { + t.Errorf("stdout should be empty on non-2xx, got %q", out.String()) + } + if !strings.Contains(errb.String(), "not_found") { + t.Errorf("stderr missing body: %q", errb.String()) + } + // Body did not end in newline, so one should be appended so the caller's + // prompt isn't glued to the response. + if !strings.HasSuffix(errb.String(), "\n") { + t.Errorf("stderr should end in newline, got %q", errb.String()) + } +} + +func TestNon2xxBodyAlreadyEndsInNewline(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 500, []byte("{}\n"), nil + }} + cmd := New(f.deps()) + stdio, _, errb := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil { + t.Fatal("expected error") + } + // One newline total — we shouldn't double it. + if strings.HasSuffix(errb.String(), "\n\n") { + t.Errorf("stderr double-newlined: %q", errb.String()) + } +} + +func TestNon2xxEmptyBody(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 400, nil, nil + }} + cmd := New(f.deps()) + stdio, _, errb := testcli.NewIO(nil) + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "HTTP 400") { + t.Fatalf("err = %v", err) + } + if errb.Len() != 0 { + t.Errorf("stderr should be empty for empty body, got %q", errb.String()) + } +} + +func TestNon2xxStderrWriteErr(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 500, []byte(`{"code":"x"}`), nil + }} + cmd := New(f.deps()) + stdio := testcli.FailingIO() + // FailingIO.Stdout is the failing writer; swap stderr too via a copy. + stdio.Stderr = testcli.FailingWriter{} + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "w boom") { + t.Fatalf("err = %v, want stderr write error", err) + } +} + +func TestNon2xxStderrTrailingNewlineWriteErr(t *testing.T) { + t.Parallel() + // Body lacks trailing newline → emitError adds one. A writer that accepts + // the body then refuses the trailing newline exercises the second Fprintln + // branch. + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 500, []byte(`{"code":"x"}`), nil // no \n + }} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + stdio.Stderr = &acceptThenFail{} + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "api:") { + t.Fatalf("err = %v, want api: wrap of trailing-newline write error", err) + } +} + +// --- 2xx happy output --- + +func TestSuccessPrettyPrint(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 200, []byte(`{"a":1,"b":2}`), nil + }} + cmd := New(f.deps()) + stdio, out, _ := testcli.NewIO(nil) + if err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio); err != nil { + t.Fatalf("Run: %v", err) + } + s := out.String() + // Pretty output has newlines + 2-space indent between keys. + if !strings.Contains(s, "\n \"a\": 1,") { + t.Errorf("expected indented output, got %q", s) + } + if !strings.HasSuffix(s, "\n") { + t.Errorf("expected trailing newline, got %q", s) + } +} + +func TestSuccessRawPassthrough(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 200, []byte(`{"a":1}`), nil + }} + cmd := New(f.deps()) + stdio, out, _ := testcli.NewIO(nil) + if err := cmd.Run(context.Background(), []string{"--raw", "foo/Bar"}, stdio); err != nil { + t.Fatalf("Run: %v", err) + } + if out.String() != `{"a":1}` { + t.Errorf("raw passthrough expected verbatim bytes, got %q", out.String()) + } +} + +func TestSuccessNonJSONFallsThroughToRaw(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 200, []byte("plain text"), nil + }} + cmd := New(f.deps()) + stdio, out, _ := testcli.NewIO(nil) + if err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio); err != nil { + t.Fatalf("Run: %v", err) + } + if out.String() != "plain text" { + t.Errorf("non-JSON body should pass through verbatim, got %q", out.String()) + } +} + +func TestSuccessEmptyBody(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 204, nil, nil + }} + cmd := New(f.deps()) + stdio, out, _ := testcli.NewIO(nil) + if err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio); err != nil { + t.Fatalf("Run: %v", err) + } + if out.Len() != 0 { + t.Errorf("expected empty stdout, got %q", out.String()) + } +} + +func TestSuccessRawStdoutErr(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 200, []byte(`{"a":1}`), nil + }} + cmd := New(f.deps()) + stdio := testcli.FailingIO() + err := cmd.Run(context.Background(), []string{"--raw", "foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "w boom") { + t.Fatalf("err = %v, want stdout write error", err) + } +} + +func TestSuccessPrettyStdoutErr(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 200, []byte(`{"a":1}`), nil + }} + cmd := New(f.deps()) + stdio := testcli.FailingIO() + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "w boom") { + t.Fatalf("err = %v, want stdout write error", err) + } +} + +func TestSuccessPrettyTrailingNewlineStdoutErr(t *testing.T) { + t.Parallel() + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 200, []byte(`{"a":1}`), nil + }} + cmd := New(f.deps()) + stdio, _, _ := testcli.NewIO(nil) + // Accept the pretty body, then fail on the trailing newline write. + stdio.Stdout = &acceptThenFail{} + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "api:") { + t.Fatalf("err = %v, want api: wrap of trailing-newline write error", err) + } +} + +func TestSuccessNonJSONStdoutErr(t *testing.T) { + t.Parallel() + // Body isn't JSON → json.Indent fails → fallthrough to raw write. With a + // failing stdout the raw write errors, exercising the inner branch. + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 200, []byte("plain text"), nil + }} + cmd := New(f.deps()) + stdio := testcli.FailingIO() + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "w boom") { + t.Fatalf("err = %v, want stdout write error on non-JSON fallthrough", err) + } +} + +func TestSuccessRawEmptyBodyStdoutErr(t *testing.T) { + t.Parallel() + // Empty body + --raw takes the `c.raw || len(body) == 0` branch but + // stdio.Stdout.Write on 0 bytes still returns (0, nil) for a FailingWriter? + // FailingWriter always errors — so this exercises the Write-err branch + // with an empty body input. + f := &fakeDeps{doRawFn: func(context.Context, string, string, []byte) (int, []byte, error) { + return 200, []byte{}, nil + }} + cmd := New(f.deps()) + stdio := testcli.FailingIO() + err := cmd.Run(context.Background(), []string{"foo/Bar"}, stdio) + if err == nil || !strings.Contains(err.Error(), "w boom") { + t.Fatalf("err = %v, want stdout write error on empty-body raw path", err) + } +} + +// acceptThenFail accepts the first Write call completely, then fails every +// subsequent call. Exercises the "body wrote OK but trailing-newline write +// failed" branches in emitError and emitSuccess without per-test byte counts. +type acceptThenFail struct{ firstDone bool } + +func (w *acceptThenFail) Write(p []byte) (int, error) { + if !w.firstDone { + w.firstDone = true + return len(p), nil + } + return 0, errors.New("w boom") +} diff --git a/internal/transport/CLAUDE.md b/internal/transport/CLAUDE.md index 4cbf6ad..0871aef 100644 --- a/internal/transport/CLAUDE.md +++ b/internal/transport/CLAUDE.md @@ -4,7 +4,7 @@ Minimal Connect-RPC over HTTP client used by every verb package. Supports unary ## Files -- `client.go` — `Client`, `New`, functional `Option`s (`WithHTTPClient`, `WithUserAgent`), `Unary`, and `Stream`. Injects a `tokenFn` so the transport stays agnostic to where the bearer token comes from. +- `client.go` — `Client`, `New`, functional `Option`s (`WithHTTPClient`, `WithUserAgent`), `Unary`, `Stream`, and `DoRaw` (raw authenticated HTTP used by `internal/api`). Injects a `tokenFn` so the transport stays agnostic to where the bearer token comes from. Bearer + User-Agent attach via a `bearerTransport` RoundTripper middleware wrapped around the configured `http.Client.Transport` — every call site (Unary, Stream, DoRaw) inherits auth for free; there is no per-call-site header plumbing. - `stream.go` — `StreamReader` (one `Next`/`Close` per frame). Terminal frame has the `trailerFlag` bit set and either an empty body or a `{code, message}` error envelope. - `error.go` — `Error` (wraps HTTP status + Connect error code/message), the `IsAuth` predicate used by commands to surface `auth.ErrNotLoggedIn`, and the `IsAuthError()` method that lets `*Error` satisfy the unexported `IsAuthError() bool` interface picked up by both `cli.ExitCode` and `auth.translateErr` — the typed escape hatch that replaces string-matching `"unauthenticated"`. - `client_test.go`, `stream_test.go`, `error_test.go`, `transport_test.go` — drive `httptest.Server` instances to cover happy paths, mid-stream errors, trailer parsing, and auth classification. diff --git a/internal/transport/client.go b/internal/transport/client.go index 6c45a80..551bfad 100644 --- a/internal/transport/client.go +++ b/internal/transport/client.go @@ -43,6 +43,10 @@ type Client struct { // call; tokenFn supplies a bearer token per request (return "" to skip the // Authorization header). Options may override the default http.Client and set // a User-Agent. +// +// After options run, the resolved http.Client's Transport is wrapped with +// bearerTransport so auth + User-Agent attach at the transport layer. Unary, +// Stream, and DoRaw all inherit this — no per-call-site header plumbing. func New(baseURL string, tokenFn func(context.Context) (string, error), opts ...Option) *Client { c := &Client{ httpClient: http.DefaultClient, @@ -52,9 +56,53 @@ func New(baseURL string, tokenFn func(context.Context) (string, error), opts ... for _, opt := range opts { opt(c) } + // Clone so we don't mutate the caller's *http.Client (or http.DefaultClient). + clone := *c.httpClient + base := clone.Transport + if base == nil { + base = http.DefaultTransport + } + clone.Transport = &bearerTransport{next: base, c: c} + c.httpClient = &clone return c } +// bearerTransport is an http.RoundTripper middleware that attaches the bearer +// token + User-Agent to every outbound request. It reads tokenFn and userAgent +// off the parent Client on each call, so post-construction mutation (test +// harnesses that tweak tokenFn after New) still takes effect. +type bearerTransport struct { + next http.RoundTripper + c *Client +} + +// RoundTrip injects auth + User-Agent then delegates to next. Per the +// net/http RoundTripper contract the incoming *http.Request must not be +// mutated, so we clone before touching headers. A tokenFn error is wrapped +// with "token: %w" so callers can still errors.Is the underlying cause +// (http.Client.Do wraps this in *url.Error, which preserves %w). Existing +// Authorization / User-Agent headers on the clone are never overwritten — +// lets a caller that pre-sets them (e.g. a future --header flag) opt out. +func (b *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + cloned := req.Clone(req.Context()) + if cloned.Header == nil { + cloned.Header = make(http.Header) + } + if b.c.userAgent != "" && cloned.Header.Get("User-Agent") == "" { + cloned.Header.Set("User-Agent", b.c.userAgent) + } + if b.c.tokenFn != nil { + token, err := b.c.tokenFn(cloned.Context()) + if err != nil { + return nil, fmt.Errorf("token: %w", err) + } + if token != "" && cloned.Header.Get("Authorization") == "" { + cloned.Header.Set("Authorization", "Bearer "+token) + } + } + return b.next.RoundTrip(cloned) +} + // joinURL concatenates baseURL and path, collapsing at most one pair of // adjacent slashes so callers don't need to care whether baseURL has a // trailing slash or path has a leading one. @@ -65,9 +113,9 @@ func joinURL(baseURL, path string) string { return baseURL + path } -// buildRequest marshals req to JSON and constructs a POST request with all -// standard Connect-over-JSON headers applied. It is shared by Unary and -// Stream so header/body behavior stays in lockstep. +// buildRequest marshals req to JSON and constructs a POST request with the +// Connect-over-JSON content/accept headers. Auth + User-Agent are attached by +// bearerTransport at round-trip time. func (c *Client) buildRequest(ctx context.Context, path string, req any) (*http.Request, error) { var body []byte if req == nil { @@ -87,18 +135,6 @@ func (c *Client) buildRequest(ctx context.Context, path string, req any) (*http. } httpReq.Header.Set("content-type", "application/json") httpReq.Header.Set("accept", "application/json") - if c.userAgent != "" { - httpReq.Header.Set("User-Agent", c.userAgent) - } - if c.tokenFn != nil { - token, err := c.tokenFn(ctx) - if err != nil { - return nil, fmt.Errorf("token: %w", err) - } - if token != "" { - httpReq.Header.Set("Authorization", "Bearer "+token) - } - } return httpReq, nil } @@ -235,18 +271,6 @@ func (c *Client) Stream(ctx context.Context, path string, req any) (*StreamReade httpReq.Header.Set("content-type", "application/connect+json") httpReq.Header.Set("accept", "application/connect+json") httpReq.Header.Set("connect-protocol-version", "1") - if c.userAgent != "" { - httpReq.Header.Set("User-Agent", c.userAgent) - } - if c.tokenFn != nil { - token, err := c.tokenFn(ctx) - if err != nil { - return nil, fmt.Errorf("token: %w", err) - } - if token != "" { - httpReq.Header.Set("Authorization", "Bearer "+token) - } - } httpResp, err := c.httpClient.Do(httpReq) if err != nil { if ctxErr := ctx.Err(); ctxErr != nil { @@ -261,3 +285,39 @@ func (c *Client) Stream(ctx context.Context, path string, req any) (*StreamReade } return newStreamReader(httpResp.Body), nil } + +// DoRaw performs an authenticated HTTP request and returns the raw response +// status + body. An empty or nil body is treated uniformly: no request body +// is sent and Content-Type is omitted (so `--data ""` at the verb layer +// behaves like "no --data" rather than "zero-byte JSON"). Auth + User-Agent +// are applied by the client's bearerTransport middleware. No status-code +// interpretation — the caller decides how to handle non-2xx. Intended for +// the `ana api` raw verb; typed verbs should keep using Unary. +func (c *Client) DoRaw(ctx context.Context, method, path string, body []byte) (int, []byte, error) { + url := joinURL(c.baseURL, path) + var r io.Reader + if len(body) > 0 { + r = bytes.NewReader(body) + } + req, err := http.NewRequestWithContext(ctx, method, url, r) + if err != nil { + return 0, nil, fmt.Errorf("build request: %w", err) + } + if len(body) > 0 { + req.Header.Set("content-type", "application/json") + } + req.Header.Set("accept", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + return 0, nil, fmt.Errorf("doraw: %w", ctxErr) + } + return 0, nil, fmt.Errorf("doraw: %w", err) + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, nil, fmt.Errorf("read response: %w", err) + } + return resp.StatusCode, b, nil +} diff --git a/internal/transport/client_test.go b/internal/transport/client_test.go index fa16b44..b1d3ada 100644 --- a/internal/transport/client_test.go +++ b/internal/transport/client_test.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -276,6 +277,23 @@ func TestUnaryUserAgent(t *testing.T) { } } +// roundTripFn adapts a function into an http.RoundTripper. Used by tests that +// need to inspect the forwarded request without maintaining a named type. +type roundTripFn func(*http.Request) (*http.Response, error) + +func (f roundTripFn) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +// mustParseURL parses raw or t.Fatals — collapses the noise in tests that +// hand-construct an *http.Request. +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + u, err := url.Parse(raw) + if err != nil { + t.Fatalf("parse %q: %v", raw, err) + } + return u +} + // recordingRT captures the last outbound request and returns a canned 200. type recordingRT struct { lastReq *http.Request @@ -313,10 +331,23 @@ func TestWithHTTPClient(t *testing.T) { func TestWithHTTPClientNilIgnored(t *testing.T) { t.Parallel() - // A nil http.Client must not replace the default (guarded in WithHTTPClient). + // A nil http.Client must not crash or install a bad client; the + // bearerTransport middleware wraps http.DefaultTransport in that case + // (we clone DefaultClient so DefaultClient itself stays unmutated). c := New("http://example.invalid", nil, WithHTTPClient(nil)) - if c.httpClient != http.DefaultClient { - t.Fatalf("nil HTTP client replaced default") + if c.httpClient == http.DefaultClient { + t.Fatalf("expected a clone, not DefaultClient itself") + } + bt, ok := c.httpClient.Transport.(*bearerTransport) + if !ok { + t.Fatalf("expected bearerTransport wrap, got %T", c.httpClient.Transport) + } + if bt.next != http.DefaultTransport { + t.Fatalf("expected next == http.DefaultTransport, got %T", bt.next) + } + // DefaultClient's Transport must remain untouched (nil). + if http.DefaultClient.Transport != nil { + t.Fatalf("http.DefaultClient.Transport mutated") } } @@ -488,3 +519,274 @@ func TestUnaryTokenErrorNoServerCall(t *testing.T) { t.Fatalf("server hit despite token error") } } + +// TestBearerTransportPreservesUserAgent verifies a caller-supplied User-Agent +// is not clobbered by the middleware (existing header wins). +func TestBearerTransportPreservesUserAgent(t *testing.T) { + t.Parallel() + _, c := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("User-Agent"); got != "caller/1.0" { + t.Errorf("User-Agent = %q, want caller/1.0", got) + } + w.Write([]byte("{}")) + }, WithUserAgent("ana/0.0.1")) + + // Drive a raw request with a pre-set User-Agent to prove the middleware + // doesn't overwrite it. + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, c.baseURL, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + req.Header.Set("User-Agent", "caller/1.0") + resp, err := c.httpClient.Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + resp.Body.Close() +} + +// TestBearerTransportNilHeader covers the defensive nil-Header init in +// RoundTrip. http.Client.Do pre-initializes nil Headers before calling a +// RoundTripper, so the branch is only reachable by invoking the middleware +// directly — which is exactly what would happen if a test harness stacked +// our transport inside another RoundTripper and handed it a hand-built +// request. +func TestBearerTransportNilHeader(t *testing.T) { + t.Parallel() + var saw http.Header + c := New("http://example.invalid", staticToken("tok"), + WithHTTPClient(&http.Client{Transport: roundTripFn(func(r *http.Request) (*http.Response, error) { + saw = r.Header + return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("{}")), Header: make(http.Header), Request: r}, nil + })})) + bt, ok := c.httpClient.Transport.(*bearerTransport) + if !ok { + t.Fatalf("expected bearerTransport, got %T", c.httpClient.Transport) + } + req := &http.Request{Method: http.MethodGet, URL: mustParseURL(t, "http://example.invalid/x"), Header: nil} + resp, err := bt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + resp.Body.Close() + if got := saw.Get("Authorization"); got != "Bearer tok" { + t.Fatalf("Authorization = %q on forwarded request", got) + } + if req.Header != nil { + t.Fatalf("incoming request was mutated (Header set); RoundTripper contract broken") + } +} + +// TestBearerTransportPreservesAuthorization verifies a caller-supplied +// Authorization header is not clobbered by the middleware (existing header +// wins). +func TestBearerTransportPreservesAuthorization(t *testing.T) { + t.Parallel() + _, c := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer caller-token" { + t.Errorf("Authorization = %q, want Bearer caller-token", got) + } + w.Write([]byte("{}")) + }) + c.tokenFn = staticToken("middleware-token") + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, c.baseURL, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + req.Header.Set("Authorization", "Bearer caller-token") + resp, err := c.httpClient.Do(req) + if err != nil { + t.Fatalf("Do: %v", err) + } + resp.Body.Close() +} + +// TestWithHTTPClientWrapsCallerTransport proves a caller-supplied +// *http.Client with its own Transport still gets auth applied: the middleware +// wraps the caller's Transport, it isn't swapped out. +func TestWithHTTPClientWrapsCallerTransport(t *testing.T) { + t.Parallel() + rt := &recordingRT{} + c := New("http://example.invalid", staticToken("tok"), + WithHTTPClient(&http.Client{Transport: rt})) + if err := c.Unary(context.Background(), "/x", nil, nil); err != nil { + t.Fatalf("Unary: %v", err) + } + if rt.lastReq == nil { + t.Fatalf("caller's Transport was not reached") + } + if got := rt.lastReq.Header.Get("Authorization"); got != "Bearer tok" { + t.Fatalf("Authorization = %q, want Bearer tok", got) + } +} + +// TestDoRawGET exercises the happy GET path: no body sent, method + path +// preserved, response bytes returned verbatim. +func TestDoRawGET(t *testing.T) { + t.Parallel() + _, c := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("method = %q, want GET", r.Method) + } + if r.URL.Path != "/v1/things" { + t.Errorf("path = %q", r.URL.Path) + } + if _, ok := r.Header["Content-Type"]; ok { + t.Errorf("unexpected content-type on no-body request") + } + if r.ContentLength != 0 { + t.Errorf("ContentLength = %d, want 0 (nil body)", r.ContentLength) + } + w.Header().Set("content-type", "application/json") + w.Write([]byte(`{"ok":true}`)) + }) + status, body, err := c.DoRaw(context.Background(), http.MethodGet, "/v1/things", nil) + if err != nil { + t.Fatalf("DoRaw: %v", err) + } + if status != 200 { + t.Errorf("status = %d", status) + } + if string(body) != `{"ok":true}` { + t.Errorf("body = %q", body) + } +} + +// TestDoRawPOSTWithBody covers a body-bearing POST: Content-Type is set and +// the body bytes are round-tripped. +func TestDoRawPOSTWithBody(t *testing.T) { + t.Parallel() + _, c := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if ct := r.Header.Get("content-type"); ct != "application/json" { + t.Errorf("content-type = %q", ct) + } + got := drainBody(t, r) + if string(got) != `{"x":1}` { + t.Errorf("body = %q", got) + } + w.WriteHeader(200) + w.Write(got) // echo + }) + status, body, err := c.DoRaw(context.Background(), http.MethodPost, "/rpc/foo", []byte(`{"x":1}`)) + if err != nil { + t.Fatalf("DoRaw: %v", err) + } + if status != 200 || string(body) != `{"x":1}` { + t.Fatalf("status=%d body=%q", status, body) + } +} + +// TestDoRawNon2xxPassthrough: the body is returned intact on a 4xx/5xx; no +// error-envelope parsing happens at this layer (DoRaw's whole job). +func TestDoRawNon2xxPassthrough(t *testing.T) { + t.Parallel() + _, c := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"code":"not_found","message":"gone"}`)) + }) + status, body, err := c.DoRaw(context.Background(), http.MethodGet, "/missing", nil) + if err != nil { + t.Fatalf("DoRaw: %v", err) + } + if status != 404 { + t.Errorf("status = %d", status) + } + if string(body) != `{"code":"not_found","message":"gone"}` { + t.Errorf("body = %q", body) + } +} + +// TestDoRawAuthApplied: proves the middleware attaches bearer auth to DoRaw +// requests too — the whole point of the refactor. +func TestDoRawAuthApplied(t *testing.T) { + t.Parallel() + _, c := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer tok" { + t.Errorf("Authorization = %q", got) + } + w.Write([]byte("{}")) + }) + c.tokenFn = staticToken("tok") + if _, _, err := c.DoRaw(context.Background(), http.MethodGet, "/", nil); err != nil { + t.Fatalf("DoRaw: %v", err) + } +} + +// TestDoRawBuildRequestError covers the http.NewRequestWithContext failure +// branch — a control-byte URL is rejected by stdlib. +func TestDoRawBuildRequestError(t *testing.T) { + t.Parallel() + c := New("http://\x7f/bad", nil) + _, _, err := c.DoRaw(context.Background(), http.MethodGet, "/x", nil) + if err == nil || !strings.Contains(err.Error(), "build request") { + t.Fatalf("want build request error, got %v", err) + } +} + +// TestDoRawTransportError covers the non-context transport error branch. +func TestDoRawTransportError(t *testing.T) { + t.Parallel() + want := errors.New("dial fail") + c := New("http://example.invalid", nil, + WithHTTPClient(&http.Client{Transport: doErrRT{err: want}})) + _, _, err := c.DoRaw(context.Background(), http.MethodGet, "/", nil) + if err == nil || !errors.Is(err, want) { + t.Fatalf("want wrapped %v, got %v", want, err) + } +} + +// TestDoRawContextCancel covers the ctx-cancelled branch of DoRaw. +func TestDoRawContextCancel(t *testing.T) { + t.Parallel() + block := make(chan struct{}) + _, c := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + <-block + }) + t.Cleanup(func() { close(block) }) + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + _, _, err := c.DoRaw(ctx, http.MethodGet, "/", nil) + done <- err + }() + time.Sleep(20 * time.Millisecond) + cancel() + select { + case err := <-done: + if !errors.Is(err, context.Canceled) { + t.Fatalf("want context.Canceled, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("DoRaw did not return after cancel") + } +} + +// TestDoRawReadBodyError covers the io.ReadAll failure path on a 200 body. +func TestDoRawReadBodyError(t *testing.T) { + t.Parallel() + c := New("http://example.invalid", nil, + WithHTTPClient(&http.Client{Transport: readErrRT{}})) + status, _, err := c.DoRaw(context.Background(), http.MethodGet, "/", nil) + if err == nil || !strings.Contains(err.Error(), "read response") { + t.Fatalf("want read response error, got %v", err) + } + if status != 200 { + t.Errorf("status = %d, want 200 (surfaced even on read err)", status) + } +} + +// TestDoRawTokenError covers the tokenFn-error branch propagating through the +// middleware → http.Client.Do → DoRaw wrap. +func TestDoRawTokenError(t *testing.T) { + t.Parallel() + tokenErr := errors.New("no creds") + _, c := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + t.Errorf("server should not have been called") + }) + c.tokenFn = func(context.Context) (string, error) { return "", tokenErr } + _, _, err := c.DoRaw(context.Background(), http.MethodGet, "/", nil) + if err == nil || !errors.Is(err, tokenErr) { + t.Fatalf("want wrapped %v, got %v", tokenErr, err) + } +}