diff --git a/context.go b/context.go index 1d65285..c977a51 100644 --- a/context.go +++ b/context.go @@ -456,23 +456,30 @@ func (mw wrapM) handle(c *Context) { defer func() { req.Pattern = p }() req.Pattern = c.Pattern() + r := req if route := c.Route(); route != nil && route.ParamsLen() > 0 { params := slices.AppendSeq(make(Params, 0, route.ParamsLen()), c.Params()) - ctx := context.WithValue(c.Request().Context(), paramsKey, params) - req = req.WithContext(ctx) + ctx := context.WithValue(req.Context(), paramsKey, params) + r = req.WithContext(ctx) } mw.m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Avoid allocation if w has not been wrapped by m. - rec, ok := w.(*recorder) - if !ok { + var rec *recorder + switch v := w.(type) { + case flusherWriter: + rec, _ = v.ResponseWriter.(*recorder) + case *recorder: + rec = v + } + if rec == nil { rec = new(recorder) rec.reset(w) } cc := c.CloneWith(rec, r) defer cc.Close() mw.next(cc) - })).ServeHTTP(c.Writer(), req) + })).ServeHTTP(flusherWriter{c.Writer()}, r) } func sumLen(s []string) int { diff --git a/context_test.go b/context_test.go index 166e73f..5a8a3e1 100644 --- a/context_test.go +++ b/context_test.go @@ -617,6 +617,56 @@ func TestWrapM(t *testing.T) { assert.Equal(t, "OK", w.Body.String()) } +func TestWrapM_RestoresRequestPattern(t *testing.T) { + mw := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + }) + } + + f := MustRouter(WithMiddleware(WrapM(mw))) + f.MustAdd(MethodGet, "/foo/{bar}", func(c *Context) { + assert.Equal(t, "/foo/{bar}", c.Request().Pattern) + require.NoError(t, c.String(http.StatusOK, "OK")) + }) + + req := httptest.NewRequest(http.MethodGet, "/foo/bar", nil) + w := httptest.NewRecorder() + + f.ServeHTTP(w, req) + + assert.Empty(t, req.Pattern) +} + +func TestWrapM_FlusherShim(t *testing.T) { + var sawFlusher bool + var flushed bool + + mw := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + sawFlusher = ok + h.ServeHTTP(w, r) + if ok { + flusher.Flush() + flushed = true + } + }) + } + + f := MustRouter(WithMiddleware(WrapM(mw))) + f.MustAdd(MethodGet, "/", func(c *Context) { + require.NoError(t, c.String(http.StatusOK, "ok")) + }) + + w := httptest.NewRecorder() + f.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil)) + + assert.True(t, sawFlusher) + assert.True(t, flushed) + assert.True(t, w.Flushed) +} + func BenchmarkWrapH(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "https://example.com/a/b/c", nil) w := httptest.NewRecorder() diff --git a/response_writer.go b/response_writer.go index 0f9b4e7..95ad22c 100644 --- a/response_writer.go +++ b/response_writer.go @@ -296,6 +296,14 @@ type onlyWrite struct { io.Writer } +type flusherWriter struct { + ResponseWriter +} + +func (w flusherWriter) Flush() { _ = w.FlushError() } + +func (w flusherWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } + type noopWriter struct { h http.Header }