diff --git a/api.go b/api.go index 9de22bd..480c792 100644 --- a/api.go +++ b/api.go @@ -7,8 +7,10 @@ package rest import ( "errors" + "log" "net/http" "regexp" + "strings" "github.com/go-rs/rest-api-framework/utils" ) @@ -69,23 +71,89 @@ func (api *API) Route(method string, pattern string, handle Handler) { }) } +func (api *API) Use(handle Handler) { + task := interceptor{ + handle: handle, + } + api.interceptors = append(api.interceptors, task) +} + +func (api *API) All(pattern string, handle Handler) { + api.Route("", pattern, handle) +} + +func (api *API) Get(pattern string, handle Handler) { + api.Route(http.MethodGet, pattern, handle) +} + +func (api *API) Post(pattern string, handle Handler) { + api.Route(http.MethodPost, pattern, handle) +} + +func (api *API) Put(pattern string, handle Handler) { + api.Route(http.MethodPut, pattern, handle) +} + +func (api *API) Delete(pattern string, handle Handler) { + api.Route(http.MethodDelete, pattern, handle) +} + +func (api *API) Options(pattern string, handle Handler) { + api.Route(http.MethodOptions, pattern, handle) +} + +func (api *API) Head(pattern string, handle Handler) { + api.Route(http.MethodHead, pattern, handle) +} + +func (api *API) Patch(pattern string, handle Handler) { + api.Route(http.MethodPatch, pattern, handle) +} + +func (api *API) Exception(err string, handle Handler) { + exp := exception{ + message: err, + handle: handle, + } + api.exceptions = append(api.exceptions, exp) +} + +func (api *API) UnhandledException(handle Handler) { + api.unhandled = handle +} + +var ( + ErrNotFound = errors.New("URL_NOT_FOUND") + ErrUncaughtException = errors.New("UNCAUGHT_EXCEPTION") +) + /** * Required handle for http module */ func (api API) ServeHTTP(res http.ResponseWriter, req *http.Request) { - urlPath := []byte(req.URL.Path) - + // STEP 1: initialize context ctx := Context{ Request: req, Response: res, Query: req.URL.Query(), } - // STEP 1: initialize context ctx.init() defer ctx.destroy() + defer func() { + err := recover() + if err != nil { + log.Fatalln("uncaught exception - ", err) + if !ctx.end { + ctx.err = ErrUncaughtException + ctx.unhandledException() + return + } + } + }() + // STEP 2: execute all interceptors for _, task := range api.interceptors { if ctx.end || ctx.err != nil { @@ -96,12 +164,13 @@ func (api API) ServeHTTP(res http.ResponseWriter, req *http.Request) { } // STEP 3: check routes + urlPath := []byte(req.URL.Path) for _, route := range api.routes { if ctx.end || ctx.err != nil { break } - if (route.method == "" || route.method == req.Method) && route.regex.Match(urlPath) { + if (route.method == "" || strings.EqualFold(route.method, req.Method)) && route.regex.Match(urlPath) { ctx.found = route.method != "" //? ctx.Params = utils.Exec(route.regex, route.params, urlPath) route.handle(&ctx) @@ -122,7 +191,7 @@ func (api API) ServeHTTP(res http.ResponseWriter, req *http.Request) { // STEP 5: unhandled exceptions if !ctx.end { if ctx.err == nil && !ctx.found { - ctx.err = errors.New("URL_NOT_FOUND") + ctx.err = ErrNotFound } if api.unhandled != nil { @@ -135,54 +204,3 @@ func (api API) ServeHTTP(res http.ResponseWriter, req *http.Request) { ctx.unhandledException() } } - -func (api *API) Use(handle Handler) { - task := interceptor{ - handle: handle, - } - api.interceptors = append(api.interceptors, task) -} - -func (api *API) All(pattern string, handle Handler) { - api.Route("", pattern, handle) -} - -func (api *API) Get(pattern string, handle Handler) { - api.Route("GET", pattern, handle) -} - -func (api *API) Post(pattern string, handle Handler) { - api.Route("POST", pattern, handle) -} - -func (api *API) Put(pattern string, handle Handler) { - api.Route("PUT", pattern, handle) -} - -func (api *API) Delete(pattern string, handle Handler) { - api.Route("DELETE", pattern, handle) -} - -func (api *API) Options(pattern string, handle Handler) { - api.Route("OPTIONS", pattern, handle) -} - -func (api *API) Head(pattern string, handle Handler) { - api.Route("HEAD", pattern, handle) -} - -func (api *API) Patch(pattern string, handle Handler) { - api.Route("PATCH", pattern, handle) -} - -func (api *API) Exception(err string, handle Handler) { - exp := exception{ - message: err, - handle: handle, - } - api.exceptions = append(api.exceptions, exp) -} - -func (api *API) UnhandledException(handle Handler) { - api.unhandled = handle -} diff --git a/api_test.go b/api_test.go index 49eb062..9233993 100644 --- a/api_test.go +++ b/api_test.go @@ -118,13 +118,13 @@ func TestAPI_ServeHTTP(t *testing.T) { defer dummy.Close() res, err := http.Get(dummy.URL) - if err != nil { t.Error("ServeHTTP error") + return } greeting, err := ioutil.ReadAll(res.Body) - res.Body.Close() + _ = res.Body.Close() if err != nil { t.Error("ServeHTTP error") } diff --git a/context.go b/context.go index 5173ca2..25605aa 100644 --- a/context.go +++ b/context.go @@ -6,37 +6,47 @@ package rest import ( + "bytes" + "compress/gzip" "net/http" "net/url" + "strings" "github.com/go-rs/rest-api-framework/render" ) +type Task func() + /** * Context */ type Context struct { - Request *http.Request - Response http.ResponseWriter - Query url.Values - Body map[string]interface{} - Params map[string]string - preSendTasks []func() error - postSendTasks []func() error - data map[string]interface{} - err error - status int - found bool - end bool + Request *http.Request + Response http.ResponseWriter + Query url.Values + Body map[string]interface{} + Params map[string]string + headers map[string]string + data map[string]interface{} + err error + status int + found bool + end bool + requestSent bool + preTasksCalled bool + postTasksCalled bool + preSendTasks []Task + postSendTasks []Task } /** * Initialization of context on every request */ func (ctx *Context) init() { + ctx.headers = make(map[string]string) ctx.data = make(map[string]interface{}) - ctx.preSendTasks = make([]func() error, 0) - ctx.postSendTasks = make([]func() error, 0) + ctx.preSendTasks = make([]Task, 0) + ctx.postSendTasks = make([]Task, 0) ctx.status = 200 ctx.found = false ctx.end = false @@ -84,7 +94,7 @@ func (ctx *Context) Status(code int) *Context { * Set Header */ func (ctx *Context) SetHeader(key string, val string) *Context { - ctx.Response.Header().Set(key, val) + ctx.headers[key] = val return ctx } @@ -124,6 +134,7 @@ func (ctx *Context) JSON(data interface{}) { Body: data, } body, err := json.Write(ctx.Response) + //ctx.SetHeader("Content-Type", "application/json;charset=UTF-8") ctx.send(body, err) } @@ -135,24 +146,49 @@ func (ctx *Context) Text(data string) { Body: data, } body, err := txt.Write(ctx.Response) + //ctx.SetHeader("Content-Type", "text/plain;charset=UTF-8") ctx.send(body, err) } /** * */ -func (ctx *Context) PreSend(task func() error) { +func (ctx *Context) PreSend(task Task) { ctx.preSendTasks = append(ctx.preSendTasks, task) } /** * */ -func (ctx *Context) PostSend(task func() error) { +func (ctx *Context) PostSend(task Task) { ctx.postSendTasks = append(ctx.postSendTasks, task) } ////////////////////////////////////////////////// +func compress(data []byte) (cdata []byte, err error) { + var b bytes.Buffer + gz := gzip.NewWriter(&b) + + _, err = gz.Write(data) + if err != nil { + return + } + + err = gz.Flush() + if err != nil { + return + } + + err = gz.Close() + if err != nil { + return + } + + cdata = b.Bytes() + + return +} + /** * Send data */ @@ -166,23 +202,42 @@ func (ctx *Context) send(data []byte, err error) { return } - for _, task := range ctx.preSendTasks { - //TODO: handle error - _ = task() + if !ctx.preTasksCalled { + ctx.preTasksCalled = true + for _, task := range ctx.preSendTasks { + task() + } } - ctx.Response.WriteHeader(ctx.status) - _, err = ctx.Response.Write(data) + if !ctx.requestSent { + ctx.requestSent = true - //TODO: check - should not be recursive - if err != nil { - ctx.err = err - return + for key, val := range ctx.headers { + ctx.Response.Header().Set(key, val) + } + + if strings.Contains(ctx.Request.Header.Get("Accept-Encoding"), "gzip") { + data, err = compress(data) + if err == nil { + ctx.Response.Header().Set("Content-Encoding", "gzip") + } + } + + ctx.Response.WriteHeader(ctx.status) + + _, err = ctx.Response.Write(data) + + if err != nil { + ctx.err = err + return + } } - for _, task := range ctx.postSendTasks { - //TOD: handle error - _ = task() + if !ctx.postTasksCalled { + ctx.postTasksCalled = true + for _, task := range ctx.postSendTasks { + task() + } } ctx.End() @@ -192,12 +247,25 @@ func (ctx *Context) send(data []byte, err error) { * Unhandled Exception */ func (ctx *Context) unhandledException() { + defer func() { + err := recover() + if err != nil { + if !ctx.requestSent { + ctx.Response.WriteHeader(http.StatusInternalServerError) + ctx.Response.Header().Set("Content-Type", "text/plain;charset=UTF-8") + _, _ = ctx.Response.Write([]byte("Internal Server Error")) + } + } + }() + err := ctx.GetError() + if err != nil { msg := err.Error() - ctx.Status(500) + ctx.Status(http.StatusInternalServerError) + ctx.SetHeader("Content-Type", "text/plain;charset=UTF-8") if msg == "URL_NOT_FOUND" { - ctx.Status(404) + ctx.Status(http.StatusNotFound) } ctx.Write([]byte(msg)) } diff --git a/context_test.go b/context_test.go index d00474f..c240d64 100644 --- a/context_test.go +++ b/context_test.go @@ -7,7 +7,9 @@ package rest import ( "errors" + "net/http" "net/http/httptest" + "strings" "testing" ) @@ -84,6 +86,7 @@ func TestContext_End(t *testing.T) { func TestContext_Write1(t *testing.T) { ctx.init() + ctx.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader("")) ctx.Response = httptest.NewRecorder() data := []byte("Hello World") ctx.Write(data) diff --git a/namespace.go b/namespace.go index fce0a96..ea9af9d 100644 --- a/namespace.go +++ b/namespace.go @@ -5,6 +5,8 @@ */ package rest +import "net/http" + /** * Namespace - Application */ @@ -13,7 +15,6 @@ type Namespace struct { api *API } -//TODO: error handling on unset api func (n *Namespace) Set(prefix string, api *API) { n.prefix = prefix n.api = api @@ -28,31 +29,31 @@ func (n *Namespace) All(pattern string, handle Handler) { } func (n *Namespace) Get(pattern string, handle Handler) { - n.api.Route("GET", n.prefix+pattern, handle) + n.api.Route(http.MethodGet, n.prefix+pattern, handle) } func (n *Namespace) Post(pattern string, handle Handler) { - n.api.Route("POST", n.prefix+pattern, handle) + n.api.Route(http.MethodPost, n.prefix+pattern, handle) } func (n *Namespace) Put(pattern string, handle Handler) { - n.api.Route("PUT", n.prefix+pattern, handle) + n.api.Route(http.MethodPut, n.prefix+pattern, handle) } func (n *Namespace) Delete(pattern string, handle Handler) { - n.api.Route("DELETE", n.prefix+pattern, handle) + n.api.Route(http.MethodDelete, n.prefix+pattern, handle) } func (n *Namespace) Options(pattern string, handle Handler) { - n.api.Route("OPTIONS", n.prefix+pattern, handle) + n.api.Route(http.MethodOptions, n.prefix+pattern, handle) } func (n *Namespace) Head(pattern string, handle Handler) { - n.api.Route("HEAD", n.prefix+pattern, handle) + n.api.Route(http.MethodHead, n.prefix+pattern, handle) } func (n *Namespace) Patch(pattern string, handle Handler) { - n.api.Route("PATCH", n.prefix+pattern, handle) + n.api.Route(http.MethodPatch, n.prefix+pattern, handle) } func (n *Namespace) Exception(err string, handle Handler) { diff --git a/render/json.go b/render/json.go index 6de3b62..702bb26 100644 --- a/render/json.go +++ b/render/json.go @@ -17,12 +17,12 @@ type JSON struct { Body interface{} } -var ( +const ( jsonType = "application/json" ) var ( - invalidJson = errors.New("INVALID_JSON_RESPONSE") + ErrInvalidJson = errors.New("INVALID_JSON_RESPONSE") ) /** @@ -31,7 +31,7 @@ var ( func (j JSON) Write(w http.ResponseWriter) (data []byte, err error) { _type := reflect.TypeOf(j.Body).String() if _type == "int" || _type == "float64" || _type == "bool" { - err = invalidJson + err = ErrInvalidJson } else if _type == "string" { data, err = json.RawMessage(j.Body.(string)).MarshalJSON() } else { @@ -43,7 +43,7 @@ func (j JSON) Write(w http.ResponseWriter) (data []byte, err error) { } if !json.Valid(data) { - err = invalidJson + err = ErrInvalidJson return } diff --git a/render/text.go b/render/text.go index 15cedad..1f98b17 100644 --- a/render/text.go +++ b/render/text.go @@ -13,8 +13,8 @@ type Text struct { Body string } -var ( - plainType = "text/plain" +const ( + plainType = "text/plain;charset=UTF-8" ) /** diff --git a/version.txt b/version.txt index cb1cd2e..4a16fd9 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -v0.0.1-beta.2 +v0.0.1-beta.3