Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 8 additions & 21 deletions web/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc {

switch h := fn.(type) {
case http.HandlerFunc:
return warpHandlerCtx(h)
return warpContext(h)
case http.Handler:
return warpHandlerCtx(h.ServeHTTP)
return warpContext(h.ServeHTTP)
case func(http.ResponseWriter, *http.Request):
return warpHandlerCtx(h)
return warpContext(h)
default:
// valid func
if err := validMappingFunc(fnType); nil != err {
Expand All @@ -75,12 +75,8 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {

// param of context
ctx := request.Context()
webCtx := FromContext(ctx)
if nil == webCtx {
webCtx = &Context{Writer: writer, Request: request}
ctx = WithContext(request.Context(), webCtx)
}
webCtx := &Context{Writer: writer, Request: request}
ctx := WithContext(request.Context(), webCtx)

defer func() {
if nil != request.MultipartForm {
Expand Down Expand Up @@ -205,22 +201,13 @@ func validMappingFunc(fnType reflect.Type) error {
return nil
}

func warpHandlerCtx(handler http.HandlerFunc) http.HandlerFunc {
func warpContext(handler http.HandlerFunc) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
if nil != FromContext(request.Context()) {
handler.ServeHTTP(writer, request)
} else {
webCtx := &Context{Writer: writer, Request: request}
handler.ServeHTTP(writer, requestWithCtx(request, webCtx))
}
webCtx := &Context{Writer: writer, Request: request}
handler.ServeHTTP(writer, request.WithContext(WithContext(request.Context(), webCtx)))
}
}

func requestWithCtx(r *http.Request, webCtx *Context) *http.Request {
ctx := WithContext(r.Context(), webCtx)
return r.WithContext(ctx)
}

func defaultJsonRender(ctx *Context, err error, result interface{}) {

var code = 0
Expand Down
63 changes: 63 additions & 0 deletions web/bind_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package web

import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"

"go-spring.dev/spring/internal/utils/assert"
)

func TestBindWithoutParams(t *testing.T) {

var handler = func(ctx context.Context) string {
webCtx := FromContext(ctx)
assert.NotNil(t, webCtx)
return "0987654321"
}

request := httptest.NewRequest(http.MethodGet, "/get", strings.NewReader("{}"))
response := httptest.NewRecorder()
Bind(handler, RendererFunc(defaultJsonRender))(response, request)
assert.Equal(t, response.Body.String(), "{\"code\":0,\"data\":\"0987654321\"}\n")
}

func TestBindWithParams(t *testing.T) {
var handler = func(ctx context.Context, req struct {
Username string `json:"username"`
Password string `json:"password"`
}) string {
webCtx := FromContext(ctx)
assert.NotNil(t, webCtx)
assert.Equal(t, req.Username, "aaa")
assert.Equal(t, req.Password, "88888888")
return "success"
}

request := httptest.NewRequest(http.MethodPost, "/post", strings.NewReader(`{"username": "aaa", "password": "88888888"}`))
request.Header.Add("Content-Type", "application/json")
response := httptest.NewRecorder()
Bind(handler, RendererFunc(defaultJsonRender))(response, request)
assert.Equal(t, response.Body.String(), "{\"code\":0,\"data\":\"success\"}\n")
}

func TestBindWithParamsAndError(t *testing.T) {
var handler = func(ctx context.Context, req struct {
Username string `json:"username"`
Password string `json:"password"`
}) (string, error) {
webCtx := FromContext(ctx)
assert.NotNil(t, webCtx)
assert.Equal(t, req.Username, "aaa")
assert.Equal(t, req.Password, "88888888")
return "requestid: 9999999", Error(403, "user locked")
}

request := httptest.NewRequest(http.MethodPost, "/post", strings.NewReader(`{"username": "aaa", "password": "88888888"}`))
request.Header.Add("Content-Type", "application/json")
response := httptest.NewRecorder()
Bind(handler, RendererFunc(defaultJsonRender))(response, request)
assert.Equal(t, response.Body.String(), "{\"code\":403,\"message\":\"user locked\",\"data\":\"requestid: 9999999\"}\n")
}
90 changes: 52 additions & 38 deletions web/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,33 +54,9 @@ type Context struct {
// or to be sent by a client.
Request *http.Request

routes Routes

// SameSite allows a server to define a cookie attribute making it impossible for
// the browser to send this cookie along with cross-site requests.
sameSite http.SameSite

// URLParams are the stack of routeParams captured during the
// routing lifecycle across a stack of sub-routers.
urlParams RouteParams

// routeParams matched for the current sub-router. It is
// intentionally unexported so it can't be tampered.
routeParams RouteParams

// Routing path/method override used during the route search.
routePath string
routeMethod string

// The endpoint routing pattern that matched the request URI path
// or `RoutePath` of the current sub-router. This value will update
// during the lifecycle of a request passing through a stack of
// sub-routers.
routePattern string
routePatterns []string

methodNotAllowed bool
methodsAllowed []methodTyp
}

// Context returns the request's context.
Expand Down Expand Up @@ -116,7 +92,10 @@ func (c *Context) Cookie(name string) (string, bool) {

// PathParam returns the named variables in the request.
func (c *Context) PathParam(name string) (string, bool) {
return c.urlParams.Get(name)
if ctx := FromRouteContext(c.Request.Context()); nil != ctx {
return ctx.URLParams.Get(name)
}
return "", false
}

// QueryParam returns the named query in the request.
Expand Down Expand Up @@ -315,18 +294,53 @@ func (c *Context) ClientIP() string {
return remoteIP.String()
}

type routeContextKey struct{}

func WithRouteContext(parent context.Context, ctx *RouteContext) context.Context {
return context.WithValue(parent, routeContextKey{}, ctx)
}

func FromRouteContext(ctx context.Context) *RouteContext {
if v := ctx.Value(routeContextKey{}); v != nil {
return v.(*RouteContext)
}
return nil
}

type RouteContext struct {
Routes Routes
// URLParams are the stack of routeParams captured during the
// routing lifecycle across a stack of sub-routers.
URLParams RouteParams

// routeParams matched for the current sub-router. It is
// intentionally unexported so it can't be tampered.
routeParams RouteParams

// Routing path/method override used during the route search.
RoutePath string
RouteMethod string

// The endpoint routing pattern that matched the request URI path
// or `RoutePath` of the current sub-router. This value will update
// during the lifecycle of a request passing through a stack of
// sub-routers.
RoutePattern string
routePatterns []string

methodNotAllowed bool
methodsAllowed []methodTyp
}

// Reset context to initial state
func (c *Context) Reset() {
c.Writer = nil
c.Request = nil
c.sameSite = 0
c.routes = nil
c.routePath = ""
c.routeMethod = ""
c.routePattern = ""
func (c *RouteContext) Reset() {
c.Routes = nil
c.RoutePath = ""
c.RouteMethod = ""
c.RoutePattern = ""
c.routePatterns = c.routePatterns[:0]
c.urlParams.Keys = c.urlParams.Keys[:0]
c.urlParams.Values = c.urlParams.Values[:0]
c.URLParams.Keys = c.URLParams.Keys[:0]
c.URLParams.Values = c.URLParams.Values[:0]
c.routeParams.Keys = c.routeParams.Keys[:0]
c.routeParams.Values = c.routeParams.Values[:0]
c.methodNotAllowed = false
Expand All @@ -345,9 +359,9 @@ func (s *RouteParams) Add(key, value string) {
}

func (s *RouteParams) Get(key string) (value string, ok bool) {
for index, k := range s.Keys {
if key == k {
return s.Values[index], true
for i := len(s.Keys) - 1; i >= 0; i-- {
if s.Keys[i] == key {
return s.Values[i], true
}
}
return "", false
Expand Down
Loading