Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ fix!: context key collisions #896

Open
wants to merge 3 commits into
base: v3-beta
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions jwt/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ type Config struct {
// The order of precedence is: KeyFunc, JWKSetURLs, SigningKeys, SigningKey.
SigningKeys map[string]SigningKey

// Context key to store user information from the token into context.
// Optional. Default: "user".
ContextKey string
sixcolors marked this conversation as resolved.
Show resolved Hide resolved

// Claims are extendable claims data defining token content.
// Optional. Default value jwt.MapClaims
Claims jwt.Claims
Expand Down Expand Up @@ -122,9 +118,6 @@ func makeCfg(config []Config) (cfg Config) {
if cfg.SigningKey.Key == nil && len(cfg.SigningKeys) == 0 && len(cfg.JWKSetURLs) == 0 && cfg.KeyFunc == nil {
panic("Fiber: JWT middleware configuration: At least one of the following is required: KeyFunc, JWKSetURLs, SigningKeys, or SigningKey.")
}
if cfg.ContextKey == "" {
cfg.ContextKey = "user"
}
if cfg.Claims == nil {
cfg.Claims = jwt.MapClaims{}
}
Expand Down
3 changes: 0 additions & 3 deletions jwt/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ func TestDefaultConfiguration(t *testing.T) {
cfg := makeCfg(config)

// Assert
if cfg.ContextKey != "user" {
t.Fatalf("Default context key should be 'user'")
}
if cfg.Claims == nil {
t.Fatalf("Default claims should not be 'nil'")
}
Expand Down
17 changes: 16 additions & 1 deletion jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ import (
"github.com/golang-jwt/jwt/v5"
)

// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int

// The following contextKey values are defined to store values in context.
const (
tokenKey contextKey = 0
)

var (
defaultTokenLookup = "header:" + fiber.HeaderAuthorization
)
Expand Down Expand Up @@ -51,9 +60,15 @@ func New(config ...Config) fiber.Handler {
}
if err == nil && token.Valid {
// Store user information from token into context.
c.Locals(cfg.ContextKey, token)
c.Locals(tokenKey, token)
return cfg.SuccessHandler(c)
}
return cfg.ErrorHandler(c, err)
}
}

// FromContext returns the token from the context.
// If there is no token, nil is returned.
func FromContext(c *fiber.Ctx) *jwt.Token {
return c.Locals(tokenKey).(*jwt.Token)
}
41 changes: 41 additions & 0 deletions jwt/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,44 @@ func customKeyfunc() jwt.Keyfunc {
return []byte(defaultSigningKey), nil
}
}

