diff --git a/docs/api/middleware/cors.md b/docs/api/middleware/cors.md index ca250833d6..882a74808b 100644 --- a/docs/api/middleware/cors.md +++ b/docs/api/middleware/cors.md @@ -4,13 +4,15 @@ id: cors # CORS -CORS middleware for [Fiber](https://github.com/gofiber/fiber) that can be used to enable [Cross-Origin Resource Sharing](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) with various options. +CORS (Cross-Origin Resource Sharing) is a middleware for [Fiber](https://github.com/gofiber/fiber) that allows servers to specify who can access its resources and how. It's not a security feature, but a way to relax the security model of web browsers for cross-origin requests. You can learn more about CORS on [Mozilla Developer Network](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS). -The middleware conforms to the `access-control-allow-origin` specification by parsing `AllowOrigins`. First, the middleware checks if there is a matching allowed origin for the requesting 'origin' header. If there is a match, it returns exactly one matching domain from the list of allowed origins. +This middleware works by adding CORS headers to responses from your Fiber application. These headers specify which origins, methods, and headers are allowed for cross-origin requests. It also handles preflight requests, which are a CORS mechanism to check if the actual request is safe to send. -For more control, `AllowOriginsFunc` can be used to programatically determine if an origin is allowed. If no match was found in `AllowOrigins` and if `AllowOriginsFunc` returns true then the 'access-control-allow-origin' response header is set to the 'origin' request header. +The middleware uses the `AllowOrigins` option to control which origins can make cross-origin requests. It supports single origin, multiple origins, subdomain matching, and wildcard origin. It also allows programmatic origin validation with the `AllowOriginsFunc` option. -When defining your Origins make sure they are properly formatted. The middleware validates and normalizes the provided origins, ensuring they're in the correct format by checking for valid schemes (http or https), and removing any trailing slashes. +To ensure that the provided `AllowOrigins` origins are correctly formatted, this middleware validates and normalizes them. It checks for valid schemes, i.e., HTTP or HTTPS, and it will automatically remove trailing slashes. If the provided origin is invalid, the middleware will panic. + +When configuring CORS, it's important to avoid [common pitfalls](#common-pitfalls) like using a wildcard origin with credentials, being overly permissive with origins, and inadequate validation with `AllowOriginsFunc`. Misconfiguration can expose your application to various security risks. ## Signatures @@ -31,6 +33,16 @@ import ( After you initiate your Fiber app, you can use the following possibilities: +### Basic usage + +To use the default configuration, simply use `cors.New()`. This will allow wildcard origins '*', all methods, no credentials, and no headers or exposed headers. + +```go +app.Use(cors.New()) +``` + +### Custom configuration (specific origins, headers, etc.) + ```go // Initialize default config app.Use(cors.New()) @@ -38,27 +50,50 @@ app.Use(cors.New()) // Or extend your config for customization app.Use(cors.New(cors.Config{ AllowOrigins: "https://gofiber.io, https://gofiber.net", - AllowHeaders: "Origin, Content-Type, Accept", + AllowHeaders: "Origin, Content-Type, Accept", })) ``` -Using the `AllowOriginsFunc` function. In this example any origin will be allowed via CORS. +### Dynamic origin validation + +You can use `AllowOriginsFunc` to programmatically determine whether to allow a request based on its origin. This is useful when you need to validate origins against a database or other dynamic sources. The function should return `true` if the origin is allowed, and `false` otherwise. -For example, if a browser running on `http://localhost:3000` sends a request, this will be accepted and the `access-control-allow-origin` response header will be set to `http://localhost:3000`. +Be sure to review the [security considerations](#security-considerations) when using `AllowOriginsFunc`. -**Note: Using this feature is discouraged in production and it's best practice to explicitly set CORS origins via `AllowOrigins`.** +:::caution +Never allow `AllowOriginsFunc` to return `true` for all origins. This is particularly crucial when `AllowCredentials` is set to `true`. Doing so can bypass the restriction of using a wildcard origin with credentials, exposing your application to serious security threats. + +If you need to allow wildcard origins, use `AllowOrigins` with a wildcard `"*"` instead of `AllowOriginsFunc`. +::: ```go -app.Use(cors.New()) +// dbCheckOrigin checks if the origin is in the list of allowed origins in the database. +func dbCheckOrigin(db *sql.DB, origin string) bool { + // Placeholder query - adjust according to your database schema and query needs + query := "SELECT COUNT(*) FROM allowed_origins WHERE origin = $1" + + var count int + err := db.QueryRow(query, origin).Scan(&count) + if err != nil { + // Handle error (e.g., log it); for simplicity, we return false here + return false + } + + return count > 0 +} + +// ... app.Use(cors.New(cors.Config{ - AllowOriginsFunc: func(origin string) bool { - return os.Getenv("ENVIRONMENT") == "development" - }, + AllowOriginsFunc: func(origin string) bool { + return dbCheckOrigin(db, origin) + }, })) ``` -**Note: The following configuration is considered insecure and will result in a panic.** +### Prohibited usage + +The following example is prohibited because it can expose your application to security risks. It sets `AllowOrigins` to `"*"` (a wildcard) and `AllowCredentials` to `true`. ```go app.Use(cors.New(cors.Config{ @@ -67,18 +102,24 @@ app.Use(cors.New(cors.Config{ })) ``` +This will result in the following panic: + +``` +panic: [CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to `"*"`. +``` + ## Config | Property | Type | Description | Default | |:-----------------|:---------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------| | Next | `func(*fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` | -| AllowOriginsFunc | `func(origin string) bool` | AllowOriginsFunc defines a function that will set the 'access-control-allow-origin' response header to the 'origin' request header when returned true. This allows for dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins will be not have the 'access-control-allow-credentials' header set to 'true'. | `nil` | -| AllowOrigins | `string` | AllowOrigin defines a comma separated list of origins that may access the resource. | `"*"` | +| AllowOriginsFunc | `func(origin string) bool` | `AllowOriginsFunc` is a function that dynamically determines whether to allow a request based on its origin. If this function returns `true`, the 'Access-Control-Allow-Origin' response header will be set to the request's 'origin' header. This function is only used if the request's origin doesn't match any origin in `AllowOrigins`. | `nil` | +| AllowOrigins | `string` | AllowOrigins defines a comma separated list of origins that may access the resource. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `"*"` | | AllowMethods | `string` | AllowMethods defines a list of methods allowed when accessing the resource. This is used in response to a preflight request. | `"GET,POST,HEAD,PUT,DELETE,PATCH"` | | AllowHeaders | `string` | AllowHeaders defines a list of request headers that can be used when making the actual request. This is in response to a preflight request. | `""` | -| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. When used as part of a response to a preflight request, this indicates whether or not the actual request can be made using credentials. Note: If true, AllowOrigins cannot be set to a wildcard ("*") to prevent security vulnerabilities. | `false` | -| ExposeHeaders | `string` | ExposeHeaders defines a whitelist headers that clients are allowed to access. | `""` | -| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If you pass MaxAge 0, Access-Control-Max-Age header will not be added and browser will use 5 seconds by default. To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0. | `0` | +| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. When used as part of a response to a preflight request, this indicates whether or not the actual request can be made using credentials. Note: If true, AllowOrigins cannot be set to a wildcard (`"*"`) to prevent security vulnerabilities. | `false` | +| ExposeHeaders | `string` | ExposeHeaders defines whitelist headers that clients are allowed to access. | `""` | +| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If you pass MaxAge 0, the Access-Control-Max-Age header will not be added and the browser will use 5 seconds by default. To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header to 0. | `0` | ## Default Config @@ -101,3 +142,73 @@ var ConfigDefault = Config{ MaxAge: 0, } ``` + +## Subdomain Matching + +The `AllowOrigins` configuration supports matching subdomains at any level. This means you can use a value like `"https://*.example.com"` to allow any subdomain of `example.com` to submit requests, including multiple subdomain levels such as `"https://sub.sub.example.com"`. + +### Example + +If you want to allow CORS requests from any subdomain of `example.com`, including nested subdomains, you can configure the `AllowOrigins` like so: + +```go +app.Use(cors.New(cors.Config{ + AllowOrigins: "https://*.example.com", +})) +``` + +# How It Works + +The CORS middleware works by adding the necessary CORS headers to responses from your Fiber application. These headers tell browsers what origins, methods, and headers are allowed for cross-origin requests. + +When a request comes in, the middleware first checks if it's a preflight request, which is a CORS mechanism to determine whether the actual request is safe to send. Preflight requests are HTTP OPTIONS requests with specific CORS headers. If it's a preflight request, the middleware responds with the appropriate CORS headers and ends the request. + +If it's not a preflight request, the middleware adds the CORS headers to the response and passes the request to the next handler. The actual CORS headers added depend on the configuration of the middleware. + +The `AllowOrigins` option controls which origins can make cross-origin requests. The middleware handles different `AllowOrigins` configurations as follows: + +- **Single origin:** If `AllowOrigins` is set to a single origin like `"http://www.example.com"`, and that origin matches the origin of the incoming request, the middleware adds the header `Access-Control-Allow-Origin: http://www.example.com` to the response. + +- **Multiple origins:** If `AllowOrigins` is set to multiple origins like `"https://example.com, https://www.example.com"`, the middleware picks the origin that matches the origin of the incoming request. + +- **Subdomain matching:** If `AllowOrigins` includes `"https://*.example.com"`, a subdomain like `https://sub.example.com` will be matched and `"https://sub.example.com"` will be the header. This will also match `https://sub.sub.example.com` and so on, but not `https://example.com`. + +- **Wildcard origin:** If `AllowOrigins` is set to `"*"`, the middleware uses that and adds the header `Access-Control-Allow-Origin: *` to the response. + +In all cases above, except the **Wildcard origin**, the middleware will either add the `Access-Control-Allow-Origin` header to the response matching the origin of the incoming request, or it will not add the header at all if the origin is not allowed. + +- **Programmatic origin validation:**: The middleware also handles the `AllowOriginsFunc` option, which allows you to programmatically determine if an origin is allowed. If `AllowOriginsFunc` returns `true` for an origin, the middleware sets the `Access-Control-Allow-Origin` header to that origin. + +The `AllowMethods` option controls which HTTP methods are allowed. For example, if `AllowMethods` is set to `"GET, POST"`, the middleware adds the header `Access-Control-Allow-Methods: GET, POST` to the response. + +The `AllowHeaders` option specifies which headers are allowed in the actual request. The middleware sets the Access-Control-Allow-Headers response header to the value of `AllowHeaders`. This informs the client which headers it can use in the actual request. + +The `AllowCredentials` option indicates whether the response to the request can be exposed when the credentials flag is true. If `AllowCredentials` is set to `true`, the middleware adds the header `Access-Control-Allow-Credentials: true` to the response. To prevent security vulnerabilities, `AllowCredentials` cannot be set to `true` if `AllowOrigins` is set to a wildcard (`*`). + +The `ExposeHeaders` option defines a whitelist of headers that clients are allowed to access. If `ExposeHeaders` is set to `"X-Custom-Header"`, the middleware adds the header `Access-Control-Expose-Headers: X-Custom-Header` to the response. + +The `MaxAge` option indicates how long the results of a preflight request can be cached. If `MaxAge` is set to `3600`, the middleware adds the header `Access-Control-Max-Age: 3600` to the response. + +The `Vary` header is used in this middleware to inform the client that the server's response to a request. For or both preflight and actual requests, the Vary header is set to `Access-Control-Request-Method` and `Access-Control-Request-Headers`. For preflight requests, the Vary header is also set to `Origin`. The `Vary` header is important for caching. It helps caches (like a web browser's cache or a CDN) determine when a cached response can be used in response to a future request, and when the server needs to be queried for a new response. + +## Security Considerations + +When configuring CORS, misconfiguration can potentially expose your application to various security risks. Here are some secure configurations and common pitfalls to avoid: + +### Secure Configurations + +- **Specify Allowed Origins**: Instead of using a wildcard (`"*"`), specify the exact domains allowed to make requests. For example, `AllowOrigins: "https://www.example.com, https://api.example.com"` ensures only these domains can make cross-origin requests to your application. + +- **Use Credentials Carefully**: If your application needs to support credentials in cross-origin requests, ensure `AllowCredentials` is set to `true` and specify exact origins in `AllowOrigins`. Do not use a wildcard origin in this case. + +- **Limit Exposed Headers**: Only whitelist headers that are necessary for the client-side application by setting `ExposeHeaders` appropriately. This minimizes the risk of exposing sensitive information. + +### Common Pitfalls + +- **Wildcard Origin with Credentials**: Setting `AllowOrigins` to `"*"` (a wildcard) and `AllowCredentials` to `true` is a common misconfiguration. This combination is prohibited because it can expose your application to security risks. + +- **Overly Permissive Origins**: Specifying too many origins or using overly broad patterns (e.g., `https://*.example.com`) can inadvertently allow malicious sites to interact with your application. Be as specific as possible with allowed origins. + +- **Inadequate `AllowOriginsFunc` Validation**: When using `AllowOriginsFunc` for dynamic origin validation, ensure the function includes robust checks to prevent unauthorized origins from being accepted. Overly permissive validation can lead to security vulnerabilities. Never allow `AllowOriginsFunc` to return `true` for all origins. This is particularly crucial when `AllowCredentials` is set to `true`. Doing so can bypass the restriction of using a wildcard origin with credentials, exposing your application to serious security threats. If you need to allow wildcard origins, use `AllowOrigins` with a wildcard `"*"` instead of `AllowOriginsFunc`. + +Remember, the key to secure CORS configuration is specificity and caution. By carefully selecting which origins, methods, and headers are allowed, you can help protect your application from cross-origin attacks. \ No newline at end of file diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 3accc2f1dd..e27e74cba8 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -15,10 +15,10 @@ type Config struct { // Optional. Default: nil Next func(c *fiber.Ctx) bool - // AllowOriginsFunc defines a function that will set the 'access-control-allow-origin' + // AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin' // response header to the 'origin' request header when returned true. This allows for // dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins - // will be not have the 'access-control-allow-credentials' header set to 'true'. + // will be not have the 'Access-Control-Allow-Credentials' header set to 'true'. // // Optional. Default: nil AllowOriginsFunc func(origin string) bool @@ -115,28 +115,43 @@ func New(config ...Config) fiber.Handler { // allowOrigins is a slice of strings that contains the allowed origins // defined in the 'AllowOrigins' configuration. - var allowOrigins []string + allowOrigins := []string{} + allowSOrigins := []subdomain{} + allowAllOrigins := false + + // processOrigin processes an origin string, normalizes it and checks its validity + // it will panic if the origin is invalid + processOrigin := func(origin string) (string, bool) { + trimmedOrigin := strings.TrimSpace(origin) + isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) + if !isValid { + log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) + panic("[CORS] Invalid origin provided in configuration") + } + return normalizedOrigin, true + } // Validate and normalize static AllowOrigins if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" { origins := strings.Split(cfg.AllowOrigins, ",") - allowOrigins = make([]string, len(origins)) - - for i, origin := range origins { - trimmedOrigin := strings.TrimSpace(origin) - isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) - - if isValid { - allowOrigins[i] = normalizedOrigin + for _, origin := range origins { + if i := strings.Index(origin, "://*."); i != -1 { + normalizedOrigin, isValid := processOrigin(origin[:i+3] + origin[i+4:]) + if !isValid { + continue + } + sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]} + allowSOrigins = append(allowSOrigins, sd) } else { - log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) - panic("[CORS] Invalid origin provided in configuration") + normalizedOrigin, isValid := processOrigin(origin) + if !isValid { + continue + } + allowOrigins = append(allowOrigins, normalizedOrigin) } } - } else { - // If AllowOrigins is set to a wildcard or not set, - // set allowOrigins to a slice with a single element - allowOrigins = []string{cfg.AllowOrigins} + } else if cfg.AllowOrigins == "*" { + allowAllOrigins = true } // Strip white spaces @@ -155,18 +170,36 @@ func New(config ...Config) fiber.Handler { } // Get originHeader header - originHeader := c.Get(fiber.HeaderOrigin) + originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin)) + + // If the request does not have an Origin header, the request is outside the scope of CORS + if originHeader == "" { + return c.Next() + } + + // Set default allowOrigin to empty string allowOrigin := "" // Check allowed origins - for _, origin := range allowOrigins { - if origin == "*" { - allowOrigin = "*" - break + if allowAllOrigins { + allowOrigin = "*" + } else { + // Check if the origin is in the list of allowed origins + for _, origin := range allowOrigins { + if origin == originHeader { + allowOrigin = originHeader + break + } } - if validateDomain(originHeader, origin) { - allowOrigin = originHeader - break + + // Check if the origin is in the list of allowed subdomains + if allowOrigin == "" { + for _, sOrigin := range allowSOrigins { + if sOrigin.match(originHeader) { + allowOrigin = originHeader + break + } + } } } @@ -179,56 +212,63 @@ func New(config ...Config) fiber.Handler { // Simple request if c.Method() != fiber.MethodOptions { - c.Vary(fiber.HeaderOrigin) - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - - if cfg.AllowCredentials { - c.Set(fiber.HeaderAccessControlAllowCredentials, "true") - } - if exposeHeaders != "" { - c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders) - } + setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) return c.Next() } // Preflight request - c.Vary(fiber.HeaderOrigin) c.Vary(fiber.HeaderAccessControlRequestMethod) c.Vary(fiber.HeaderAccessControlRequestHeaders) - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods) - if cfg.AllowCredentials { - // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' - if allowOrigin != "*" && allowOrigin != "" { - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - c.Set(fiber.HeaderAccessControlAllowCredentials, "true") - } else if allowOrigin == "*" { - log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") - } - } else { - // For non-credential requests, it's safe to set to '*' or specific origins + setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) + + // Send 204 No Content + return c.SendStatus(fiber.StatusNoContent) + } +} + +// Function to set CORS headers +func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) { + c.Vary(fiber.HeaderOrigin) + + if cfg.AllowCredentials { + // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' + if allowOrigin != "*" && allowOrigin != "" { + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + c.Set(fiber.HeaderAccessControlAllowCredentials, "true") + } else if allowOrigin == "*" { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") } + } else if len(allowOrigin) > 0 { + // For non-credential requests, it's safe to set to '*' or specific origins + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + } - // Set Allow-Headers if not empty - if allowHeaders != "" { - c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders) - } else { - h := c.Get(fiber.HeaderAccessControlRequestHeaders) - if h != "" { - c.Set(fiber.HeaderAccessControlAllowHeaders, h) - } - } + // Set Allow-Methods if not empty + if allowMethods != "" { + c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods) + } - // Set MaxAge is set - if cfg.MaxAge > 0 { - c.Set(fiber.HeaderAccessControlMaxAge, maxAge) - } else if cfg.MaxAge < 0 { - c.Set(fiber.HeaderAccessControlMaxAge, "0") + // Set Allow-Headers if not empty + if allowHeaders != "" { + c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders) + } else { + h := c.Get(fiber.HeaderAccessControlRequestHeaders) + if h != "" { + c.Set(fiber.HeaderAccessControlAllowHeaders, h) } + } - // Send 204 No Content - return c.SendStatus(fiber.StatusNoContent) + // Set MaxAge if set + if cfg.MaxAge > 0 { + c.Set(fiber.HeaderAccessControlMaxAge, maxAge) + } else if cfg.MaxAge < 0 { + c.Set(fiber.HeaderAccessControlMaxAge, "0") + } + + // Set Expose-Headers if not empty + if exposeHeaders != "" { + c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders) } } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 23692c3f84..ff5cdd7c25 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -49,6 +49,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default GET response headers ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) @@ -58,6 +59,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods))) @@ -98,6 +100,7 @@ func Test_CORS_Wildcard(t *testing.T) { // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") handler(ctx) utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) @@ -137,6 +140,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodGet) handler(ctx) @@ -171,27 +175,39 @@ func Test_CORS_Wildcard_AllowCredentials_Panic(t *testing.T) { } // go test -run -v Test_CORS_Invalid_Origin_Panic -func Test_CORS_Invalid_Origin_Panic(t *testing.T) { +func Test_CORS_Invalid_Origins_Panic(t *testing.T) { t.Parallel() - // New fiber instance - app := fiber.New() - didPanic := false - func() { - defer func() { - if r := recover(); r != nil { - didPanic = true - } - }() + invalidOrigins := []string{ + "localhost", + "http://foo.[a-z]*.example.com", + "http://*", + "https://*", + "invalid url", + // add more invalid origins as needed + } - app.Use(New(Config{ - AllowOrigins: "localhost", - AllowCredentials: true, - })) - }() + for _, origin := range invalidOrigins { + // New fiber instance + app := fiber.New() - if !didPanic { - t.Errorf("Expected a panic when Origin is missing scheme") + didPanic := false + func() { + defer func() { + if r := recover(); r != nil { + didPanic = true + } + }() + + app.Use(New(Config{ + AllowOrigins: origin, + AllowCredentials: true, + })) + }() + + if !didPanic { + t.Errorf("Expected a panic for invalid origin: %s", origin) + } } } @@ -221,6 +237,18 @@ func Test_CORS_Subdomain(t *testing.T) { ctx.Request.Reset() ctx.Response.Reset() + // Make request with domain only (disallowed) + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") + + handler(ctx) + + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) + + ctx.Request.Reset() + ctx.Response.Reset() + // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) @@ -293,7 +321,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { shouldAllowOrigin: false, }, { - pattern: "https://*--aaa.bbb.com", + pattern: "https://--aaa.bbb.com", reqOrigin: "https://prod-preview--aaa.bbb.com", shouldAllowOrigin: false, }, @@ -303,8 +331,13 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { shouldAllowOrigin: true, }, { - pattern: "http://foo.[a-z]*.example.com", - reqOrigin: "http://ccc.bbb.example.com", + pattern: "http://domain-1.com, http://example.com", + reqOrigin: "http://example.com", + shouldAllowOrigin: true, + }, + { + pattern: "http://domain-1.com, http://example.com", + reqOrigin: "http://domain-2.com", shouldAllowOrigin: false, }, { @@ -345,6 +378,35 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { } } +func Test_CORS_AllowOriginHeader_NoMatch(t *testing.T) { + t.Parallel() + // New fiber instance + app := fiber.New() + app.Use("/", New(Config{ + AllowOrigins: "http://example-1.com, https://example-1.com", + })) + + // Get handler pointer + handler := app.Handler() + + // Make request with disallowed origin + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/") + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") + + // Perform request + handler(ctx) + + var headerExists bool + ctx.Response.Header.VisitAll(func(key, _ []byte) { + if string(key) == fiber.HeaderAccessControlAllowOrigin { + headerExists = true + } + }) + utils.AssertEqual(t, false, headerExists, "Access-Control-Allow-Origin header should not be set") +} + // go test -run Test_CORS_Next func Test_CORS_Next(t *testing.T) { t.Parallel() diff --git a/middleware/cors/utils.go b/middleware/cors/utils.go index d1280899c9..443e648903 100644 --- a/middleware/cors/utils.go +++ b/middleware/cors/utils.go @@ -12,37 +12,6 @@ func matchScheme(domain, pattern string) bool { return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx] } -// validateDomain checks if the domain matches the pattern -func validateDomain(domain, pattern string) bool { - // Directly compare the domain and pattern for an exact match. - if domain == pattern { - return true - } - - // Normalize domain and pattern to exclude schemes and ports for matching purposes - normalizedDomain := normalizeDomain(domain) - normalizedPattern := normalizeDomain(pattern) - - // Handling the case where pattern is a wildcard subdomain pattern. - if strings.HasPrefix(normalizedPattern, "*.") { - // Trim leading "*." from pattern for comparison. - trimmedPattern := normalizedPattern[2:] - - // Check if the domain ends with the trimmed pattern. - if strings.HasSuffix(normalizedDomain, trimmedPattern) { - // Ensure that the domain is not exactly the base domain. - if normalizedDomain != trimmedPattern { - // Special handling to prevent "example.com" matching "*.example.com". - if strings.TrimSuffix(normalizedDomain, trimmedPattern) != "" { - return true - } - } - } - } - - return false -} - // normalizeDomain removes the scheme and port from the input domain func normalizeDomain(input string) string { // Remove scheme @@ -73,6 +42,13 @@ func normalizeOrigin(origin string) (bool, string) { return false, "" } + // Don't allow a wildcard with a protocol + // wildcards cannot be used within any other value. For example, the following header is not valid: + // Access-Control-Allow-Origin: https://* + if strings.Contains(parsedOrigin.Host, "*") { + return false, "" + } + // Validate there is a host present. The presence of a path, query, or fragment components // is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" { @@ -83,3 +59,13 @@ func normalizeOrigin(origin string) (bool, string) { // The path or trailing slash is not included in the normalized origin. return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host) } + +type subdomain struct { + // The wildcard pattern + prefix string + suffix string +} + +func (s subdomain) match(o string) bool { + return len(o) >= len(s.prefix)+len(s.suffix) && strings.HasPrefix(o, s.prefix) && strings.HasSuffix(o, s.suffix) +} diff --git a/middleware/cors/utils_test.go b/middleware/cors/utils_test.go index 3acd692521..adc729d05f 100644 --- a/middleware/cors/utils_test.go +++ b/middleware/cors/utils_test.go @@ -75,44 +75,6 @@ func Test_matchScheme(t *testing.T) { } } -// go test -run -v Test_validateOrigin -func Test_validateOrigin(t *testing.T) { - testCases := []struct { - domain string - pattern string - expected bool - }{ - {"http://example.com", "http://example.com", true}, // Exact match should work. - {"https://example.com", "http://example.com", false}, // Scheme mismatch should matter in CORS context. - {"http://example.com", "https://example.com", false}, // Scheme mismatch should matter in CORS context. - {"http://example.com", "http://example.org", false}, // Different domains should not match. - {"http://example.com", "http://example.com:8080", false}, // Port mismatch should matter. - {"http://example.com:8080", "http://example.com", false}, // Port mismatch should matter. - {"http://example.com:8080", "http://example.com:8081", false}, // Different ports should not match. - {"example.com", "example.com", true}, // Simplified form, assuming scheme and port are not considered here, but in practice, they are part of the origin. - {"sub.example.com", "example.com", false}, // Subdomain should not match the base domain directly. - {"sub.example.com", "*.example.com", true}, // Correct assumption for wildcard subdomain matching. - {"example.com", "*.example.com", false}, // Base domain should not match its wildcard subdomain pattern. - {"sub.example.com", "*.com", true}, // Technically correct for pattern matching, but broad wildcard use like this is not recommended for CORS. - {"sub.sub.example.com", "*.example.com", true}, // Nested subdomain should match the wildcard pattern. - {"example.com", "*.org", false}, // Different TLDs should not match. - {"example.com", "example.org", false}, // Different domains should not match. - {"example.com:8080", "*.example.com", false}, // Different ports mean different origins. - {"example.com", "sub.example.net", false}, // Different domains should not match. - {"http://localhost", "http://localhost", true}, // Localhost should match. - {"http://127.0.0.1", "http://127.0.0.1", true}, // IPv4 address should match. - {"http://[::1]", "http://[::1]", true}, // IPv6 address should match. - } - - for _, tc := range testCases { - result := validateDomain(tc.domain, tc.pattern) - - if result != tc.expected { - t.Errorf("Expected validateOrigin('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result) - } - } -} - // go test -run -v Test_normalizeDomain func Test_normalizeDomain(t *testing.T) { testCases := []struct {