Skip to content

Commit

Permalink
feat(middleware/csrf): TrustedOrigins using https://*.example.com sty…
Browse files Browse the repository at this point in the history
…le subdomains (#2925)

* feat(middleware/csrf): TrustedOrigins using https://*.example.com style subdomains

* Update middleware/csrf/csrf_test.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* test(middleware/csrf): parallel test

* test(middleware/csrf): parallel fix

* chmore(middleware/csrf): no pkg/log

* feat(middleware/csrf): Add tests for Trusted Origin deeply nested subdomain

* test(middleware/csrf): fix loop variable tt being captured

* docs(middleware/csrf): TrustedOrigins validates and normalizes note

* test(middleware/csrf): fix Benchmark_Middleware_CSRF_Check

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 25, 2024
1 parent 95c1814 commit 643b4b3
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 81 deletions.
33 changes: 31 additions & 2 deletions docs/api/middleware/csrf.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (h *Handler) DeleteToken(c fiber.Ctx) error
| Storage | `fiber.Storage` | Store is used to store the state of the middleware. | `nil` |
| Session | `*session.Store` | Session is used to store the state of the middleware. Overrides Storage if set. | `nil` |
| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "csrfToken" |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://.example.com" to allow any subdomain of example.com to submit requests. | `[]` |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `[]` |

### Default Config

Expand Down Expand Up @@ -154,6 +154,36 @@ var ConfigDefault = Config{
}
```

### Trusted Origins

The `TrustedOrigins` option is used to specify a list of trusted origins for unsafe requests. This is useful when you want to allow requests from other origins. This supports matching subdomains at any level. This means you can use a value like `"https://*.example.com"` to allow any subdomain of `example.com` to submit requests, including multiple subdomain levels such as `"https://sub.sub.example.com"`.

To ensure that the provided `TrustedOrigins` origins are correctly formatted, this middleware validates and normalizes them. It checks for valid schemes, i.e., HTTP or HTTPS, and it will automatically remove trailing slashes. If the provided origin is invalid, the middleware will panic.

#### Example with Explicit Origins

In the following example, the CSRF middleware will allow requests from `trusted.example.com`, in addition to the current host.

```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{"https://trusted.example.com"},
}))
```

#### Example with Subdomain Matching

In the following example, the CSRF middleware will allow requests from any subdomain of `example.com`, in addition to the current host.

```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{"https://*.example.com"},
}))
```

::caution
When using `TrustedOrigins` with subdomain matching, make sure you control and trust all the subdomains, including all subdomain levels. If not, an attacker could create a subdomain under a trusted origin and use it to send harmful requests.
:::

## Constants

```go
Expand Down Expand Up @@ -273,7 +303,6 @@ When HTTPS requests are protected by CSRF, referer checking is always carried ou
The Referer header is automatically included in requests by all modern browsers, including those made using the JS Fetch API. However, if you're making use of this middleware with a custom client, it's important to ensure that the client sends a valid Referer header.
:::


### Token Lifecycle

Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 1 hour, and each subsequent request extends the expiration by 1 hour. The token only expires if the user doesn't make a request for the duration of the expiration time.
Expand Down
127 changes: 69 additions & 58 deletions middleware/csrf/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ var (

// Handler for CSRF middleware
type Handler struct {
config *Config
config Config
sessionManager *sessionManager
storageManager *storageManager
}
Expand Down Expand Up @@ -56,6 +56,36 @@ func New(config ...Config) fiber.Handler {
storageManager = newStorageManager(cfg.Storage)
}

// Pre-parse trusted origins
trustedOrigins := []string{}
trustedSubOrigins := []subdomain{}

for _, origin := range cfg.TrustedOrigins {
if i := strings.Index(origin, "://*."); i != -1 {
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CSRF] Invalid origin format in configuration:" + origin)
}
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
trustedSubOrigins = append(trustedSubOrigins, sd)
} else {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CSRF] Invalid origin format in configuration:" + origin)
}
trustedOrigins = append(trustedOrigins, normalizedOrigin)
}
}

// Create the handler outside of the returned function
handler := &Handler{
config: cfg,
sessionManager: sessionManager,
storageManager: storageManager,
}

// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
Expand All @@ -64,11 +94,7 @@ func New(config ...Config) fiber.Handler {
}

// Store the CSRF handler in the context
c.Locals(handlerKey, &Handler{
config: &cfg,
sessionManager: sessionManager,
storageManager: storageManager,
})
c.Locals(handlerKey, handler)

var token string