func TestFromContext(t *testing.T) {
t.Parallel()

defer func() {
// Assert
if err := recover(); err != nil {
t.Fatalf("Middleware should not panic")
}
}()

for _, test := range hamac {
// Arrange
app := fiber.New()

app.Use(jwtware.New(jwtware.Config{
SigningKey: jwtware.SigningKey{
JWTAlg: test.SigningMethod,
Key: []byte(defaultSigningKey),
},
}))

app.Get("/ok", func(c *fiber.Ctx) error {
token := jwtware.FromContext(c)
if token == nil {
return c.SendStatus(fiber.StatusUnauthorized)
}
return c.SendString("OK")
})

req := httptest.NewRequest("GET", "/ok", nil)
req.Header.Add("Authorization", "Bearer "+test.Token)

// Act
resp, err := app.Test(req)

// Assert
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
}
7 changes: 3 additions & 4 deletions paseto/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ pasetoware.New(config ...pasetoware.Config) func(*fiber.Ctx) error
| SymmetricKey | `[]byte` | Secret key to encrypt token. If present the middleware will generate local tokens. | `nil` |
| PrivateKey | `ed25519.PrivateKey` | Secret key to sign the tokens. If present (along with its `PublicKey`) the middleware will generate public tokens. | `nil`
| PublicKey | `crypto.PublicKey` | Public key to verify the tokens. If present (along with `PrivateKey`) the middleware will generate public tokens. | `nil`
| ContextKey | `string` | Context key to store user information from the token into context. | `"auth-token"` |
| TokenLookup | `[2]string` | TokenLookup is a string slice with size 2, that is used to extract token from the request | `["header","Authorization"]` |

## Instructions
Expand Down Expand Up @@ -128,7 +127,7 @@ func accessible(c *fiber.Ctx) error {
}

func restricted(c *fiber.Ctx) error {
payload := c.Locals(pasetoware.DefaultContextKey).(string)
payload := pasetoware.FromContext(c).(string)
return c.SendString("Welcome " + payload)
}

Expand Down Expand Up @@ -242,7 +241,7 @@ func accessible(c *fiber.Ctx) error {
}

func restricted(c *fiber.Ctx) error {
payload := c.Locals(pasetoware.DefaultContextKey).(customPayloadStruct)
payload := pasetoware.FromContext(c).(customPayloadStruct)
return c.SendString("Welcome " + payload.Name)
}

Expand Down Expand Up @@ -350,7 +349,7 @@ func accessible(c *fiber.Ctx) error {
}

func restricted(c *fiber.Ctx) error {
payload := c.Locals(pasetoware.DefaultContextKey).(string)
payload := pasetoware.FromContext(c).(string)
return c.SendString("Welcome " + payload)
}

Expand Down
9 changes: 0 additions & 9 deletions paseto/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ type Config struct {
// Required if SymmetricKey is not set
PublicKey crypto.PublicKey

// ContextKey to store user information from the token into context.
// Optional. Default: DefaultContextKey.
ContextKey string

// TokenLookup is a string slice with size 2, that is used to extract token from the request.
// Optional. Default value ["header","Authorization"].
// Possible values:
Expand All @@ -79,7 +75,6 @@ var ConfigDefault = Config{
ErrorHandler: nil,
Validate: nil,
SymmetricKey: nil,
ContextKey: DefaultContextKey,
TokenLookup: [2]string{LookupHeader, fiber.HeaderAuthorization},
}

Expand Down Expand Up @@ -140,10 +135,6 @@ func configDefault(authConfigs ...Config) Config {
config.Validate = defaultValidateFunc
}

if config.ContextKey == "" {
config.ContextKey = ConfigDefault.ContextKey
}

if config.TokenLookup[0] == "" {
config.TokenLookup[0] = ConfigDefault.TokenLookup[0]
}
Expand Down
1 change: 0 additions & 1 deletion paseto/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ func Test_ConfigDefault(t *testing.T) {
utils.AssertEqual(t, LookupHeader, config.TokenLookup[0])
utils.AssertEqual(t, fiber.HeaderAuthorization, config.TokenLookup[1])

utils.AssertEqual(t, DefaultContextKey, config.ContextKey)
utils.AssertEqual(t, true, config.Validate != nil)
}

Expand Down
16 changes: 15 additions & 1 deletion paseto/paseto.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ import (
"github.com/gofiber/fiber/v2"
)

// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int

// The following contextKey values are defined to store values in context.
const (
payloadKey contextKey = 0
)

// New PASETO middleware, returns a handler that takes a token in selected lookup param and in case token is valid
// it saves the decrypted token on ctx.Locals, take a look on Config to know more configuration options
func New(authConfigs ...Config) fiber.Handler {
Expand Down Expand Up @@ -47,11 +56,16 @@ func New(authConfigs ...Config) fiber.Handler {
payload, err := config.Validate(outData)
if err == nil {
// Store user information from token into context.
c.Locals(config.ContextKey, payload)
c.Locals(payloadKey, payload)

return config.SuccessHandler(c)
}

return config.ErrorHandler(c, err)
}
}

// FromContext returns the payload from the context.
func FromContext(c *fiber.Ctx) interface{} {
return c.Locals(payloadKey)
}
19 changes: 4 additions & 15 deletions paseto/paseto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ func Test_PASETO_LocalToken_MissingToken(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
SymmetricKey: []byte(symmetricKey),
ContextKey: DefaultContextKey,
ErrorHandler: assertErrorHandler(t, ErrMissingToken),
}))
request := httptest.NewRequest("GET", "/", nil)
Expand All @@ -101,7 +100,6 @@ func Test_PASETO_PublicToken_MissingToken(t *testing.T) {
app.Use(New(Config{
PrivateKey: privateKey,
PublicKey: privateKey.Public(),
ContextKey: DefaultContextKey,
ErrorHandler: assertErrorHandler(t, ErrMissingToken),
}))

Expand All @@ -117,7 +115,6 @@ func Test_PASETO_LocalToken_ErrDataUnmarshal(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
SymmetricKey: []byte(symmetricKey),
ContextKey: DefaultContextKey,
ErrorHandler: assertErrorHandler(t, ErrDataUnmarshal),
}))
request, err := generateTokenRequest("/", createCustomToken, durationTest, PurposeLocal)
Expand All @@ -136,7 +133,6 @@ func Test_PASETO_PublicToken_ErrDataUnmarshal(t *testing.T) {
app.Use(New(Config{
PrivateKey: privateKey,
PublicKey: privateKey.Public(),
ContextKey: DefaultContextKey,
}))

request, err := generateTokenRequest("/", createCustomToken, durationTest, PurposePublic)
Expand All @@ -153,7 +149,6 @@ func Test_PASETO_LocalToken_ErrTokenExpired(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
SymmetricKey: []byte(symmetricKey),
ContextKey: DefaultContextKey,
ErrorHandler: assertErrorHandler(t, ErrExpiredToken),
}))
request, err := generateTokenRequest("/", CreateToken, time.Nanosecond*-10, PurposeLocal)
Expand All @@ -172,7 +167,6 @@ func Test_PASETO_PublicToken_ErrTokenExpired(t *testing.T) {
app.Use(New(Config{
PrivateKey: privateKey,
PublicKey: privateKey.Public(),
ContextKey: DefaultContextKey,
ErrorHandler: assertErrorHandler(t, ErrExpiredToken),
}))

Expand Down Expand Up @@ -226,10 +220,9 @@ func Test_PASETO_LocalTokenDecrypt(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
SymmetricKey: []byte(symmetricKey),
ContextKey: DefaultContextKey,
}))
app.Get("/", func(ctx *fiber.Ctx) error {
utils.AssertEqual(t, testMessage, ctx.Locals(DefaultContextKey))
utils.AssertEqual(t, testMessage, FromContext(ctx))
return nil
})
request, err := generateTokenRequest("/", CreateToken, durationTest, PurposeLocal)
Expand All @@ -249,10 +242,9 @@ func Test_PASETO_PublicTokenVerify(t *testing.T) {
app.Use(New(Config{
PrivateKey: privateKey,
PublicKey: privateKey.Public(),
ContextKey: DefaultContextKey,
}))
app.Get("/", func(ctx *fiber.Ctx) error {
utils.AssertEqual(t, testMessage, ctx.Locals(DefaultContextKey))
utils.AssertEqual(t, testMessage, FromContext(ctx))
return nil
})
request, err := generateTokenRequest("/", CreateToken, durationTest, PurposePublic)
Expand All @@ -268,7 +260,6 @@ func Test_PASETO_LocalToken_IncorrectBearerToken(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
SymmetricKey: []byte(symmetricKey),
ContextKey: DefaultContextKey,
TokenPrefix: "Gopher",
ErrorHandler: func(ctx *fiber.Ctx, err error) error {
if errors.Is(err, ErrIncorrectTokenPrefix) {
Expand Down Expand Up @@ -312,7 +303,6 @@ func Test_PASETO_LocalToken_InvalidToken(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
SymmetricKey: []byte(symmetricKey),
ContextKey: DefaultContextKey,
}))
request := httptest.NewRequest("GET", "/", nil)
request.Header.Set(fiber.HeaderAuthorization, invalidToken)
Expand All @@ -328,7 +318,6 @@ func Test_PASETO_PublicToken_InvalidToken(t *testing.T) {
app.Use(New(Config{
PrivateKey: privateKey,
PublicKey: privateKey.Public(),
ContextKey: DefaultContextKey,
}))

request := httptest.NewRequest("GET", "/", nil)
Expand Down Expand Up @@ -357,7 +346,7 @@ func Test_PASETO_LocalToken_CustomValidate(t *testing.T) {
}))

app.Get("/", func(ctx *fiber.Ctx) error {
utils.AssertEqual(t, testMessage, ctx.Locals(DefaultContextKey))
utils.AssertEqual(t, testMessage, FromContext(ctx))
return nil
})

Expand Down Expand Up @@ -394,7 +383,7 @@ func Test_PASETO_PublicToken_CustomValidate(t *testing.T) {
}))

app.Get("/", func(ctx *fiber.Ctx) error {
utils.AssertEqual(t, testMessage, ctx.Locals(DefaultContextKey))
utils.AssertEqual(t, testMessage, FromContext(ctx))
return nil
})

Expand Down
Loading