Skip to content

Commit

Permalink
middleware/earlydata: backport to v2
Browse files Browse the repository at this point in the history
Backport of #2270 to v2.
  • Loading branch information
leonklingele committed Jan 28, 2023
1 parent 18ff026 commit 1091e7c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 53 deletions.
14 changes: 7 additions & 7 deletions middleware/earlydata/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ First import the middleware from Fiber,

```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/earlydata"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/earlydata"
)
```

Expand Down Expand Up @@ -65,17 +65,17 @@ type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c fiber.Ctx) bool
Next func(c *fiber.Ctx) bool

// IsEarlyData returns whether the request is an early-data request.
//
// Optional. Default: a function which checks if the "Early-Data" request header equals "1".
IsEarlyData func(c fiber.Ctx) bool
IsEarlyData func(c *fiber.Ctx) bool

// AllowEarlyData returns whether the early-data request should be allowed or rejected.
//
// Optional. Default: a function which rejects the request on unsafe and allows the request on safe HTTP request methods.
AllowEarlyData func(c fiber.Ctx) bool
AllowEarlyData func(c *fiber.Ctx) bool

// Error is returned in case an early-data request is rejected.
//
Expand All @@ -88,11 +88,11 @@ type Config struct {

```go
var ConfigDefault = Config{
IsEarlyData: func(c fiber.Ctx) bool {
IsEarlyData: func(c *fiber.Ctx) bool {
return c.Get("Early-Data") == "1"
},

AllowEarlyData: func(c fiber.Ctx) bool {
AllowEarlyData: func(c *fiber.Ctx) bool {
return fiber.IsMethodSafe(c.Method())
},

Expand Down
12 changes: 6 additions & 6 deletions middleware/earlydata/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package earlydata

import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v2"
)

const (
Expand All @@ -14,17 +14,17 @@ type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c fiber.Ctx) bool
Next func(c *fiber.Ctx) bool

// IsEarlyData returns whether the request is an early-data request.
//
// Optional. Default: a function which checks if the "Early-Data" request header equals "1".
IsEarlyData func(c fiber.Ctx) bool
IsEarlyData func(c *fiber.Ctx) bool

// AllowEarlyData returns whether the early-data request should be allowed or rejected.
//
// Optional. Default: a function which rejects the request on unsafe and allows the request on safe HTTP request methods.
AllowEarlyData func(c fiber.Ctx) bool
AllowEarlyData func(c *fiber.Ctx) bool

// Error is returned in case an early-data request is rejected.
//
Expand All @@ -34,11 +34,11 @@ type Config struct {

// ConfigDefault is the default config
var ConfigDefault = Config{
IsEarlyData: func(c fiber.Ctx) bool {
IsEarlyData: func(c *fiber.Ctx) bool {
return c.Get(DefaultHeaderName) == DefaultHeaderTrueValue
},

AllowEarlyData: func(c fiber.Ctx) bool {
AllowEarlyData: func(c *fiber.Ctx) bool {
return fiber.IsMethodSafe(c.Method())
},

Expand Down
6 changes: 3 additions & 3 deletions middleware/earlydata/earlydata.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package earlydata

import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v2"
)

const (
localsKeyAllowed = "earlydata_allowed"
)

func IsEarly(c fiber.Ctx) bool {
func IsEarly(c *fiber.Ctx) bool {
return c.Locals(localsKeyAllowed) != nil
}

Expand All @@ -19,7 +19,7 @@ func New(config ...Config) fiber.Handler {
cfg := configDefault(config...)

// Return new handler
return func(c fiber.Ctx) error {
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
Expand Down
78 changes: 41 additions & 37 deletions middleware/earlydata/earlydata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"net/http/httptest"
"testing"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/earlydata"
"github.com/stretchr/testify/require"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/earlydata"
"github.com/gofiber/fiber/v2/utils"
)

const (
Expand All @@ -33,7 +33,7 @@ func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App {

// Middleware to test IsEarly func
const localsKeyTestValid = "earlydata_testvalid"
app.Use(func(c fiber.Ctx) error {
app.Use(func(c *fiber.Ctx) error {
isEarly := earlydata.IsEarly(c)

switch h := c.Get(headerName); h {
Expand Down Expand Up @@ -64,16 +64,20 @@ func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App {
return c.Next()
})

app.Add([]string{
fiber.MethodGet,
fiber.MethodPost,
}, "/", func(c fiber.Ctx) error {
if !c.Locals(localsKeyTestValid).(bool) {
return errors.New("handler called even though validation failed")
}
{
{
handler := func(c *fiber.Ctx) error {
if !c.Locals(localsKeyTestValid).(bool) {
return errors.New("handler called even though validation failed")
}

return nil
})
return nil
}

app.Get("/", handler)
app.Post("/", handler)
}
}

return app
}
Expand All @@ -89,36 +93,36 @@ func Test_EarlyData(t *testing.T) {
req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)

resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)

req.Header.Set(headerName, headerValOff)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)

req.Header.Set(headerName, headerValOn)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
}

{
req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)

resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)

req.Header.Set(headerName, headerValOff)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)

req.Header.Set(headerName, headerValOn)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
}
}

Expand All @@ -129,36 +133,36 @@ func Test_EarlyData(t *testing.T) {
req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)

resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)

req.Header.Set(headerName, headerValOff)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)

req.Header.Set(headerName, headerValOn)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
}

{
req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)

resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)

req.Header.Set(headerName, headerValOff)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)

req.Header.Set(headerName, headerValOn)
resp, err = app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusTooEarly, resp.StatusCode)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTooEarly, resp.StatusCode)
}
}

Expand Down

0 comments on commit 1091e7c

Please sign in to comment.