From 3a800ab2bf11afe3a3f3f38bdb76b30906bc0708 Mon Sep 17 00:00:00 2001 From: aibuddy Date: Tue, 19 Aug 2025 16:28:40 +0000 Subject: [PATCH] tool(http_fetch): add tool implementation, tests, docs, and testutil builder --- docs/reference/http_fetch.md | 87 +++++++ tools/cmd/http_fetch/http_fetch.go | 297 ++++++++++++++++++++++++ tools/cmd/http_fetch/http_fetch_test.go | 268 +++++++++++++++++++++ tools/testutil/buildtool.go | 86 +++++++ tools/testutil/buildtool_test.go | 21 ++ tools/testutil/tempdir.go | 25 ++ 6 files changed, 784 insertions(+) create mode 100644 docs/reference/http_fetch.md create mode 100644 tools/cmd/http_fetch/http_fetch.go create mode 100644 tools/cmd/http_fetch/http_fetch_test.go create mode 100644 tools/testutil/buildtool.go create mode 100644 tools/testutil/buildtool_test.go create mode 100644 tools/testutil/tempdir.go diff --git a/docs/reference/http_fetch.md b/docs/reference/http_fetch.md new file mode 100644 index 0000000..a4fb98b --- /dev/null +++ b/docs/reference/http_fetch.md @@ -0,0 +1,87 @@ +# HTTP fetch tool (http_fetch) + +Safe HTTP/HTTPS fetcher with hard byte caps, limited redirects, optional gzip decompression, and SSRF guard. The tool streams JSON over stdin/stdout; errors are single-line JSON on stderr with non-zero exit. + +## Contracts + +- Stdin: single JSON object +- Stdout (success): single-line JSON object +- Stderr (failure): single-line JSON `{ "error": "...", "hint?": "..." }` and non-zero exit + +### Parameters + +- `url` (string, required): http/https URL +- `method` (string, optional): `GET` or `HEAD` (default `GET`) +- `max_bytes` (int, optional): hard byte cap for response body (default 1048576) +- `timeout_ms` (int, optional): request timeout in milliseconds (default 10000; falls back to `HTTP_TIMEOUT_MS` env if unset) +- `decompress` (bool, optional): when true (default), enables transparent gzip decoding; when false, returns raw bytes + +### Output + +``` +{ + "status": 200, + "headers": {"Content-Type": "text/plain; charset=utf-8", "ETag": "\"abc123\""}, + "body_base64": "...", + "truncated": false +} +``` + +### Example: GET + +Input to stdin: + +```json +{"url": "https://example.org/robots.txt", "max_bytes": 65536} +``` + +### Example: HEAD + +Input to stdin: + +```json +{"url": "https://example.org/", "method": "HEAD"} +``` + +## Behavior + +- Schemes: only `http` and `https` are allowed +- Redirects: up to 5 redirects are followed; further redirects fail with `"too many redirects"` +- Headers: response headers are returned as a simple string map; `ETag` and `Last-Modified` are preserved when present +- Decompression: gzip decoding is enabled by default; set `decompress=false` to receive raw compressed bytes +- Byte cap: responses are read with a strict byte cap; when exceeded, `truncated=true` and the body is cut at `max_bytes` +- User-Agent: `agentcli-http-fetch/0.1` + +## Security (SSRF guard) + +- Blocks loopback, RFC1918, link-local, and IPv6 ULA destinations +- Blocks `.onion` hosts +- Redirect targets are re-validated +- For tests/local-only usage, setting `HTTP_FETCH_ALLOW_LOCAL=1` disables the block + +## Environment + +- `HTTP_TIMEOUT_MS` (optional): default timeout in milliseconds when `timeout_ms` is unset + +## Audit + +On each run, an NDJSON line is appended under `.goagent/audit/YYYYMMDD.log` with fields: + +``` +{tool:"http_fetch",url_host,status,bytes,truncated,ms} +``` + +## Manifest + +Ensure an entry similar to the following exists in `tools.json`: + +```json +{ + "name": "http_fetch", + "description": "Safe HTTP/HTTPS fetcher with byte cap and redirects", + "schema": {"type": "object", "required": ["url"], "properties": {"url": {"type": "string"}, "method": {"type": "string", "enum": ["GET", "HEAD"]}, "max_bytes": {"type": "integer", "minimum": 1, "default": 1048576}, "timeout_ms": {"type": "integer", "minimum": 1, "default": 10000}, "decompress": {"type": "boolean", "default": true}}, "additionalProperties": false}, + "command": ["./tools/bin/http_fetch"], + "timeoutSec": 15, + "envPassthrough": ["HTTP_TIMEOUT_MS"] +} +``` diff --git a/tools/cmd/http_fetch/http_fetch.go b/tools/cmd/http_fetch/http_fetch.go new file mode 100644 index 0000000..1cb1cb3 --- /dev/null +++ b/tools/cmd/http_fetch/http_fetch.go @@ -0,0 +1,297 @@ +package main + +import ( + "bufio" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" +) + +type input struct { + URL string `json:"url"` + Method string `json:"method"` + MaxBytes int `json:"max_bytes"` + TimeoutMs int `json:"timeout_ms"` + Decompress *bool `json:"decompress"` +} + +type output struct { + Status int `json:"status"` + Headers map[string]string `json:"headers"` + BodyBase64 string `json:"body_base64,omitempty"` + Truncated bool `json:"truncated"` +} + +func main() { + if err := run(); err != nil { + msg := strings.ReplaceAll(err.Error(), "\n", " ") + fmt.Fprintf(os.Stderr, "{\"error\":%q}\n", msg) + os.Exit(1) + } +} + +func run() error { + in, err := decodeInput() + if err != nil { + return err + } + method, u, maxBytes, timeout, decompress, err := prepareRequestParams(in) + if err != nil { + return err + } + client := newHTTPClient(timeout, decompress) + req, err := http.NewRequest(method, in.URL, nil) + if err != nil { + return fmt.Errorf("new request: %w", err) + } + req.Header.Set("User-Agent", "agentcli-http-fetch/0.1") + + start := time.Now() + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("http: %w", err) + } + defer func() { + _ = resp.Body.Close() //nolint:errcheck + }() + + headers := collectHeaders(resp.Header) + bodyB64, truncated, bodyBytes, err := maybeReadBody(method, resp.Body, maxBytes) + if err != nil { + return err + } + + out := output{Status: resp.StatusCode, Headers: headers, BodyBase64: bodyB64, Truncated: truncated} + if err := json.NewEncoder(os.Stdout).Encode(out); err != nil { + return fmt.Errorf("encode json: %w", err) + } + // Best-effort audit. Failures are ignored. + _ = appendAudit(map[string]any{ //nolint:errcheck + "ts": time.Now().UTC().Format(time.RFC3339Nano), + "tool": "http_fetch", + "url_host": u.Hostname(), + "status": resp.StatusCode, + "bytes": bodyBytes, + "truncated": truncated, + "ms": time.Since(start).Milliseconds(), + }) + return nil +} + +func decodeInput() (input, error) { + var in input + dec := json.NewDecoder(bufio.NewReader(os.Stdin)) + if err := dec.Decode(&in); err != nil { + return in, fmt.Errorf("parse json: %w", err) + } + return in, nil +} + +func prepareRequestParams(in input) (method string, u *url.URL, maxBytes int, timeout time.Duration, decompress bool, err error) { + if strings.TrimSpace(in.URL) == "" { + return "", nil, 0, 0, false, errors.New("url is required") + } + u, err = url.Parse(in.URL) + if err != nil || (u.Scheme != "http" && u.Scheme != "https") { + return "", nil, 0, 0, false, errors.New("only http/https are allowed") + } + method = strings.ToUpper(strings.TrimSpace(in.Method)) + if method == "" { + method = http.MethodGet + } + if method != http.MethodGet && method != http.MethodHead { + return "", nil, 0, 0, false, errors.New("method must be GET or HEAD") + } + maxBytes = in.MaxBytes + if maxBytes <= 0 { + maxBytes = 1 << 20 // default 1 MiB + } + timeout = resolveTimeout(in.TimeoutMs) + // Enforce SSRF guard before any request and on every redirect target. + if err = ssrfGuard(u); err != nil { + return "", nil, 0, 0, false, err + } + decompress = true + if in.Decompress != nil { + decompress = *in.Decompress + } + return method, u, maxBytes, timeout, decompress, nil +} + +func resolveTimeout(timeoutMs int) time.Duration { + timeout := time.Duration(timeoutMs) * time.Millisecond + if timeout > 0 { + return timeout + } + if v := strings.TrimSpace(os.Getenv("HTTP_TIMEOUT_MS")); v != "" { + if ms, perr := time.ParseDuration(v + "ms"); perr == nil { + timeout = ms + } + } + if timeout <= 0 { + timeout = 10 * time.Second + } + return timeout +} + +func newHTTPClient(timeout time.Duration, decompress bool) *http.Client { + tr := &http.Transport{DisableCompression: !decompress} + return &http.Client{Timeout: timeout, Transport: tr, CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 5 { + return errors.New("too many redirects") + } + return ssrfGuard(req.URL) + }} +} + +func collectHeaders(h http.Header) map[string]string { + headers := make(map[string]string, len(h)) + for k, v := range h { + if len(v) > 0 { + headers[k] = v[0] + } else { + headers[k] = "" + } + } + return headers +} + +func maybeReadBody(method string, r io.Reader, maxBytes int) (bodyB64 string, truncated bool, bodyBytes int, err error) { + if method == http.MethodHead { + return "", false, 0, nil + } + limited := io.LimitedReader{R: r, N: int64(maxBytes) + 1} + data, rerr := io.ReadAll(&limited) + if rerr != nil { + return "", false, 0, fmt.Errorf("read body: %w", rerr) + } + if int64(len(data)) > int64(maxBytes) { + truncated = true + data = data[:maxBytes] + } + bodyBytes = len(data) + bodyB64 = base64.StdEncoding.EncodeToString(data) + return bodyB64, truncated, bodyBytes, nil +} + +// ssrfGuard blocks requests to loopback, RFC1918, link-local, and ULA addresses, +// unless HTTP_FETCH_ALLOW_LOCAL=1 is set (only used in tests). +func ssrfGuard(u *url.URL) error { + host := u.Hostname() + if host == "" { + return errors.New("invalid host") + } + if strings.HasSuffix(strings.ToLower(host), ".onion") { + return errors.New("SSRF blocked: onion domains are not allowed") + } + if os.Getenv("HTTP_FETCH_ALLOW_LOCAL") == "1" { + return nil + } + ips, err := net.LookupIP(host) + if err != nil || len(ips) == 0 { + // If DNS fails, be conservative and block + return errors.New("SSRF blocked: cannot resolve host") + } + for _, ip := range ips { + if isPrivateIP(ip) { + return errors.New("SSRF blocked: private or loopback address") + } + } + return nil +} + +func isPrivateIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + // Normalize to 16-byte form for IPv4 + if v4 := ip.To4(); v4 != nil { + ip = v4 + // 10.0.0.0/8 + if v4[0] == 10 { + return true + } + // 172.16.0.0/12 + if v4[0] == 172 && v4[1]&0xf0 == 16 { + return true + } + // 192.168.0.0/16 + if v4[0] == 192 && v4[1] == 168 { + return true + } + // 169.254.0.0/16 link-local + if v4[0] == 169 && v4[1] == 254 { + return true + } + // 127.0.0.0/8 loopback handled by IsLoopback but keep explicit + if v4[0] == 127 { + return true + } + return false + } + // IPv6 ranges: ::1 (loopback), fe80::/10 (link-local), fc00::/7 (ULA) + if ip.Equal(net.ParseIP("::1")) { + return true + } + // fe80::/10 + if ip[0] == 0xfe && (ip[1]&0xc0) == 0x80 { + return true + } + // fc00::/7 + if ip[0]&0xfe == 0xfc { + return true + } + return false +} + +// appendAudit writes an NDJSON line under .goagent/audit/YYYYMMDD.log at the repo root. +func appendAudit(entry any) error { + b, err := json.Marshal(entry) + if err != nil { + return err + } + root := moduleRoot() + dir := filepath.Join(root, ".goagent", "audit") + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + fname := time.Now().UTC().Format("20060102") + ".log" + path := filepath.Join(dir, fname) + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return err + } + defer func() { _ = f.Close() }() //nolint:errcheck + if _, err := f.Write(append(b, '\n')); err != nil { + return err + } + return nil +} + +// moduleRoot walks upward from CWD to the directory containing go.mod; falls back to CWD. +func moduleRoot() string { + cwd, err := os.Getwd() + if err != nil || cwd == "" { + return "." + } + dir := cwd + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + return cwd + } + dir = parent + } +} diff --git a/tools/cmd/http_fetch/http_fetch_test.go b/tools/cmd/http_fetch/http_fetch_test.go new file mode 100644 index 0000000..5853e7e --- /dev/null +++ b/tools/cmd/http_fetch/http_fetch_test.go @@ -0,0 +1,268 @@ +package main_test + +import ( + "bytes" + "compress/gzip" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "strings" + "testing" + + testutil "github.com/hyperifyio/goagent/tools/testutil" +) + +type fetchOutput struct { + Status int `json:"status"` + Headers map[string]string `json:"headers"` + BodyBase64 string `json:"body_base64,omitempty"` + Truncated bool `json:"truncated"` +} + +// TestMain enables local SSRF allowance for most tests that rely on httptest servers. +func TestMain(m *testing.M) { + if err := os.Setenv("HTTP_FETCH_ALLOW_LOCAL", "1"); err != nil { + panic(err) + } + os.Exit(m.Run()) +} + +func runFetch(t *testing.T, bin string, input any) (fetchOutput, string) { + t.Helper() + data, err := json.Marshal(input) + if err != nil { + t.Fatalf("marshal input: %v", err) + } + cmd := exec.Command(bin) + cmd.Stdin = bytes.NewReader(data) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + t.Fatalf("http_fetch failed to run: %v, stderr=%s", err, stderr.String()) + } + out := strings.TrimSpace(stdout.String()) + var parsed fetchOutput + if err := json.Unmarshal([]byte(out), &parsed); err != nil { + t.Fatalf("failed to parse http_fetch output JSON: %v; raw=%q", err, out) + } + return parsed, stderr.String() +} + +// TestHttpFetch_Get200_Basic verifies a simple GET returns status, headers, and base64 body without truncation. +func TestHttpFetch_Get200_Basic(t *testing.T) { + // Arrange a test server that returns plain text + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("ETag", "\"abc123\"") + if _, err := w.Write([]byte("hello world")); err != nil { + t.Fatalf("write: %v", err) + } + })) + defer srv.Close() + + bin := testutil.BuildTool(t, "http_fetch") + + out, _ := runFetch(t, bin, map[string]any{ + "url": srv.URL, + "max_bytes": 1 << 20, // 1 MiB cap + "timeout_ms": 2000, + "decompress": true, + }) + + if out.Status != 200 { + t.Fatalf("expected status 200, got %d", out.Status) + } + if out.Truncated { + t.Fatalf("expected truncated=false") + } + if ct := out.Headers["Content-Type"]; !strings.HasPrefix(ct, "text/plain") { + t.Fatalf("unexpected content-type: %q", ct) + } + body, err := base64.StdEncoding.DecodeString(out.BodyBase64) + if err != nil { + t.Fatalf("body_base64 not valid base64: %v", err) + } + if string(body) != "hello world" { + t.Fatalf("unexpected body: %q", string(body)) + } +} + +// TestHttpFetch_HeadRequest ensures no body is returned and headers are present. +func TestHttpFetch_HeadRequest(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodHead { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + w.Header().Set("Last-Modified", "Mon, 02 Jan 2006 15:04:05 GMT") + w.WriteHeader(204) + })) + defer srv.Close() + + bin := testutil.BuildTool(t, "http_fetch") + out, _ := runFetch(t, bin, map[string]any{ + "url": srv.URL, + "method": "HEAD", + }) + if out.Status != 204 { + t.Fatalf("expected 204, got %d", out.Status) + } + if out.BodyBase64 != "" { + t.Fatalf("expected empty body for HEAD, got %q", out.BodyBase64) + } + if out.Headers["Last-Modified"] == "" { + t.Fatalf("expected Last-Modified header present") + } +} + +// TestHttpFetch_Redirects_Limited ensures redirects are followed up to 5 and then fail. +func TestHttpFetch_Redirects_Limited(t *testing.T) { + // Chain of 6 redirects + mux := http.NewServeMux() + for i := 0; i < 6; i++ { + idx := i + path := fmt.Sprintf("/r%c", 'a'+i) + next := fmt.Sprintf("/r%c", 'a'+i+1) + mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + if idx == 5 { + w.WriteHeader(200) + if _, err := w.Write([]byte("ok")); err != nil { + t.Fatalf("write: %v", err) + } + return + } + http.Redirect(w, r, next, http.StatusFound) + }) + } + srv := httptest.NewServer(mux) + defer srv.Close() + + bin := testutil.BuildTool(t, "http_fetch") + // Expect error due to >5 redirects + cmd := exec.Command(bin) + in := map[string]any{"url": srv.URL + "/ra", "timeout_ms": 2000} + data, err := json.Marshal(in) + if err != nil { + t.Fatalf("marshal: %v", err) + } + cmd.Stdin = bytes.NewReader(data) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err = cmd.Run() + if err == nil { + t.Fatalf("expected error after too many redirects") + } + if !strings.Contains(stderr.String(), "too many redirects") { + t.Fatalf("expected too many redirects error, got %q", stderr.String()) + } +} + +// TestHttpFetch_GzipDecompress checks automatic gzip decoding by default. +func TestHttpFetch_GzipDecompress(t *testing.T) { + gz := func(s string) []byte { + var buf bytes.Buffer + zw := gzip.NewWriter(&buf) + if _, err := zw.Write([]byte(s)); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + return buf.Bytes() + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Encoding", "gzip") + if _, err := w.Write(gz("zipper")); err != nil { + t.Fatalf("write: %v", err) + } + })) + defer srv.Close() + + bin := testutil.BuildTool(t, "http_fetch") + + // Default: decompress=true + out, _ := runFetch(t, bin, map[string]any{"url": srv.URL}) + body, err := base64.StdEncoding.DecodeString(out.BodyBase64) + if err != nil { + t.Fatalf("decode base64: %v", err) + } + if string(body) != "zipper" { + t.Fatalf("expected decompressed body, got %q", string(body)) + } + + // With decompress=false, expect raw gzip bytes + out, _ = runFetch(t, bin, map[string]any{"url": srv.URL, "decompress": false}) + body, err = base64.StdEncoding.DecodeString(out.BodyBase64) + if err != nil { + t.Fatalf("decode base64: %v", err) + } + if string(body) == "zipper" { + t.Fatalf("expected raw gzip bytes when decompress=false") + } +} + +// TestHttpFetch_Truncation enforces max_bytes cap. +func TestHttpFetch_Truncation(t *testing.T) { + data := strings.Repeat("A", 1024) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write([]byte(data)); err != nil { + t.Fatalf("write: %v", err) + } + })) + defer srv.Close() + + bin := testutil.BuildTool(t, "http_fetch") + out, _ := runFetch(t, bin, map[string]any{"url": srv.URL, "max_bytes": 100}) + if !out.Truncated { + t.Fatalf("expected truncated=true") + } + body, err := base64.StdEncoding.DecodeString(out.BodyBase64) + if err != nil { + t.Fatalf("decode base64: %v", err) + } + if len(body) != 100 { + t.Fatalf("expected 100 bytes, got %d", len(body)) + } +} + +// TestHttpFetch_SSRF_Block_Localhost ensures SSRF guard blocks localhost by default. +func TestHttpFetch_SSRF_Block_Localhost(t *testing.T) { + bin := testutil.BuildTool(t, "http_fetch") + cmd := exec.Command(bin) + in := map[string]any{"url": "http://127.0.0.1:9"} + data, err := json.Marshal(in) + if err != nil { + t.Fatalf("marshal: %v", err) + } + cmd.Stdin = bytes.NewReader(data) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + // Ensure guard is active + // Inherit env but explicitly remove HTTP_FETCH_ALLOW_LOCAL to enforce guard + var env []string + for _, e := range os.Environ() { + if strings.HasPrefix(e, "HTTP_FETCH_ALLOW_LOCAL=") { + continue + } + env = append(env, e) + } + cmd.Env = env + err = cmd.Run() + if err == nil { + t.Fatalf("expected SSRF block error") + } + if !strings.Contains(stderr.String(), "SSRF blocked") { + t.Fatalf("expected SSRF blocked error, got %q", stderr.String()) + } +} diff --git a/tools/testutil/buildtool.go b/tools/testutil/buildtool.go new file mode 100644 index 0000000..b8a3456 --- /dev/null +++ b/tools/testutil/buildtool.go @@ -0,0 +1,86 @@ +package testutil + +import ( + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "testing" +) + +// BuildTool builds the named tool binary into a test-scoped temporary +// directory and returns the absolute path to the produced executable. +// +// Source discovery (absolute paths used to satisfy repository path hygiene +// rules in linters/tests): +// - tools/cmd/ (canonical layout only) +func BuildTool(t *testing.T, name string) string { + t.Helper() + + repoRoot, err := findRepoRoot() + if err != nil { + t.Fatalf("find repo root: %v", err) + } + + // Determine binary name with OS suffix + binName := name + if runtime.GOOS == "windows" { + binName += ".exe" + } + outPath := filepath.Join(t.TempDir(), binName) + + // Candidate source locations (canonical layout only) + var candidates []string + candidates = append(candidates, filepath.Join(repoRoot, "tools", "cmd", name)) + + var srcPath string + for _, c := range candidates { + if fi, statErr := os.Stat(c); statErr == nil { + // Accept directories and regular files + if fi.IsDir() || fi.Mode().IsRegular() { + srcPath = c + break + } + } + } + if srcPath == "" { + t.Fatalf("tool sources not found for %q under %s", name, filepath.Join(repoRoot, "tools")) + } + + cmd := exec.Command("go", "build", "-o", outPath, srcPath) + cmd.Dir = repoRoot + // Inherit environment; ensure CGO disabled for determinism + cmd.Env = append(os.Environ(), "CGO_ENABLED=0") + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("build %s from %s failed: %v\n%s", name, relOrSame(repoRoot, srcPath), err, string(output)) + } + return outPath +} + +func findRepoRoot() (string, error) { + // Start from CWD and walk up until go.mod is found + start, err := os.Getwd() + if err != nil || start == "" { + return "", errors.New("cannot determine working directory") + } + dir := start + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir, nil + } + parent := filepath.Dir(dir) + if parent == dir { + return "", fmt.Errorf("go.mod not found from %s upward", start) + } + dir = parent + } +} + +func relOrSame(base, target string) string { + if rel, err := filepath.Rel(base, target); err == nil { + return rel + } + return target +} diff --git a/tools/testutil/buildtool_test.go b/tools/testutil/buildtool_test.go new file mode 100644 index 0000000..53618d3 --- /dev/null +++ b/tools/testutil/buildtool_test.go @@ -0,0 +1,21 @@ +package testutil + +import ( + "runtime" + "strings" + "testing" +) + +func TestBuildTool_WindowsSuffix(t *testing.T) { + // Use a real tool name to ensure build succeeds across environments. + path := BuildTool(t, "fs_listdir") + if runtime.GOOS == "windows" { + if !strings.HasSuffix(path, ".exe") { + t.Fatalf("expected .exe suffix on Windows, got %q", path) + } + } else { + if strings.HasSuffix(path, ".exe") { + t.Fatalf("did not expect .exe suffix on non-Windows, got %q", path) + } + } +} diff --git a/tools/testutil/tempdir.go b/tools/testutil/tempdir.go new file mode 100644 index 0000000..9fd3717 --- /dev/null +++ b/tools/testutil/tempdir.go @@ -0,0 +1,25 @@ +package testutil + +import ( + "os" + "path/filepath" + "testing" +) + +// MakeRepoRelTempDir creates a temporary directory under the current +// package working directory and returns its relative path (basename). +// The directory is removed at test cleanup. +func MakeRepoRelTempDir(t *testing.T, prefix string) string { + t.Helper() + tmpAbs, err := os.MkdirTemp(".", prefix) + if err != nil { + t.Fatalf("mkdir temp under repo: %v", err) + } + base := filepath.Base(tmpAbs) + t.Cleanup(func() { + if err := os.RemoveAll(base); err != nil { + t.Logf("cleanup remove %s: %v", base, err) + } + }) + return base +}