-
Notifications
You must be signed in to change notification settings - Fork 2
/
csrf.go
73 lines (55 loc) · 1.65 KB
/
csrf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package page
import (
	"bytes"
	"crypto/subtle"
	"errors"
	"io"
	"mime"
	"mime/multipart"
	"net/http"

	"github.com/ministryofjustice/opg-modernising-lpa/internal/sesh"
)
// contextKey is an unexported string type. NOTE(review): it is not referenced
// in this file — presumably it is the package's collision-free key type for
// context.Context values; confirm against its users.
type contextKey string

var (
	// ErrCsrfInvalid is handed to the ErrorHandler by ValidateCsrf when a POST
	// request's submitted CSRF token does not match the stored one.
	ErrCsrfInvalid = errors.New("CSRF token not valid")
)

const (
	// csrfTokenLength is the length requested from randomString whenever a
	// fresh CSRF token has to be generated.
	csrfTokenLength = 12
)
// ValidateCsrf wraps next with CSRF protection backed by store.
//
// Every request first loads the CSRF session with store.Csrf.
//
//   - POST requests are rejected via errorHandler if the session cannot be
//     loaded (with the load error) or if csrfValid reports a mismatch (with
//     ErrCsrfInvalid). In either case next is not called.
//   - Any other method that has no loadable session gets a new token of
//     csrfTokenLength characters from randomString, saved with store.SetCsrf.
//
// The token in effect is then copied into the request's AppData as CsrfToken
// before next is invoked, so downstream handlers can render it into forms.
//
// NOTE(review): only POST is validated; PUT/PATCH/DELETE pass through
// unchecked — confirm no routes accept those methods.
func ValidateCsrf(next http.Handler, store SessionStore, randomString func(int) string, errorHandler ErrorHandler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		session, loadErr := store.Csrf(r)

		if r.Method == http.MethodPost {
			switch {
			case loadErr != nil:
				errorHandler(w, r, loadErr)
				return
			case !csrfValid(r, session):
				errorHandler(w, r, ErrCsrfInvalid)
				return
			}
		} else if loadErr != nil {
			session = &sesh.CsrfSession{Token: randomString(csrfTokenLength)}
			// Best-effort save: the error is deliberately discarded so the page
			// still renders. NOTE(review): if this fails the client holds no
			// stored token and its next POST will presumably be rejected —
			// consider logging here.
			_ = store.SetCsrf(r, w, session)
		}

		ctx := r.Context()
		data := AppDataFromContext(ctx)
		data.CsrfToken = session.Token
		next.ServeHTTP(w, r.WithContext(ContextWithAppData(ctx, data)))
	}
}
// csrfValid reports whether the CSRF token submitted with r matches the token
// stored in csrfSession.
//
// For multipart/form-data requests the token must be the FIRST part and be
// named "csrf". Only that part is inspected, so the rest of the (possibly
// large) body is not parsed here. Every byte consumed from the body while
// looking is replayed by replacing r.Body, so later handlers still see the
// full body. On a failed check the body is not restored, because the request
// is rejected anyway.
//
// All other requests are checked with r.PostFormValue("csrf").
//
// An empty stored token never validates; otherwise a request that simply
// omitted the field would match it. The final comparison is constant-time so
// response timing does not reveal how much of a guessed token was correct.
func csrfValid(r *http.Request, csrfSession *sesh.CsrfSession) bool {
	expected := csrfSession.Token
	if expected == "" {
		return false
	}

	var submitted string
	if mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")); err == nil && mediaType == "multipart/form-data" {
		// Tee everything the multipart reader pulls from the body into buf so
		// it can be replayed. The reader buffers ahead, so buf may hold more
		// than just the first part.
		var buf bytes.Buffer
		reader := multipart.NewReader(io.TeeReader(r.Body, &buf), params["boundary"])

		part, err := reader.NextPart()
		if err != nil || part.FormName() != "csrf" {
			return false
		}

		// Read at most one byte beyond the expected token. That is enough to
		// detect an over-long value without buffering an unbounded part.
		value, err := io.ReadAll(io.LimitReader(part, int64(len(expected))+1))
		if err != nil {
			return false
		}

		r.Body = MultiReadCloser(io.NopCloser(&buf), r.Body)
		submitted = string(value)
	} else {
		submitted = r.PostFormValue("csrf")
	}

	return subtle.ConstantTimeCompare([]byte(submitted), []byte(expected)) == 1
}