From 3c0d68615e55096e8f98f52b7f2efa76d63f9b20 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Mon, 12 Jan 2026 12:38:45 +1100 Subject: [PATCH] feat: add a fluent builder for creating strategy HTTP handlers This vastly simplifies the strategy handlers. Both the github-releases and host strategy have been updated to use the builder. --- .golangci.yml | 1 + internal/config/config.go | 2 +- internal/httputil/error.go | 16 +- internal/httputil/logging.go | 4 +- internal/strategy/github_releases.go | 88 ++----- internal/strategy/handler/handler.go | 202 ++++++++++++++ internal/strategy/handler/handler_test.go | 306 ++++++++++++++++++++++ internal/strategy/host.go | 62 ++--- internal/strategy/host_test.go | 6 +- 9 files changed, 566 insertions(+), 121 deletions(-) create mode 100644 internal/strategy/handler/handler.go create mode 100644 internal/strategy/handler/handler_test.go diff --git a/.golangci.yml b/.golangci.yml index edf18be..0adf0fa 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -227,6 +227,7 @@ linters: # Such cases aren't reported by default. # Default: false check-type-assertions: true + check-blank: true exhaustive: # Program elements to check for exhaustiveness. diff --git a/internal/config/config.go b/internal/config/config.go index da89c38..e686edf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -87,7 +87,7 @@ func Load(ctx context.Context, r io.Reader, mux *http.ServeMux, vars map[string] } func expandVars(ast *hcl.AST, vars map[string]string) { - _ = hcl.Visit(ast, func(node hcl.Node, next func() error) error { + _ = hcl.Visit(ast, func(node hcl.Node, next func() error) error { //nolint:errcheck attr, ok := node.(*hcl.Attribute) if ok { switch attr := attr.Value.(type) { diff --git a/internal/httputil/error.go b/internal/httputil/error.go index 2187e5b..42aaade 100644 --- a/internal/httputil/error.go +++ b/internal/httputil/error.go @@ -17,14 +17,24 @@ func ErrorResponse(w http.ResponseWriter, r *http.Request, status int, msg strin http.Error(w, msg, status) } +// HTTPResponder is an error that knows how to write itself as an HTTP response. +type HTTPResponder interface { + error + WriteHTTP(http.ResponseWriter, *http.Request) +} + type HTTPError struct { status int err error } -func (h HTTPError) Error() string { return fmt.Sprintf("%d: %s", h.status, h.err) } -func (h HTTPError) Unwrap() error { return h.err } -func (h HTTPError) StatusCode() int { return h.status } +func (h HTTPError) Error() string { return fmt.Sprintf("%d: %s", h.status, h.err) } +func (h HTTPError) Unwrap() error { return h.err } + +// WriteHTTP writes this error as an HTTP response. +func (h HTTPError) WriteHTTP(w http.ResponseWriter, r *http.Request) { + ErrorResponse(w, r, h.status, h.err.Error()) +} func Errorf(status int, format string, args ...any) error { return HTTPError{ diff --git a/internal/httputil/logging.go b/internal/httputil/logging.go index 8266175..2555fb2 100644 --- a/internal/httputil/logging.go +++ b/internal/httputil/logging.go @@ -1,6 +1,7 @@ package httputil import ( + "fmt" "net/http" "github.com/block/sfptc/internal/logging" @@ -8,7 +9,8 @@ import ( func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger := logging.FromContext(r.Context()).With("url", r.RequestURI) + // Propagate attributes tot the handlers. + logger := logging.FromContext(r.Context()).With("request", fmt.Sprintf("%s %s", r.Method, r.RequestURI)) r = r.WithContext(logging.ContextWithLogger(r.Context(), logger)) logger.Debug("Request received") next.ServeHTTP(w, r) diff --git a/internal/strategy/github_releases.go b/internal/strategy/github_releases.go index 9df235c..0eff0f9 100644 --- a/internal/strategy/github_releases.go +++ b/internal/strategy/github_releases.go @@ -4,11 +4,8 @@ import ( "context" "encoding/json" "fmt" - "io" "log/slog" - "maps" "net/http" - "os" "slices" "github.com/alecthomas/errors" @@ -16,6 +13,7 @@ import ( "github.com/block/sfptc/internal/cache" "github.com/block/sfptc/internal/httputil" "github.com/block/sfptc/internal/logging" + "github.com/block/sfptc/internal/strategy/handler" ) func init() { @@ -31,6 +29,7 @@ type GitHubReleasesConfig struct { type GitHubReleases struct { config GitHubReleasesConfig cache cache.Cache + client *http.Client } // NewGitHubReleases creates a [Strategy] that fetches private (and public) release binaries from GitHub. @@ -38,13 +37,29 @@ func NewGitHubReleases(ctx context.Context, config GitHubReleasesConfig, cache c s := &GitHubReleases{ config: config, cache: cache, + client: http.DefaultClient, } logger := logging.FromContext(ctx) if config.Token == "" { logger.WarnContext(ctx, "No token configured for github-releases strategy") } // eg. https://github.com/alecthomas/chroma/releases/download/v2.21.1/chroma-2.21.1-darwin-amd64.tar.gz - mux.Handle("GET /github.com/{org}/{repo}/releases/download/{release}/{file}", http.HandlerFunc(s.fetch)) + h := handler.New(s.client, cache). + CacheKey(func(r *http.Request) string { + org := r.PathValue("org") + repo := r.PathValue("repo") + release := r.PathValue("release") + file := r.PathValue("file") + return fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s", org, repo, release, file) + }). + Transform(func(r *http.Request) (*http.Request, error) { + org := r.PathValue("org") + repo := r.PathValue("repo") + release := r.PathValue("release") + file := r.PathValue("file") + return s.downloadRelease(r.Context(), org, repo, release, file) + }) + mux.Handle("GET /github.com/{org}/{repo}/releases/download/{release}/{file}", h) return s, nil } @@ -52,69 +67,6 @@ var _ Strategy = (*GitHubReleases)(nil) func (g *GitHubReleases) String() string { return "github-releases" } -func (g *GitHubReleases) fetch(w http.ResponseWriter, r *http.Request) { - org := r.PathValue("org") - repo := r.PathValue("repo") - release := r.PathValue("release") - file := r.PathValue("file") - ghURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s", org, repo, release, file) - - logger := logging.FromContext(r.Context()).With("upstream", ghURL) - - key := cache.NewKey(ghURL) - - logger.Debug("Fetching GitHub release") - - // Check if the key exists in the cache - cr, headers, err := g.cache.Open(r.Context(), key) - if err == nil { - logger.Debug("Cache hit") - // Cache hit - stream directly from cache - defer cr.Close() - maps.Copy(w.Header(), headers) - if _, err := io.Copy(w, cr); err != nil { - httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to stream from cache", "error", err.Error()) - return - } - return - } - if !errors.Is(err, os.ErrNotExist) { - httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to open cache", "error", err.Error()) - return - } - - // Cache miss - fetch from GitHub and stream while caching - req, err := g.downloadRelease(r.Context(), org, repo, release, file) - if err != nil { - if herr, ok := errors.AsType[httputil.HTTPError](err); ok { - httputil.ErrorResponse(w, r, herr.StatusCode(), herr.Error(), "upstream", ghURL) - } else { - httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to create download request", "error", err.Error()) - } - return - } - - response, err := cache.FetchDirect(http.DefaultClient, req, g.cache, key) - if err != nil { - if herr, ok := errors.AsType[httputil.HTTPError](err); ok { - httputil.ErrorResponse(w, r, herr.StatusCode(), herr.Error()) - } else { - httputil.ErrorResponse(w, r, http.StatusInternalServerError, err.Error()) - } - return - } - defer response.Body.Close() - if response.StatusCode != http.StatusOK { - httputil.ErrorResponse(w, r, response.StatusCode, response.Status) - return - } - maps.Copy(w.Header(), response.Header) - if _, err := io.Copy(w, response.Body); err != nil { - httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to stream response", "error", err.Error()) - return - } -} - // newGitHubRequest creates a new HTTP request with GitHub API headers and authentication. func (g *GitHubReleases) newGitHubRequest(ctx context.Context, url, accept string) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -160,7 +112,7 @@ func (g *GitHubReleases) downloadRelease(ctx context.Context, org, repo, release return nil, httputil.Errorf(http.StatusInternalServerError, "create API request") } - resp, err := http.DefaultClient.Do(req) + resp, err := g.client.Do(req) if err != nil { return nil, httputil.Errorf(http.StatusBadGateway, "fetch release info failed: %w", err) } diff --git a/internal/strategy/handler/handler.go b/internal/strategy/handler/handler.go new file mode 100644 index 0000000..367b5a0 --- /dev/null +++ b/internal/strategy/handler/handler.go @@ -0,0 +1,202 @@ +package handler + +import ( + "io" + "log/slog" + "maps" + "net/http" + "net/textproto" + "os" + "time" + + "github.com/alecthomas/errors" + + "github.com/block/sfptc/internal/cache" + "github.com/block/sfptc/internal/httputil" + "github.com/block/sfptc/internal/logging" +) + +// Handler provides a fluent API for creating cache-backed HTTP handlers. +// +// Example usage: +// +// h := handler.New(client, cache). +// CacheKey(func(r *http.Request) string { +// return "custom-key" +// }). +// Transform(func(r *http.Request) (*http.Request, error) { +// // Modify request before fetching +// return modifiedRequest, nil +// }) +type Handler struct { + client *http.Client + cache cache.Cache + cacheKeyFunc func(*http.Request) string + transformFunc func(*http.Request) (*http.Request, error) + errorHandler func(error, http.ResponseWriter, *http.Request) + ttlFunc func(*http.Request) time.Duration +} + +// New creates a new Handler with the given HTTP client and cache. +// By default: +// - Cache key is derived from the request URL +// - No request transformation is performed +// - Standard error handling is used. +func New(client *http.Client, c cache.Cache) *Handler { + return &Handler{ + client: client, + cache: c, + cacheKeyFunc: func(r *http.Request) string { + return r.URL.String() + }, + transformFunc: func(r *http.Request) (*http.Request, error) { + return r, nil + }, + errorHandler: defaultErrorHandler, + ttlFunc: func(_ *http.Request) time.Duration { + return 0 + }, + } +} + +// CacheKey sets the function used to determine the cache key for a request. +// The function receives the original incoming request. +func (h *Handler) CacheKey(f func(*http.Request) string) *Handler { + h.cacheKeyFunc = f + return h +} + +// Transform sets the function used to transform the incoming request before fetching. +// This is where you can modify the request URL, headers, etc. +// The function receives the original incoming request and should return the request +// that will be sent to the upstream server. +func (h *Handler) Transform(f func(*http.Request) (*http.Request, error)) *Handler { + h.transformFunc = f + return h +} + +// OnError sets a custom error handler for the built handler. +// If not set, a default error handler is used. +func (h *Handler) OnError(f func(error, http.ResponseWriter, *http.Request)) *Handler { + h.errorHandler = f + return h +} + +// TTL sets the function used to determine the cache TTL for a request. +// The function receives the original incoming request. +// If not set or returns 0, the cache's default/maximum TTL is used. +func (h *Handler) TTL(f func(*http.Request) time.Duration) *Handler { + h.ttlFunc = f + return h +} + +// ServeHTTP implements http.Handler. +// The handler will: +// 1. Determine the cache key using the configured function +// 2. Check if the content exists in cache +// 3. If cached, stream from cache +// 4. If not cached, transform the request and fetch from upstream +// 5. Cache the response while streaming to the client. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + logger := logging.FromContext(r.Context()) + + cacheKeyStr := h.cacheKeyFunc(r) + key := cache.NewKey(cacheKeyStr) + + logger.DebugContext(r.Context(), "Processing request", slog.String("cache_key", cacheKeyStr)) + + if h.serveCached(w, r, key, logger) { + return + } + + h.fetchAndCache(w, r, key, logger) +} + +func (h *Handler) serveCached(w http.ResponseWriter, r *http.Request, key cache.Key, logger *slog.Logger) bool { + cr, headers, err := h.cache.Open(r.Context(), key) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + h.errorHandler(httputil.Errorf(http.StatusInternalServerError, "failed to open cache: %w", err), w, r) + return true + } + return false + } + + logger.DebugContext(r.Context(), "Cache hit") + defer cr.Close() + maps.Copy(w.Header(), headers) + if _, err := io.Copy(w, cr); err != nil { + logger.ErrorContext(r.Context(), "Failed to stream from cache", slog.String("error", err.Error())) + httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to stream from cache", "error", err.Error()) + } + return true +} + +func (h *Handler) fetchAndCache(w http.ResponseWriter, r *http.Request, key cache.Key, logger *slog.Logger) { + logger.DebugContext(r.Context(), "Cache miss, fetching from upstream") + + upstreamReq, err := h.transformFunc(r) + if err != nil { + h.errorHandler(err, w, r) + return + } + + resp, err := h.client.Do(upstreamReq) + if err != nil { + h.errorHandler(httputil.Errorf(http.StatusBadGateway, "failed to fetch: %w", err), w, r) + return + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + logger.ErrorContext(r.Context(), "Failed to close response body", slog.String("error", closeErr.Error())) + } + }() + + if resp.StatusCode != http.StatusOK { + h.streamNonOKResponse(w, resp, logger) + return + } + + h.streamAndCache(w, r, key, resp, logger) +} + +func (h *Handler) streamNonOKResponse(w http.ResponseWriter, resp *http.Response, logger *slog.Logger) { + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + logger.ErrorContext(resp.Request.Context(), "Failed to stream error response", slog.String("error", err.Error())) + } +} + +func (h *Handler) streamAndCache(w http.ResponseWriter, r *http.Request, key cache.Key, resp *http.Response, logger *slog.Logger) { + ttl := h.ttlFunc(r) + responseHeaders := textproto.MIMEHeader(maps.Clone(resp.Header)) + cw, err := h.cache.Create(r.Context(), key, responseHeaders, ttl) + if err != nil { + h.errorHandler(httputil.Errorf(http.StatusInternalServerError, "failed to create cache entry: %w", err), w, r) + return + } + + pr, pw := io.Pipe() + go func() { + mw := io.MultiWriter(pw, cw) + _, copyErr := io.Copy(mw, resp.Body) + closeErr := errors.Join(cw.Close(), resp.Body.Close()) + pw.CloseWithError(errors.Join(copyErr, closeErr)) + }() + + maps.Copy(w.Header(), resp.Header) + if _, err := io.Copy(w, pr); err != nil { + logger.ErrorContext(r.Context(), "Failed to stream response", slog.String("error", err.Error())) + } + if closeErr := pr.Close(); closeErr != nil { + logger.ErrorContext(r.Context(), "Failed to close pipe", slog.String("error", closeErr.Error())) + } +} + +func defaultErrorHandler(err error, w http.ResponseWriter, r *http.Request) { + if h, ok := errors.AsType[httputil.HTTPResponder](err); ok { + h.WriteHTTP(w, r) + } else { + httputil.ErrorResponse(w, r, http.StatusInternalServerError, err.Error()) + } +} diff --git a/internal/strategy/handler/handler_test.go b/internal/strategy/handler/handler_test.go new file mode 100644 index 0000000..cb9f4ef --- /dev/null +++ b/internal/strategy/handler/handler_test.go @@ -0,0 +1,306 @@ +package handler_test + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + "github.com/alecthomas/errors" + + "github.com/block/sfptc/internal/cache" + "github.com/block/sfptc/internal/httputil" + "github.com/block/sfptc/internal/logging" + "github.com/block/sfptc/internal/strategy/handler" +) + +type testRequest struct { + url string + headers map[string]string + expectStatus int + expectBody string + expectContains string +} + +func TestBuilder(t *testing.T) { + callCounts := make(map[string]int) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCounts[r.URL.Path]++ + + switch r.URL.Path { + case "/simple": + w.Header().Set("Content-Type", "text/plain") + _, _ = fmt.Fprint(w, "simple response") + case "/echo-header": + _, _ = fmt.Fprintf(w, "header: %s", r.Header.Get("X-Custom")) + case "/conditional": + if r.Header.Get("X-Private") == "true" { + _, _ = fmt.Fprint(w, "private") + } else { + _, _ = fmt.Fprint(w, "public") + } + case "/not-found": + w.WriteHeader(http.StatusNotFound) + _, _ = fmt.Fprint(w, "not found") + case "/stream": + w.Header().Set("Content-Type", "application/octet-stream") + for i := range 100 { + _, _ = fmt.Fprintf(w, "chunk %d\n", i) + } + default: + _, _ = fmt.Fprintf(w, "path: %s", r.URL.Path) + } + })) + defer upstream.Close() + + tests := []struct { + name string + buildHandler func(cache.Cache) http.Handler + requests []testRequest + expectUpstreamCalls map[string]int + }{ + { + name: "BasicFlow", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/simple", nil) + }) + }, + requests: []testRequest{ + {url: "/test", expectStatus: http.StatusOK, expectBody: "simple response"}, + }, + expectUpstreamCalls: map[string]int{"/simple": 1}, + }, + { + name: "CacheHit", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/simple", nil) + }) + }, + requests: []testRequest{ + {url: "/test", expectStatus: http.StatusOK, expectBody: "simple response"}, + {url: "/test", expectStatus: http.StatusOK, expectBody: "simple response"}, + }, + expectUpstreamCalls: map[string]int{"/simple": 1}, + }, + { + name: "CustomCacheKey", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + CacheKey(func(_ *http.Request) string { + return "constant-key" + }). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/path1", nil) + }) + }, + requests: []testRequest{ + {url: "/anything1", expectStatus: http.StatusOK, expectBody: "path: /path1"}, + {url: "/anything2", expectStatus: http.StatusOK, expectBody: "path: /path1"}, + }, + expectUpstreamCalls: map[string]int{"/path1": 1}, + }, + { + name: "Transform", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + Transform(func(r *http.Request) (*http.Request, error) { + upstreamReq, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/echo-header", nil) + if err != nil { + return nil, err + } + upstreamReq.Header.Set("X-Custom", "transformed") + return upstreamReq, nil + }) + }, + requests: []testRequest{ + {url: "/test", expectStatus: http.StatusOK, expectBody: "header: transformed"}, + }, + expectUpstreamCalls: map[string]int{"/echo-header": 1}, + }, + { + name: "TransformError", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + Transform(func(_ *http.Request) (*http.Request, error) { + return nil, httputil.Errorf(http.StatusBadRequest, "transform failed") + }) + }, + requests: []testRequest{ + {url: "/test", expectStatus: http.StatusBadRequest}, + }, + expectUpstreamCalls: map[string]int{}, + }, + { + name: "ConditionalTransform", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + CacheKey(func(r *http.Request) string { + return r.URL.String() + ":" + r.Header.Get("X-Private") + }). + Transform(func(r *http.Request) (*http.Request, error) { + upstreamReq, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/conditional", nil) + if err != nil { + return nil, err + } + if r.Header.Get("X-Private") == "true" { + upstreamReq.Header.Set("X-Private", "true") + } + return upstreamReq, nil + }) + }, + requests: []testRequest{ + {url: "/test", expectStatus: http.StatusOK, expectBody: "public"}, + {url: "/test", headers: map[string]string{"X-Private": "true"}, expectStatus: http.StatusOK, expectBody: "private"}, + }, + expectUpstreamCalls: map[string]int{"/conditional": 2}, + }, + { + name: "CustomErrorHandler", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + Transform(func(_ *http.Request) (*http.Request, error) { + return nil, errors.New("test error") + }). + OnError(func(err error, w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = fmt.Fprint(w, "custom error: "+err.Error()) + }) + }, + requests: []testRequest{ + {url: "/test", expectStatus: http.StatusTeapot, expectContains: "custom error"}, + }, + expectUpstreamCalls: map[string]int{}, + }, + { + name: "UpstreamError", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/not-found", nil) + }) + }, + requests: []testRequest{ + {url: "/test", expectStatus: http.StatusNotFound, expectBody: "not found"}, + }, + expectUpstreamCalls: map[string]int{"/not-found": 1}, + }, + { + name: "CacheKeyWithTransform", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + CacheKey(func(r *http.Request) string { + return "original:" + r.URL.Path + }). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/transformed", nil) + }) + }, + requests: []testRequest{ + {url: "/original", expectStatus: http.StatusOK, expectBody: "path: /transformed"}, + {url: "/original", expectStatus: http.StatusOK, expectBody: "path: /transformed"}, + }, + expectUpstreamCalls: map[string]int{"/transformed": 1}, + }, + { + name: "StreamingResponse", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/stream", nil) + }) + }, + requests: []testRequest{ + {url: "/test", expectStatus: http.StatusOK, expectContains: "chunk 0"}, + {url: "/test", expectStatus: http.StatusOK, expectContains: "chunk 99"}, + }, + expectUpstreamCalls: map[string]int{"/stream": 1}, + }, + { + name: "CustomTTL", + buildHandler: func(c cache.Cache) http.Handler { + return handler.New(http.DefaultClient, c). + TTL(func(r *http.Request) time.Duration { + if r.Header.Get("X-Short-Cache") == "true" { + return 100 * time.Millisecond + } + return time.Hour + }). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/simple", nil) + }) + }, + requests: []testRequest{ + {url: "/test", headers: map[string]string{"X-Short-Cache": "true"}, expectStatus: http.StatusOK, expectBody: "simple response"}, + }, + expectUpstreamCalls: map[string]int{"/simple": 1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for path := range callCounts { + delete(callCounts, path) + } + + c := mustNewMemoryCache() + handler := tt.buildHandler(c) + ctx := logging.ContextWithLogger(context.Background(), slog.Default()) + + for i, req := range tt.requests { + r := httptest.NewRequest(http.MethodGet, "http://example.com"+req.url, nil) + r = r.WithContext(ctx) + for k, v := range req.headers { + r.Header.Set(k, v) + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + assert.Equal(t, req.expectStatus, w.Code, "request %d status mismatch", i) + if req.expectBody != "" { + assert.Equal(t, req.expectBody, w.Body.String(), "request %d body mismatch", i) + } + if req.expectContains != "" { + assert.True(t, strings.Contains(w.Body.String(), req.expectContains), + "request %d: expected body to contain %q, got %q", i, req.expectContains, w.Body.String()) + } + } + + for path, expectedCount := range tt.expectUpstreamCalls { + assert.Equal(t, expectedCount, callCounts[path], "upstream call count mismatch for %s", path) + } + }) + } +} + +func TestHandlerMethodChaining(t *testing.T) { + c := mustNewMemoryCache() + client := &http.Client{} + + h := handler.New(client, c) + result := h. + CacheKey(func(_ *http.Request) string { return "key" }). + Transform(func(r *http.Request) (*http.Request, error) { return r, nil }). + OnError(func(_ error, _ http.ResponseWriter, _ *http.Request) {}). + TTL(func(_ *http.Request) time.Duration { return time.Hour }) + + assert.Equal(t, h, result, "methods should return the same handler instance") +} + +func mustNewMemoryCache() cache.Cache { + c, err := cache.NewMemory(context.Background(), cache.MemoryConfig{ + MaxTTL: time.Hour, + }) + if err != nil { + panic(err) + } + return c +} diff --git a/internal/strategy/host.go b/internal/strategy/host.go index efff241..3cbc692 100644 --- a/internal/strategy/host.go +++ b/internal/strategy/host.go @@ -3,17 +3,13 @@ package strategy import ( "context" "fmt" - "io" "log/slog" - "maps" "net/http" "net/url" - "github.com/alecthomas/errors" - "github.com/block/sfptc/internal/cache" - "github.com/block/sfptc/internal/httputil" "github.com/block/sfptc/internal/logging" + "github.com/block/sfptc/internal/strategy/handler" ) func init() { @@ -57,18 +53,24 @@ func NewHost(ctx context.Context, config HostConfig, cache cache.Cache, mux Mux) logger: logging.FromContext(ctx), prefix: prefix, } - mux.HandleFunc("GET "+prefix+"/", h.serveHTTP) + + hdlr := handler.New(h.client, cache). + CacheKey(func(r *http.Request) string { + return h.buildTargetURL(r).String() + }). + Transform(func(r *http.Request) (*http.Request, error) { + targetURL := h.buildTargetURL(r) + return http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil) + }) + + mux.Handle("GET "+prefix+"/", hdlr) return h, nil } func (d *Host) String() string { return "host:" + d.target.Host + d.target.Path } -func (d *Host) serveHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - +// buildTargetURL constructs the target URL from the incoming request. +func (d *Host) buildTargetURL(r *http.Request) *url.URL { // Strip the prefix from the request path path := r.URL.Path if len(path) >= len(d.prefix) { @@ -80,40 +82,10 @@ func (d *Host) serveHTTP(w http.ResponseWriter, r *http.Request) { targetURL, err := url.Parse(d.target.String()) if err != nil { - httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to parse target URL", "error", err.Error(), "upstream", d.target.String()) - return + d.logger.Error("Failed to parse target URL", "error", err.Error(), "target", d.target.String()) + return &url.URL{} } targetURL.Path = path targetURL.RawQuery = r.URL.RawQuery - fullURL := targetURL.String() - - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, fullURL, nil) - if err != nil { - httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to create request", "error", err.Error(), "upstream", fullURL) - return - } - - resp, err := cache.Fetch(d.client, req, d.cache) - if err != nil { - if httpErr, ok := errors.AsType[httputil.HTTPError](err); ok { - httputil.ErrorResponse(w, r, httpErr.StatusCode(), httpErr.Error(), "error", httpErr.Error(), "upstream", fullURL) - } else { - httputil.ErrorResponse(w, r, http.StatusInternalServerError, "Failed to fetch", "error", err.Error(), "upstream", fullURL) - } - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(w, resp.Body); err != nil { - d.logger.Error("Failed to copy error response", "error", err.Error(), "upstream", fullURL) - } - return - } - - maps.Copy(w.Header(), resp.Header) - if _, err := io.Copy(w, resp.Body); err != nil { - d.logger.Error("Failed to copy response", "error", err.Error(), "upstream", fullURL) - } + return targetURL } diff --git a/internal/strategy/host_test.go b/internal/strategy/host_test.go index 1ec7eea..e86e812 100644 --- a/internal/strategy/host_test.go +++ b/internal/strategy/host_test.go @@ -38,7 +38,7 @@ func TestHostCaching(t *testing.T) { u, _ := url.Parse(backend.URL) reqPath := "/" + u.Host + "/test" - req1 := httptest.NewRequest(http.MethodGet, reqPath, nil) + req1 := httptest.NewRequestWithContext(ctx, http.MethodGet, reqPath, nil) w1 := httptest.NewRecorder() mux.ServeHTTP(w1, req1) @@ -46,7 +46,7 @@ func TestHostCaching(t *testing.T) { assert.Equal(t, "response", w1.Body.String()) assert.Equal(t, 1, callCount) - req2 := httptest.NewRequest(http.MethodGet, reqPath, nil) + req2 := httptest.NewRequestWithContext(ctx, http.MethodGet, reqPath, nil) w2 := httptest.NewRecorder() mux.ServeHTTP(w2, req2) @@ -75,7 +75,7 @@ func TestHostNonOKStatus(t *testing.T) { u, _ := url.Parse(backend.URL) reqPath := "/" + u.Host + "/missing" - req := httptest.NewRequest(http.MethodGet, reqPath, nil) + req := httptest.NewRequestWithContext(ctx, http.MethodGet, reqPath, nil) w := httptest.NewRecorder() mux.ServeHTTP(w, req)