-
Notifications
You must be signed in to change notification settings - Fork 2
/
csrf.go
82 lines (64 loc) · 1.85 KB
/
csrf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package page
import (
	"bytes"
	"crypto/subtle"
	"errors"
	"io"
	"mime"
	"mime/multipart"
	"net/http"

	"github.com/gorilla/sessions"
)
// contextKey is a package-private string type for context keys.
// NOTE(review): not referenced in this file; presumably used elsewhere in the
// package to avoid collisions with context keys from other packages.
type contextKey string

// csrfTokenLength is the argument handed to randomString when a new CSRF
// token is minted for a fresh session.
const csrfTokenLength = 12

// ErrCsrfInvalid is passed to the ErrorHandler when a POST request fails CSRF
// token validation.
var ErrCsrfInvalid = errors.New("CSRF token not valid")
// ValidateCsrf returns middleware that enforces a session-backed CSRF token
// before calling next.
//
// The token lives in the "csrf" session obtained from store. For POST requests
// the session must load without error and csrfValid must accept the request;
// otherwise errorHandler receives the load error or ErrCsrfInvalid and next is
// not called. Requests with other methods are not checked.
//
// If the session is new, a token is minted with randomString(csrfTokenLength)
// and the session is saved with MaxAge 24*60*60, Secure, HttpOnly and
// SameSite=Lax options. A failed save is reported to errorHandler and next is
// not called. A token that never reached the client could not pass the next
// POST, and would otherwise surface later as a misleading ErrCsrfInvalid.
//
// The session's token is exposed to next as AppData.CsrfToken in the request
// context. It is empty if the session holds no string token.
func ValidateCsrf(next http.Handler, store sessions.Store, randomString func(int) string, errorHandler ErrorHandler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		// A load error is only fatal for POST. For other methods the returned
		// session is used as-is and, if new, gets a fresh token below.
		// NOTE(review): assumes store.Get returns a non-nil session even when
		// it also returns an error (gorilla's stores do) — confirm for custom stores.
		csrfSession, err := store.Get(r, "csrf")
		if r.Method == http.MethodPost {
			if err != nil {
				errorHandler(w, r, err)
				return
			}
			if !csrfValid(r, csrfSession) {
				errorHandler(w, r, ErrCsrfInvalid)
				return
			}
		}
		if csrfSession.IsNew {
			csrfSession.Values = map[any]any{"token": randomString(csrfTokenLength)}
			csrfSession.Options = &sessions.Options{
				MaxAge:   24 * 60 * 60, // 24h expressed in seconds
				Secure:   true,
				HttpOnly: true,
				SameSite: http.SameSiteLaxMode,
			}
			if saveErr := store.Save(r, w, csrfSession); saveErr != nil {
				errorHandler(w, r, saveErr)
				return
			}
		}
		appData := AppDataFromContext(ctx)
		// A missing or non-string token leaves CsrfToken as "".
		appData.CsrfToken, _ = csrfSession.Values["token"].(string)
		next.ServeHTTP(w, r.WithContext(ContextWithAppData(ctx, appData)))
	}
}
// csrfValid reports whether r carries a "csrf" form value equal to the token
// stored under "token" in csrfSession.
//
// For multipart/form-data bodies only the first part is inspected, and it must
// be named "csrf". This avoids parsing the whole upload just to check the
// token. Every byte the multipart reader consumes is captured and replayed in
// front of the unread remainder of r.Body, so later handlers see the full body.
// Other content types fall back to r.PostFormValue.
//
// A missing or empty stored token never validates, and the final comparison is
// constant-time.
func csrfValid(r *http.Request, csrfSession *sessions.Session) bool {
	expected, ok := csrfSession.Values["token"].(string)
	// An empty stored token would match an absent form field; never accept it.
	if !ok || expected == "" {
		return false
	}

	var submitted string
	if mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")); err == nil && mediaType == "multipart/form-data" {
		// Tee everything the multipart reader pulls from the body (including
		// its read-ahead) so it can be put back for downstream handlers.
		var buf bytes.Buffer
		reader := multipart.NewReader(io.TeeReader(r.Body, &buf), params["boundary"])
		part, err := reader.NextPart()
		if err != nil || part.FormName() != "csrf" {
			return false
		}
		// Read one byte beyond the expected token length so an over-long value
		// cannot be truncated into a match. The limit follows the stored token
		// rather than csrfTokenLength because the byte length randomString
		// produces for that argument is not fixed here.
		value, err := io.ReadAll(io.LimitReader(part, int64(len(expected))+1))
		if err != nil {
			// A truncated or failing body must not be compared as a partial token.
			return false
		}
		r.Body = MultiReadCloser(io.NopCloser(&buf), r.Body)
		submitted = string(value)
	} else {
		submitted = r.PostFormValue("csrf")
	}

	// Constant-time comparison avoids leaking token contents through timing.
	return subtle.ConstantTimeCompare([]byte(submitted), []byte(expected)) == 1
}