diff --git a/api.go b/api.go index 6bd3ba2..40893eb 100644 --- a/api.go +++ b/api.go @@ -250,19 +250,29 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac baseURL = "/" + prefix + baseURL } - api.router.Handle("OPTIONS", baseURL, func(w http.ResponseWriter, r *http.Request, _ map[string]string) { + api.router.Handle("OPTIONS", baseURL, func(w http.ResponseWriter, r *http.Request, _ map[string]string, context map[string]interface{}) { c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) w.Header().Set("Allow", strings.Join(getAllowedMethods(source, true), ",")) w.WriteHeader(http.StatusNoContent) api.contextPool.Put(c) }) - api.router.Handle("GET", baseURL, func(w http.ResponseWriter, r *http.Request, _ map[string]string) { + api.router.Handle("GET", baseURL, func(w http.ResponseWriter, r *http.Request, _ map[string]string, context map[string]interface{}) { info := requestInfo(r, api) c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleIndex(c, w, r, *info) @@ -273,19 +283,29 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac }) if _, ok := source.(ResourceGetter); ok { - api.router.Handle("OPTIONS", baseURL+"/:id", func(w http.ResponseWriter, r *http.Request, _ map[string]string) { + api.router.Handle("OPTIONS", baseURL+"/:id", func(w http.ResponseWriter, r *http.Request, _ map[string]string, context map[string]interface{}) { c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) w.Header().Set("Allow", strings.Join(getAllowedMethods(source, false), ",")) w.WriteHeader(http.StatusNoContent) api.contextPool.Put(c) }) - api.router.Handle("GET", baseURL+"/:id", func(w http.ResponseWriter, r *http.Request, params map[string]string) { + api.router.Handle("GET", baseURL+"/:id", func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { info := requestInfo(r, api) c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleRead(c, w, r, params, *info) api.contextPool.Put(c) @@ -301,10 +321,15 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac relations := casted.GetReferences() for _, relation := range relations { api.router.Handle("GET", baseURL+"/:id/relationships/"+relation.Name, func(relation jsonapi.Reference) routing.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request, params map[string]string) { + return func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { info := requestInfo(r, api) c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleReadRelation(c, w, r, params, *info, relation) api.contextPool.Put(c) @@ -315,10 +340,15 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac }(relation)) api.router.Handle("GET", baseURL+"/:id/"+relation.Name, func(relation jsonapi.Reference) routing.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request, params map[string]string) { + return func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { info := requestInfo(r, api) c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleLinked(c, api, w, r, params, relation, *info) api.contextPool.Put(c) @@ -329,9 +359,14 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac }(relation)) api.router.Handle("PATCH", baseURL+"/:id/relationships/"+relation.Name, func(relation jsonapi.Reference) routing.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request, params map[string]string) { + return func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleReplaceRelation(c, w, r, params, relation) api.contextPool.Put(c) @@ -344,9 +379,14 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac if _, ok := ptrPrototype.(jsonapi.EditToManyRelations); ok && relation.Name == jsonapi.Pluralize(relation.Name) { // generate additional routes to manipulate to-many relationships api.router.Handle("POST", baseURL+"/:id/relationships/"+relation.Name, func(relation jsonapi.Reference) routing.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request, params map[string]string) { + return func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleAddToManyRelation(c, w, r, params, relation) api.contextPool.Put(c) @@ -357,9 +397,14 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac }(relation)) api.router.Handle("DELETE", baseURL+"/:id/relationships/"+relation.Name, func(relation jsonapi.Reference) routing.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request, params map[string]string) { + return func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleDeleteToManyRelation(c, w, r, params, relation) api.contextPool.Put(c) @@ -373,10 +418,15 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac } if _, ok := source.(ResourceCreator); ok { - api.router.Handle("POST", baseURL, func(w http.ResponseWriter, r *http.Request, params map[string]string) { + api.router.Handle("POST", baseURL, func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { info := requestInfo(r, api) c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleCreate(c, w, r, info.prefix, *info) api.contextPool.Put(c) @@ -387,9 +437,14 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac } if _, ok := source.(ResourceDeleter); ok { - api.router.Handle("DELETE", baseURL+"/:id", func(w http.ResponseWriter, r *http.Request, params map[string]string) { + api.router.Handle("DELETE", baseURL+"/:id", func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleDelete(c, w, r, params) api.contextPool.Put(c) @@ -400,10 +455,15 @@ func (api *API) addResource(prototype jsonapi.MarshalIdentifier, source interfac } if _, ok := source.(ResourceUpdater); ok { - api.router.Handle("PATCH", baseURL+"/:id", func(w http.ResponseWriter, r *http.Request, params map[string]string) { + api.router.Handle("PATCH", baseURL+"/:id", func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) { info := requestInfo(r, api) c := api.contextPool.Get().(APIContexter) c.Reset() + + for key, val := range context { + c.Set(key, val) + } + api.middlewareChain(c, w, r) err := res.handleUpdate(c, w, r, params, *info) api.contextPool.Put(c) diff --git a/routing/echo.go b/routing/echo.go index e1a5326..fd48339 100644 --- a/routing/echo.go +++ b/routing/echo.go @@ -24,7 +24,7 @@ func (e echoRouter) Handle(protocol, route string, handler HandlerFunc) { params[p] = c.ParamValues()[i] } - handler(c.Response(), c.Request(), params) + handler(c.Response(), c.Request(), params, make(map[string]interface{})) return nil } diff --git a/routing/gingonic.go b/routing/gingonic.go index 51a26bf..65f6424 100644 --- a/routing/gingonic.go +++ b/routing/gingonic.go @@ -23,7 +23,7 @@ func (g ginRouter) Handle(protocol, route string, handler HandlerFunc) { params[p.Key] = p.Value } - handler(c.Writer, c.Request, params) + handler(c.Writer, c.Request, params, c.Keys) } g.router.Handle(protocol, route, wrappedCallback) diff --git a/routing/gingonic_test.go b/routing/gingonic_test.go index 23fd805..787b7bf 100644 --- a/routing/gingonic_test.go +++ b/routing/gingonic_test.go @@ -7,7 +7,9 @@ import ( "log" "net/http" "net/http/httptest" + "reflect" "strings" + "unsafe" "github.com/gin-gonic/gin" "github.com/manyminds/api2go" @@ -22,10 +24,13 @@ import ( var _ = Describe("api2go with gingonic router adapter", func() { var ( - router routing.Routeable - gg *gin.Engine - api *api2go.API - rec *httptest.ResponseRecorder + router routing.Routeable + gg *gin.Engine + api *api2go.API + rec *httptest.ResponseRecorder + contextKey = "userID" + contextValue *string + apiContext api2go.APIContext ) BeforeSuite(func() { @@ -38,9 +43,22 @@ var _ = Describe("api2go with gingonic router adapter", func() { router, ) + // Define the ApiContext to allow for access. + apiContext = api2go.APIContext{} + api.SetContextAllocator(func(*api2go.API) api2go.APIContexter { + return &apiContext + }) + userStorage := storage.NewUserStorage() chocStorage := storage.NewChocolateStorage() api.AddResource(model.User{}, resource.UserResource{ChocStorage: chocStorage, UserStorage: userStorage}) + + gg.Use(func(c *gin.Context) { + if contextValue != nil { + c.Set(contextKey, *contextValue) + } + }) + api.AddResource(model.Chocolate{}, resource.ChocolateResource{ChocStorage: chocStorage, UserStorage: userStorage}) }) @@ -144,4 +162,40 @@ var _ = Describe("api2go with gingonic router adapter", func() { Expect(string(rec.Body.Bytes())).To(MatchJSON(expected)) }) }) + + Context("Gin Context Key Copy Tests", func() { + BeforeEach(func() { + contextValue = nil + }) + + It("context value is present for chocolate resource", func() { + tempVal := "1" + contextValue = &tempVal + expected := `{"data":[],"meta":{"author": "The api2go examples crew", "license": "wtfpl", "license-url": "http://www.wtfpl.net"}}` + req, err := http.NewRequest("GET", "/api/chocolates", strings.NewReader("")) + Expect(err).To(BeNil()) + gg.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(string(rec.Body.Bytes())).To(MatchJSON(expected)) + + rawKeys := reflect.ValueOf(&apiContext).Elem().Field(0) + keys := reflect.NewAt(rawKeys.Type(), unsafe.Pointer(rawKeys.UnsafeAddr())).Elem().Interface().(map[string]interface{}) + + Expect(keys).To(Equal(map[string]interface{}{contextKey: *contextValue})) + }) + + It("context value is not present for chocolate resource", func() { + expected := `{"data":[],"meta":{"author": "The api2go examples crew", "license": "wtfpl", "license-url": "http://www.wtfpl.net"}}` + req, err := http.NewRequest("GET", "/api/chocolates", strings.NewReader("")) + Expect(err).To(BeNil()) + gg.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(string(rec.Body.Bytes())).To(MatchJSON(expected)) + + rawKeys := reflect.ValueOf(&apiContext).Elem().Field(0) + keys := reflect.NewAt(rawKeys.Type(), unsafe.Pointer(rawKeys.UnsafeAddr())).Elem().Interface().(map[string]interface{}) + + Expect(keys).To(BeNil()) + }) + }) }) diff --git a/routing/gorillamux.go b/routing/gorillamux.go index 9945aa5..f761359 100644 --- a/routing/gorillamux.go +++ b/routing/gorillamux.go @@ -20,7 +20,7 @@ func (gm gorillamuxRouter) Handler() http.Handler { func (gm gorillamuxRouter) Handle(protocol, route string, handler HandlerFunc) { wrappedHandler := func(w http.ResponseWriter, r *http.Request) { - handler(w, r, mux.Vars(r)) + handler(w, r, mux.Vars(r), make(map[string]interface{})) } // The request path will have parameterized segments indicated as :name. Convert diff --git a/routing/httprouter.go b/routing/httprouter.go index 2f6a6fd..ee6b8bb 100644 --- a/routing/httprouter.go +++ b/routing/httprouter.go @@ -19,7 +19,7 @@ func (h HTTPRouter) Handle(protocol, route string, handler HandlerFunc) { params[p.Key] = p.Value } - handler(w, r, params) + handler(w, r, params, make(map[string]interface{})) } h.router.Handle(protocol, route, wrappedCallback) diff --git a/routing/router.go b/routing/router.go index 9a7d9cd..5d42400 100644 --- a/routing/router.go +++ b/routing/router.go @@ -4,7 +4,7 @@ import "net/http" // HandlerFunc must contain all params from the route // in the form key,value -type HandlerFunc func(w http.ResponseWriter, r *http.Request, params map[string]string) +type HandlerFunc func(w http.ResponseWriter, r *http.Request, params map[string]string, context map[string]interface{}) // Routeable allows drop in replacement for api2go's router // by default, we are using julienschmidt/httprouter