11package middleware
22
33import (
4- "crypto/hmac"
5- "crypto/rand"
6- "crypto/sha1"
7- "encoding/hex"
4+ "crypto/subtle"
85 "errors"
9- "fmt "
6+ "math/rand "
107 "net/http"
118 "strings"
129 "time"
@@ -17,8 +14,9 @@ import (
1714type (
1815 // CSRFConfig defines the config for CSRF middleware.
1916 CSRFConfig struct {
20- // Key to create CSRF token.
21- Secret []byte `json:"secret"`
17+ // TokenLength is the length of the generated token.
18+ TokenLength uint8 `json:"token_length"`
19+ // Optional. Default value 32.
2220
2321 // TokenLookup is a string in the form of "<source>:<key>" that is used
2422 // to extract token from the request.
@@ -52,6 +50,10 @@ type (
5250 // Indicates if CSRF cookie is secure.
5351 // Optional. Default value false.
5452 CookieSecure bool `json:"cookie_secure"`
53+
54+ // Indicates if CSRF cookie is HTTP only.
55+ // Optional. Default value false.
56+ CookieHTTPOnly bool `json:"cookie_http_only"`
5557 }
5658
5759 // csrfTokenExtractor defines a function that takes `echo.Context` and returns
6264var (
6365 // DefaultCSRFConfig is the default CSRF middleware config.
6466 DefaultCSRFConfig = CSRFConfig {
67+ TokenLength : 32 ,
6568 TokenLookup : "header:" + echo .HeaderXCSRFToken ,
6669 ContextKey : "csrf" ,
6770 CookieName : "_csrf" ,
@@ -71,18 +74,17 @@ var (
7174
7275// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
7376// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
74- func CSRF (secret [] byte ) echo.MiddlewareFunc {
77+ func CSRF () echo.MiddlewareFunc {
7578 c := DefaultCSRFConfig
76- c .Secret = secret
7779 return CSRFWithConfig (c )
7880}
7981
8082// CSRFWithConfig returns a CSRF middleware from config.
8183// See `CSRF()`.
8284func CSRFWithConfig (config CSRFConfig ) echo.MiddlewareFunc {
8385 // Defaults
84- if config .Secret == nil {
85- panic ( "csrf secret must be provided" )
86+ if config .TokenLength == 0 {
87+ config . TokenLength = DefaultCSRFConfig . TokenLength
8688 }
8789 if config .TokenLookup == "" {
8890 config .TokenLookup = DefaultCSRFConfig .TokenLookup
@@ -110,51 +112,51 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
110112 return func (next echo.HandlerFunc ) echo.HandlerFunc {
111113 return func (c echo.Context ) error {
112114 req := c .Request ()
113- cookie , err := c .Cookie (config .CookieName )
115+ k , err := c .Cookie (config .CookieName )
114116 token := ""
115117
116118 if err != nil {
117- // Token expired, generate it
118- salt , err := generateSalt (8 )
119- if err != nil {
120- return err
121- }
122- token = generateCSRFToken (config .Secret , salt )
123- cookie := new (echo.Cookie )
124- cookie .SetName (config .CookieName )
125- cookie .SetValue (token )
126- if config .CookiePath != "" {
127- cookie .SetPath (config .CookiePath )
128- }
129- if config .CookieDomain != "" {
130- cookie .SetDomain (config .CookieDomain )
131- }
132- cookie .SetExpires (time .Now ().Add (time .Duration (config .CookieMaxAge ) * time .Second ))
133- cookie .SetSecure (config .CookieSecure )
134- cookie .SetHTTPOnly (true )
135- c .SetCookie (cookie )
119+ // Generate token
120+ token = generateCSRFToken (config .TokenLength )
136121 } else {
137122 // Reuse token
138- token = cookie .Value ()
123+ token = k .Value ()
139124 }
140125
141- c .Set (config .ContextKey , token )
142-
143126 switch req .Method () {
144127 case echo .GET , echo .HEAD , echo .OPTIONS , echo .TRACE :
145128 default :
129+ // Validate token only for requests which are not defined as 'safe' by RFC7231
146130 clientToken , err := extractor (c )
147131 if err != nil {
148132 return err
149133 }
150- ok , err := validateCSRFToken (token , clientToken , config .Secret )
151- if err != nil {
152- return err
153- }
154- if ! ok {
155- return echo .NewHTTPError (http .StatusForbidden , "invalid csrf token" )
134+ if ! validateCSRFToken (token , clientToken ) {
135+ return echo .NewHTTPError (http .StatusForbidden , "csrf token is invalid" )
156136 }
157137 }
138+
139+ // Set CSRF cookie
140+ cookie := new (echo.Cookie )
141+ cookie .SetName (config .CookieName )
142+ cookie .SetValue (token )
143+ if config .CookiePath != "" {
144+ cookie .SetPath (config .CookiePath )
145+ }
146+ if config .CookieDomain != "" {
147+ cookie .SetDomain (config .CookieDomain )
148+ }
149+ cookie .SetExpires (time .Now ().Add (time .Duration (config .CookieMaxAge ) * time .Second ))
150+ cookie .SetSecure (config .CookieSecure )
151+ cookie .SetHTTPOnly (config .CookieHTTPOnly )
152+ c .SetCookie (cookie )
153+
154+ // Store token in the context
155+ c .Set (config .ContextKey , token )
156+
157+ // Protect clients from caching the response
158+ c .Response ().Header ().Add (echo .HeaderVary , echo .HeaderCookie )
159+
158160 return next (c )
159161 }
160162 }
@@ -192,29 +194,16 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
192194 }
193195}
194196
195- func generateCSRFToken (secret , salt []byte ) string {
196- h := hmac .New (sha1 .New , secret )
197- h .Write (salt )
198- return fmt .Sprintf ("%s:%s" , hex .EncodeToString (h .Sum (nil )), hex .EncodeToString (salt ))
199- }
200-
201- func validateCSRFToken (serverToken , clientToken string , secret []byte ) (bool , error ) {
202- if serverToken != clientToken {
203- return false , nil
204- }
205- sep := strings .Index (clientToken , ":" )
206- if sep < 0 {
207- return false , nil
208- }
209- salt , err := hex .DecodeString (clientToken [sep + 1 :])
210- if err != nil {
211- return false , err
197+ func generateCSRFToken (n uint8 ) string {
198+ // TODO: From utility library
199+ chars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
200+ b := make ([]byte , n )
201+ for i := range b {
202+ b [i ] = chars [rand .Int63 ()% int64 (len (chars ))]
212203 }
213- return clientToken == generateCSRFToken ( secret , salt ), nil
204+ return string ( b )
214205}
215206
216- func generateSalt (len uint8 ) (salt []byte , err error ) {
217- salt = make ([]byte , len )
218- _ , err = rand .Read (salt )
219- return
207+ func validateCSRFToken (token , clientToken string ) bool {
208+ return subtle .ConstantTimeCompare ([]byte (token ), []byte (clientToken )) == 1
220209}
0 commit comments