Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions proof/oidc/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package oidc_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/stretchr/testify/require"

authkitoidc "github.com/meigma/authkit/proof/oidc"
)

const (
testAudience = "authkit-api"
testSubject = "user-123"
)

type failingProviderSource struct {
err error
}

func (s failingProviderSource) FindProvider(context.Context, string) (authkitoidc.Provider, error) {
return authkitoidc.Provider{}, s.err
}

type testIssuer struct {
server *httptest.Server
issuer string
jwksURL string
signingKey jwk.Key
publicSet jwk.Set
}

type tokenRequest struct {
issuer string
subject string
audiences []string
expiresAt time.Time
notBefore *time.Time
jwtID string
claims map[string]any
}

func newTestIssuer(t *testing.T) *testIssuer {
t.Helper()

return newTestIssuerWithPublicKey(t, nil)
}

func newTestIssuerWithPublicKey(t *testing.T, configure func(*testing.T, jwk.Key)) *testIssuer {
t.Helper()

rawKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
signingKey, err := jwk.Import(rawKey)
require.NoError(t, err)
require.NoError(t, signingKey.Set(jwk.KeyIDKey, "test-key"))
require.NoError(t, signingKey.Set(jwk.AlgorithmKey, jwa.RS256()))

privateSet := jwk.NewSet()
require.NoError(t, privateSet.AddKey(signingKey))
publicSet, err := jwk.PublicSetOf(privateSet)
require.NoError(t, err)
if configure != nil {
publicKey, ok := publicSet.Key(0)
require.True(t, ok)
configure(t, publicKey)
}

issuer := &testIssuer{
signingKey: signingKey,
publicSet: publicSet,
}
mux := http.NewServeMux()
mux.HandleFunc("/jwks", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(issuer.publicSet); err != nil {
t.Errorf("encode JWKS: %v", err)
}
})
mux.HandleFunc("/unavailable", func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
})
issuer.server = httptest.NewTLSServer(mux)
t.Cleanup(issuer.server.Close)
issuer.issuer = issuer.server.URL
issuer.jwksURL = issuer.server.URL + "/jwks"

return issuer
}

func (i *testIssuer) provider() authkitoidc.Provider {
return authkitoidc.Provider{
Issuer: i.issuer,
Audiences: []string{testAudience},
JWKSURL: i.jwksURL,
}
}

func (i *testIssuer) verifier(t *testing.T, opts ...authkitoidc.Option) *authkitoidc.Verifier {
t.Helper()

opts = append([]authkitoidc.Option{authkitoidc.WithHTTPClient(i.server.Client())}, opts...)

return newVerifier(t, i.provider(), opts...)
}

func (i *testIssuer) sign(t *testing.T, req tokenRequest) string {
t.Helper()

issuer := req.issuer
if issuer == "" {
issuer = i.issuer
}
builder := jwt.NewBuilder().
Issuer(issuer).
Audience(req.audiences).
IssuedAt(fixedTime().Add(-time.Minute))
if req.subject != "" {
builder.Subject(req.subject)
}
if !req.expiresAt.IsZero() {
builder.Expiration(req.expiresAt)
}
if req.notBefore != nil {
builder.NotBefore(*req.notBefore)
}
if req.jwtID != "" {
builder.JwtID(req.jwtID)
}
for name, value := range req.claims {
builder.Claim(name, value)
}

token, err := builder.Build()
require.NoError(t, err)
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), i.signingKey))
require.NoError(t, err)

return string(signed)
}

func newVerifier(
t *testing.T,
provider authkitoidc.Provider,
opts ...authkitoidc.Option,
) *authkitoidc.Verifier {
t.Helper()

source, err := authkitoidc.NewStaticProviderSource(provider)
require.NoError(t, err)
verifier, err := authkitoidc.NewVerifier(source, opts...)
require.NoError(t, err)

return verifier
}

func fixedTime() time.Time {
return time.Date(2026, 5, 8, 12, 0, 0, 0, time.UTC)
}
2 changes: 2 additions & 0 deletions proof/oidc/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const (
defaultKeySetCacheTTL = 5 * time.Minute
)

