Skip to content

Commit 0dab439

Browse files
committed
Fixed #600
Signed-off-by: Vishal Rana <vr@labstack.com>
1 parent c1358ed commit 0dab439

File tree

2 files changed

+53
-70
lines changed

2 files changed

+53
-70
lines changed

middleware/csrf.go

Lines changed: 51 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
package middleware
22

33
import (
4-
"crypto/hmac"
5-
"crypto/rand"
6-
"crypto/sha1"
7-
"encoding/hex"
4+
"crypto/subtle"
85
"errors"
9-
"fmt"
6+
"math/rand"
107
"net/http"
118
"strings"
129
"time"
@@ -17,8 +14,9 @@ import (
1714
type (
1815
// CSRFConfig defines the config for CSRF middleware.
1916
CSRFConfig struct {
20-
// Key to create CSRF token.
21-
Secret []byte `json:"secret"`
17+
// TokenLength is the length of the generated token.
18+
TokenLength uint8 `json:"token_length"`
19+
// Optional. Default value 32.
2220

2321
// TokenLookup is a string in the form of "<source>:<key>" that is used
2422
// to extract token from the request.
@@ -52,6 +50,10 @@ type (
5250
// Indicates if CSRF cookie is secure.
5351
// Optional. Default value false.
5452
CookieSecure bool `json:"cookie_secure"`
53+
54+
// Indicates if CSRF cookie is HTTP only.
55+
// Optional. Default value false.
56+
CookieHTTPOnly bool `json:"cookie_http_only"`
5557
}
5658

5759
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
@@ -62,6 +64,7 @@ type (
6264
var (
6365
// DefaultCSRFConfig is the default CSRF middleware config.
6466
DefaultCSRFConfig = CSRFConfig{
67+
TokenLength: 32,
6568
TokenLookup: "header:" + echo.HeaderXCSRFToken,
6669
ContextKey: "csrf",
6770
CookieName: "_csrf",
@@ -71,18 +74,17 @@ var (
7174

7275
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
7376
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
74-
func CSRF(secret []byte) echo.MiddlewareFunc {
77+
func CSRF() echo.MiddlewareFunc {
7578
c := DefaultCSRFConfig
76-
c.Secret = secret
7779
return CSRFWithConfig(c)
7880
}
7981

8082
// CSRFWithConfig returns a CSRF middleware from config.
8183
// See `CSRF()`.
8284
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
8385
// Defaults
84-
if config.Secret == nil {
85-
panic("csrf secret must be provided")
86+
if config.TokenLength == 0 {
87+
config.TokenLength = DefaultCSRFConfig.TokenLength
8688
}
8789
if config.TokenLookup == "" {
8890
config.TokenLookup = DefaultCSRFConfig.TokenLookup
@@ -110,51 +112,51 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
110112
return func(next echo.HandlerFunc) echo.HandlerFunc {
111113
return func(c echo.Context) error {
112114
req := c.Request()
113-
cookie, err := c.Cookie(config.CookieName)
115+
k, err := c.Cookie(config.CookieName)
114116
token := ""
115117

116118
if err != nil {
117-
// Token expired, generate it
118-
salt, err := generateSalt(8)
119-
if err != nil {
120-
return err
121-
}
122-
token = generateCSRFToken(config.Secret, salt)
123-
cookie := new(echo.Cookie)
124-
cookie.SetName(config.CookieName)
125-
cookie.SetValue(token)
126-
if config.CookiePath != "" {
127-
cookie.SetPath(config.CookiePath)
128-
}
129-
if config.CookieDomain != "" {
130-
cookie.SetDomain(config.CookieDomain)
131-
}
132-
cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
133-
cookie.SetSecure(config.CookieSecure)
134-
cookie.SetHTTPOnly(true)
135-
c.SetCookie(cookie)
119+
// Generate token
120+
token = generateCSRFToken(config.TokenLength)
136121
} else {
137122
// Reuse token
138-
token = cookie.Value()
123+
token = k.Value()
139124
}
140125

141-
c.Set(config.ContextKey, token)
142-
143126
switch req.Method() {
144127
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
145128
default:
129+
// Validate token only for requests which are not defined as 'safe' by RFC7231
146130
clientToken, err := extractor(c)
147131
if err != nil {
148132
return err
149133
}
150-
ok, err := validateCSRFToken(token, clientToken, config.Secret)
151-
if err != nil {
152-
return err
153-
}
154-
if !ok {
155-
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
134+
if !validateCSRFToken(token, clientToken) {
135+
return echo.NewHTTPError(http.StatusForbidden, "csrf token is invalid")
156136
}
157137
}
138+
139+
// Set CSRF cookie
140+
cookie := new(echo.Cookie)
141+
cookie.SetName(config.CookieName)
142+
cookie.SetValue(token)
143+
if config.CookiePath != "" {
144+
cookie.SetPath(config.CookiePath)
145+
}
146+
if config.CookieDomain != "" {
147+
cookie.SetDomain(config.CookieDomain)
148+
}
149+
cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second))
150+
cookie.SetSecure(config.CookieSecure)
151+
cookie.SetHTTPOnly(config.CookieHTTPOnly)
152+
c.SetCookie(cookie)
153+
154+
// Store token in the context
155+
c.Set(config.ContextKey, token)
156+
157+
// Protect clients from caching the response
158+
c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
159+
158160
return next(c)
159161
}
160162
}
@@ -192,29 +194,16 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
192194
}
193195
}
194196

195-
func generateCSRFToken(secret, salt []byte) string {
196-
h := hmac.New(sha1.New, secret)
197-
h.Write(salt)
198-
return fmt.Sprintf("%s:%s", hex.EncodeToString(h.Sum(nil)), hex.EncodeToString(salt))
199-
}
200-
201-
func validateCSRFToken(serverToken, clientToken string, secret []byte) (bool, error) {
202-
if serverToken != clientToken {
203-
return false, nil
204-
}
205-
sep := strings.Index(clientToken, ":")
206-
if sep < 0 {
207-
return false, nil
208-
}
209-
salt, err := hex.DecodeString(clientToken[sep+1:])
210-
if err != nil {
211-
return false, err
197+
func generateCSRFToken(n uint8) string {
198+
// TODO: From utility library
199+
chars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
200+
b := make([]byte, n)
201+
for i := range b {
202+
b[i] = chars[rand.Int63()%int64(len(chars))]
212203
}
213-
return clientToken == generateCSRFToken(secret, salt), nil
204+
return string(b)
214205
}
215206

216-
func generateSalt(len uint8) (salt []byte, err error) {
217-
salt = make([]byte, len)
218-
_, err = rand.Read(salt)
219-
return
207+
func validateCSRFToken(token, clientToken string) bool {
208+
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
220209
}

middleware/csrf_test.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,12 @@ func TestCSRF(t *testing.T) {
1717
rec := test.NewResponseRecorder()
1818
c := e.NewContext(req, rec)
1919
csrf := CSRFWithConfig(CSRFConfig{
20-
Secret: []byte("secret"),
20+
TokenLength: 16,
2121
})
2222
h := csrf(func(c echo.Context) error {
2323
return c.String(http.StatusOK, "test")
2424
})
2525

26-
// No secret
27-
assert.Panics(t, func() {
28-
CSRF(nil)
29-
})
30-
3126
// Generate CSRF token
3227
h(c)
3328
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
@@ -46,8 +41,7 @@ func TestCSRF(t *testing.T) {
4641
assert.Error(t, h(c))
4742

4843
// Valid CSRF token
49-
salt, _ := generateSalt(8)
50-
token := generateCSRFToken([]byte("secret"), salt)
44+
token := generateCSRFToken(16)
5145
req.Header().Set(echo.HeaderCookie, "_csrf="+token)
5246
req.Header().Set(echo.HeaderXCSRFToken, token)
5347
if assert.NoError(t, h(c)) {

0 commit comments

Comments
 (0)