From fbcde8c9cd18e98b13f3f462a06347a7246df52c Mon Sep 17 00:00:00 2001 From: Dmitry Kotik <7944694+dkotik@users.noreply.github.com> Date: Thu, 29 Feb 2024 13:54:17 +0100 Subject: [PATCH] add method mux --- examples/htmxform/feedback/form.go | 33 ++------ examples/htmxform/main.go | 7 +- method.go | 132 +++++++++++++++++++++++++++++ redirect.go | 29 +++++++ 4 files changed, 174 insertions(+), 27 deletions(-) create mode 100644 method.go create mode 100644 redirect.go diff --git a/examples/htmxform/feedback/form.go b/examples/htmxform/feedback/form.go index 9e2524c..1c9494e 100644 --- a/examples/htmxform/feedback/form.go +++ b/examples/htmxform/feedback/form.go @@ -81,28 +81,12 @@ func (f *formResponse) Success() (string, error) { }) } -type formHandler struct { - get http.Handler - post http.Handler -} - -func (h *formHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - h.get.ServeHTTP(w, r) - case http.MethodPost: - h.post.ServeHTTP(w, r) - default: - h.get.ServeHTTP(w, r) - } -} - -func New(sender Sender) (http.Handler, error) { +func New(sender Sender) (get, post http.Handler, err error) { if sender == nil { - return nil, errors.New("cannot use a feedback sender") + return nil, nil, errors.New("cannot use a feedback sender") } - get, err := htadaptor.NewNullaryFuncAdaptor( + get, err = htadaptor.NewNullaryFuncAdaptor( func(ctx context.Context) (*formResponse, error) { // localizer is passed through context using // acceptlanguage middleware all the same @@ -117,10 +101,10 @@ func New(sender Sender) (http.Handler, error) { htadaptor.WithTemplate(templates.Lookup("page")), ) if err != nil { - return nil, fmt.Errorf("unable to create get handler: %w", err) + return nil, nil, fmt.Errorf("unable to create get handler: %w", err) } - post, err := htadaptor.NewUnaryFuncAdaptor( + post, err = htadaptor.NewUnaryFuncAdaptor( func(ctx context.Context, r *formRequest) (*formResponse, error) { // localizer is passed through context using // acceptlanguage middleware @@ -144,11 +128,8 @@ func New(sender Sender) (http.Handler, error) { htadaptor.WithTemplate(templates.Lookup("form")), ) if err != nil { - return nil, fmt.Errorf("unable to create post handler: %w", err) + return nil, nil, fmt.Errorf("unable to create post handler: %w", err) } - return &formHandler{ - get: get, - post: post, - }, nil + return get, post, nil } diff --git a/examples/htmxform/main.go b/examples/htmxform/main.go index 2f98c69..252a3d8 100644 --- a/examples/htmxform/main.go +++ b/examples/htmxform/main.go @@ -51,7 +51,12 @@ func main() { "message", ), ))) - mux.Handle("/{$}", htadaptor.Must(feedback.New(mailer))) + getForm, postForm, err := feedback.New(mailer) + if err != nil { + panic(err) + } + mux.Handle("GET /{$}", getForm) + mux.Handle("POST /{$}", postForm) fmt.Printf( `Listening at http://%[1]s/ diff --git a/method.go b/method.go new file mode 100644 index 0000000..07d75b3 --- /dev/null +++ b/method.go @@ -0,0 +1,132 @@ +package htadaptor + +import ( + "io" + "net/http" + "reflect" + "strings" +) + +// MethodSwitch provides method selections for [NewMethodMux]. +type MethodSwitch struct { + Get http.Handler + Post http.Handler + Put http.Handler + Patch http.Handler + Delete http.Handler + Head http.Handler +} + +func (ms *MethodSwitch) AllowedMethods() (methods []string) { + if ms.Get != nil { + methods = append(methods, http.MethodGet) + } + if ms.Post != nil { + methods = append(methods, http.MethodPost) + } + if ms.Put != nil { + methods = append(methods, http.MethodPut) + } + if ms.Patch != nil { + methods = append(methods, http.MethodPatch) + } + if ms.Delete != nil { + methods = append(methods, http.MethodDelete) + } + return +} + +type methodMux struct { + Get http.Handler + Post http.Handler + Put http.Handler + Patch http.Handler + Delete http.Handler + Head http.Handler + allowed string +} + +func NewMethodMux(ms *MethodSwitch) http.Handler { + if ms == nil { + return &getPostMux{} + } + allowed := ms.AllowedMethods() + if len(allowed) == 2 && reflect.DeepEqual(allowed, []string{"GET", "POST"}) { + return &getPostMux{ + Get: ms.Get, + Post: ms.Post, + } + } + if ms.Head == nil { + ms.Head = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + } + return &methodMux{ + Get: ms.Get, + Post: ms.Post, + Put: ms.Put, + Patch: ms.Patch, + Delete: ms.Delete, + Head: ms.Head, + allowed: strings.Join(append(allowed, http.MethodHead), ", "), + } +} + +func (m *methodMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { // http.Request ALWAYS has a method + case http.MethodGet: + if m.Get != nil { + m.Get.ServeHTTP(w, r) + return + } + case http.MethodPost: + if m.Post != nil { + m.Post.ServeHTTP(w, r) + return + } + case http.MethodPut: + if m.Put != nil { + m.Put.ServeHTTP(w, r) + return + } + case http.MethodPatch: + if m.Patch != nil { + m.Patch.ServeHTTP(w, r) + return + } + case http.MethodDelete: + if m.Delete != nil { + m.Delete.ServeHTTP(w, r) + return + } + case http.MethodOptions: + w.Header().Set("Allow", m.allowed) + return + case http.MethodHead: + m.Head.ServeHTTP(w, r) + } + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, http.StatusText(http.StatusMethodNotAllowed)) +} + +type getPostMux struct { + Get http.Handler + Post http.Handler +} + +func (m *getPostMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { // http.Request ALWAYS has a method + case http.MethodGet: + m.Get.ServeHTTP(w, r) + return + case http.MethodPost: + m.Post.ServeHTTP(w, r) + return + case http.MethodOptions: + w.Header().Set("Allow", "GET, POST, HEAD") + return + case http.MethodHead: + return // no operation + } + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, http.StatusText(http.StatusMethodNotAllowed)) +} diff --git a/redirect.go b/redirect.go new file mode 100644 index 0000000..a7023ac --- /dev/null +++ b/redirect.go @@ -0,0 +1,29 @@ +package htadaptor + +import "net/http" + +type temporaryRedirect string + +func (t temporaryRedirect) ServeHTTP( + w http.ResponseWriter, + r *http.Request, +) { + http.Redirect(w, r, string(t), http.StatusTemporaryRedirect) +} + +func NewTemporaryRedirect(to string) http.Handler { + return temporaryRedirect(to) +} + +type permanentRedirect string + +func (p permanentRedirect) ServeHTTP( + w http.ResponseWriter, + r *http.Request, +) { + http.Redirect(w, r, string(p), http.StatusPermanentRedirect) +} + +func NewPermanentRedirect(to string) http.Handler { + return permanentRedirect(to) +}