Expand All @@ -88,12 +114,12 @@ func New(config ...Config) fiber.Handler {
// Assume that anything not defined as 'safe' by RFC7231 needs protection

// Enforce an origin check for unsafe requests.
err := originMatchesHost(c, cfg.TrustedOrigins)
err := originMatchesHost(c, trustedOrigins, trustedSubOrigins)

// If there's no origin, enforce a referer check for HTTPS connections.
if errors.Is(err, errOriginNotFound) {
if c.Scheme() == "https" {
err = refererMatchesHost(c, cfg.TrustedOrigins)
err = refererMatchesHost(c, trustedOrigins, trustedSubOrigins)
} else {
// If it's not HTTPS, clear the error to allow the request to proceed.
err = nil
Expand Down Expand Up @@ -237,20 +263,15 @@ func setCSRFCookie(c fiber.Ctx, cfg Config, token string, expiry time.Duration)
// DeleteToken removes the token found in the context from the storage
// and expires the CSRF cookie
func (handler *Handler) DeleteToken(c fiber.Ctx) error {
// Get the config from the context
config := handler.config
if config == nil {
panic("CSRF Handler config not found in context")
}
// Extract token from the client request cookie
cookieToken := c.Cookies(config.CookieName)
cookieToken := c.Cookies(handler.config.CookieName)
if cookieToken == "" {
return config.ErrorHandler(c, ErrTokenNotFound)
return handler.config.ErrorHandler(c, ErrTokenNotFound)
}
// Remove the token from storage
deleteTokenFromStorage(c, cookieToken, *config, handler.sessionManager, handler.storageManager)
deleteTokenFromStorage(c, cookieToken, handler.config, handler.sessionManager, handler.storageManager)
// Expire the cookie
expireCSRFCookie(c, *config)
expireCSRFCookie(c, handler.config)
return nil
}

Expand All @@ -262,8 +283,8 @@ func isFromCookie(extractor any) bool {
// originMatchesHost checks that the origin header matches the host header
// returns an error if the origin header is not present or is invalid
// returns nil if the origin header is valid
func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
origin := c.Get(fiber.HeaderOrigin)
func originMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
origin := strings.ToLower(c.Get(fiber.HeaderOrigin))
if origin == "" || origin == "null" { // "null" is set by some browsers when the origin is a secure context https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin#description
return errOriginNotFound
}
Expand All @@ -273,23 +294,31 @@ func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return ErrOriginInvalid
}

if originURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isTrustedSchemeAndDomain(trustedOrigin, origin) {
return nil
}
if originURL.Scheme == c.Scheme() && originURL.Host == c.Host() {
return nil
}

for _, trustedOrigin := range trustedOrigins {
if origin == trustedOrigin {
return nil
}
return ErrOriginNoMatch
}

return nil
for _, trustedSubOrigin := range trustedSubOrigins {
if trustedSubOrigin.match(origin) {
return nil
}
}

return ErrOriginNoMatch
}

// refererMatchesHost checks that the referer header matches the host header
// returns an error if the referer header is not present or is invalid
// returns nil if the referer header is valid
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
referer := c.Get(fiber.HeaderReferer)
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
referer := strings.ToLower(c.Get(fiber.HeaderReferer))

if referer == "" {
return ErrRefererNotFound
}
Expand All @@ -299,41 +328,23 @@ func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return ErrRefererInvalid
}

if refererURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isTrustedSchemeAndDomain(trustedOrigin, referer) {
return nil
}
}
return ErrRefererNoMatch
}

return nil
}

// isTrustedSchemeAndDomain checks if the trustedProtoDomain is the same as the protoDomain
// or if the protoDomain is a subdomain of the trustedProtoDomain where trustedProtoDomain
// is prefixed with "https://." or "http://."
func isTrustedSchemeAndDomain(trustedProtoDomain, protoDomain string) bool {
if trustedProtoDomain == protoDomain {
return true
if refererURL.Scheme == c.Scheme() && refererURL.Host == c.Host() {
return nil
}

// Use constant prefixes for better readability and avoid magic numbers.
const httpsPrefix = "https://."
const httpPrefix = "http://."
referer = refererURL.String()

if strings.HasPrefix(trustedProtoDomain, httpsPrefix) {
trustedProtoDomain = trustedProtoDomain[len(httpsPrefix):]
protoDomain = strings.TrimPrefix(protoDomain, "https://")
return strings.HasSuffix(protoDomain, "."+trustedProtoDomain)
for _, trustedOrigin := range trustedOrigins {
if referer == trustedOrigin {
return nil
}
}

if strings.HasPrefix(trustedProtoDomain, httpPrefix) {
trustedProtoDomain = trustedProtoDomain[len(httpPrefix):]
protoDomain = strings.TrimPrefix(protoDomain, "http://")
return strings.HasSuffix(protoDomain, "."+trustedProtoDomain)
for _, trustedSubOrigin := range trustedSubOrigins {
if trustedSubOrigin.match(referer) {
return nil
}
}

return false
return ErrRefererNoMatch
}
Loading

1 comment on commit 643b4b3

@ReneWerner87
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 643b4b3 Previous: ba10e68 Ratio
Benchmark_Etag 198.5 ns/op 0 B/op 0 allocs/op 97.45 ns/op 0 B/op 0 allocs/op 2.04
Benchmark_Middleware_Favicon 210.3 ns/op 12 B/op 4 allocs/op 91.88 ns/op 3 B/op 1 allocs/op 2.29

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.