// options is the resolved configuration consumed by a Verifier.
type options struct {
httpClient *http.Client
clock func() time.Time
Expand All @@ -23,6 +24,7 @@ type options struct {
// Option configures a Verifier.
type Option func(*options)

// defaultOptions returns the baseline options applied before any caller-supplied Option.
func defaultOptions() options {
return options{
httpClient: &http.Client{Timeout: defaultHTTPTimeout},
Expand Down
88 changes: 88 additions & 0 deletions proof/oidc/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package oidc_test

import (
"testing"

"github.com/stretchr/testify/require"

authkitoidc "github.com/meigma/authkit/proof/oidc"
)

func TestProviderValidation(t *testing.T) {
tests := []struct {
name string
provider authkitoidc.Provider
}{
{
name: "missing issuer",
provider: authkitoidc.Provider{
Audiences: []string{testAudience},
JWKSURL: "https://issuer.example/jwks",
},
},
{
name: "missing audience",
provider: authkitoidc.Provider{
Issuer: "https://issuer.example",
JWKSURL: "https://issuer.example/jwks",
},
},
{
name: "missing JWKS URL",
provider: authkitoidc.Provider{
Issuer: "https://issuer.example",
Audiences: []string{testAudience},
},
},
{
name: "insecure issuer",
provider: authkitoidc.Provider{
Issuer: "http://issuer.example",
Audiences: []string{testAudience},
JWKSURL: "https://issuer.example/jwks",
},
},
{
name: "relative issuer",
provider: authkitoidc.Provider{
Issuer: "/issuer",
Audiences: []string{testAudience},
JWKSURL: "https://issuer.example/jwks",
},
},
{
name: "insecure JWKS URL",
provider: authkitoidc.Provider{
Issuer: "https://issuer.example",
Audiences: []string{testAudience},
JWKSURL: "http://issuer.example/jwks",
},
},
{
name: "symmetric algorithm",
provider: authkitoidc.Provider{
Issuer: "https://issuer.example",
Audiences: []string{testAudience},
JWKSURL: "https://issuer.example/jwks",
SupportedSigningAlgorithms: []string{"HS256"},
},
},
{
name: "unknown algorithm",
provider: authkitoidc.Provider{
Issuer: "https://issuer.example",
Audiences: []string{testAudience},
JWKSURL: "https://issuer.example/jwks",
SupportedSigningAlgorithms: []string{"unknown"},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.provider.Validate()

require.Error(t, err)
})
}
}
11 changes: 11 additions & 0 deletions proof/oidc/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type StaticProviderSource struct {
providers map[string]Provider
}

var _ ProviderSource = (*StaticProviderSource)(nil)

// NewStaticProviderSource constructs a static provider source.
func NewStaticProviderSource(providers ...Provider) (*StaticProviderSource, error) {
source := &StaticProviderSource{
Expand Down Expand Up @@ -64,6 +66,8 @@ func (s *StaticProviderSource) FindProvider(ctx context.Context, issuer string)
return cloneProvider(provider), nil
}

// cloneProvider returns a deep copy of provider so callers cannot mutate the
// audience, signing-algorithm, or forwarded-claim slices the source retains.
func cloneProvider(provider Provider) Provider {
provider.Audiences = cloneStrings(provider.Audiences)
provider.SupportedSigningAlgorithms = cloneStrings(provider.SupportedSigningAlgorithms)
Expand All @@ -72,6 +76,7 @@ func cloneProvider(provider Provider) Provider {
return provider
}

// cloneStrings returns a defensive copy of values, returning nil for empty input.
func cloneStrings(values []string) []string {
if len(values) == 0 {
return nil
Expand All @@ -83,6 +88,7 @@ func cloneStrings(values []string) []string {
return cloned
}

// cloneClaimPaths returns a deep copy of paths, returning nil for empty input.
func cloneClaimPaths(paths []authkit.ClaimPath) []authkit.ClaimPath {
if len(paths) == 0 {
return nil
Expand All @@ -96,6 +102,7 @@ func cloneClaimPaths(paths []authkit.ClaimPath) []authkit.ClaimPath {
return cloned
}

// cloneClaimPath returns a defensive copy of path, returning nil for empty input.
func cloneClaimPath(path authkit.ClaimPath) authkit.ClaimPath {
if len(path) == 0 {
return nil
Expand All @@ -107,6 +114,8 @@ func cloneClaimPath(path authkit.ClaimPath) authkit.ClaimPath {
return cloned
}

// claimPathKey serializes path into a deterministic dedup key. Returns the empty
// string when path is invalid so the caller can skip it.
func claimPathKey(path authkit.ClaimPath) string {
if !path.Valid() {
return ""
Expand All @@ -115,6 +124,8 @@ func claimPathKey(path authkit.ClaimPath) string {
return strings.Join(path, "\x00")
}

// cloneClaimValue returns a deep copy of value covering the JSON-decoded shapes
// that token claims can carry. Other scalars are returned as-is.
func cloneClaimValue(value any) any {
switch typed := value.(type) {
case map[string]any:
Expand Down
53 changes: 53 additions & 0 deletions proof/oidc/source_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package oidc_test

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

authkitoidc "github.com/meigma/authkit/proof/oidc"
)

func TestStaticProviderSourceFindsProvidersByIssuer(t *testing.T) {
source, err := authkitoidc.NewStaticProviderSource(authkitoidc.Provider{
Issuer: "https://issuer.example",
Audiences: []string{testAudience},
JWKSURL: "https://issuer.example/jwks",
SupportedSigningAlgorithms: []string{"RS256"},
})
require.NoError(t, err)

provider, err := source.FindProvider(context.Background(), "https://issuer.example")
require.NoError(t, err)
assert.Equal(t, "https://issuer.example", provider.Issuer)

provider.Audiences[0] = "mutated"
provider, err = source.FindProvider(context.Background(), "https://issuer.example")
require.NoError(t, err)
assert.Equal(t, []string{testAudience}, provider.Audiences)
}

func TestStaticProviderSourceRejectsDuplicateIssuers(t *testing.T) {
provider := authkitoidc.Provider{
Issuer: "https://issuer.example",
Audiences: []string{testAudience},
JWKSURL: "https://issuer.example/jwks",
}

source, err := authkitoidc.NewStaticProviderSource(provider, provider)

require.Error(t, err)
assert.Nil(t, source)
}

func TestStaticProviderSourceReturnsProviderNotFound(t *testing.T) {
source, err := authkitoidc.NewStaticProviderSource()
require.NoError(t, err)

provider, err := source.FindProvider(context.Background(), "https://issuer.example")

require.ErrorIs(t, err, authkitoidc.ErrProviderNotFound)
assert.Empty(t, provider)
}
Loading