diff --git a/ghttp/handlers.go b/ghttp/handlers.go index a80b27ddf..fa7fc0ba3 100644 --- a/ghttp/handlers.go +++ b/ghttp/handlers.go @@ -112,6 +112,21 @@ func (g GHTTPWithGomega) VerifyHeaderKV(key string, values ...string) http.Handl return g.VerifyHeader(http.Header{key: values}) } +// VerifyHost returns a handler that verifies the host of a request matches the expected host +// Host is a special header in net/http, which is not set on the request.Header but rather on the Request itself +// +// Host may be a string or a matcher +func (g GHTTPWithGomega) VerifyHost(host interface{}) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + switch p := host.(type) { + case types.GomegaMatcher: + g.gomega.Expect(req.Host).Should(p, "Host mismatch") + default: + g.gomega.Expect(req.Host).Should(Equal(host), "Host mismatch") + } + } +} + //VerifyBody returns a handler that verifies that the body of the request matches the passed in byte array. //It does this using Equal(). func (g GHTTPWithGomega) VerifyBody(expectedBody []byte) http.HandlerFunc { @@ -358,6 +373,10 @@ func VerifyHeaderKV(key string, values ...string) http.HandlerFunc { return NewGHTTPWithGomega(gomega.Default).VerifyHeaderKV(key, values...) } +func VerifyHost(host interface{}) http.HandlerFunc { + return NewGHTTPWithGomega(gomega.Default).VerifyHost(host) +} + func VerifyBody(expectedBody []byte) http.HandlerFunc { return NewGHTTPWithGomega(gomega.Default).VerifyBody(expectedBody) } diff --git a/ghttp/test_server_test.go b/ghttp/test_server_test.go index 1f2b1cc53..1b1342f49 100644 --- a/ghttp/test_server_test.go +++ b/ghttp/test_server_test.go @@ -499,6 +499,69 @@ var _ = Describe("TestServer", func() { }) }) + Describe("VerifyHost", func() { + var ( + err error + req *http.Request + ) + + BeforeEach(func() { + req, err = http.NewRequest("GET", s.URL()+"/host", nil) + Expect(err).ShouldNot(HaveOccurred()) + }) + + When("passed a matcher for host", func() { + BeforeEach(func() { + s.AppendHandlers(CombineHandlers( + VerifyRequest("GET", "/host"), + VerifyHost(Equal("my-host")), + )) + }) + + It("should verify the host", func() { + req.Host = "my-host" + + resp, err = http.DefaultClient.Do(req) + Expect(err).ShouldNot(HaveOccurred()) + }) + + It("should reject an invalid host", func() { + req.Host = "not-my-host" + + failures := InterceptGomegaFailures(func() { + http.DefaultClient.Do(req) + }) + Expect(failures).Should(HaveLen(1)) + }) + }) + + When("passed a string for host", func() { + BeforeEach(func() { + s.AppendHandlers(CombineHandlers( + VerifyRequest("GET", "/host"), + VerifyHost("my-host"), + )) + }) + + It("should verify the host", func() { + req.Host = "my-host" + + resp, err = http.DefaultClient.Do(req) + Expect(err).ShouldNot(HaveOccurred()) + }) + + It("should reject an invalid host", func() { + req.Host = "not-my-host" + + failures := InterceptGomegaFailures(func() { + http.DefaultClient.Do(req) + }) + Expect(failures).Should(HaveLen(1)) + }) + }) + + }) + Describe("VerifyBody", func() { BeforeEach(func() { s.AppendHandlers(CombineHandlers(