From f6673d89a7d5d6dd0f9b9b54744bfbd3bf4df56d Mon Sep 17 00:00:00 2001 From: Eric Marden Date: Sat, 11 Mar 2017 16:52:27 -0600 Subject: [PATCH] implements automatic Content Negotiation - adds `Respond()` helper function to be used in Endpoint method handlers - `Respond()` takes an http.ResponseWriter, http.Request, and user supplied: status, body, and 0 or more `http.Header`s - if value of `Accept` header ends in `json`, `body` will be encoded as json - if value of `Accept` header ends in `xml`, `body` will be encoded as xml - switches `Present()`, `Representation` and friends to just using struct tags to aid in encoding `RootResource` in xml/json - adds `EndpointResource` to replace the old `Representation` fixes #17 --- config.go | 6 ++++ discovery.go | 61 +++++++++++++++++---------------------- discovery_test.go | 15 +--------- encoder.go | 63 ++++++++++++++++++++++++++++++++++++++++ encoder_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++ endpoint.go | 24 ++++++++++++++-- endpoint_test.go | 6 ++++ error.go | 14 +++++++++ error_test.go | 15 ++++++++++ hyperdrive.go | 8 ++---- hyperdrive_test.go | 14 ++++----- middleware.go | 12 ++++---- 12 files changed, 238 insertions(+), 71 deletions(-) create mode 100644 encoder.go create mode 100644 encoder_test.go create mode 100644 error.go create mode 100644 error_test.go diff --git a/config.go b/config.go index 99a8344..d16e638 100644 --- a/config.go +++ b/config.go @@ -7,6 +7,12 @@ import ( "github.com/caarlos0/env" ) +var conf Config + +func init() { + conf = NewConfig() +} + // Config holds configuration values from the environment, with sane defaults // (where possible). Required configuration will throw a Fatal error if they // are missing. diff --git a/discovery.go b/discovery.go index 5c6c06e..2dfb837 100644 --- a/discovery.go +++ b/discovery.go @@ -1,7 +1,7 @@ package hyperdrive import ( - "encoding/json" + "encoding/xml" "net/http" ) @@ -14,51 +14,42 @@ type Representation map[string]interface{} // the hypermedia respresentation returned by the Discovery URL endpoint for // API clients to learn about the API. type RootResource struct { - Name string - Endpoints []Endpointer + XMLName xml.Name `json:"-" xml:"api"` + Resource string `json:"resource" xml:"-"` + Name string `json:"name" xml:"name,attr"` + Endpoints []EndpointResource `json:"endpoints" xml:"endpoints"` } -// NewRootResource creates an instance of RootResource, based on the given API. -func NewRootResource(api API) *RootResource { - return &RootResource{Name: api.Name} -} - -// AddEndpointer adds Endpointers to the slice of Endpointers on an instance of RootResource. -func (root *RootResource) AddEndpointer(e Endpointer) { - root.Endpoints = append(root.Endpoints, e) +// EndpointResource contains information about and Endpoint, and is +// the hypermedia respresentation returned by the Discovery URL endpoint for +// API clients to learn about the Endpoint. +type EndpointResource struct { + XMLName xml.Name `json:"-" xml:"endpoint"` + Resource string `json:"resource" xml:"-"` + Name string `json:"name" xml:"name,attr"` + Path string `json:"path" xml:"path,attr"` + MethodsList string `json:"-" xml:"methods,attr"` + Methods []string `json:"methods" xml:"-"` + Desc string `json:"description" xml:"description"` } -// Present returns an Representation of the RootResource to describe the API -// for the Discovery URL. -func (root *RootResource) Present() Representation { - return Representation{ - "resource": "api", - "name": root.Name, - "endpoints": root.endpointRepresentations(), - } +// NewRootResource creates an instance of RootResource from the given API. +func NewRootResource(api API) *RootResource { + return &RootResource{Resource: "api", Name: api.Name} } -func (root *RootResource) endpointRepresentations() []Representation { - var endpoints = []Representation{} - for _, e := range root.Endpoints { - endpoints = append(endpoints, PresentEndpoint(e)) - } - return endpoints +// NewEndpointResource creates an instance of EndpointResource from the given Endpointer. +func NewEndpointResource(e Endpointer) EndpointResource { + return EndpointResource{Resource: "endpoint", Name: e.GetName(), Path: e.GetPath(), MethodsList: GetMethodsList(e), Methods: GetMethods(e), Desc: e.GetDesc()} } -// PresentEndpoint returns a Representation to describe an Endpoint for the Discovery URL. -func PresentEndpoint(e Endpointer) Representation { - return Representation{ - "name": e.GetName(), - "desc": e.GetDesc(), - "path": e.GetPath(), - "methods": GetMethods(e), - } +// AddEndpoint adds EndpointResources to the slice of Endpoints on an instance of RootResource. +func (root *RootResource) AddEndpoint(e Endpointer) { + root.Endpoints = append(root.Endpoints, NewEndpointResource(e)) } // ServeHTTP satisfies the http.Handler interface and returns the hypermedia // representation of the Discovery URL. func (root *RootResource) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - rw.Header().Set("Content-Type", "application/json") - json.NewEncoder(rw).Encode(root.Present()) + Respond(rw, r, 200, root) } diff --git a/discovery_test.go b/discovery_test.go index 83b5da4..421a35d 100644 --- a/discovery_test.go +++ b/discovery_test.go @@ -11,23 +11,10 @@ func (suite *HyperdriveTestSuite) TestRootResourceEndpointsEmpty() { } func (suite *HyperdriveTestSuite) TestAddEndpointer() { - suite.TestRoot.AddEndpointer(suite.TestEndpoint) + suite.TestRoot.AddEndpoint(suite.TestEndpoint) suite.Equal(1, len(suite.TestRoot.Endpoints), "expects 1 Endpoints") } func (suite *HyperdriveTestSuite) TestRootResourceServeHTTP() { suite.Implements((*http.Handler)(nil), suite.TestRoot, "return an implementation of http.Handler") } - -func (suite *HyperdriveTestSuite) TestPresentRepresentation() { - suite.IsType(Representation{}, suite.TestRoot.Present(), "return a Representation") -} - -func (suite *HyperdriveTestSuite) TestPresent() { - suite.TestRoot.AddEndpointer(suite.TestEndpoint) - suite.Equal(suite.TestRootRepresentation, suite.TestRoot.Present(), "return the correct Representation of RootResource") -} - -func (suite *HyperdriveTestSuite) TestPresentEndpoint() { - suite.Equal(suite.TestEndpointRepresentation, PresentEndpoint(suite.TestEndpoint), "return the correct Representation of RootResource") -} diff --git a/encoder.go b/encoder.go new file mode 100644 index 0000000..1226832 --- /dev/null +++ b/encoder.go @@ -0,0 +1,63 @@ +package hyperdrive + +import ( + "encoding/json" + "encoding/xml" + "errors" + "net/http" + "strings" +) + +// ContentEncoder interface wraps the details of encoding response bodies to +// support automatic Content Negotiation. +type ContentEncoder interface { + Encode(interface{}) error +} + +// NullEncoder is an implementation of ContentEncoder, and is the default +// encoder used when Content Negotiation has falied. It produces a 406 +// NOT ACCEPTABLE error when it's Encode() function is run. +type NullEncoder struct{} + +// Encode returns a 406 NOT ACCEPTABLE error. +func (enc NullEncoder) Encode(v interface{}) error { + return errors.New(http.StatusText(http.StatusNotAcceptable)) +} + +// JSONEncoder is an implementation of ContentEncoder and wraps the Encoder +// found in encoding/json package. +type JSONEncoder struct { + Encoder *json.Encoder +} + +// Encode encodes input as json text or returns an error. +func (enc JSONEncoder) Encode(v interface{}) error { + return enc.Encoder.Encode(v) +} + +// XMLEncoder is an implementation of ContentEncoder and wraps the Encoder +// found in encoding/xml package. +type XMLEncoder struct { + Encoder *xml.Encoder +} + +// Encode encodes input as xml text or returns an error. +func (enc XMLEncoder) Encode(v interface{}) error { + return enc.Encoder.Encode(v) +} + +// GetEncoder returns the correct ContentEncoder, determined by the Accept +// header, to support automatic Content Negotiation. +func GetEncoder(rw http.ResponseWriter, accept string) (ContentEncoder, http.ResponseWriter) { + if strings.HasSuffix(accept, "json") { + rw.Header().Set("Content-Type", accept) + return JSONEncoder{json.NewEncoder(rw)}, rw + } + + if strings.HasSuffix(accept, "xml") { + rw.Header().Set("Content-Type", accept) + return XMLEncoder{xml.NewEncoder(rw)}, rw + } + + return NullEncoder{}, rw +} diff --git a/encoder_test.go b/encoder_test.go new file mode 100644 index 0000000..d048610 --- /dev/null +++ b/encoder_test.go @@ -0,0 +1,71 @@ +package hyperdrive + +import ( + "encoding/json" + "encoding/xml" + "net/http/httptest" +) + +func (suite *HyperdriveTestSuite) TestNullEncoder() { + suite.Implements((*ContentEncoder)(nil), NullEncoder{}, "return an implementation of ContentEncoder") +} + +func (suite *HyperdriveTestSuite) TestNullEncoderEncode() { + suite.Error(NullEncoder{}.Encode(suite.TestEndpointResource), "return an error") +} + +func (suite *HyperdriveTestSuite) TestJSONEncoder() { + suite.Implements((*ContentEncoder)(nil), JSONEncoder{}, "return an implementation of ContentEncoder") +} + +func (suite *HyperdriveTestSuite) TestJSONEncoderEncodeNoError() { + rw := httptest.NewRecorder() + enc := JSONEncoder{Encoder: json.NewEncoder(rw)} + suite.Nil(enc.Encode(suite.TestEndpointResource), "returns nil") +} + +func (suite *HyperdriveTestSuite) TestJSONEncoderEncode() { + rw := httptest.NewRecorder() + enc := JSONEncoder{Encoder: json.NewEncoder(rw)} + enc.Encode(suite.TestEndpointResource) + json := `{"resource":"endpoint","name":"Test","path":"/test","methods":["OPTIONS"],"description":"Test Endpoint"}` + "\n" + suite.Equal(json, rw.Body.String(), "returns nil") +} + +func (suite *HyperdriveTestSuite) TestXMLEncoder() { + suite.Implements((*ContentEncoder)(nil), XMLEncoder{}, "return an implementation of ContentEncoder") +} + +func (suite *HyperdriveTestSuite) TestXMLEncoderEncodeNoError() { + rw := httptest.NewRecorder() + enc := XMLEncoder{Encoder: xml.NewEncoder(rw)} + suite.Nil(enc.Encode(suite.TestEndpointResource), "returns nil") +} + +func (suite *HyperdriveTestSuite) TestXMLEncoderEncode() { + rw := httptest.NewRecorder() + enc := XMLEncoder{Encoder: xml.NewEncoder(rw)} + enc.Encode(suite.TestEndpointResource) + xml := `Test Endpoint` + suite.Equal(xml, rw.Body.String(), "returns nil") +} + +func (suite *HyperdriveTestSuite) TestGetEncoder() { + enc, _ := GetEncoder(httptest.NewRecorder(), "text/plain") + suite.Implements((*ContentEncoder)(nil), enc, "return an implementation of ContentEncoder") +} + +func (suite *HyperdriveTestSuite) TestGetEncoderXML() { + enc, _ := GetEncoder(httptest.NewRecorder(), "application/xml") + suite.IsType(XMLEncoder{}, enc, "return an XMLEncoder") +} + +func (suite *HyperdriveTestSuite) TestGetEncoderJSON() { + enc, _ := GetEncoder(httptest.NewRecorder(), "application/json") + suite.IsType(JSONEncoder{}, enc, "return a JSONEncoder") +} + +func (suite *HyperdriveTestSuite) TestGetEncoderNULL() { + enc, _ := GetEncoder(httptest.NewRecorder(), "text/plain") + suite.IsType(NullEncoder{}, enc, "return a NullEncoder") +} diff --git a/endpoint.go b/endpoint.go index 85f7eea..5d523eb 100644 --- a/endpoint.go +++ b/endpoint.go @@ -163,9 +163,9 @@ func GetMethodsList(e Endpointer) string { } // NewMethodHandler sets the correct http.Handler for each method, depending on -// the interfaces the Enpointer supports. It returns an http.HandlerFunc, ready +// the interfaces the Enpointer supports. It returns an http.Handler, ready // to be served directly, wrapped in other middleware, etc. -func NewMethodHandler(e Endpointer) http.HandlerFunc { +func NewMethodHandler(e Endpointer) http.Handler { handler := make(handlers.MethodHandler) if h, ok := interface{}(e).(GetHandler); ok { handler["GET"] = http.HandlerFunc(h.Get) @@ -190,5 +190,23 @@ func NewMethodHandler(e Endpointer) http.HandlerFunc { if h, ok := interface{}(e).(OptionsHandler); ok { handler["OPTIONS"] = http.HandlerFunc(h.Options) } - return http.HandlerFunc(handler.ServeHTTP) + return handler +} + +// Respond is a helper function to make it easy for an Endpointer's method +// handler (e.g. GetHandler) to respond with the appropriate Content-Type. +func Respond(rw http.ResponseWriter, r *http.Request, status int, body interface{}, headers ...http.Header) (http.ResponseWriter, *http.Request) { + var enc ContentEncoder + enc, rw = GetEncoder(rw, r.Header.Get("Accept")) + err := enc.Encode(body) + if err != nil { + http.Error(rw, err.Error(), http.StatusNotAcceptable) + // TODO: Add LOGGING + return rw, r + } + rw.WriteHeader(status) + for _, header := range headers { + header.Write(rw) + } + return rw, r } diff --git a/endpoint_test.go b/endpoint_test.go index 5c8f5d0..0fdfc40 100644 --- a/endpoint_test.go +++ b/endpoint_test.go @@ -1,5 +1,7 @@ package hyperdrive +import "net/http" + func (suite *HyperdriveTestSuite) TestNewEndpoint() { suite.IsType(&Endpoint{}, suite.TestEndpoint, "expects an instance of hyperdrive.Endpoint") } @@ -31,3 +33,7 @@ func (suite *HyperdriveTestSuite) TestGetMethods() { func (suite *HyperdriveTestSuite) TestGetMethodsList() { suite.Equal("OPTIONS", GetMethodsList(suite.TestEndpoint), "expects a list of supported method strings") } + +func (suite *HyperdriveTestSuite) TestNewMethodHandler() { + suite.Implements((*http.Handler)(nil), NewMethodHandler(suite.TestEndpoint), "return an implementation of http.Handler") +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..02c7bec --- /dev/null +++ b/error.go @@ -0,0 +1,14 @@ +package hyperdrive + +import "net/http" + +// GetErrorText helps ensure implementation details are not leaked in production +// environments. If this is production, it returns the http.StatusText for the +// given status code. If this is not production, the error message is returned +// to aid in debugging. +func GetErrorText(status int, err error) string { + if conf.Env != "production" { + return err.Error() + } + return http.StatusText(status) +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..3299e6a --- /dev/null +++ b/error_test.go @@ -0,0 +1,15 @@ +package hyperdrive + +import ( + "errors" + "net/http" +) + +func (suite *HyperdriveTestSuite) TestGetErrorTextProduction() { + conf.Env = "production" + suite.Equal(http.StatusText(406), GetErrorText(406, errors.New("Test Error")), "returns 406 Status Text") +} + +func (suite *HyperdriveTestSuite) TestGetErrorText() { + suite.Equal("Test Error", GetErrorText(406, errors.New("Test Error")), "returns Error Text") +} diff --git a/hyperdrive.go b/hyperdrive.go index ae009da..910b4e5 100644 --- a/hyperdrive.go +++ b/hyperdrive.go @@ -19,7 +19,6 @@ type API struct { Router *mux.Router Server *http.Server Root *RootResource - conf Config endpoints []Endpoint } @@ -29,13 +28,12 @@ func NewAPI(name string, desc string) API { Name: name, Desc: desc, Router: mux.NewRouter(), - conf: NewConfig(), } api.Root = NewRootResource(api) api.Router.Handle("/", api.DefaultMiddlewareChain(api.Root)).Methods("GET") api.Server = &http.Server{ Handler: api.Router, - Addr: api.conf.GetPort(), + Addr: conf.GetPort(), WriteTimeout: 15 * time.Second, ReadTimeout: 15 * time.Second, } @@ -46,7 +44,7 @@ func NewAPI(name string, desc string) API { // respond with a 405 error if the endpoint does not support a particular // HTTP method. func (api *API) AddEndpoint(e Endpointer) { - api.Root.AddEndpointer(e) + api.Root.AddEndpoint(e) api.Router.Handle(e.GetPath(), api.DefaultMiddlewareChain(NewMethodHandler(e))) } @@ -63,7 +61,7 @@ func (api *API) GetMediaType(e Endpointer) string { // Start starts the configured http server, listening on the configured Port // (default: 5000). Set the PORT environment variable to change this. func (api *API) Start() { - log.Printf("Hyperdrive API starting on PORT %d in ENVIRONMENT %s", api.conf.Port, api.conf.Env) + log.Printf("Hyperdriven API: %s starting on: http://0.0.0.0:%d in: %s", api.Name, conf.Port, conf.Env) log.Fatal(api.Server.ListenAndServe()) } diff --git a/hyperdrive_test.go b/hyperdrive_test.go index 066f550..414a806 100644 --- a/hyperdrive_test.go +++ b/hyperdrive_test.go @@ -9,12 +9,11 @@ import ( type HyperdriveTestSuite struct { suite.Suite - TestAPI API - TestEndpoint Endpointer - TestHandler http.Handler - TestRoot *RootResource - TestRootRepresentation Representation - TestEndpointRepresentation Representation + TestAPI API + TestEndpoint Endpointer + TestHandler http.Handler + TestRoot *RootResource + TestEndpointResource EndpointResource } func (suite *HyperdriveTestSuite) SetupTest() { @@ -22,8 +21,7 @@ func (suite *HyperdriveTestSuite) SetupTest() { suite.TestEndpoint = NewEndpoint("Test", "Test Endpoint", "/test", "1.0.1") suite.TestHandler = NewMethodHandler(suite.TestEndpoint) suite.TestRoot = NewRootResource(suite.TestAPI) - suite.TestEndpointRepresentation = Representation{"name": "Test", "desc": "Test Endpoint", "path": "/test", "methods": []string{"OPTIONS"}} - suite.TestRootRepresentation = Representation{"resource": "api", "name": "API", "endpoints": []Representation{suite.TestEndpointRepresentation}} + suite.TestEndpointResource = NewEndpointResource(suite.TestEndpoint) } func (suite *HyperdriveTestSuite) TestNewAPI() { diff --git a/middleware.go b/middleware.go index 2ab565f..fb760c1 100644 --- a/middleware.go +++ b/middleware.go @@ -25,7 +25,7 @@ func (api *API) LoggingMiddleware(h http.Handler) http.Handler { // RecoveryMiddleware wraps the given http.Handler and recovers from panics. It wil log // the stacktrace if HYPERDRIVE_ENVIRONMENT env var is not set to "production". func (api *API) RecoveryMiddleware(h http.Handler) http.Handler { - opt := handlers.PrintRecoveryStack(api.conf.Env != "production") + opt := handlers.PrintRecoveryStack(conf.Env != "production") return handlers.RecoveryHandler(opt)(h) } @@ -49,7 +49,7 @@ func (api *API) RecoveryMiddleware(h http.Handler) http.Handler { // More info can be found in the docs for the compress/flate package: // https://golang.org/pkg/compress/flate/ func (api *API) CompressionMiddleware(h http.Handler) http.Handler { - return handlers.CompressHandlerLevel(h, api.conf.GzipLevel) + return handlers.CompressHandlerLevel(h, conf.GzipLevel) } // MethodOverrideMiddleware allows clients who can not perform native PUT, PATCH, @@ -67,13 +67,13 @@ func (api *API) MethodOverrideMiddleware(h http.Handler) http.Handler { // - CORS_HEADERS (string) // - CORS_CREDENTIALS (bool) func (api *API) CorsMiddleware(h http.Handler) http.Handler { - if api.conf.CorsEnabled == true { + if conf.CorsEnabled == true { return h } defaultHeaders := []string{"Content-Type", "X-Content-Type-Options"} - headers := handlers.AllowedHeaders(append(defaultHeaders, strings.Split(api.conf.CorsHeaders, ",")...)) - origins := handlers.AllowedOrigins(strings.Split(api.conf.CorsOrigins, ",")) - if api.conf.CorsCredentials == true { + headers := handlers.AllowedHeaders(append(defaultHeaders, strings.Split(conf.CorsHeaders, ",")...)) + origins := handlers.AllowedOrigins(strings.Split(conf.CorsOrigins, ",")) + if conf.CorsCredentials == true { handlers.AllowCredentials() } return handlers.CORS(headers, origins)(h)