/
methodoverride.go
212 lines (184 loc) · 5.52 KB
/
methodoverride.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
package methodoverride
import (
stdContext "context"
"net/http"
"strings"
"github.com/kataras/iris/v12/context"
"github.com/kataras/iris/v12/core/router"
)
// options collects the settings applied by Option values for one method
// override wrapper created by New.
type options struct {
	// getters are consulted in order by get; the first non-empty result
	// becomes the override method.
	getters []GetterFunc
	// methods lists the request methods that are eligible for overriding;
	// Methods stores them upper-cased and canOverride compares exactly.
	methods []string
	// saveOriginalMethodContextKey, when non-nil, is the request-context key
	// under which New stores the original method before overriding it.
	saveOriginalMethodContextKey interface{} // if not nil original value will be saved.
}
// configure applies each Option to o in the order given, so a later option
// can extend or overwrite what an earlier one set.
func (o *options) configure(opts ...Option) {
	for i := range opts {
		apply := opts[i]
		apply(o)
	}
}
// canOverride reports whether method is one of the registered overridable
// methods. The comparison is exact (case-sensitive): entries are stored
// upper-cased by Methods and New passes the upper-cased request method.
func (o *options) canOverride(method string) bool {
	found := false
	for i := 0; i < len(o.methods) && !found; i++ {
		found = o.methods[i] == method
	}
	return found
}
// get asks each registered getter, in registration order, for an override
// method. It returns the first non-empty answer upper-cased, or "" when no
// getter produced a value. Getters after the first match are not called.
func (o *options) get(w http.ResponseWriter, r *http.Request) string {
	var method string
	for _, fn := range o.getters {
		method = fn(w, r)
		if method != "" {
			break
		}
	}
	return strings.ToUpper(method)
}
// Option sets options for a fresh method override wrapper.
// Options are applied in the order they are passed, so a later Option can
// extend or replace what an earlier one set.
// See `New` package-level function for more.
type Option func(*options)
// Methods adds request methods that are eligible to be overridden.
// Names are upper-cased before being stored. Calling Methods more than once
// accumulates the methods rather than replacing them.
//
// The caller's slice is never modified; an upper-cased copy is kept instead.
//
// Defaults to "POST".
func Methods(methods ...string) Option {
	// Copy instead of upper-casing in place: with Methods(list...) the
	// variadic parameter aliases the caller's slice, so in-place writes would
	// mutate the caller's data. Later caller writes would also change what the
	// option appends.
	upper := make([]string, len(methods))
	for i, s := range methods {
		upper[i] = strings.ToUpper(s)
	}
	return func(opts *options) {
		opts.methods = append(opts.methods, upper...)
	}
}
// SaveOriginalMethod will save the original (upper-cased) request method
// whenever it is overridden, on
// Context.Request().Context().Value(requestContextKey).
//
// The key is passed to context.WithValue, so it must be comparable.
// Passing nil disables saving.
//
// Defaults to nil, don't save it.
func SaveOriginalMethod(requestContextKey interface{}) Option {
	return func(opts *options) {
		// nil simply means "don't save"; no special casing is needed because
		// the wrapper only calls context.WithValue for a non-nil key.
		opts.saveOriginalMethodContextKey = requestContextKey
	}
}
// GetterFunc is the type signature for declaring custom logic
// to extract the method name which an eligible (by default POST) request
// will be replaced with.
// It should return "" when the request carries no override. Getters are
// consulted in registration order; the first non-empty result wins and is
// upper-cased before use.
type GetterFunc func(http.ResponseWriter, *http.Request) string
// Getter registers customFunc as an additional source for the method name
// that overrides an eligible request's method. Getters are consulted in
// registration order and the first non-empty result wins.
//
// A nil customFunc is ignored rather than stored, because calling it while
// serving a request would panic.
//
// No custom getter is registered by default.
func Getter(customFunc GetterFunc) Option {
	return func(opts *options) {
		if customFunc == nil {
			return
		}
		opts.getters = append(opts.getters, customFunc)
	}
}
// Headers registers a getter that reads the override method from the first
// of the given request headers that has a non-empty value. When a header
// matches, its name is added to the response's "Vary" header.
//
// Defaults to:
//   - X-HTTP-Method
//   - X-HTTP-Method-Override
//   - X-Method-Override
func Headers(headers ...string) Option {
	return Getter(func(w http.ResponseWriter, r *http.Request) string {
		for i := 0; i < len(headers); i++ {
			name := headers[i]
			value := r.Header.Get(name)
			if value == "" {
				continue
			}
			w.Header().Add("Vary", name)
			return value
		}
		return ""
	})
}
// FormField specifies a form field to use to determine the method
// to override the request's method with.
//
// Example field:
//
//	<input type="hidden" name="_method" value="DELETE">
//
// The form is parsed with FormFieldWithConf's built-in defaults (32 MB post
// max memory, no body reset); use FormFieldWithConf to follow the
// application's configuration instead.
//
// Defaults to: "_method".
func FormField(fieldName string) Option {
	var noConf context.ConfigurationReadOnly
	return FormFieldWithConf(fieldName, noConf)
}
// FormFieldWithConf is the same as `FormField`, but it parses the form
// based on the application's core configuration. conf's GetPostMaxMemory and
// GetDisableBodyConsumptionOnUnmarshal values are read once, when the option
// is created. They are passed to context.FormValueDefault as its max-memory
// and reset-body arguments.
//
// A nil conf falls back to 32 MB of max memory and no body reset.
func FormFieldWithConf(fieldName string, conf context.ConfigurationReadOnly) Option {
	const defaultPostMaxMemory int64 = 32 << 20 // 32 MB

	maxMemory, resetBody := defaultPostMaxMemory, false
	if conf != nil {
		maxMemory, resetBody = conf.GetPostMaxMemory(), conf.GetDisableBodyConsumptionOnUnmarshal()
	}

	return Getter(func(_ http.ResponseWriter, r *http.Request) string {
		return context.FormValueDefault(r, fieldName, "", maxMemory, resetBody)
	})
}
// Query specifies a URL query parameter name to use to determine the method
// to override the request's method with.
//
// Example URL query string:
//
//	http://localhost:8080/path?_method=DELETE
//
// Defaults to: "_method".
func Query(paramName string) Option {
	fromQuery := func(_ http.ResponseWriter, r *http.Request) string {
		values := r.URL.Query()
		return values.Get(paramName)
	}
	return Getter(fromQuery)
}
// Only discards every getter registered so far (the defaults and any added
// by earlier options) and then applies "o". Only the getters added by "o"
// are used to determine the method to override with.
//
// Only resets getters: the overridable methods and the SaveOriginalMethod
// key set by other options are kept.
//
// The default behavior is to check for all the following, in order:
// headers, form field, query string, and any custom getter (if set).
//
// Use cases:
//  1. Check only a custom header and ignore the other sources:
//     New(Only(Headers("X-Custom-Header")))
//  2. Check only a form field first and a custom getter second:
//     New(Only(FormField("fieldName"), Getter(...)))
func Only(o ...Option) Option {
	return func(opts *options) {
		// Start from a fresh nil slice instead of truncating with [0:0]. This
		// releases the previous getters, and the new ones do not reuse and
		// overwrite the old backing array.
		opts.getters = nil
		opts.configure(o...)
	}
}
// New returns a new method override wrapper,
// to be registered with `Application.WrapRouter`.
//
// Use this wrapper when clients cannot send certain HTTP methods, such as
// DELETE or PUT, for example for security reasons. For a request whose
// upper-cased method is eligible (POST by default), the wrapper looks for a
// replacement method in this order:
//   - the X-HTTP-Method, X-HTTP-Method-Override and X-Method-Override headers
//   - the "_method" form field
//   - the "_method" query parameter
//   - any getters added through opt
//
// The Only option replaces these getters. The first non-empty value,
// upper-cased, becomes the request's method.
//
// Read more at:
// https://github.com/kataras/iris/issues/1325
func New(opt ...Option) router.WrapperFunc {
	opts := new(options)
	// Defaults go first so the caller's options can extend them (or replace
	// the getters through Only).
	opts.configure(
		Methods(http.MethodPost),
		Headers("X-HTTP-Method", "X-HTTP-Method-Override", "X-Method-Override"),
		FormField("_method"),
		Query("_method"),
	)
	opts.configure(opt...)

	return func(w http.ResponseWriter, r *http.Request, proceed http.HandlerFunc) {
		original := strings.ToUpper(r.Method)
		if !opts.canOverride(original) {
			proceed(w, r)
			return
		}

		if method := opts.get(w, r); method != "" {
			if key := opts.saveOriginalMethodContextKey; key != nil {
				ctx := stdContext.WithValue(r.Context(), key, original)
				r = r.WithContext(ctx)
			}
			r.Method = method
		}
		proceed(w, r)
	}
}