diff --git a/libs/go-libs/auth/auth.go b/libs/go-libs/auth/auth.go index 289db2c98c..920f8e9884 100644 --- a/libs/go-libs/auth/auth.go +++ b/libs/go-libs/auth/auth.go @@ -1,27 +1,22 @@ package auth import ( - "context" "fmt" "net/http" - "os" "strings" "github.com/formancehq/stack/libs/go-libs/collectionutils" "github.com/formancehq/stack/libs/go-libs/logging" "github.com/hashicorp/go-retryablehttp" - "github.com/zitadel/oidc/v2/pkg/client/rp" "github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/op" "go.uber.org/zap" ) type jwtAuth struct { - logger logging.Logger - httpClient *http.Client - accessTokenVerifier op.AccessTokenVerifier - - issuer string + logger logging.Logger + httpClient *http.Client + verifiers map[string]op.AccessTokenVerifier // issuer -> verifier checkScopes bool service string } @@ -35,17 +30,16 @@ func newOtlpHttpClient(maxRetries int) *http.Client { func newJWTAuth( logger logging.Logger, readKeySetMaxRetries int, - issuer string, + verifiers map[string]op.AccessTokenVerifier, service string, checkScopes bool, ) *jwtAuth { return &jwtAuth{ - logger: logger, - httpClient: newOtlpHttpClient(readKeySetMaxRetries), - accessTokenVerifier: nil, - issuer: issuer, - checkScopes: checkScopes, - service: service, + logger: logger, + httpClient: newOtlpHttpClient(readKeySetMaxRetries), + verifiers: verifiers, + checkScopes: checkScopes, + service: service, } } @@ -66,13 +60,28 @@ func (ja *jwtAuth) Authenticate(w http.ResponseWriter, r *http.Request) (bool, e token := strings.TrimPrefix(authHeader, strings.ToLower(oidc.PrefixBearer)) token = strings.TrimPrefix(token, oidc.PrefixBearer) - accessTokenVerifier, err := ja.getAccessTokenVerifier(r.Context()) - if err != nil { - ja.logger.Error("unable to create access token verifier", zap.Error(err)) - return false, fmt.Errorf("unable to create access token verifier: %w", err) + // Pre-parse the token to extract the issuer claim, so we can select + // the correct verifier (each issuer has its own key set). + var preClaims oidc.TokenClaims + if _, err := oidc.ParseToken(token, &preClaims); err != nil { + ja.logger.Error("unable to parse token", zap.Error(err)) + return false, fmt.Errorf("unable to parse token: %w", err) } - claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](r.Context(), token, accessTokenVerifier) + verifier, ok := ja.verifiers[preClaims.Issuer] + if !ok { + issuers := make([]string, 0, len(ja.verifiers)) + for iss := range ja.verifiers { + issuers = append(issuers, iss) + } + ja.logger.Error("untrusted issuer", + zap.String("got", preClaims.Issuer), + zap.Strings("trusted", issuers), + ) + return false, fmt.Errorf("issuer does not match: got: %s, trusted: %v", preClaims.Issuer, issuers) + } + + claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](r.Context(), token, verifier) if err != nil { ja.logger.Error("unable to verify access token", zap.Error(err)) return false, fmt.Errorf("unable to verify access token: %w", err) @@ -97,26 +106,3 @@ func (ja *jwtAuth) Authenticate(w http.ResponseWriter, r *http.Request) (bool, e return true, nil } - -func (ja *jwtAuth) getAccessTokenVerifier(ctx context.Context) (op.AccessTokenVerifier, error) { - if ja.accessTokenVerifier == nil { - //discoveryConfiguration, err := client.Discover(ja.Issuer, ja.httpClient) - //if err != nil { - // return nil, err - //} - - // todo: ugly quick fix - authServicePort := "8080" - if fromEnv := os.Getenv("AUTH_SERVICE_PORT"); fromEnv != "" { - authServicePort = fromEnv - } - keySet := rp.NewRemoteKeySet(ja.httpClient, fmt.Sprintf("http://auth:%s/keys", authServicePort)) - - ja.accessTokenVerifier = op.NewAccessTokenVerifier( - os.Getenv("STACK_PUBLIC_URL")+"/api/auth", - keySet, - ) - } - - return ja.accessTokenVerifier, nil -} diff --git a/libs/go-libs/auth/cli.go b/libs/go-libs/auth/cli.go index b324af6d28..eb3cb3de48 100644 --- a/libs/go-libs/auth/cli.go +++ b/libs/go-libs/auth/cli.go @@ -9,6 +9,7 @@ import ( const ( AuthEnabled = "auth-enabled" AuthIssuerFlag = "auth-issuer" + AuthIssuersFlag = "auth-issuers" AuthReadKeySetMaxRetriesFlag = "auth-read-key-set-max-retries" AuthCheckScopesFlag = "auth-check-scopes" AuthServiceFlag = "auth-service" @@ -16,16 +17,34 @@ const ( func InitAuthFlags(flags *flag.FlagSet) { flags.Bool(AuthEnabled, false, "Enable auth") - flags.String(AuthIssuerFlag, "", "Issuer") + flags.String(AuthIssuerFlag, "", "Issuer (single issuer, for backward compatibility)") + flags.StringSlice(AuthIssuersFlag, nil, "Trusted issuers (comma-separated, e.g. --auth-issuers=https://issuer1,https://issuer2)") flags.Int(AuthReadKeySetMaxRetriesFlag, 10, "ReadKeySetMaxRetries") flags.Bool(AuthCheckScopesFlag, false, "CheckScopes") flags.String(AuthServiceFlag, "", "Service") } func CLIAuthModule() fx.Option { + authIssuer := viper.GetString(AuthIssuerFlag) + authIssuers := viper.GetStringSlice(AuthIssuersFlag) + + // Merge --auth-issuer into --auth-issuers for backward compatibility + if authIssuer != "" { + found := false + for _, iss := range authIssuers { + if iss == authIssuer { + found = true + break + } + } + if !found { + authIssuers = append(authIssuers, authIssuer) + } + } + return Module(ModuleConfig{ Enabled: viper.GetBool(AuthEnabled), - Issuer: viper.GetString(AuthIssuerFlag), + Issuers: authIssuers, ReadKeySetMaxRetries: viper.GetInt(AuthReadKeySetMaxRetriesFlag), CheckScopes: viper.GetBool(AuthCheckScopesFlag), Service: viper.GetString(AuthServiceFlag), diff --git a/libs/go-libs/auth/module.go b/libs/go-libs/auth/module.go index 7bc29379b7..451998d589 100644 --- a/libs/go-libs/auth/module.go +++ b/libs/go-libs/auth/module.go @@ -1,21 +1,55 @@ package auth import ( + "errors" + "net/http" + "time" + "github.com/formancehq/stack/libs/go-libs/logging" + "github.com/hashicorp/go-retryablehttp" + "github.com/zitadel/oidc/v2/pkg/client" + "github.com/zitadel/oidc/v2/pkg/client/rp" + "github.com/zitadel/oidc/v2/pkg/op" "go.uber.org/fx" ) type ModuleConfig struct { Enabled bool - Issuer string + Issuers []string ReadKeySetMaxRetries int CheckScopes bool Service string + + // Deprecated: use Issuers instead. + Issuer string +} + +func (cfg ModuleConfig) resolveIssuers() []string { + issuers := cfg.Issuers + if cfg.Issuer != "" { + found := false + for _, iss := range issuers { + if iss == cfg.Issuer { + found = true + break + } + } + if !found { + issuers = append(issuers, cfg.Issuer) + } + } + return issuers } func Module(cfg ModuleConfig) fx.Option { options := make([]fx.Option, 0) + issuers := cfg.resolveIssuers() + + if cfg.Enabled && len(issuers) == 0 { + return fx.Error(errors.New("auth is enabled but no issuers are configured")) + } + options = append(options, fx.Provide(func() Auth { return NewNoAuth() @@ -24,14 +58,27 @@ func Module(cfg ModuleConfig) fx.Option { if cfg.Enabled { options = append(options, - fx.Decorate(func(logger logging.Logger) Auth { + fx.Decorate(func(logger logging.Logger) (Auth, error) { + retryClient := retryablehttp.NewClient() + retryClient.RetryMax = cfg.ReadKeySetMaxRetries + discoveryHTTPClient := retryClient.StandardClient() + + verifiers := make(map[string]op.AccessTokenVerifier, len(issuers)) + for _, issuer := range issuers { + discovery, err := client.Discover(issuer, discoveryHTTPClient) + if err != nil { + return nil, err + } + keySet := rp.NewRemoteKeySet(&http.Client{Timeout: 10 * time.Second}, discovery.JwksURI) + verifiers[issuer] = op.NewAccessTokenVerifier(issuer, keySet) + } return newJWTAuth( logger, cfg.ReadKeySetMaxRetries, - cfg.Issuer, + verifiers, cfg.Service, cfg.CheckScopes, - ) + ), nil }), ) }