diff --git a/middleware/heartbeat.go b/middleware/heartbeat.go index f36e8ccf..5f620074 100644 --- a/middleware/heartbeat.go +++ b/middleware/heartbeat.go @@ -3,6 +3,8 @@ package middleware import ( "net/http" "strings" + + "github.com/go-chi/chi/v5" ) // Heartbeat endpoint middleware useful to setting up a path like @@ -12,7 +14,9 @@ import ( func Heartbeat(endpoint string) func(http.Handler) http.Handler { f := func(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - if (r.Method == "GET" || r.Method == "HEAD") && strings.EqualFold(r.URL.Path, endpoint) { + rctx := chi.RouteContext(r.Context()) + routePath := rctx.RoutePath + if (r.Method == "GET" || r.Method == "HEAD") && (strings.EqualFold(routePath, endpoint)) || strings.EqualFold(r.URL.Path, endpoint) { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(".")) diff --git a/middleware/heartbeat_test.go b/middleware/heartbeat_test.go new file mode 100644 index 00000000..141939f3 --- /dev/null +++ b/middleware/heartbeat_test.go @@ -0,0 +1,47 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestHeartbeat(t *testing.T) { + r := chi.NewRouter() + r.Use(Heartbeat("/ping")) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "yes") + w.Write([]byte("bye")) + }) + r.Mount("/sub", NewSubrouter()) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/", nil); body != "bye" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/ping", nil); body != "." { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/sub", nil); body != "bye" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/sub/ping", nil); body != "." { + t.Fatalf(body) + } +} + +func NewSubrouter() chi.Router { + r := chi.NewRouter() + r.Use(Heartbeat("/ping")) + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "yes") + w.Write([]byte("bye")) + }) + + return r +}