Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
packages:
github.com/meigma/authkit:
interfaces:
PrincipalFinder:
config:
template: testify
structname: PrincipalFinder
pkgname: authkitmocks
dir: mocks/authkit
filename: principal_finder.go
320 changes: 320 additions & 0 deletions access/jwt/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
package jwt

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"testing"
"time"

"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
jwxjwt "github.com/lestrrat-go/jwx/v3/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
testIssuer = "https://auth.example.test"
testAudience = "notes-api"
testPrincipalID = "principal_1"
testTokenID = "token-123"
testAction = "note:read"
testRoleID = "reader"
testKeyID = "key-1"
rsaKeyBits = 2048
)

func newIssuerAndVerifier(t *testing.T) (*Issuer, *Verifier, jwk.Set) {
t.Helper()

privateKey, publicKey := newRSAKeyPair(t)
keySet := newKeySet(t, publicKey)
issuer, err := NewIssuer(issuerOptions(privateKey, nil))
require.NoError(t, err)
verifier, err := NewVerifier(verifierOptions(keySet, nil))
require.NoError(t, err)

return issuer, verifier, keySet
}

func issuerOptions(signingKey jwk.Key, mutate func(*IssuerOptions)) IssuerOptions {
opts := IssuerOptions{
Issuer: testIssuer,
Audience: testAudience,
TTL: time.Hour,
SigningKey: signingKey,
Clock: fixedTime,
TokenID: func() (string, error) {
return testTokenID, nil
},
}
if mutate != nil {
mutate(&opts)
}

return opts
}

func verifierOptions(keySet jwk.Set, mutate func(*VerifierOptions)) VerifierOptions {
opts := VerifierOptions{
Issuer: testIssuer,
Audience: testAudience,
KeySet: keySet,
Clock: fixedTime,
}
if mutate != nil {
mutate(&opts)
}

return opts
}

func issueToken(t *testing.T, issuer *Issuer) IssuedToken {
t.Helper()

issued, err := issuer.IssueToken(context.Background(), IssueRequest{
PrincipalID: testPrincipalID,
})
require.NoError(t, err)

return issued
}

func issueTokenWithOptions(t *testing.T, mutate func(*IssuerOptions)) string {
t.Helper()

privateKey, _ := newRSAKeyPair(t)
issuer, err := NewIssuer(issuerOptions(privateKey, mutate))
require.NoError(t, err)

return issueToken(t, issuer).JWT
}

func issueTokenWithClock(t *testing.T, now time.Time) string {
t.Helper()

return issueTokenWithOptions(t, func(opts *IssuerOptions) {
opts.Clock = func() time.Time {
return now
}
})
}

func issueWithWrongSignature(t *testing.T) string {
t.Helper()

privateKey, _ := newRSAKeyPair(t)
issuer, err := NewIssuer(issuerOptions(privateKey, nil))
require.NoError(t, err)

return issueToken(t, issuer).JWT
}

func issueWithWrongKeyID(t *testing.T) string {
t.Helper()

return issueTokenWithOptions(t, func(opts *IssuerOptions) {
opts.SigningKey = newRSAKey(t, "other-key", DefaultAlgorithm)
})
}

func signToken(
t *testing.T,
claims map[string]any,
tokenType string,
algorithmName string,
keyID string,
) string {
t.Helper()

key := newRSAKey(t, keyID, algorithmName)
algorithm, err := signatureAlgorithm(algorithmName)
require.NoError(t, err)
token := jwxjwt.New()
for name, value := range claims {
require.NoError(t, token.Set(name, value))
}
headers := jws.NewHeaders()
require.NoError(t, headers.Set(jws.TypeKey, tokenType))
signed, err := jwxjwt.Sign(token, jwxjwt.WithKey(algorithm, key, jws.WithProtectedHeaders(headers)))
require.NoError(t, err)

return string(signed)
}

func signTokenWithHeaders(
t *testing.T,
key jwk.Key,
claims map[string]any,
mutate func(jws.Headers),
) string {
t.Helper()

algorithm, err := signatureAlgorithm(DefaultAlgorithm)
require.NoError(t, err)
token := jwxjwt.New()
for name, value := range claims {
require.NoError(t, token.Set(name, value))
}
headers := jws.NewHeaders()
require.NoError(t, headers.Set(jws.TypeKey, TokenType))
if mutate != nil {
mutate(headers)
}
signed, err := jwxjwt.Sign(token, jwxjwt.WithKey(algorithm, key, jws.WithProtectedHeaders(headers)))
require.NoError(t, err)

return string(signed)
}

func signTokenWithoutType(t *testing.T, claims map[string]any) string {
t.Helper()

token := jwxjwt.New()
for name, value := range claims {
require.NoError(t, token.Set(name, value))
}
payload, err := json.Marshal(token)
require.NoError(t, err)

key := newRSAKey(t, testKeyID, DefaultAlgorithm)
signed, err := jws.Sign(payload, jws.WithKey(jwa.RS256(), key))
require.NoError(t, err)

return string(signed)
}

func unsignedToken(t *testing.T, claims map[string]any) string {
t.Helper()

header := map[string]any{
jws.AlgorithmKey: jwa.NoSignature().String(),
jws.TypeKey: TokenType,
}
headerBytes, err := json.Marshal(header)
require.NoError(t, err)
payloadBytes, err := json.Marshal(claims)
require.NoError(t, err)

return base64.RawURLEncoding.EncodeToString(headerBytes) + "." +
base64.RawURLEncoding.EncodeToString(payloadBytes) + "."
}

