From d9996706c213b81c04d7fc77ea78e1d589598803 Mon Sep 17 00:00:00 2001 From: Gustavo Carvalho Date: Tue, 17 Mar 2026 16:54:29 -0300 Subject: [PATCH 1/5] Add JWT bootstrap command and wire key setup into run --- .env.example | 1 + Makefile | 3 +- README.md | 6 + cmd/bootstrap.go | 41 +++++++ cmd/jwt_bootstrap.go | 51 +++++++++ cmd/root.go | 1 + cmd/run.go | 24 ++++ internal/config/jwt_bootstrap.go | 152 ++++++++++++++++++++++++++ internal/config/jwt_bootstrap_test.go | 88 +++++++++++++++ 9 files changed, 365 insertions(+), 2 deletions(-) create mode 100644 cmd/bootstrap.go create mode 100644 cmd/jwt_bootstrap.go create mode 100644 internal/config/jwt_bootstrap.go create mode 100644 internal/config/jwt_bootstrap_test.go diff --git a/.env.example b/.env.example index c3872c83..4dbd9664 100644 --- a/.env.example +++ b/.env.example @@ -6,6 +6,7 @@ CCF_DB_CONNECTION="host=db user=postgres password=postgres dbname=ccf port=5432 CCF_JWT_SECRET="some-secret" CCF_JWT_PRIVATE_KEY=private.pem CCF_JWT_PUBLIC_KEY=public.pem +# Unset both CCF_JWT_PRIVATE_KEY and CCF_JWT_PUBLIC_KEY to use in-memory key generation. CCF_API_ALLOWED_ORIGINS="http://localhost:3000,http://localhost:8000" CCF_RISK_CONFIG="risk.yaml" diff --git a/Makefile b/Makefile index f6203def..256c9b36 100644 --- a/Makefile +++ b/Makefile @@ -145,8 +145,7 @@ swag: ## swag setup and lint .PHONY: generate-keys generate-keys: @$(INFO) "Generating keys for the service" - @openssl genrsa -out private_key.pem 2048 - @openssl rsa -in private_key.pem -pubout -out public_key.pem + @go run main.go bootstrap --private-key private_key.pem --public-key public_key.pem --force @$(OK) keys generated tag: ## Build and tag a production-based image of the service diff --git a/README.md b/README.md index 9ac5e2d3..c596fb0d 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,8 @@ Some examples include: ```shell $ go run main.go run # Run the API itself +$ go run main.go bootstrap # Bootstrap JWT key files (defaults to CCF_JWT_* or private.pem/public.pem) + $ go run main.go users add # Create a new user in the CCF API which can be used to authenticate with $ go run main.go migrate up # Create the database schema, or upgrade it to the current version @@ -62,6 +64,10 @@ You can configure the API using environment variables or a `.env` file. Available variables are shown in [`.env.example`](./.env.example) +JWT key behavior: +- If both `CCF_JWT_PRIVATE_KEY` and `CCF_JWT_PUBLIC_KEY` are set, `run` bootstraps key files at those paths (if needed) and then loads them. +- If either variable is unset, the API falls back to in-memory key generation. + Copy this file to .env to configure environment variables ```shell cp .env.example .env diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go new file mode 100644 index 00000000..ed983a71 --- /dev/null +++ b/cmd/bootstrap.go @@ -0,0 +1,41 @@ +package cmd + +import "github.com/spf13/cobra" + +func newBootstrapCMD() *cobra.Command { + var ( + privateKeyPath string + publicKeyPath string + bitSize int + force bool + ) + + bootstrap := &cobra.Command{ + Use: "bootstrap", + Short: "Bootstrap JWT key files", + RunE: func(cmd *cobra.Command, args []string) error { + privateKeyPath, publicKeyPath = resolveJWTKeyPathsForBootstrap(privateKeyPath, publicKeyPath) + + action, err := runJWTBootstrap(privateKeyPath, publicKeyPath, bitSize, force) + if err != nil { + return err + } + + cmd.Printf( + "JWT bootstrap complete (action=%s, private=%s, public=%s)\n", + action, + privateKeyPath, + publicKeyPath, + ) + + return nil + }, + } + + bootstrap.Flags().StringVar(&privateKeyPath, "private-key", "", "Path to the JWT private key file") + bootstrap.Flags().StringVar(&publicKeyPath, "public-key", "", "Path to the JWT public key file") + bootstrap.Flags().IntVar(&bitSize, "bit-size", defaultJWTKeyBitSize, "RSA key size in bits") + bootstrap.Flags().BoolVar(&force, "force", false, "Regenerate key files even when they already exist") + + return bootstrap +} diff --git a/cmd/jwt_bootstrap.go b/cmd/jwt_bootstrap.go new file mode 100644 index 00000000..f5777e1e --- /dev/null +++ b/cmd/jwt_bootstrap.go @@ -0,0 +1,51 @@ +package cmd + +import ( + "strings" + + "github.com/compliance-framework/api/internal/config" + "github.com/spf13/viper" +) + +const defaultJWTKeyBitSize = 2048 + +func bootstrapConfiguredJWTKeys(bitSize int, force bool) (config.JWTKeyBootstrapAction, string, string, bool, error) { + privateKeyPath := strings.TrimSpace(viper.GetString("jwt_private_key")) + publicKeyPath := strings.TrimSpace(viper.GetString("jwt_public_key")) + + if privateKeyPath == "" || publicKeyPath == "" { + return "", "", "", false, nil + } + + action, err := runJWTBootstrap(privateKeyPath, publicKeyPath, bitSize, force) + if err != nil { + return "", privateKeyPath, publicKeyPath, true, err + } + + return action, privateKeyPath, publicKeyPath, true, nil +} + +func resolveJWTKeyPathsForBootstrap(privateKeyPath, publicKeyPath string) (string, string) { + privateKeyPath = strings.TrimSpace(privateKeyPath) + publicKeyPath = strings.TrimSpace(publicKeyPath) + + if privateKeyPath == "" { + privateKeyPath = strings.TrimSpace(viper.GetString("jwt_private_key")) + } + if publicKeyPath == "" { + publicKeyPath = strings.TrimSpace(viper.GetString("jwt_public_key")) + } + + if privateKeyPath == "" { + privateKeyPath = "private.pem" + } + if publicKeyPath == "" { + publicKeyPath = "public.pem" + } + + return privateKeyPath, publicKeyPath +} + +func runJWTBootstrap(privateKeyPath, publicKeyPath string, bitSize int, force bool) (config.JWTKeyBootstrapAction, error) { + return config.BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, bitSize, force) +} diff --git a/cmd/root.go b/cmd/root.go index 891c6017..49627cd5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -78,6 +78,7 @@ func init() { rootCmd.AddCommand(users.RootCmd) rootCmd.AddCommand(seed.RootCmd) rootCmd.AddCommand(newMigrateCMD()) + rootCmd.AddCommand(newBootstrapCMD()) rootCmd.AddCommand(dashboards.RootCmd) rootCmd.AddCommand(DigestCmd) } diff --git a/cmd/run.go b/cmd/run.go index 1dfa0cda..5f8376a7 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -46,6 +46,30 @@ func RunServer(cmd *cobra.Command, args []string) { } }() + bootstrapAction, privateKeyPath, publicKeyPath, bootstrapConfigured, err := bootstrapConfiguredJWTKeys(defaultJWTKeyBitSize, false) + if err != nil { + sugar.Fatalw( + "Failed to bootstrap JWT key files", + "error", + err, + "private_key_path", + privateKeyPath, + "public_key_path", + publicKeyPath, + ) + } + if bootstrapConfigured { + sugar.Infow( + "JWT key bootstrap completed", + "action", + bootstrapAction, + "private_key_path", + privateKeyPath, + "public_key_path", + publicKeyPath, + ) + } + cfg := config.NewConfig(sugar) db, err := service.ConnectSQLDb(ctx, cfg, sugar) diff --git a/internal/config/jwt_bootstrap.go b/internal/config/jwt_bootstrap.go new file mode 100644 index 00000000..0f453994 --- /dev/null +++ b/internal/config/jwt_bootstrap.go @@ -0,0 +1,152 @@ +package config + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "path/filepath" + "strings" +) + +type JWTKeyBootstrapAction string + +const ( + JWTKeyBootstrapNoop JWTKeyBootstrapAction = "noop" + JWTKeyBootstrapGenerated JWTKeyBootstrapAction = "generated" + JWTKeyBootstrapDerivedPublic JWTKeyBootstrapAction = "derived_public" + JWTKeyBootstrapRegenerated JWTKeyBootstrapAction = "regenerated" +) + +// BootstrapJWTKeyPair ensures a matching JWT RSA keypair exists at the given paths. +// +// Behavior: +// - when both files exist and force=false: no-op +// - when only private exists and force=false: derives and writes the public key +// - otherwise: generates a new keypair and writes both files +func BootstrapJWTKeyPair(privateKeyPath, publicKeyPath string, bitSize int, force bool) (JWTKeyBootstrapAction, error) { + privateKeyPath = strings.TrimSpace(privateKeyPath) + publicKeyPath = strings.TrimSpace(publicKeyPath) + + if privateKeyPath == "" { + return "", fmt.Errorf("private key path cannot be empty") + } + if publicKeyPath == "" { + return "", fmt.Errorf("public key path cannot be empty") + } + if bitSize <= 0 { + return "", fmt.Errorf("bit size must be greater than 0") + } + + privateExists, err := fileExists(privateKeyPath) + if err != nil { + return "", fmt.Errorf("failed to check private key path %s: %w", privateKeyPath, err) + } + + publicExists, err := fileExists(publicKeyPath) + if err != nil { + return "", fmt.Errorf("failed to check public key path %s: %w", publicKeyPath, err) + } + + if !force { + if privateExists && publicExists { + return JWTKeyBootstrapNoop, nil + } + + if privateExists && !publicExists { + privateKey, err := loadRSAPrivateKey(privateKeyPath) + if err != nil { + return "", fmt.Errorf("failed to load existing private key from %s: %w", privateKeyPath, err) + } + + if err := writeRSAPublicKey(publicKeyPath, &privateKey.PublicKey); err != nil { + return "", err + } + + return JWTKeyBootstrapDerivedPublic, nil + } + } + + privateKey, publicKey, err := GenerateKeyPair(bitSize) + if err != nil { + return "", err + } + + if err := writeRSAPrivateKey(privateKeyPath, privateKey); err != nil { + return "", err + } + + if err := writeRSAPublicKey(publicKeyPath, publicKey); err != nil { + return "", err + } + + if privateExists || publicExists { + return JWTKeyBootstrapRegenerated, nil + } + + return JWTKeyBootstrapGenerated, nil +} + +func fileExists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func writeRSAPrivateKey(path string, key *rsa.PrivateKey) error { + if err := ensureParentDirectory(path); err != nil { + return err + } + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + if err := os.WriteFile(path, privateKeyPEM, 0o600); err != nil { + return fmt.Errorf("unable to write private key to %s: %w", path, err) + } + + return nil +} + +func writeRSAPublicKey(path string, key *rsa.PublicKey) error { + if err := ensureParentDirectory(path); err != nil { + return err + } + + publicKeyDER, err := x509.MarshalPKIXPublicKey(key) + if err != nil { + return fmt.Errorf("unable to marshal public key: %w", err) + } + + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyDER, + }) + + if err := os.WriteFile(path, publicKeyPEM, 0o644); err != nil { + return fmt.Errorf("unable to write public key to %s: %w", path, err) + } + + return nil +} + +func ensureParentDirectory(path string) error { + dir := filepath.Dir(path) + if dir == "." || dir == "" { + return nil + } + + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("unable to create parent directory %s: %w", dir, err) + } + + return nil +} diff --git a/internal/config/jwt_bootstrap_test.go b/internal/config/jwt_bootstrap_test.go new file mode 100644 index 00000000..25960e0f --- /dev/null +++ b/internal/config/jwt_bootstrap_test.go @@ -0,0 +1,88 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestBootstrapJWTKeyPair_GenerateAndNoop(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(t.TempDir(), "public.pem") + + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false) + if err != nil { + t.Fatalf("expected bootstrap to generate keys, got error: %v", err) + } + if action != JWTKeyBootstrapGenerated { + t.Fatalf("expected action %q, got %q", JWTKeyBootstrapGenerated, action) + } + + if _, err := loadRSAPrivateKey(privateKeyPath); err != nil { + t.Fatalf("expected generated private key to be readable, got error: %v", err) + } + if _, err := loadRSAPublicKey(publicKeyPath); err != nil { + t.Fatalf("expected generated public key to be readable, got error: %v", err) + } + + action, err = BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false) + if err != nil { + t.Fatalf("expected bootstrap noop when both keys exist, got error: %v", err) + } + if action != JWTKeyBootstrapNoop { + t.Fatalf("expected action %q, got %q", JWTKeyBootstrapNoop, action) + } +} + +func TestBootstrapJWTKeyPair_DerivePublicFromExistingPrivate(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") + + if _, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false); err != nil { + t.Fatalf("expected setup bootstrap to generate keys, got error: %v", err) + } + + if err := os.Remove(publicKeyPath); err != nil { + t.Fatalf("expected to remove public key for test setup, got error: %v", err) + } + + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false) + if err != nil { + t.Fatalf("expected bootstrap to derive missing public key, got error: %v", err) + } + if action != JWTKeyBootstrapDerivedPublic { + t.Fatalf("expected action %q, got %q", JWTKeyBootstrapDerivedPublic, action) + } + + if _, err := loadRSAPublicKey(publicKeyPath); err != nil { + t.Fatalf("expected derived public key to be readable, got error: %v", err) + } +} + +func TestBootstrapJWTKeyPair_RegenerateWhenOnlyPublicExists(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") + + if _, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false); err != nil { + t.Fatalf("expected setup bootstrap to generate keys, got error: %v", err) + } + + if err := os.Remove(privateKeyPath); err != nil { + t.Fatalf("expected to remove private key for test setup, got error: %v", err) + } + + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false) + if err != nil { + t.Fatalf("expected bootstrap to regenerate missing private key, got error: %v", err) + } + if action != JWTKeyBootstrapRegenerated { + t.Fatalf("expected action %q, got %q", JWTKeyBootstrapRegenerated, action) + } + + if _, err := loadRSAPrivateKey(privateKeyPath); err != nil { + t.Fatalf("expected regenerated private key to be readable, got error: %v", err) + } + if _, err := loadRSAPublicKey(publicKeyPath); err != nil { + t.Fatalf("expected regenerated public key to be readable, got error: %v", err) + } +} From f7ad55fc1ee2f5a188414d555602a900aac927f7 Mon Sep 17 00:00:00 2001 From: Gustavo Carvalho Date: Tue, 17 Mar 2026 17:13:33 -0300 Subject: [PATCH 2/5] Harden JWT bootstrap key validation and writes --- cmd/bootstrap.go | 2 +- cmd/jwt_bootstrap.go | 22 +++- internal/config/jwt_bootstrap.go | 165 ++++++++++++++++++++++---- internal/config/jwt_bootstrap_test.go | 137 ++++++++++++--------- 4 files changed, 239 insertions(+), 87 deletions(-) diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go index ed983a71..2051a080 100644 --- a/cmd/bootstrap.go +++ b/cmd/bootstrap.go @@ -12,7 +12,7 @@ func newBootstrapCMD() *cobra.Command { bootstrap := &cobra.Command{ Use: "bootstrap", - Short: "Bootstrap JWT key files", + Short: "Initialize JWT signing key files for API startup", RunE: func(cmd *cobra.Command, args []string) error { privateKeyPath, publicKeyPath = resolveJWTKeyPathsForBootstrap(privateKeyPath, publicKeyPath) diff --git a/cmd/jwt_bootstrap.go b/cmd/jwt_bootstrap.go index f5777e1e..4ab7d0c4 100644 --- a/cmd/jwt_bootstrap.go +++ b/cmd/jwt_bootstrap.go @@ -10,8 +10,8 @@ import ( const defaultJWTKeyBitSize = 2048 func bootstrapConfiguredJWTKeys(bitSize int, force bool) (config.JWTKeyBootstrapAction, string, string, bool, error) { - privateKeyPath := strings.TrimSpace(viper.GetString("jwt_private_key")) - publicKeyPath := strings.TrimSpace(viper.GetString("jwt_public_key")) + privateKeyPath := normalizePathValue(viper.GetString("jwt_private_key")) + publicKeyPath := normalizePathValue(viper.GetString("jwt_public_key")) if privateKeyPath == "" || publicKeyPath == "" { return "", "", "", false, nil @@ -26,14 +26,14 @@ func bootstrapConfiguredJWTKeys(bitSize int, force bool) (config.JWTKeyBootstrap } func resolveJWTKeyPathsForBootstrap(privateKeyPath, publicKeyPath string) (string, string) { - privateKeyPath = strings.TrimSpace(privateKeyPath) - publicKeyPath = strings.TrimSpace(publicKeyPath) + privateKeyPath = normalizePathValue(privateKeyPath) + publicKeyPath = normalizePathValue(publicKeyPath) if privateKeyPath == "" { - privateKeyPath = strings.TrimSpace(viper.GetString("jwt_private_key")) + privateKeyPath = normalizePathValue(viper.GetString("jwt_private_key")) } if publicKeyPath == "" { - publicKeyPath = strings.TrimSpace(viper.GetString("jwt_public_key")) + publicKeyPath = normalizePathValue(viper.GetString("jwt_public_key")) } if privateKeyPath == "" { @@ -49,3 +49,13 @@ func resolveJWTKeyPathsForBootstrap(privateKeyPath, publicKeyPath string) (strin func runJWTBootstrap(privateKeyPath, publicKeyPath string, bitSize int, force bool) (config.JWTKeyBootstrapAction, error) { return config.BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, bitSize, force) } + +func normalizePathValue(value string) string { + value = strings.TrimSpace(value) + if len(value) >= 2 { + if (value[0] == '"' && value[len(value)-1] == '"') || (value[0] == '\'' && value[len(value)-1] == '\'') { + value = value[1 : len(value)-1] + } + } + return strings.TrimSpace(value) +} diff --git a/internal/config/jwt_bootstrap.go b/internal/config/jwt_bootstrap.go index 0f453994..e4f3381e 100644 --- a/internal/config/jwt_bootstrap.go +++ b/internal/config/jwt_bootstrap.go @@ -4,6 +4,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" "fmt" "os" "path/filepath" @@ -13,6 +14,7 @@ import ( type JWTKeyBootstrapAction string const ( + minimumJWTKeyBitSize = 2048 JWTKeyBootstrapNoop JWTKeyBootstrapAction = "noop" JWTKeyBootstrapGenerated JWTKeyBootstrapAction = "generated" JWTKeyBootstrapDerivedPublic JWTKeyBootstrapAction = "derived_public" @@ -26,8 +28,8 @@ const ( // - when only private exists and force=false: derives and writes the public key // - otherwise: generates a new keypair and writes both files func BootstrapJWTKeyPair(privateKeyPath, publicKeyPath string, bitSize int, force bool) (JWTKeyBootstrapAction, error) { - privateKeyPath = strings.TrimSpace(privateKeyPath) - publicKeyPath = strings.TrimSpace(publicKeyPath) + privateKeyPath = normalizeKeyPath(privateKeyPath) + publicKeyPath = normalizeKeyPath(publicKeyPath) if privateKeyPath == "" { return "", fmt.Errorf("private key path cannot be empty") @@ -35,8 +37,16 @@ func BootstrapJWTKeyPair(privateKeyPath, publicKeyPath string, bitSize int, forc if publicKeyPath == "" { return "", fmt.Errorf("public key path cannot be empty") } - if bitSize <= 0 { - return "", fmt.Errorf("bit size must be greater than 0") + if bitSize < minimumJWTKeyBitSize { + return "", fmt.Errorf("bit size must be at least %d", minimumJWTKeyBitSize) + } + + equal, err := keyPathsReferToSameLocation(privateKeyPath, publicKeyPath) + if err != nil { + return "", err + } + if equal { + return "", fmt.Errorf("private and public key paths must be different") } privateExists, err := fileExists(privateKeyPath) @@ -51,6 +61,9 @@ func BootstrapJWTKeyPair(privateKeyPath, publicKeyPath string, bitSize int, forc if !force { if privateExists && publicExists { + if err := validateExistingJWTKeyPair(privateKeyPath, publicKeyPath); err != nil { + return "", fmt.Errorf("existing JWT keypair is invalid or mismatched: %w; rerun with --force to regenerate", err) + } return JWTKeyBootstrapNoop, nil } @@ -100,27 +113,15 @@ func fileExists(path string) (bool, error) { } func writeRSAPrivateKey(path string, key *rsa.PrivateKey) error { - if err := ensureParentDirectory(path); err != nil { - return err - } - privateKeyPEM := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key), }) - if err := os.WriteFile(path, privateKeyPEM, 0o600); err != nil { - return fmt.Errorf("unable to write private key to %s: %w", path, err) - } - - return nil + return writePEMAtomically(path, privateKeyPEM, 0o600) } func writeRSAPublicKey(path string, key *rsa.PublicKey) error { - if err := ensureParentDirectory(path); err != nil { - return err - } - publicKeyDER, err := x509.MarshalPKIXPublicKey(key) if err != nil { return fmt.Errorf("unable to marshal public key: %w", err) @@ -131,11 +132,7 @@ func writeRSAPublicKey(path string, key *rsa.PublicKey) error { Bytes: publicKeyDER, }) - if err := os.WriteFile(path, publicKeyPEM, 0o644); err != nil { - return fmt.Errorf("unable to write public key to %s: %w", path, err) - } - - return nil + return writePEMAtomically(path, publicKeyPEM, 0o644) } func ensureParentDirectory(path string) error { @@ -150,3 +147,127 @@ func ensureParentDirectory(path string) error { return nil } + +func normalizeKeyPath(path string) string { + return filepath.Clean(stripQuotes(strings.TrimSpace(path))) +} + +func keyPathsReferToSameLocation(privateKeyPath, publicKeyPath string) (bool, error) { + privateAbs, err := filepath.Abs(privateKeyPath) + if err != nil { + return false, fmt.Errorf("failed to resolve private key path %s: %w", privateKeyPath, err) + } + publicAbs, err := filepath.Abs(publicKeyPath) + if err != nil { + return false, fmt.Errorf("failed to resolve public key path %s: %w", publicKeyPath, err) + } + + privateResolved := resolvePathIfExists(privateAbs) + publicResolved := resolvePathIfExists(publicAbs) + return privateResolved == publicResolved, nil +} + +func resolvePathIfExists(path string) string { + resolved, err := filepath.EvalSymlinks(path) + if err == nil { + return resolved + } + return path +} + +func validateExistingJWTKeyPair(privateKeyPath, publicKeyPath string) error { + privateKey, err := loadRSAPrivateKey(privateKeyPath) + if err != nil { + return err + } + publicKey, err := loadRSAPublicKey(publicKeyPath) + if err != nil { + return err + } + if !rsaPublicKeysEqual(&privateKey.PublicKey, publicKey) { + return fmt.Errorf("public key does not match private key") + } + return nil +} + +func rsaPublicKeysEqual(a, b *rsa.PublicKey) bool { + if a == nil || b == nil { + return false + } + return a.E == b.E && a.N.Cmp(b.N) == 0 +} + +func writePEMAtomically(path string, data []byte, mode os.FileMode) error { + if err := ensureParentDirectory(path); err != nil { + return err + } + if err := rejectSymlinkPath(path); err != nil { + return err + } + + dir := filepath.Dir(path) + base := filepath.Base(path) + + tmpFile, err := os.CreateTemp(dir, "."+base+".tmp-*") + if err != nil { + return fmt.Errorf("unable to create temp file for %s: %w", path, err) + } + + tmpPath := tmpFile.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if err := tmpFile.Chmod(mode); err != nil { + _ = tmpFile.Close() + return fmt.Errorf("unable to set permissions on temp key file for %s: %w", path, err) + } + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + return fmt.Errorf("unable to write temp key file for %s: %w", path, err) + } + if err := tmpFile.Sync(); err != nil { + _ = tmpFile.Close() + return fmt.Errorf("unable to sync temp key file for %s: %w", path, err) + } + if err := tmpFile.Close(); err != nil { + return fmt.Errorf("unable to close temp key file for %s: %w", path, err) + } + + if err := os.Rename(tmpPath, path); err != nil { + return fmt.Errorf("unable to atomically write key file %s: %w", path, err) + } + + fileInfo, err := os.Stat(path) + if err != nil { + return fmt.Errorf("unable to stat key file %s after write: %w", path, err) + } + if fileInfo.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to use symlink key path %s", path) + } + if mode != 0 { + if err := os.Chmod(path, mode); err != nil { + return fmt.Errorf("unable to enforce permissions on key file %s: %w", path, err) + } + } + + cleanup = false + return nil +} + +func rejectSymlinkPath(path string) error { + fileInfo, err := os.Lstat(path) + if err == nil { + if fileInfo.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to write to symlink path %s", path) + } + return nil + } + if errors.Is(err, os.ErrNotExist) { + return nil + } + return fmt.Errorf("unable to inspect key path %s: %w", path, err) +} diff --git a/internal/config/jwt_bootstrap_test.go b/internal/config/jwt_bootstrap_test.go index 25960e0f..3bc76431 100644 --- a/internal/config/jwt_bootstrap_test.go +++ b/internal/config/jwt_bootstrap_test.go @@ -4,85 +4,106 @@ import ( "os" "path/filepath" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBootstrapJWTKeyPair_GenerateAndNoop(t *testing.T) { privateKeyPath := filepath.Join(t.TempDir(), "private.pem") publicKeyPath := filepath.Join(t.TempDir(), "public.pem") - action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false) - if err != nil { - t.Fatalf("expected bootstrap to generate keys, got error: %v", err) - } - if action != JWTKeyBootstrapGenerated { - t.Fatalf("expected action %q, got %q", JWTKeyBootstrapGenerated, action) - } - - if _, err := loadRSAPrivateKey(privateKeyPath); err != nil { - t.Fatalf("expected generated private key to be readable, got error: %v", err) - } - if _, err := loadRSAPublicKey(publicKeyPath); err != nil { - t.Fatalf("expected generated public key to be readable, got error: %v", err) - } - - action, err = BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false) - if err != nil { - t.Fatalf("expected bootstrap noop when both keys exist, got error: %v", err) - } - if action != JWTKeyBootstrapNoop { - t.Fatalf("expected action %q, got %q", JWTKeyBootstrapNoop, action) - } + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.NoError(t, err) + assert.Equal(t, JWTKeyBootstrapGenerated, action) + + _, err = loadRSAPrivateKey(privateKeyPath) + require.NoError(t, err) + _, err = loadRSAPublicKey(publicKeyPath) + require.NoError(t, err) + + action, err = BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.NoError(t, err) + assert.Equal(t, JWTKeyBootstrapNoop, action) } func TestBootstrapJWTKeyPair_DerivePublicFromExistingPrivate(t *testing.T) { privateKeyPath := filepath.Join(t.TempDir(), "private.pem") publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") - if _, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false); err != nil { - t.Fatalf("expected setup bootstrap to generate keys, got error: %v", err) - } + _, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.NoError(t, err) - if err := os.Remove(publicKeyPath); err != nil { - t.Fatalf("expected to remove public key for test setup, got error: %v", err) - } + require.NoError(t, os.Remove(publicKeyPath)) - action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false) - if err != nil { - t.Fatalf("expected bootstrap to derive missing public key, got error: %v", err) - } - if action != JWTKeyBootstrapDerivedPublic { - t.Fatalf("expected action %q, got %q", JWTKeyBootstrapDerivedPublic, action) - } - - if _, err := loadRSAPublicKey(publicKeyPath); err != nil { - t.Fatalf("expected derived public key to be readable, got error: %v", err) - } + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.NoError(t, err) + assert.Equal(t, JWTKeyBootstrapDerivedPublic, action) + + _, err = loadRSAPublicKey(publicKeyPath) + require.NoError(t, err) } func TestBootstrapJWTKeyPair_RegenerateWhenOnlyPublicExists(t *testing.T) { privateKeyPath := filepath.Join(t.TempDir(), "private.pem") publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") - if _, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false); err != nil { - t.Fatalf("expected setup bootstrap to generate keys, got error: %v", err) - } + _, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.NoError(t, err) + + require.NoError(t, os.Remove(privateKeyPath)) + + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.NoError(t, err) + assert.Equal(t, JWTKeyBootstrapRegenerated, action) + + _, err = loadRSAPrivateKey(privateKeyPath) + require.NoError(t, err) + _, err = loadRSAPublicKey(publicKeyPath) + require.NoError(t, err) +} - if err := os.Remove(privateKeyPath); err != nil { - t.Fatalf("expected to remove private key for test setup, got error: %v", err) - } +func TestBootstrapJWTKeyPair_RejectsSmallKeySizes(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, 1024, false) - if err != nil { - t.Fatalf("expected bootstrap to regenerate missing private key, got error: %v", err) - } - if action != JWTKeyBootstrapRegenerated { - t.Fatalf("expected action %q, got %q", JWTKeyBootstrapRegenerated, action) - } - - if _, err := loadRSAPrivateKey(privateKeyPath); err != nil { - t.Fatalf("expected regenerated private key to be readable, got error: %v", err) - } - if _, err := loadRSAPublicKey(publicKeyPath); err != nil { - t.Fatalf("expected regenerated public key to be readable, got error: %v", err) - } + require.Error(t, err) + assert.Empty(t, action) + assert.Contains(t, err.Error(), "bit size must be at least") +} + +func TestBootstrapJWTKeyPair_RejectsSamePath(t *testing.T) { + samePath := filepath.Join(t.TempDir(), "jwt.pem") + + action, err := BootstrapJWTKeyPair(samePath, samePath, minimumJWTKeyBitSize, false) + require.Error(t, err) + assert.Empty(t, action) + assert.Contains(t, err.Error(), "must be different") +} + +func TestBootstrapJWTKeyPair_RejectsMismatchedExistingPair(t *testing.T) { + privateOne := filepath.Join(t.TempDir(), "private1.pem") + publicOne := filepath.Join(filepath.Dir(privateOne), "public1.pem") + _, err := BootstrapJWTKeyPair(privateOne, publicOne, minimumJWTKeyBitSize, false) + require.NoError(t, err) + + privateTwo := filepath.Join(filepath.Dir(privateOne), "private2.pem") + publicTwo := filepath.Join(filepath.Dir(privateOne), "public2.pem") + _, err = BootstrapJWTKeyPair(privateTwo, publicTwo, minimumJWTKeyBitSize, false) + require.NoError(t, err) + + mismatchedPrivatePath := filepath.Join(filepath.Dir(privateOne), "mismatched_private.pem") + mismatchedPublicPath := filepath.Join(filepath.Dir(privateOne), "mismatched_public.pem") + privateData, err := os.ReadFile(privateOne) + require.NoError(t, err) + publicData, err := os.ReadFile(publicTwo) + require.NoError(t, err) + require.NoError(t, os.WriteFile(mismatchedPrivatePath, privateData, 0o600)) + require.NoError(t, os.WriteFile(mismatchedPublicPath, publicData, 0o644)) + + action, err := BootstrapJWTKeyPair(mismatchedPrivatePath, mismatchedPublicPath, minimumJWTKeyBitSize, false) + require.Error(t, err) + assert.Empty(t, action) + assert.Contains(t, err.Error(), "invalid or mismatched") } From 624c87a4ef178d994d209edaa6659d710d0c2450 Mon Sep 17 00:00:00 2001 From: Gustavo Carvalho Date: Tue, 17 Mar 2026 18:10:24 -0300 Subject: [PATCH 3/5] Handle empty JWT key env vars in bootstrap path resolution --- .env.example | 3 ++- cmd/jwt_bootstrap.go | 17 ++++++++++++++++- internal/config/jwt_bootstrap.go | 6 +++++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/.env.example b/.env.example index 4dbd9664..b9e964b2 100644 --- a/.env.example +++ b/.env.example @@ -6,7 +6,8 @@ CCF_DB_CONNECTION="host=db user=postgres password=postgres dbname=ccf port=5432 CCF_JWT_SECRET="some-secret" CCF_JWT_PRIVATE_KEY=private.pem CCF_JWT_PUBLIC_KEY=public.pem -# Unset both CCF_JWT_PRIVATE_KEY and CCF_JWT_PUBLIC_KEY to use in-memory key generation. +# To use in-memory key generation, remove/comment out both CCF_JWT_PRIVATE_KEY and CCF_JWT_PUBLIC_KEY lines. +# Do not set these variables to empty strings. CCF_API_ALLOWED_ORIGINS="http://localhost:3000,http://localhost:8000" CCF_RISK_CONFIG="risk.yaml" diff --git a/cmd/jwt_bootstrap.go b/cmd/jwt_bootstrap.go index 4ab7d0c4..1ad59208 100644 --- a/cmd/jwt_bootstrap.go +++ b/cmd/jwt_bootstrap.go @@ -1,6 +1,7 @@ package cmd import ( + "errors" "strings" "github.com/compliance-framework/api/internal/config" @@ -10,11 +11,18 @@ import ( const defaultJWTKeyBitSize = 2048 func bootstrapConfiguredJWTKeys(bitSize int, force bool) (config.JWTKeyBootstrapAction, string, string, bool, error) { + privateKeyConfigured := viper.IsSet("jwt_private_key") + publicKeyConfigured := viper.IsSet("jwt_public_key") + + if !privateKeyConfigured || !publicKeyConfigured { + return "", "", "", false, nil + } + privateKeyPath := normalizePathValue(viper.GetString("jwt_private_key")) publicKeyPath := normalizePathValue(viper.GetString("jwt_public_key")) if privateKeyPath == "" || publicKeyPath == "" { - return "", "", "", false, nil + return "", privateKeyPath, publicKeyPath, true, configErrorForEmptyJWTKeyPath() } action, err := runJWTBootstrap(privateKeyPath, publicKeyPath, bitSize, force) @@ -59,3 +67,10 @@ func normalizePathValue(value string) string { } return strings.TrimSpace(value) } + +func configErrorForEmptyJWTKeyPath() error { + return errors.New( + "CCF_JWT_PRIVATE_KEY and CCF_JWT_PUBLIC_KEY are set but one or both are empty. " + + "Set both to non-empty paths, or remove/comment out both to use in-memory key generation", + ) +} diff --git a/internal/config/jwt_bootstrap.go b/internal/config/jwt_bootstrap.go index e4f3381e..758b7ccc 100644 --- a/internal/config/jwt_bootstrap.go +++ b/internal/config/jwt_bootstrap.go @@ -149,7 +149,11 @@ func ensureParentDirectory(path string) error { } func normalizeKeyPath(path string) string { - return filepath.Clean(stripQuotes(strings.TrimSpace(path))) + path = stripQuotes(strings.TrimSpace(path)) + if path == "" { + return "" + } + return filepath.Clean(path) } func keyPathsReferToSameLocation(privateKeyPath, publicKeyPath string) (bool, error) { From 7814393748d7e5a588bfb09ec1102fe61d3724ba Mon Sep 17 00:00:00 2001 From: Gustavo Carvalho Date: Tue, 17 Mar 2026 18:26:22 -0300 Subject: [PATCH 4/5] Harden JWT bootstrap with locking and weak key validation --- internal/config/jwt_bootstrap.go | 189 ++++++++++++++++++++------ internal/config/jwt_bootstrap_test.go | 72 ++++++++++ 2 files changed, 223 insertions(+), 38 deletions(-) diff --git a/internal/config/jwt_bootstrap.go b/internal/config/jwt_bootstrap.go index 758b7ccc..b146105b 100644 --- a/internal/config/jwt_bootstrap.go +++ b/internal/config/jwt_bootstrap.go @@ -2,6 +2,7 @@ package config import ( "crypto/rsa" + "crypto/sha256" "crypto/x509" "encoding/pem" "errors" @@ -9,18 +10,22 @@ import ( "os" "path/filepath" "strings" + "time" ) -type JWTKeyBootstrapAction string - const ( - minimumJWTKeyBitSize = 2048 - JWTKeyBootstrapNoop JWTKeyBootstrapAction = "noop" - JWTKeyBootstrapGenerated JWTKeyBootstrapAction = "generated" - JWTKeyBootstrapDerivedPublic JWTKeyBootstrapAction = "derived_public" - JWTKeyBootstrapRegenerated JWTKeyBootstrapAction = "regenerated" + minimumJWTKeyBitSize = 2048 + jwtBootstrapLockWaitTimeout = 30 * time.Second + jwtBootstrapLockRetryInterval = 100 * time.Millisecond + jwtBootstrapLockStaleThreshold = 2 * time.Minute + JWTKeyBootstrapNoop JWTKeyBootstrapAction = "noop" + JWTKeyBootstrapGenerated JWTKeyBootstrapAction = "generated" + JWTKeyBootstrapDerivedPublic JWTKeyBootstrapAction = "derived_public" + JWTKeyBootstrapRegenerated JWTKeyBootstrapAction = "regenerated" ) +type JWTKeyBootstrapAction string + // BootstrapJWTKeyPair ensures a matching JWT RSA keypair exists at the given paths. // // Behavior: @@ -49,56 +54,158 @@ func BootstrapJWTKeyPair(privateKeyPath, publicKeyPath string, bitSize int, forc return "", fmt.Errorf("private and public key paths must be different") } - privateExists, err := fileExists(privateKeyPath) - if err != nil { - return "", fmt.Errorf("failed to check private key path %s: %w", privateKeyPath, err) - } - - publicExists, err := fileExists(publicKeyPath) - if err != nil { - return "", fmt.Errorf("failed to check public key path %s: %w", publicKeyPath, err) - } + return withJWTBootstrapLock(privateKeyPath, publicKeyPath, func() (JWTKeyBootstrapAction, error) { + privateExists, err := fileExists(privateKeyPath) + if err != nil { + return "", fmt.Errorf("failed to check private key path %s: %w", privateKeyPath, err) + } - if !force { - if privateExists && publicExists { - if err := validateExistingJWTKeyPair(privateKeyPath, publicKeyPath); err != nil { - return "", fmt.Errorf("existing JWT keypair is invalid or mismatched: %w; rerun with --force to regenerate", err) - } - return JWTKeyBootstrapNoop, nil + publicExists, err := fileExists(publicKeyPath) + if err != nil { + return "", fmt.Errorf("failed to check public key path %s: %w", publicKeyPath, err) } - if privateExists && !publicExists { - privateKey, err := loadRSAPrivateKey(privateKeyPath) - if err != nil { - return "", fmt.Errorf("failed to load existing private key from %s: %w", privateKeyPath, err) + if !force { + if privateExists && publicExists { + if err := validateExistingJWTKeyPair(privateKeyPath, publicKeyPath, bitSize); err != nil { + return "", fmt.Errorf("existing JWT keypair is invalid or mismatched: %w; rerun with --force to regenerate", err) + } + return JWTKeyBootstrapNoop, nil } - if err := writeRSAPublicKey(publicKeyPath, &privateKey.PublicKey); err != nil { - return "", err + if privateExists && !publicExists { + privateKey, err := loadRSAPrivateKey(privateKeyPath) + if err != nil { + return "", fmt.Errorf("failed to load existing private key from %s: %w", privateKeyPath, err) + } + if privateKey.N.BitLen() < bitSize { + return "", fmt.Errorf( + "existing private key size %d is below required minimum %d; rerun with --force to regenerate", + privateKey.N.BitLen(), + bitSize, + ) + } + + if err := writeRSAPublicKey(publicKeyPath, &privateKey.PublicKey); err != nil { + return "", err + } + + return JWTKeyBootstrapDerivedPublic, nil } + } - return JWTKeyBootstrapDerivedPublic, nil + privateKey, publicKey, err := GenerateKeyPair(bitSize) + if err != nil { + return "", err + } + + if err := writeRSAPrivateKey(privateKeyPath, privateKey); err != nil { + return "", err } - } - privateKey, publicKey, err := GenerateKeyPair(bitSize) + if err := writeRSAPublicKey(publicKeyPath, publicKey); err != nil { + return "", err + } + + if privateExists || publicExists { + return JWTKeyBootstrapRegenerated, nil + } + + return JWTKeyBootstrapGenerated, nil + }) +} + +func withJWTBootstrapLock(privateKeyPath, publicKeyPath string, fn func() (JWTKeyBootstrapAction, error)) (JWTKeyBootstrapAction, error) { + lockPath, err := jwtBootstrapLockPath(privateKeyPath, publicKeyPath) if err != nil { return "", err } - if err := writeRSAPrivateKey(privateKeyPath, privateKey); err != nil { + release, err := acquireJWTBootstrapLock(lockPath, jwtBootstrapLockWaitTimeout) + if err != nil { return "", err } + defer release() - if err := writeRSAPublicKey(publicKeyPath, publicKey); err != nil { - return "", err + return fn() +} + +func jwtBootstrapLockPath(privateKeyPath, publicKeyPath string) (string, error) { + privateAbs, err := filepath.Abs(privateKeyPath) + if err != nil { + return "", fmt.Errorf("failed to resolve private key path %s: %w", privateKeyPath, err) + } + publicAbs, err := filepath.Abs(publicKeyPath) + if err != nil { + return "", fmt.Errorf("failed to resolve public key path %s: %w", publicKeyPath, err) } - if privateExists || publicExists { - return JWTKeyBootstrapRegenerated, nil + lockKey := privateAbs + "\n" + publicAbs + lockHash := sha256.Sum256([]byte(lockKey)) + return filepath.Join(os.TempDir(), fmt.Sprintf("ccf-jwt-bootstrap-%x.lock", lockHash[:])), nil +} + +func acquireJWTBootstrapLock(lockPath string, timeout time.Duration) (func(), error) { + deadline := time.Now().Add(timeout) + + for { + lockFile, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600) + if err == nil { + if _, writeErr := lockFile.WriteString(fmt.Sprintf("pid=%d time=%s\n", os.Getpid(), time.Now().UTC().Format(time.RFC3339Nano))); writeErr != nil { + _ = lockFile.Close() + _ = os.Remove(lockPath) + return nil, fmt.Errorf("failed to initialize bootstrap lock %s: %w", lockPath, writeErr) + } + if closeErr := lockFile.Close(); closeErr != nil { + _ = os.Remove(lockPath) + return nil, fmt.Errorf("failed to close bootstrap lock %s: %w", lockPath, closeErr) + } + + return func() { + _ = os.Remove(lockPath) + }, nil + } + + if !errors.Is(err, os.ErrExist) { + return nil, fmt.Errorf("failed to acquire bootstrap lock %s: %w", lockPath, err) + } + + staleRemoved, staleErr := removeStaleBootstrapLock(lockPath) + if staleErr != nil { + return nil, staleErr + } + if staleRemoved { + continue + } + + if time.Now().After(deadline) { + return nil, fmt.Errorf("timed out waiting for bootstrap lock %s", lockPath) + } + + time.Sleep(jwtBootstrapLockRetryInterval) + } +} + +func removeStaleBootstrapLock(lockPath string) (bool, error) { + lockInfo, err := os.Stat(lockPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + return false, fmt.Errorf("failed to inspect bootstrap lock %s: %w", lockPath, err) } - return JWTKeyBootstrapGenerated, nil + if time.Since(lockInfo.ModTime()) < jwtBootstrapLockStaleThreshold { + return false, nil + } + + if err := os.Remove(lockPath); err != nil { + if errors.Is(err, os.ErrNotExist) { + return true, nil + } + return false, fmt.Errorf("failed to remove stale bootstrap lock %s: %w", lockPath, err) + } + return true, nil } func fileExists(path string) (bool, error) { @@ -179,15 +286,21 @@ func resolvePathIfExists(path string) string { return path } -func validateExistingJWTKeyPair(privateKeyPath, publicKeyPath string) error { +func validateExistingJWTKeyPair(privateKeyPath, publicKeyPath string, bitSize int) error { privateKey, err := loadRSAPrivateKey(privateKeyPath) if err != nil { return err } + if privateKey.N.BitLen() < bitSize { + return fmt.Errorf("private key size %d is below required minimum %d", privateKey.N.BitLen(), bitSize) + } publicKey, err := loadRSAPublicKey(publicKeyPath) if err != nil { return err } + if publicKey.N.BitLen() < bitSize { + return fmt.Errorf("public key size %d is below required minimum %d", publicKey.N.BitLen(), bitSize) + } if !rsaPublicKeysEqual(&privateKey.PublicKey, publicKey) { return fmt.Errorf("public key does not match private key") } diff --git a/internal/config/jwt_bootstrap_test.go b/internal/config/jwt_bootstrap_test.go index 3bc76431..79d9056f 100644 --- a/internal/config/jwt_bootstrap_test.go +++ b/internal/config/jwt_bootstrap_test.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -107,3 +108,74 @@ func TestBootstrapJWTKeyPair_RejectsMismatchedExistingPair(t *testing.T) { assert.Empty(t, action) assert.Contains(t, err.Error(), "invalid or mismatched") } + +func TestBootstrapJWTKeyPair_RejectsWeakExistingPair(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") + + weakPrivateKey, weakPublicKey, err := GenerateKeyPair(1024) + require.NoError(t, err) + require.NoError(t, writeRSAPrivateKey(privateKeyPath, weakPrivateKey)) + require.NoError(t, writeRSAPublicKey(publicKeyPath, weakPublicKey)) + + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.Error(t, err) + assert.Empty(t, action) + assert.Contains(t, err.Error(), "below required minimum") +} + +func TestBootstrapJWTKeyPair_RejectsWeakPrivateWhenDerivingPublic(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") + + weakPrivateKey, _, err := GenerateKeyPair(1024) + require.NoError(t, err) + require.NoError(t, writeRSAPrivateKey(privateKeyPath, weakPrivateKey)) + + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.Error(t, err) + assert.Empty(t, action) + assert.Contains(t, err.Error(), "below required minimum") + + action, err = BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, true) + require.NoError(t, err) + assert.Equal(t, JWTKeyBootstrapRegenerated, action) + + regeneratedPrivateKey, err := loadRSAPrivateKey(privateKeyPath) + require.NoError(t, err) + assert.GreaterOrEqual(t, regeneratedPrivateKey.N.BitLen(), minimumJWTKeyBitSize) +} + +func TestBootstrapJWTKeyPair_ConcurrentCallsUseConsistentKeyPair(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") + + const workers = 8 + start := make(chan struct{}) + errCh := make(chan error, workers) + var wg sync.WaitGroup + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + _, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + errCh <- err + }() + } + + close(start) + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } + + privateKey, err := loadRSAPrivateKey(privateKeyPath) + require.NoError(t, err) + publicKey, err := loadRSAPublicKey(publicKeyPath) + require.NoError(t, err) + assert.True(t, rsaPublicKeysEqual(&privateKey.PublicKey, publicKey)) +} From 8750d421f6f417def1f11aa526e4d20244cc64ec Mon Sep 17 00:00:00 2001 From: Gustavo Carvalho Date: Tue, 17 Mar 2026 18:41:51 -0300 Subject: [PATCH 5/5] Add JWT bootstrap rollback safeguards and failure tests --- internal/config/jwt_bootstrap.go | 53 +++++++++++++++++++++++++-- internal/config/jwt_bootstrap_test.go | 39 ++++++++++++++++++++ 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/internal/config/jwt_bootstrap.go b/internal/config/jwt_bootstrap.go index b146105b..e548c5fa 100644 --- a/internal/config/jwt_bootstrap.go +++ b/internal/config/jwt_bootstrap.go @@ -26,6 +26,12 @@ const ( type JWTKeyBootstrapAction string +type fileBackup struct { + exists bool + data []byte + mode os.FileMode +} + // BootstrapJWTKeyPair ensures a matching JWT RSA keypair exists at the given paths. // // Behavior: @@ -99,12 +105,20 @@ func BootstrapJWTKeyPair(privateKeyPath, publicKeyPath string, bitSize int, forc return "", err } + privateBackup, err := backupExistingFile(privateKeyPath) + if err != nil { + return "", fmt.Errorf("failed to backup existing private key %s: %w", privateKeyPath, err) + } + if err := writeRSAPrivateKey(privateKeyPath, privateKey); err != nil { return "", err } if err := writeRSAPublicKey(publicKeyPath, publicKey); err != nil { - return "", err + if rollbackErr := restoreOrRemoveFile(privateKeyPath, privateBackup); rollbackErr != nil { + return "", fmt.Errorf("failed to write public key: %w (private key rollback failed: %v)", err, rollbackErr) + } + return "", fmt.Errorf("failed to write public key: %w", err) } if privateExists || publicExists { @@ -151,7 +165,7 @@ func acquireJWTBootstrapLock(lockPath string, timeout time.Duration) (func(), er for { lockFile, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600) if err == nil { - if _, writeErr := lockFile.WriteString(fmt.Sprintf("pid=%d time=%s\n", os.Getpid(), time.Now().UTC().Format(time.RFC3339Nano))); writeErr != nil { + if _, writeErr := fmt.Fprintf(lockFile, "pid=%d time=%s\n", os.Getpid(), time.Now().UTC().Format(time.RFC3339Nano)); writeErr != nil { _ = lockFile.Close() _ = os.Remove(lockPath) return nil, fmt.Errorf("failed to initialize bootstrap lock %s: %w", lockPath, writeErr) @@ -219,6 +233,39 @@ func fileExists(path string) (bool, error) { return false, err } +func backupExistingFile(path string) (fileBackup, error) { + fileInfo, err := os.Stat(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return fileBackup{exists: false}, nil + } + return fileBackup{}, err + } + + data, err := os.ReadFile(path) + if err != nil { + return fileBackup{}, err + } + + return fileBackup{ + exists: true, + data: data, + mode: fileInfo.Mode().Perm(), + }, nil +} + +func restoreOrRemoveFile(path string, backup fileBackup) error { + if !backup.exists { + err := os.Remove(path) + if err == nil || errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + + return writePEMAtomically(path, backup.data, backup.mode) +} + func writeRSAPrivateKey(path string, key *rsa.PrivateKey) error { privateKeyPEM := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", @@ -358,7 +405,7 @@ func writePEMAtomically(path string, data []byte, mode os.FileMode) error { return fmt.Errorf("unable to atomically write key file %s: %w", path, err) } - fileInfo, err := os.Stat(path) + fileInfo, err := os.Lstat(path) if err != nil { return fmt.Errorf("unable to stat key file %s after write: %w", path, err) } diff --git a/internal/config/jwt_bootstrap_test.go b/internal/config/jwt_bootstrap_test.go index 79d9056f..094816cc 100644 --- a/internal/config/jwt_bootstrap_test.go +++ b/internal/config/jwt_bootstrap_test.go @@ -179,3 +179,42 @@ func TestBootstrapJWTKeyPair_ConcurrentCallsUseConsistentKeyPair(t *testing.T) { require.NoError(t, err) assert.True(t, rsaPublicKeysEqual(&privateKey.PublicKey, publicKey)) } + +func TestBootstrapJWTKeyPair_RollsBackPrivateKeyWhenPublicWriteFailsOnRegenerate(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") + + _, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, false) + require.NoError(t, err) + + originalPrivateData, err := os.ReadFile(privateKeyPath) + require.NoError(t, err) + + require.NoError(t, os.Remove(publicKeyPath)) + require.NoError(t, os.Symlink(filepath.Join(filepath.Dir(privateKeyPath), "public-target.pem"), publicKeyPath)) + + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, true) + require.Error(t, err) + assert.Empty(t, action) + assert.Contains(t, err.Error(), "failed to write public key") + + restoredPrivateData, err := os.ReadFile(privateKeyPath) + require.NoError(t, err) + assert.Equal(t, originalPrivateData, restoredPrivateData) +} + +func TestBootstrapJWTKeyPair_RemovesNewPrivateKeyWhenPublicWriteFailsOnCreate(t *testing.T) { + privateKeyPath := filepath.Join(t.TempDir(), "private.pem") + publicKeyPath := filepath.Join(filepath.Dir(privateKeyPath), "public.pem") + + require.NoError(t, os.Symlink(filepath.Join(filepath.Dir(privateKeyPath), "public-target.pem"), publicKeyPath)) + + action, err := BootstrapJWTKeyPair(privateKeyPath, publicKeyPath, minimumJWTKeyBitSize, true) + require.Error(t, err) + assert.Empty(t, action) + assert.Contains(t, err.Error(), "failed to write public key") + + _, statErr := os.Stat(privateKeyPath) + require.Error(t, statErr) + assert.True(t, os.IsNotExist(statErr)) +}