From c49faf9a8ae63a09fdec77b2b17972c447a92116 Mon Sep 17 00:00:00 2001 From: RW Date: Fri, 22 Dec 2023 14:49:58 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20[Bug]:=20Adaptator=20+=20otelfib?= =?UTF-8?q?er=20issue=20#2641=20(#2772)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/adaptor/adaptor.go | 2 ++ middleware/adaptor/adaptor_test.go | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/middleware/adaptor/adaptor.go b/middleware/adaptor/adaptor.go index db2149d921..8bd6f3fe00 100644 --- a/middleware/adaptor/adaptor.go +++ b/middleware/adaptor/adaptor.go @@ -76,6 +76,7 @@ func HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler { c.Request().Header.SetMethod(r.Method) c.Request().SetRequestURI(r.RequestURI) c.Request().SetHost(r.Host) + c.Request().Header.SetHost(r.Host) for key, val := range r.Header { for _, v := range val { c.Request().Header.Set(key, v) @@ -128,6 +129,7 @@ func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc { req.Header.SetMethod(r.Method) req.SetRequestURI(r.RequestURI) req.SetHost(r.Host) + req.Header.SetHost(r.Host) for key, val := range r.Header { for _, v := range val { req.Header.Set(key, v) diff --git a/middleware/adaptor/adaptor_test.go b/middleware/adaptor/adaptor_test.go index 6e03b05a2d..dc52704760 100644 --- a/middleware/adaptor/adaptor_test.go +++ b/middleware/adaptor/adaptor_test.go @@ -116,6 +116,7 @@ var ( ) func Test_HTTPMiddleware(t *testing.T) { + const expectedHost = "foobar.com" tests := []struct { name string url string @@ -148,6 +149,7 @@ func Test_HTTPMiddleware(t *testing.T) { w.WriteHeader(http.StatusMethodNotAllowed) return } + r = r.WithContext(context.WithValue(r.Context(), TestContextKey, "okay")) r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "not_okay")) r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "okay")) @@ -180,6 +182,7 @@ func Test_HTTPMiddleware(t *testing.T) { for _, tt := range tests { req, err := http.NewRequestWithContext(context.Background(), tt.method, tt.url, nil) + req.Host = expectedHost utils.AssertEqual(t, nil, err) resp, err := app.Test(req) @@ -188,6 +191,7 @@ func Test_HTTPMiddleware(t *testing.T) { } req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", nil) + req.Host = expectedHost utils.AssertEqual(t, nil, err) resp, err := app.Test(req) @@ -239,6 +243,8 @@ func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.A utils.AssertEqual(t, expectedRequestURI, string(c.Context().RequestURI()), "RequestURI") utils.AssertEqual(t, expectedContentLength, c.Context().Request.Header.ContentLength(), "ContentLength") utils.AssertEqual(t, expectedHost, c.Hostname(), "Host") + utils.AssertEqual(t, expectedHost, string(c.Request().Header.Host()), "Host") + utils.AssertEqual(t, "http://"+expectedHost, c.BaseURL(), "BaseURL") utils.AssertEqual(t, expectedRemoteAddr, c.Context().RemoteAddr().String(), "RemoteAddr") body := string(c.Body())