From 117247c729442e2d308ede05233b7d13a78fa7c3 Mon Sep 17 00:00:00 2001 From: Max Claus Nunes Date: Mon, 17 Jul 2017 22:20:01 -0300 Subject: [PATCH] Add support for Handle It allows full control over the handler --- examples/custom_handler.go | 58 ++++++++++++++++++++++++++++++++++++++ httpfake.go | 20 ++++--------- request.go | 12 ++++++-- responder.go | 21 ++++++++++++++ 4 files changed, 93 insertions(+), 18 deletions(-) create mode 100644 examples/custom_handler.go create mode 100644 responder.go diff --git a/examples/custom_handler.go b/examples/custom_handler.go new file mode 100644 index 0000000..96b10ed --- /dev/null +++ b/examples/custom_handler.go @@ -0,0 +1,58 @@ +// nolint dupl +package examples + +import ( + "io/ioutil" + "net/http" + "testing" + + "github.com/maxcnunes/httpfake" +) + +// TestHandleCustomResponder tests a fake server handling a GET request +// with a custom responder. It allows full control over the handler. +func TestHandleCustomResponder(t *testing.T) { + fakeService := httpfake.New() + defer fakeService.Server.Close() + + // register a handler for our fake service + fakeService.NewHandler(). + Get("/users"). + Handle(func(w http.ResponseWriter, r *http.Request, rh *httpfake.Request) { + w.WriteHeader(200) + w.Header().Add("Header-From-Custom-Responder-X", "indeed") + w.Write([]byte("Body-From-Custom-Responder-X")) // nolint + }) + + req, err := http.NewRequest("GET", fakeService.ResolveURL("/users"), nil) + if err != nil { + t.Fatal(err) + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() // nolint errcheck + + // Check the status code is what we expect + if status := res.StatusCode; status != 200 { + t.Errorf("request returned wrong status code: got %v want %v", + status, 200) + } + + // Check the response body is what we expect + expected := "Body-From-Custom-Responder-X" + body, _ := ioutil.ReadAll(res.Body) + if bodyString := string(body); bodyString != expected { + t.Errorf("request returned unexpected body: got %v want %v", + bodyString, expected) + } + + // Check the response header is what we expect + expected = "indeed" + if header := res.Header.Get("Header-From-Custom-Responder-X"); header != expected { + t.Errorf("request returned unexpected value for header Header-From-Custom-Responder-X: got %v want %v", + header, expected) + } +} diff --git a/httpfake.go b/httpfake.go index bd76675..a33d885 100644 --- a/httpfake.go +++ b/httpfake.go @@ -23,7 +23,11 @@ func New() *HTTPFake { w.WriteHeader(http.StatusNotFound) return } - respond(rh, w) + if rh.CustomHandle != nil { + rh.CustomHandle(w, r, rh) + return + } + Respond(w, r, rh) })) return fake @@ -70,17 +74,3 @@ func (f *HTTPFake) findHandler(r *http.Request) *Request { } return nil } - -func respond(rh *Request, w http.ResponseWriter) { - if rh.Response.StatusCode > 0 { - w.WriteHeader(rh.Response.StatusCode) - } - if len(rh.Response.BodyBuffer) > 0 { - w.Write(rh.Response.BodyBuffer) // nolint - } - if len(rh.Response.Header) > 0 { - for k := range rh.Response.Header { - w.Header().Add(k, rh.Response.Header.Get(k)) - } - } -} diff --git a/request.go b/request.go index 4321af2..42a5fb8 100644 --- a/request.go +++ b/request.go @@ -7,9 +7,10 @@ import ( // Request ... type Request struct { - Method string - URL *url.URL - Response *Response + Method string + URL *url.URL + Response *Response + CustomHandle Responder } // NewRequest ... @@ -55,6 +56,11 @@ func (r *Request) Reply(status int) *Response { return r.Response.Status(status) } +// Handle ... +func (r *Request) Handle(handle Responder) { + r.CustomHandle = handle +} + func (r *Request) method(method, path string) *Request { if path != "/" { r.URL.Path = path diff --git a/responder.go b/responder.go new file mode 100644 index 0000000..dc9b075 --- /dev/null +++ b/responder.go @@ -0,0 +1,21 @@ +package httpfake + +import "net/http" + +// Responder ... +type Responder func(w http.ResponseWriter, r *http.Request, rh *Request) + +// Respond ... +func Respond(w http.ResponseWriter, r *http.Request, rh *Request) { + if rh.Response.StatusCode > 0 { + w.WriteHeader(rh.Response.StatusCode) + } + if len(rh.Response.BodyBuffer) > 0 { + w.Write(rh.Response.BodyBuffer) // nolint + } + if len(rh.Response.Header) > 0 { + for k := range rh.Response.Header { + w.Header().Add(k, rh.Response.Header.Get(k)) + } + } +}