diff --git a/web/bind.go b/web/bind.go index f6f025a2..8fc574e2 100644 --- a/web/bind.go +++ b/web/bind.go @@ -17,24 +17,22 @@ package web import ( - "context" "fmt" "net/http" "reflect" "github.com/go-spring-projects/go-spring/internal/utils" "github.com/go-spring-projects/go-spring/web/binding" - "github.com/go-spring-projects/go-spring/web/render" ) type Renderer interface { - Render(ctx context.Context, err error, result interface{}) render.Renderer + Render(ctx *Context, err error, result interface{}) } -type RendererFunc func(ctx context.Context, err error, result interface{}) render.Renderer +type RendererFunc func(ctx *Context, err error, result interface{}) -func (fn RendererFunc) Render(ctx context.Context, err error, result interface{}) render.Renderer { - return fn(ctx, err, result) +func (fn RendererFunc) Render(ctx *Context, err error, result interface{}) { + fn(ctx, err, result) } // Bind convert fn to HandlerFunc. @@ -76,9 +74,9 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { defer func() { if nil != request.MultipartForm { - request.MultipartForm.RemoveAll() + _ = request.MultipartForm.RemoveAll() } - request.Body.Close() + _ = request.Body.Close() }() var returnValues []reflect.Value @@ -93,7 +91,7 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { } // render error response - render.Render(ctx, err, nil).Render(writer) + render.Render(webCtx, err, nil) } }() @@ -144,7 +142,7 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { } // render response - render.Render(ctx, err, result).Render(writer) + render.Render(webCtx, err, result) } } diff --git a/web/binding/binding.go b/web/binding/binding.go index 2ece4c70..c299361a 100644 --- a/web/binding/binding.go +++ b/web/binding/binding.go @@ -18,6 +18,7 @@ package binding import ( + "errors" "fmt" "io" "mime" @@ -30,6 +31,9 @@ import ( "github.com/go-spring-projects/go-spring/conf" ) +var ErrBinding = errors.New("binding failed") +var ErrValidate = errors.New("validate failed") + const ( MIMEApplicationJSON = "application/json" MIMEApplicationXML = "application/xml" @@ -93,12 +97,17 @@ func RegisterBodyBinder(mime string, binder BodyBinder) { func Bind(i interface{}, r Request) error { if err := bindScope(i, r); err != nil { - return err + return fmt.Errorf("%w: %v", ErrBinding, err) } + if err := bindBody(i, r); err != nil { - return err + return fmt.Errorf("%w: %v", ErrBinding, err) } - return conf.ValidateStruct(i) + + if err := conf.ValidateStruct(i); nil != err { + return fmt.Errorf("%w: %v", ErrValidate, err) + } + return nil } func bindBody(i interface{}, r Request) error { diff --git a/web/error.go b/web/error.go index 9f2944a3..2f3b6018 100644 --- a/web/error.go +++ b/web/error.go @@ -19,7 +19,6 @@ package web import ( "fmt" "net/http" - "strings" ) type HttpError struct { @@ -31,10 +30,10 @@ func (e HttpError) Error() string { return fmt.Sprintf("%d: %s", e.Code, e.Message) } -func Error(code int, msg ...string) HttpError { +func Error(code int, format string, args ...interface{}) HttpError { var message = http.StatusText(code) - if len(msg) > 0 { - message = strings.Join(msg, ",") + if len(format) > 0 { + message = fmt.Sprintf(format, args...) } return HttpError{Code: code, Message: message} } diff --git a/web/examples/greeting/main.go b/web/examples/greeting/main.go index 17e3ac27..067743ee 100644 --- a/web/examples/greeting/main.go +++ b/web/examples/greeting/main.go @@ -22,6 +22,7 @@ import ( "log/slog" "math/rand" "mime/multipart" + "net/http" "time" "github.com/go-spring-projects/go-spring/gs" @@ -38,6 +39,19 @@ func (g *Greeting) OnInit(ctx context.Context) error { g.Server.Bind("/greeting", g.Greeting) g.Server.Bind("/health", g.Health) g.Server.Bind("/user/register/{username}/{password}", g.Register) + g.Server.Bind("/user/password", g.UpdatePassword) + + g.Server.Use(func(handler http.Handler) http.Handler { + + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + + start := time.Now() + handler.ServeHTTP(writer, request) + g.Logger.Info("http handle cost", + slog.String("path", request.URL.Path), slog.Duration("cost", time.Since(start))) + }) + }) + return nil } @@ -57,13 +71,13 @@ func (g *Greeting) Health(ctx context.Context) (string, error) { func (g *Greeting) Register( ctx context.Context, req struct { - Username string `path:"username"` // 用户名 - Password string `path:"password"` // 密码 - HeadImg *multipart.FileHeader `form:"headImg"` // 上传头像 - Captcha string `form:"captcha"` // 验证码 - UserAgent string `header:"User-Agent"` // 用户代理 - Ad string `query:"ad"` // 推广ID - Token string `cookie:"token"` // cookie参数 + Username string `path:"username" expr:"len($)>6 && len($)<20"` // username + Password string `path:"password" expr:"len($)>6 && len($)<20"` // password + HeadImg *multipart.FileHeader `form:"headImg"` // upload head image + Captcha string `form:"captcha" expr:"len($)==4"` // captcha + UserAgent string `header:"User-Agent"` // user agent + Ad string `query:"ad"` // AD + Token string `cookie:"token"` // token }, ) string { g.Logger.Info("register user", @@ -79,6 +93,17 @@ func (g *Greeting) Register( return "ok" } +func (g *Greeting) UpdatePassword( + ctx context.Context, + req struct { + Password string `json:"password" expr:"len($) > 6 && len($) < 20"` // password + Token string `cookie:"token"` // token + }, +) string { + g.Logger.Info("change password", slog.String("password", req.Password)) + return "ok" +} + func main() { gs.Object(new(Greeting)) diff --git a/web/server.go b/web/server.go index f9ccbc91..247fc0aa 100644 --- a/web/server.go +++ b/web/server.go @@ -22,7 +22,7 @@ import ( "errors" "net/http" - "github.com/go-spring-projects/go-spring/web/render" + "github.com/go-spring-projects/go-spring/web/binding" ) // A Server defines parameters for running an HTTP server. @@ -54,13 +54,7 @@ func NewServer(router *Router, options Options) *Server { } } - var jsonRenderer = func(ctx context.Context, err error, result interface{}) render.Renderer { - - type jsonResponse struct { - Code int `json:"code"` - Message string `json:"message,omitempty"` - Data interface{} `json:"data"` - } + var jsonRenderer = func(ctx *Context, err error, result interface{}) { var code = 0 var message = "" @@ -72,10 +66,20 @@ func NewServer(router *Router, options Options) *Server { } else { code = http.StatusInternalServerError message = err.Error() + + if errors.Is(err, binding.ErrBinding) || errors.Is(err, binding.ErrValidate) { + code = http.StatusBadRequest + } } } - return render.JsonRenderer{Data: jsonResponse{Code: code, Message: message, Data: result}} + type response struct { + Code int `json:"code"` + Message string `json:"message,omitempty"` + Data interface{} `json:"data"` + } + + ctx.JSON(http.StatusOK, response{Code: code, Message: message, Data: result}) } return &Server{ @@ -121,6 +125,11 @@ func (s *Server) Shutdown(ctx context.Context) error { return s.httpSvr.Shutdown(ctx) } +// Router returns the server router. +func (s *Server) Router() *Router { + return s.router +} + // NotFound to be used when no route matches. // This can be used to render your own 404 Not Found errors. func (s *Server) NotFound(handler http.Handler) { @@ -138,6 +147,12 @@ func (s *Server) Renderer(renderer Renderer) { s.renderer = renderer } +// Use appends a MiddlewareFunc to the chain. +// Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router. +func (s *Server) Use(mwf ...MiddlewareFunc) { + s.router.Use(mwf...) +} + // Match attempts to match the given request against the router's registered routes. // // If the request matches a route of this router or one of its subrouters the Route,