Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 29 additions & 43 deletions libs/go-libs/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
package auth

import (
"context"
"fmt"
"net/http"
"os"
"strings"

"github.com/formancehq/stack/libs/go-libs/collectionutils"
"github.com/formancehq/stack/libs/go-libs/logging"
"github.com/hashicorp/go-retryablehttp"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"go.uber.org/zap"
)

type jwtAuth struct {
logger logging.Logger
httpClient *http.Client
accessTokenVerifier op.AccessTokenVerifier

issuer string
logger logging.Logger
httpClient *http.Client
verifiers map[string]op.AccessTokenVerifier // issuer -> verifier
checkScopes bool
service string
}
Expand All @@ -35,17 +30,16 @@ func newOtlpHttpClient(maxRetries int) *http.Client {
func newJWTAuth(
logger logging.Logger,
readKeySetMaxRetries int,
issuer string,
verifiers map[string]op.AccessTokenVerifier,
service string,
checkScopes bool,
) *jwtAuth {
return &jwtAuth{
logger: logger,
httpClient: newOtlpHttpClient(readKeySetMaxRetries),
accessTokenVerifier: nil,
issuer: issuer,
checkScopes: checkScopes,
service: service,
logger: logger,
httpClient: newOtlpHttpClient(readKeySetMaxRetries),
verifiers: verifiers,
checkScopes: checkScopes,
service: service,
}
}

Expand All @@ -66,13 +60,28 @@ func (ja *jwtAuth) Authenticate(w http.ResponseWriter, r *http.Request) (bool, e
token := strings.TrimPrefix(authHeader, strings.ToLower(oidc.PrefixBearer))
token = strings.TrimPrefix(token, oidc.PrefixBearer)

accessTokenVerifier, err := ja.getAccessTokenVerifier(r.Context())
if err != nil {
ja.logger.Error("unable to create access token verifier", zap.Error(err))
return false, fmt.Errorf("unable to create access token verifier: %w", err)
// Pre-parse the token to extract the issuer claim, so we can select
// the correct verifier (each issuer has its own key set).
var preClaims oidc.TokenClaims
if _, err := oidc.ParseToken(token, &preClaims); err != nil {
ja.logger.Error("unable to parse token", zap.Error(err))
return false, fmt.Errorf("unable to parse token: %w", err)
}

claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](r.Context(), token, accessTokenVerifier)
verifier, ok := ja.verifiers[preClaims.Issuer]
if !ok {
issuers := make([]string, 0, len(ja.verifiers))
for iss := range ja.verifiers {
issuers = append(issuers, iss)
}
ja.logger.Error("untrusted issuer",
zap.String("got", preClaims.Issuer),
zap.Strings("trusted", issuers),
)
return false, fmt.Errorf("issuer does not match: got: %s, trusted: %v", preClaims.Issuer, issuers)
}

claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](r.Context(), token, verifier)
if err != nil {
ja.logger.Error("unable to verify access token", zap.Error(err))
return false, fmt.Errorf("unable to verify access token: %w", err)
Expand All @@ -97,26 +106,3 @@ func (ja *jwtAuth) Authenticate(w http.ResponseWriter, r *http.Request) (bool, e

return true, nil
}

func (ja *jwtAuth) getAccessTokenVerifier(ctx context.Context) (op.AccessTokenVerifier, error) {
if ja.accessTokenVerifier == nil {
//discoveryConfiguration, err := client.Discover(ja.Issuer, ja.httpClient)
//if err != nil {
// return nil, err
//}

// todo: ugly quick fix
authServicePort := "8080"
if fromEnv := os.Getenv("AUTH_SERVICE_PORT"); fromEnv != "" {
authServicePort = fromEnv
}
keySet := rp.NewRemoteKeySet(ja.httpClient, fmt.Sprintf("http://auth:%s/keys", authServicePort))

ja.accessTokenVerifier = op.NewAccessTokenVerifier(
os.Getenv("STACK_PUBLIC_URL")+"/api/auth",
keySet,
)
}

return ja.accessTokenVerifier, nil
}
23 changes: 21 additions & 2 deletions libs/go-libs/auth/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,42 @@ import (
const (
AuthEnabled = "auth-enabled"
AuthIssuerFlag = "auth-issuer"
AuthIssuersFlag = "auth-issuers"
AuthReadKeySetMaxRetriesFlag = "auth-read-key-set-max-retries"
AuthCheckScopesFlag = "auth-check-scopes"
AuthServiceFlag = "auth-service"
)

func InitAuthFlags(flags *flag.FlagSet) {
flags.Bool(AuthEnabled, false, "Enable auth")
flags.String(AuthIssuerFlag, "", "Issuer")
flags.String(AuthIssuerFlag, "", "Issuer (single issuer, for backward compatibility)")
flags.StringSlice(AuthIssuersFlag, nil, "Trusted issuers (comma-separated, e.g. --auth-issuers=https://issuer1,https://issuer2)")
flags.Int(AuthReadKeySetMaxRetriesFlag, 10, "ReadKeySetMaxRetries")
flags.Bool(AuthCheckScopesFlag, false, "CheckScopes")
flags.String(AuthServiceFlag, "", "Service")
}

func CLIAuthModule() fx.Option {
authIssuer := viper.GetString(AuthIssuerFlag)
authIssuers := viper.GetStringSlice(AuthIssuersFlag)

// Merge --auth-issuer into --auth-issuers for backward compatibility
if authIssuer != "" {
found := false
for _, iss := range authIssuers {
if iss == authIssuer {
found = true
break
}
}
if !found {
authIssuers = append(authIssuers, authIssuer)
}
}

return Module(ModuleConfig{
Enabled: viper.GetBool(AuthEnabled),
Issuer: viper.GetString(AuthIssuerFlag),
Issuers: authIssuers,
ReadKeySetMaxRetries: viper.GetInt(AuthReadKeySetMaxRetriesFlag),
CheckScopes: viper.GetBool(AuthCheckScopesFlag),
Service: viper.GetString(AuthServiceFlag),
Expand Down
55 changes: 51 additions & 4 deletions libs/go-libs/auth/module.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,55 @@
package auth

import (
"errors"
"net/http"
"time"

"github.com/formancehq/stack/libs/go-libs/logging"
"github.com/hashicorp/go-retryablehttp"
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/op"
"go.uber.org/fx"
)

type ModuleConfig struct {
Enabled bool
Issuer string
Issuers []string
ReadKeySetMaxRetries int
CheckScopes bool
Service string

// Deprecated: use Issuers instead.
Issuer string
}

func (cfg ModuleConfig) resolveIssuers() []string {
issuers := cfg.Issuers
if cfg.Issuer != "" {
found := false
for _, iss := range issuers {
if iss == cfg.Issuer {
found = true
break
}
}
if !found {
issuers = append(issuers, cfg.Issuer)
}
}
return issuers
}

func Module(cfg ModuleConfig) fx.Option {
options := make([]fx.Option, 0)

issuers := cfg.resolveIssuers()

if cfg.Enabled && len(issuers) == 0 {
return fx.Error(errors.New("auth is enabled but no issuers are configured"))
}

options = append(options,
fx.Provide(func() Auth {
return NewNoAuth()
Expand All @@ -24,14 +58,27 @@ func Module(cfg ModuleConfig) fx.Option {

if cfg.Enabled {
options = append(options,
fx.Decorate(func(logger logging.Logger) Auth {
fx.Decorate(func(logger logging.Logger) (Auth, error) {
retryClient := retryablehttp.NewClient()
retryClient.RetryMax = cfg.ReadKeySetMaxRetries
discoveryHTTPClient := retryClient.StandardClient()

verifiers := make(map[string]op.AccessTokenVerifier, len(issuers))
for _, issuer := range issuers {
discovery, err := client.Discover(issuer, discoveryHTTPClient)
if err != nil {
return nil, err
}
keySet := rp.NewRemoteKeySet(&http.Client{Timeout: 10 * time.Second}, discovery.JwksURI)
verifiers[issuer] = op.NewAccessTokenVerifier(issuer, keySet)
}
return newJWTAuth(
logger,
cfg.ReadKeySetMaxRetries,
cfg.Issuer,
verifiers,
cfg.Service,
cfg.CheckScopes,
)
), nil
}),
)
}
Expand Down
Loading