func hmacSignedToken(t *testing.T, claims map[string]any) string {
t.Helper()

key, err := jwk.Import([]byte("secret"))
require.NoError(t, err)
require.NoError(t, key.Set(jwk.KeyIDKey, testKeyID))
token := jwxjwt.New()
for name, value := range claims {
require.NoError(t, token.Set(name, value))
}
headers := jws.NewHeaders()
require.NoError(t, headers.Set(jws.TypeKey, TokenType))
signed, err := jwxjwt.Sign(token, jwxjwt.WithKey(jwa.HS256(), key, jws.WithProtectedHeaders(headers)))
require.NoError(t, err)

return string(signed)
}

func baseClaims() map[string]any {
now := fixedTime()

return map[string]any{
jwxjwt.IssuerKey: testIssuer,
jwxjwt.SubjectKey: testPrincipalID,
jwxjwt.AudienceKey: []string{testAudience},
jwxjwt.IssuedAtKey: now,
jwxjwt.ExpirationKey: now.Add(time.Hour),
jwxjwt.JwtIDKey: testTokenID,
}
}

func assertProtectedHeader(t *testing.T, jwt string, tokenType string, algorithm string, keyID string) {
t.Helper()

message, err := jws.Parse([]byte(jwt), jws.WithCompact())
require.NoError(t, err)
require.Len(t, message.Signatures(), 1)
headers := message.Signatures()[0].ProtectedHeaders()
require.NotNil(t, headers)

gotType, ok := headers.Type()
require.True(t, ok)
assert.Equal(t, tokenType, gotType)
gotAlgorithm, ok := headers.Algorithm()
require.True(t, ok)
assert.Equal(t, algorithm, gotAlgorithm.String())
gotKeyID, ok := headers.KeyID()
require.True(t, ok)
assert.Equal(t, keyID, gotKeyID)
}

func assertNoAuthorizationClaims(t *testing.T, jwt string, keySet jwk.Set) {
t.Helper()

token, err := jwxjwt.Parse(
[]byte(jwt),
jwxjwt.WithKeySet(keySet),
jwxjwt.WithIssuer(testIssuer),
jwxjwt.WithAudience(testAudience),
jwxjwt.WithClock(jwxjwt.ClockFunc(fixedTime)),
)
require.NoError(t, err)
assert.ElementsMatch(t, []string{
jwxjwt.AudienceKey,
jwxjwt.ExpirationKey,
jwxjwt.IssuedAtKey,
jwxjwt.IssuerKey,
jwxjwt.JwtIDKey,
jwxjwt.SubjectKey,
}, token.Keys())
assert.False(t, token.Has("roles"))
assert.False(t, token.Has("permissions"))
assert.False(t, token.Has("actions"))
}

func newRSAKeyPair(t *testing.T) (jwk.Key, jwk.Key) {
t.Helper()

privateKey := newRSAKey(t, testKeyID, DefaultAlgorithm)
publicKey, err := jwk.PublicKeyOf(privateKey)
require.NoError(t, err)

return privateKey, publicKey
}

func newRSAKey(t *testing.T, keyID string, algorithm string) jwk.Key {
t.Helper()

rawKey, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
require.NoError(t, err)
key, err := jwk.Import(rawKey)
require.NoError(t, err)
if keyID != "" {
require.NoError(t, key.Set(jwk.KeyIDKey, keyID))
}
if algorithm != "" {
require.NoError(t, key.Set(jwk.AlgorithmKey, algorithm))
}

return key
}

func newKeySet(t *testing.T, key jwk.Key) jwk.Set {
t.Helper()

keySet := jwk.NewSet()
require.NoError(t, keySet.AddKey(key))

return keySet
}

func fixedTime() time.Time {
return time.Date(2026, time.May, 13, 21, 0, 0, 0, time.UTC)
}
16 changes: 3 additions & 13 deletions access/jwt/issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/rand"
"errors"
"fmt"
"strings"
"time"

"github.com/lestrrat-go/jwx/v3/jwa"
Expand Down Expand Up @@ -117,24 +116,14 @@ func (i *Issuer) IssueToken(ctx context.Context, req IssueRequest) (IssuedToken,

return IssuedToken{
ID: tokenID,
Plaintext: string(signed),
JWT: string(signed),
PrincipalID: req.PrincipalID,
IssuedAt: issuedAt,
ExpiresAt: expiresAt,
}, nil
}

func validateRequiredString(name string, value string) error {
if strings.TrimSpace(value) == "" {
return fmt.Errorf("jwt: %s is required", name)
}
if strings.TrimSpace(value) != value {
return fmt.Errorf("jwt: %s must not contain surrounding whitespace", name)
}

return nil
}

// defaultString returns value when non-empty and fallback otherwise.
func defaultString(value string, fallback string) string {
if value == "" {
return fallback
Expand All @@ -143,6 +132,7 @@ func defaultString(value string, fallback string) string {
return value
}

// randomTokenID returns a cryptographically random token ID for use as the jti claim.
func randomTokenID() (string, error) {
return rand.Text(), nil
}
Loading