From c7084c34b36bb9be26875a58d044e03db85182d7 Mon Sep 17 00:00:00 2001 From: Joshua Gilman Date: Tue, 26 May 2026 14:25:20 -0700 Subject: [PATCH] refactor(store/postgres): split adapter by domain and extract shared helpers Break the 1584-line store.go into per-domain files (principal, role, provisioning, identity, oidc, token) plus shared helpers tx.go, errors.go, codec.go, clone.go, validation.go. CreateProvisioningRule, UpdateProvisioningRule, ProvisionIdentity, and CreateRegistration now use a generic withTx helper instead of duplicated Begin/Rollback/Commit blocks. ProvisionIdentity's unique-violation race recovery uses a package-private sentinel so the post-rollback read happens after withTx unwinds; LinkIdentity and the passkey-registration link path share a single findIdentityLink. The public surface stays exactly NewStore/*Store/Migrate. Test files mirror the new layout with per-domain TestStoreSatisfiesXContracts assertion bundles; store_integration_test.go is untouched. Every unexported helper gains a one-line godoc; inline comments name the migration advisory lock, transaction scopes, the credential-binding-pinned WHERE clause as a security invariant, the malformed-hash rejection in findToken, and the race-recovery path in ProvisionIdentity. Co-Authored-By: Claude Opus 4.7 (1M context) --- store/postgres/clone.go | 79 ++ store/postgres/codec.go | 94 ++ store/postgres/errors.go | 17 + store/postgres/identity.go | 325 ++++++ store/postgres/identity_test.go | 13 + store/postgres/migrate.go | 8 +- store/postgres/oidc.go | 111 ++ store/postgres/oidc_test.go | 12 + store/postgres/passkey.go | 187 ++-- store/postgres/passkey_test.go | 11 + store/postgres/principal.go | 167 +++ store/postgres/principal_test.go | 13 + store/postgres/provisioning.go | 401 +++++++ store/postgres/provisioning_test.go | 15 + store/postgres/role.go | 262 +++++ store/postgres/role_test.go | 16 + store/postgres/store.go | 1554 +-------------------------- store/postgres/store_test.go | 30 - store/postgres/token.go | 249 +++++ store/postgres/token_test.go | 12 + store/postgres/tx.go | 47 + store/postgres/validation.go | 47 + 22 files changed, 2017 insertions(+), 1653 deletions(-) create mode 100644 store/postgres/clone.go create mode 100644 store/postgres/codec.go create mode 100644 store/postgres/errors.go create mode 100644 store/postgres/identity.go create mode 100644 store/postgres/identity_test.go create mode 100644 store/postgres/oidc.go create mode 100644 store/postgres/oidc_test.go create mode 100644 store/postgres/passkey_test.go create mode 100644 store/postgres/principal.go create mode 100644 store/postgres/principal_test.go create mode 100644 store/postgres/provisioning.go create mode 100644 store/postgres/provisioning_test.go create mode 100644 store/postgres/role.go create mode 100644 store/postgres/role_test.go create mode 100644 store/postgres/token.go create mode 100644 store/postgres/token_test.go create mode 100644 store/postgres/tx.go create mode 100644 store/postgres/validation.go diff --git a/store/postgres/clone.go b/store/postgres/clone.go new file mode 100644 index 0000000..9b7c5f3 --- /dev/null +++ b/store/postgres/clone.go @@ -0,0 +1,79 @@ +package postgres + +import ( + "maps" + + "github.com/meigma/authkit" + "github.com/meigma/authkit/oidc" +) + +// cloneAttributes returns an independent copy of attrs, or nil for an empty +// input. +func cloneAttributes(attrs map[string]any) map[string]any { + if len(attrs) == 0 { + return nil + } + + cloned := make(map[string]any, len(attrs)) + maps.Copy(cloned, attrs) + + return cloned +} + +// cloneStrings returns an independent copy of values, or nil for an empty +// input. +func cloneStrings(values []string) []string { + if len(values) == 0 { + return nil + } + + cloned := make([]string, len(values)) + copy(cloned, values) + + return cloned +} + +// cloneClaimPath returns an independent copy of path. +func cloneClaimPath(path authkit.ClaimPath) authkit.ClaimPath { + if len(path) == 0 { + return nil + } + + cloned := make(authkit.ClaimPath, len(path)) + copy(cloned, path) + + return cloned +} + +// cloneClaimPaths returns an independent copy of paths with every inner +// ClaimPath copied as well. +func cloneClaimPaths(paths []authkit.ClaimPath) []authkit.ClaimPath { + if len(paths) == 0 { + return nil + } + + cloned := make([]authkit.ClaimPath, len(paths)) + for i, path := range paths { + cloned[i] = cloneClaimPath(path) + } + + return cloned +} + +// cloneProvider returns a copy of provider with independent Audiences, +// SupportedSigningAlgorithms, and ForwardedClaims slices. +func cloneProvider(provider oidc.Provider) oidc.Provider { + provider.Audiences = cloneStrings(provider.Audiences) + provider.SupportedSigningAlgorithms = cloneStrings(provider.SupportedSigningAlgorithms) + provider.ForwardedClaims = cloneClaimPaths(provider.ForwardedClaims) + + return provider +} + +// cloneProvisioningRule returns a copy of rule with an independent +// AssignRoleIDs slice. +func cloneProvisioningRule(rule authkit.ProvisioningRule) authkit.ProvisioningRule { + rule.AssignRoleIDs = cloneStrings(rule.AssignRoleIDs) + + return rule +} diff --git a/store/postgres/codec.go b/store/postgres/codec.go new file mode 100644 index 0000000..abb8b64 --- /dev/null +++ b/store/postgres/codec.go @@ -0,0 +1,94 @@ +package postgres + +import ( + "encoding/json" + "fmt" + + "github.com/go-webauthn/webauthn/webauthn" + + "github.com/meigma/authkit" +) + +// encodeAttributes returns the JSON encoding of attrs for storage as a JSONB +// column. An empty or nil map encodes to the empty string so call sites can +// pass the result through nullif(... , empty string) and cast to jsonb to +// normalize NULL. +func encodeAttributes(attrs map[string]any) (string, error) { + if len(attrs) == 0 { + return "", nil + } + + encoded, err := json.Marshal(attrs) + if err != nil { + return "", fmt.Errorf("postgres: encode principal attributes: %w", err) + } + + return string(encoded), nil +} + +// decodeAttributes parses a principal-attributes JSONB column into a map. +// Empty, "null", or `{}` payloads return nil so principals with no +// attributes carry no allocated map. +func decodeAttributes(encoded string) (map[string]any, error) { + if encoded == "" || encoded == "null" { + //nolint:nilnil // Nil attributes are the normalized zero value for principals. + return nil, nil + } + + var attrs map[string]any + if err := json.Unmarshal([]byte(encoded), &attrs); err != nil { + return nil, fmt.Errorf("postgres: decode principal attributes: %w", err) + } + if len(attrs) == 0 { + //nolint:nilnil // Nil attributes are the normalized zero value for principals. + return nil, nil + } + + return attrs, nil +} + +// encodeClaimPaths returns the JSON encoding of paths for storage as a JSONB +// column. A nil or empty input encodes to the literal "[]" so the column is +// never NULL. +func encodeClaimPaths(paths []authkit.ClaimPath) (string, error) { + if len(paths) == 0 { + return "[]", nil + } + + encoded, err := json.Marshal(paths) + if err != nil { + return "", fmt.Errorf("postgres: encode claim paths: %w", err) + } + + return string(encoded), nil +} + +// decodeClaimPaths parses a JSONB-encoded slice of claim paths. Empty or +// "null" payloads return nil. The returned slice is independent of any +// internal buffer. +func decodeClaimPaths(encoded string) ([]authkit.ClaimPath, error) { + if encoded == "" || encoded == "null" { + return nil, nil + } + + var paths []authkit.ClaimPath + if err := json.Unmarshal([]byte(encoded), &paths); err != nil { + return nil, fmt.Errorf("postgres: decode claim paths: %w", err) + } + if len(paths) == 0 { + return nil, nil + } + + return cloneClaimPaths(paths), nil +} + +// encodeWebAuthnCredential returns the JSON encoding of credential for +// storage as a JSONB column. The full upstream record is preserved. +func encodeWebAuthnCredential(credential webauthn.Credential) (string, error) { + encoded, err := json.Marshal(credential) + if err != nil { + return "", fmt.Errorf("postgres: encode passkey credential: %w", err) + } + + return string(encoded), nil +} diff --git a/store/postgres/errors.go b/store/postgres/errors.go new file mode 100644 index 0000000..a4fd7b5 --- /dev/null +++ b/store/postgres/errors.go @@ -0,0 +1,17 @@ +package postgres + +import ( + "errors" + + "github.com/jackc/pgx/v5/pgconn" +) + +// isPostgresCode reports whether err is a pgx-wrapped PostgreSQL error with +// the supplied sqlstate code. Use with the package-level violation +// constants (`uniqueViolation`, `foreignKeyViolation`) rather than passing +// raw sqlstate strings at call sites. +func isPostgresCode(err error, code string) bool { + var pgErr *pgconn.PgError + + return errors.As(err, &pgErr) && pgErr.Code == code +} diff --git a/store/postgres/identity.go b/store/postgres/identity.go new file mode 100644 index 0000000..18457b3 --- /dev/null +++ b/store/postgres/identity.go @@ -0,0 +1,325 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" + + "github.com/meigma/authkit" +) + +// errProvisionIdentityConflict is an internal sentinel used by +// ProvisionIdentity to signal a unique-violation race between concurrent +// provisioners. The conflict is recovered by reading the winning row from +// the pool after the transaction has been rolled back; the sentinel never +// escapes this package. +var errProvisionIdentityConflict = errors.New("postgres: provision identity link conflict") + +// LinkIdentity links an external identity to an existing principal. Linking +// the same identity to the same principal twice is idempotent; linking the +// same identity to a different principal returns an error so a credential +// rebind never silently swaps principals. +func (s *Store) LinkIdentity( + ctx context.Context, + req authkit.LinkIdentityRequest, +) (authkit.ExternalIdentity, error) { + if err := ctx.Err(); err != nil { + return authkit.ExternalIdentity{}, err + } + if req.Provider == "" { + return authkit.ExternalIdentity{}, errors.New("postgres: provider is required") + } + if req.Subject == "" { + return authkit.ExternalIdentity{}, errors.New("postgres: subject is required") + } + if req.PrincipalID == "" { + return authkit.ExternalIdentity{}, errors.New("postgres: principal ID is required") + } + + link, err := findIdentityLink(ctx, s.pool, req.Provider, req.Subject) + if err == nil { + // Cross-principal rebind is rejected to keep external identities + // pinned to one principal; an admin must unlink first. + if link.PrincipalID == req.PrincipalID { + return link, nil + } + + return authkit.ExternalIdentity{}, fmt.Errorf( + "postgres: identity %q/%q is already linked to principal %q", + req.Provider, + req.Subject, + link.PrincipalID, + ) + } + if !errors.Is(err, pgx.ErrNoRows) { + return authkit.ExternalIdentity{}, fmt.Errorf("postgres: find identity link: %w", err) + } + + link = authkit.ExternalIdentity(req) + if _, err := s.pool.Exec( + ctx, + `insert into authkit_external_identities (provider, subject, principal_id) + values ($1, $2, $3)`, + link.Provider, + link.Subject, + link.PrincipalID, + ); err != nil { + // A unique-violation here means a concurrent writer beat us to the + // insert; resolve by re-reading and re-checking the binding. + if isPostgresCode(err, uniqueViolation) { + return s.resolveIdentityLinkConflict(ctx, req) + } + if isPostgresCode(err, foreignKeyViolation) { + return authkit.ExternalIdentity{}, fmt.Errorf( + "postgres: principal %q does not exist", + req.PrincipalID, + ) + } + + return authkit.ExternalIdentity{}, fmt.Errorf("postgres: link identity: %w", err) + } + + return link, nil +} + +// ResolveIdentity returns the principal linked to identity, or wraps +// `authkit.ErrUnresolvedIdentity` when the identity is missing required +// fields or has no link. +func (s *Store) ResolveIdentity( + ctx context.Context, + identity authkit.Identity, +) (*authkit.Principal, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if identity.Provider == "" || identity.Subject == "" { + return nil, fmt.Errorf("%w: provider and subject are required", authkit.ErrUnresolvedIdentity) + } + + var principal authkit.Principal + var kind string + var attributes string + err := s.pool.QueryRow( + ctx, + `select p.id, p.kind, p.display_name, coalesce(p.attributes::text, '') + from authkit_external_identities as i + join authkit_principals as p on p.id = i.principal_id + where i.provider = $1 and i.subject = $2`, + identity.Provider, + identity.Subject, + ).Scan(&principal.ID, &kind, &principal.DisplayName, &attributes) + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf( + "%w: identity %q/%q is not linked", + authkit.ErrUnresolvedIdentity, + identity.Provider, + identity.Subject, + ) + } + if err != nil { + return nil, fmt.Errorf("postgres: resolve identity: %w", err) + } + + principal.Kind = authkit.PrincipalKind(kind) + principal.Attributes, err = decodeAttributes(attributes) + if err != nil { + return nil, err + } + + return &principal, nil +} + +// ProvisionIdentity creates a principal and links the supplied external +// identity to it in one transaction, or returns the existing link when the +// identity is already provisioned. Initial roles, when supplied, are +// assigned to the newly created principal; existing principals are never +// mutated by a subsequent ProvisionIdentity call. +// +// When two callers race to provision the same identity, the loser +// encounters a unique-violation on the identity-link insert. The +// transaction rolls back (via withTx's deferred Rollback) and the winning +// row is read from the pool to return a consistent +// `ProvisionIdentityResult{Created: false}` instead of an error. +// +//nolint:gocognit // The validation guard, existing-link short-circuit, and race-recovery branch are intrinsic to this operation's contract. +func (s *Store) ProvisionIdentity( + ctx context.Context, + req authkit.ProvisionIdentityRequest, +) (authkit.ProvisionIdentityResult, error) { + if err := ctx.Err(); err != nil { + return authkit.ProvisionIdentityResult{}, err + } + if req.Identity.Provider == "" || req.Identity.Subject == "" { + return authkit.ProvisionIdentityResult{}, fmt.Errorf( + "%w: provider and subject are required", + authkit.ErrUnresolvedIdentity, + ) + } + if req.Principal.Kind != authkit.PrincipalKindUser && req.Principal.Kind != authkit.PrincipalKindService { + return authkit.ProvisionIdentityResult{}, fmt.Errorf( + "postgres: unsupported principal kind %q", + req.Principal.Kind, + ) + } + + result, err := withTx(ctx, s.pool, "provision identity", func(tx pgx.Tx) (authkit.ProvisionIdentityResult, error) { + // Existing-link short-circuit: return without touching the + // principal row so a re-provision never mutates display name, + // attributes, or roles. + existing, err := findProvisionedIdentity(ctx, tx, req.Identity) + if err == nil { + return existing, nil + } + if !errors.Is(err, pgx.ErrNoRows) { + return authkit.ProvisionIdentityResult{}, fmt.Errorf("postgres: find provisioned identity: %w", err) + } + + principal, err := createPrincipal(ctx, tx, req.Principal) + if err != nil { + return authkit.ProvisionIdentityResult{}, err + } + + link := authkit.ExternalIdentity{ + Provider: req.Identity.Provider, + Subject: req.Identity.Subject, + PrincipalID: principal.ID, + } + if _, err := tx.Exec( + ctx, + `insert into authkit_external_identities (provider, subject, principal_id) + values ($1, $2, $3)`, + link.Provider, + link.Subject, + link.PrincipalID, + ); err != nil { + if isPostgresCode(err, uniqueViolation) { + return authkit.ProvisionIdentityResult{}, errProvisionIdentityConflict + } + + return authkit.ProvisionIdentityResult{}, fmt.Errorf("postgres: link provisioned identity: %w", err) + } + if err := assignInitialRoles(ctx, tx, principal.ID, req.InitialRoleIDs); err != nil { + return authkit.ProvisionIdentityResult{}, err + } + + return authkit.ProvisionIdentityResult{ + Principal: principal, + Link: link, + Created: true, + }, nil + }) + if errors.Is(err, errProvisionIdentityConflict) { + // The transaction has been rolled back by withTx's deferred + // Rollback; the winning row is now visible on the pool. + winner, findErr := findProvisionedIdentity(ctx, s.pool, req.Identity) + if findErr != nil { + return authkit.ProvisionIdentityResult{}, fmt.Errorf( + "postgres: find provisioned identity conflict: %w", + findErr, + ) + } + + return winner, nil + } + + return result, err +} + +// resolveIdentityLinkConflict re-reads the identity link after a +// unique-violation during LinkIdentity. Returns the link when the conflict +// resolves to the same principal (treat as idempotent) or an explanatory +// error when the conflict resolves to a different principal. +func (s *Store) resolveIdentityLinkConflict( + ctx context.Context, + req authkit.LinkIdentityRequest, +) (authkit.ExternalIdentity, error) { + link, err := findIdentityLink(ctx, s.pool, req.Provider, req.Subject) + if err != nil { + return authkit.ExternalIdentity{}, fmt.Errorf("postgres: find identity link conflict: %w", err) + } + if link.PrincipalID == req.PrincipalID { + return link, nil + } + + return authkit.ExternalIdentity{}, fmt.Errorf( + "postgres: identity %q/%q is already linked to principal %q", + req.Provider, + req.Subject, + link.PrincipalID, + ) +} + +// findIdentityLink reads the identity link row matching (provider, subject) +// using query (a pool or a transaction). Returns `pgx.ErrNoRows` when no +// link is stored; callers translate that to the appropriate domain +// sentinel. +func findIdentityLink( + ctx context.Context, + query rowQuerier, + provider string, + subject string, +) (authkit.ExternalIdentity, error) { + var link authkit.ExternalIdentity + err := query.QueryRow( + ctx, + `select provider, subject, principal_id + from authkit_external_identities + where provider = $1 and subject = $2`, + provider, + subject, + ).Scan(&link.Provider, &link.Subject, &link.PrincipalID) + if err != nil { + return authkit.ExternalIdentity{}, err + } + + return link, nil +} + +// findProvisionedIdentity joins the identity link to its principal in a +// single query, returning a `ProvisionIdentityResult{Created: false}`. +// Returns `pgx.ErrNoRows` when the identity has not been provisioned. +func findProvisionedIdentity( + ctx context.Context, + query rowQuerier, + identity authkit.Identity, +) (authkit.ProvisionIdentityResult, error) { + var principal authkit.Principal + var kind string + var attributes string + var link authkit.ExternalIdentity + err := query.QueryRow( + ctx, + `select p.id, p.kind, p.display_name, coalesce(p.attributes::text, ''), + i.provider, i.subject, i.principal_id + from authkit_external_identities as i + join authkit_principals as p on p.id = i.principal_id + where i.provider = $1 and i.subject = $2`, + identity.Provider, + identity.Subject, + ).Scan( + &principal.ID, + &kind, + &principal.DisplayName, + &attributes, + &link.Provider, + &link.Subject, + &link.PrincipalID, + ) + if err != nil { + return authkit.ProvisionIdentityResult{}, err + } + + principal.Kind = authkit.PrincipalKind(kind) + principal.Attributes, err = decodeAttributes(attributes) + if err != nil { + return authkit.ProvisionIdentityResult{}, err + } + + return authkit.ProvisionIdentityResult{ + Principal: principal, + Link: link, + Created: false, + }, nil +} diff --git a/store/postgres/identity_test.go b/store/postgres/identity_test.go new file mode 100644 index 0000000..357a63f --- /dev/null +++ b/store/postgres/identity_test.go @@ -0,0 +1,13 @@ +package postgres + +import ( + "testing" + + "github.com/meigma/authkit" +) + +func TestStoreSatisfiesIdentityContracts(_ *testing.T) { + var _ authkit.IdentityLinker = (*Store)(nil) + var _ authkit.IdentityProvisioner = (*Store)(nil) + var _ authkit.PrincipalResolver = (*Store)(nil) +} diff --git a/store/postgres/migrate.go b/store/postgres/migrate.go index 6e6ae0f..dddc5e9 100644 --- a/store/postgres/migrate.go +++ b/store/postgres/migrate.go @@ -15,7 +15,10 @@ const migrationLockID int64 = 0x617574686b6974 //go:embed migrations/*.sql var migrationFiles embed.FS -// Migrate applies authkit's PostgreSQL schema migrations. +// Migrate applies authkit's PostgreSQL schema migrations. The migrations +// are embedded into the binary, applied in lexical filename order, and +// idempotent on re-runs (each migration is a CREATE TABLE IF NOT EXISTS or +// equivalent). Pool is the caller's; Migrate does not take ownership. func Migrate(ctx context.Context, pool *pgxpool.Pool) error { if err := ctx.Err(); err != nil { return err @@ -32,6 +35,9 @@ func Migrate(ctx context.Context, pool *pgxpool.Pool) error { _ = tx.Rollback(ctx) }() + // Transaction-scoped advisory lock so concurrent Migrate callers + // serialize against each other; the lock releases automatically when + // the transaction commits or rolls back. if _, execErr := tx.Exec(ctx, "select pg_advisory_xact_lock($1)", migrationLockID); execErr != nil { return fmt.Errorf("postgres: acquire migration lock: %w", execErr) } diff --git a/store/postgres/oidc.go b/store/postgres/oidc.go new file mode 100644 index 0000000..ebe6f13 --- /dev/null +++ b/store/postgres/oidc.go @@ -0,0 +1,111 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" + + "github.com/meigma/authkit/oidc" +) + +// TrustProvider stores provider as trusted for its issuer. The provider is +// validated before storage; subsequent calls for the same issuer upsert the +// row. +func (s *Store) TrustProvider(ctx context.Context, provider oidc.Provider) (oidc.Provider, error) { + if err := ctx.Err(); err != nil { + return oidc.Provider{}, err + } + if err := provider.Validate(); err != nil { + return oidc.Provider{}, err + } + + trusted := cloneProvider(provider) + signingAlgorithms := trusted.SupportedSigningAlgorithms + if signingAlgorithms == nil { + // pgx maps a nil []string to NULL, so substitute an empty array to + // keep the column non-null and consistent with the table default. + signingAlgorithms = []string{} + } + forwardedClaims, err := encodeClaimPaths(trusted.ForwardedClaims) + if err != nil { + return oidc.Provider{}, err + } + if _, err := s.pool.Exec( + ctx, + `insert into authkit_oidc_providers + (issuer, jwks_url, audiences, supported_signing_algorithms, forwarded_claims) + values ($1, $2, $3, $4, $5::jsonb) + on conflict (issuer) do update set + jwks_url = excluded.jwks_url, + audiences = excluded.audiences, + supported_signing_algorithms = excluded.supported_signing_algorithms, + forwarded_claims = excluded.forwarded_claims, + updated_at = now()`, + trusted.Issuer, + trusted.JWKSURL, + trusted.Audiences, + signingAlgorithms, + forwardedClaims, + ); err != nil { + return oidc.Provider{}, fmt.Errorf("postgres: trust OIDC provider: %w", err) + } + + return cloneProvider(trusted), nil +} + +// FindProvider returns the trusted OIDC provider for issuer, or +// `oidc.ErrProviderNotFound` if no such provider has been trusted. The +// loaded row is revalidated before return. +func (s *Store) FindProvider(ctx context.Context, issuer string) (oidc.Provider, error) { + if err := ctx.Err(); err != nil { + return oidc.Provider{}, err + } + + provider, err := findTrustedProvider(ctx, s.pool, issuer) + if errors.Is(err, pgx.ErrNoRows) { + return oidc.Provider{}, oidc.ErrProviderNotFound + } + if err != nil { + return oidc.Provider{}, fmt.Errorf("postgres: find OIDC provider: %w", err) + } + if err := provider.Validate(); err != nil { + return oidc.Provider{}, fmt.Errorf("postgres: invalid OIDC provider %q: %w", issuer, err) + } + + return cloneProvider(provider), nil +} + +// findTrustedProvider reads an OIDC provider row using query (a pool or a +// transaction). Returns `pgx.ErrNoRows` when no provider is trusted for +// the issuer; callers translate that to the appropriate domain sentinel. +// Used by both FindProvider and by provisioning-rule validation. +func findTrustedProvider(ctx context.Context, query rowQuerier, issuer string) (oidc.Provider, error) { + var provider oidc.Provider + var forwardedClaims string + err := query.QueryRow( + ctx, + `select issuer, audiences, jwks_url, supported_signing_algorithms, + coalesce(forwarded_claims::text, '[]') + from authkit_oidc_providers + where issuer = $1`, + issuer, + ).Scan( + &provider.Issuer, + &provider.Audiences, + &provider.JWKSURL, + &provider.SupportedSigningAlgorithms, + &forwardedClaims, + ) + if err != nil { + return oidc.Provider{}, err + } + + provider.ForwardedClaims, err = decodeClaimPaths(forwardedClaims) + if err != nil { + return oidc.Provider{}, err + } + + return cloneProvider(provider), nil +} diff --git a/store/postgres/oidc_test.go b/store/postgres/oidc_test.go new file mode 100644 index 0000000..9b6b715 --- /dev/null +++ b/store/postgres/oidc_test.go @@ -0,0 +1,12 @@ +package postgres + +import ( + "testing" + + "github.com/meigma/authkit/oidc" +) + +func TestStoreSatisfiesOIDCContracts(_ *testing.T) { + var _ oidc.ProviderSource = (*Store)(nil) + var _ oidc.ProviderTrustStore = (*Store)(nil) +} diff --git a/store/postgres/passkey.go b/store/postgres/passkey.go index 78f8e84..55a0d61 100644 --- a/store/postgres/passkey.go +++ b/store/postgres/passkey.go @@ -14,7 +14,8 @@ import ( "github.com/meigma/authkit/passkey" ) -// FindUserByPrincipal returns the passkey user for principalID and rpID. +// FindUserByPrincipal returns the passkey user for principalID and rpID, or +// `passkey.ErrUserNotFound` if no such user exists. func (s *Store) FindUserByPrincipal(ctx context.Context, rpID string, principalID string) (passkey.User, error) { if err := ctx.Err(); err != nil { return passkey.User{}, err @@ -31,7 +32,8 @@ func (s *Store) FindUserByPrincipal(ctx context.Context, rpID string, principalI return user, nil } -// FindUserByHandle returns the passkey user for handle and rpID. +// FindUserByHandle returns the passkey user for handle and rpID, or +// `passkey.ErrUserNotFound` if no such user exists. func (s *Store) FindUserByHandle(ctx context.Context, rpID string, handle []byte) (passkey.User, error) { if err := ctx.Err(); err != nil { return passkey.User{}, err @@ -48,7 +50,9 @@ func (s *Store) FindUserByHandle(ctx context.Context, rpID string, handle []byte return user, nil } -// ListCredentials returns passkey credentials for userHandle and rpID. +// ListCredentials returns passkey credentials for userHandle and rpID, +// sorted by credential ID. Returns an empty slice (no error) when no user +// matches the handle. func (s *Store) ListCredentials(ctx context.Context, rpID string, userHandle []byte) ([]passkey.Credential, error) { if err := ctx.Err(); err != nil { return nil, err @@ -83,7 +87,12 @@ func (s *Store) ListCredentials(ctx context.Context, rpID string, userHandle []b return credentials, nil } -// CreateRegistration atomically stores a passkey user, credential, and identity link. +// CreateRegistration atomically stores a passkey user, credential, and +// identity link in one transaction. Returns `passkey.ErrUserExists` when +// the same handle is already bound to a different principal (or the same +// principal to a different handle), `passkey.ErrCredentialExists` when the +// credential ID is already stored, and a generic error when the identity +// link conflicts with a different principal. func (s *Store) CreateRegistration( ctx context.Context, registration passkey.Registration, @@ -95,66 +104,61 @@ func (s *Store) CreateRegistration( return passkey.RegistrationResult{}, err } - tx, err := s.pool.Begin(ctx) - if err != nil { - return passkey.RegistrationResult{}, fmt.Errorf("postgres: begin passkey registration: %w", err) - } - defer func() { - _ = tx.Rollback(ctx) - }() - - if userErr := ensurePasskeyUser(ctx, tx, registration.User); userErr != nil { - return passkey.RegistrationResult{}, userErr - } + return withTx(ctx, s.pool, "passkey registration", func(tx pgx.Tx) (passkey.RegistrationResult, error) { + if userErr := ensurePasskeyUser(ctx, tx, registration.User); userErr != nil { + return passkey.RegistrationResult{}, userErr + } - encodedCredential, err := encodeWebAuthnCredential(registration.Credential.WebAuthn) - if err != nil { - return passkey.RegistrationResult{}, err - } - if _, execErr := tx.Exec( - ctx, - `insert into authkit_passkey_credentials - (rp_id, credential_id, principal_id, user_handle, webauthn) - values ($1, $2, $3, $4, $5::jsonb)`, - registration.Credential.RPID, - registration.Credential.CredentialID, - registration.Credential.PrincipalID, - registration.Credential.UserHandle, - encodedCredential, - ); execErr != nil { - if isPostgresCode(execErr, uniqueViolation) { - return passkey.RegistrationResult{}, passkey.ErrCredentialExists + encodedCredential, err := encodeWebAuthnCredential(registration.Credential.WebAuthn) + if err != nil { + return passkey.RegistrationResult{}, err } - if isPostgresCode(execErr, foreignKeyViolation) { - return passkey.RegistrationResult{}, fmt.Errorf( - "postgres: passkey user or principal does not exist: %w", - execErr, - ) + if _, execErr := tx.Exec( + ctx, + `insert into authkit_passkey_credentials + (rp_id, credential_id, principal_id, user_handle, webauthn) + values ($1, $2, $3, $4, $5::jsonb)`, + registration.Credential.RPID, + registration.Credential.CredentialID, + registration.Credential.PrincipalID, + registration.Credential.UserHandle, + encodedCredential, + ); execErr != nil { + if isPostgresCode(execErr, uniqueViolation) { + return passkey.RegistrationResult{}, passkey.ErrCredentialExists + } + if isPostgresCode(execErr, foreignKeyViolation) { + return passkey.RegistrationResult{}, fmt.Errorf( + "postgres: passkey user or principal does not exist: %w", + execErr, + ) + } + + return passkey.RegistrationResult{}, fmt.Errorf("postgres: create passkey credential: %w", execErr) } - return passkey.RegistrationResult{}, fmt.Errorf("postgres: create passkey credential: %w", execErr) - } + link, err := linkPasskeyIdentity(ctx, tx, authkit.LinkIdentityRequest{ + Provider: registration.Identity.Provider, + Subject: registration.Identity.Subject, + PrincipalID: registration.User.PrincipalID, + }) + if err != nil { + return passkey.RegistrationResult{}, err + } - link, err := linkPasskeyIdentity(ctx, tx, authkit.LinkIdentityRequest{ - Provider: registration.Identity.Provider, - Subject: registration.Identity.Subject, - PrincipalID: registration.User.PrincipalID, + return passkey.RegistrationResult{ + User: clonePasskeyUser(registration.User), + Credential: clonePasskeyCredential(registration.Credential), + Link: link, + }, nil }) - if err != nil { - return passkey.RegistrationResult{}, err - } - if err := tx.Commit(ctx); err != nil { - return passkey.RegistrationResult{}, fmt.Errorf("postgres: commit passkey registration: %w", err) - } - - return passkey.RegistrationResult{ - User: clonePasskeyUser(registration.User), - Credential: clonePasskeyCredential(registration.Credential), - Link: link, - }, nil } -// UpdateCredentialAfterLogin stores passkey credential metadata after login. +// UpdateCredentialAfterLogin stores passkey credential metadata after a +// successful login. The credential's binding fields (RPID, PrincipalID, +// UserHandle, CredentialID) are part of the WHERE clause so a rebind +// attempt updates zero rows and returns an error rather than silently +// rebinding. func (s *Store) UpdateCredentialAfterLogin(ctx context.Context, credential passkey.Credential) error { if err := ctx.Err(); err != nil { return err @@ -167,6 +171,9 @@ func (s *Store) UpdateCredentialAfterLogin(ctx context.Context, credential passk if err != nil { return err } + // Security invariant: the WHERE clause pins (RP, credential, principal, + // user handle) so this update never rebinds a credential to a + // different user. tag, err := s.pool.Exec( ctx, `update authkit_passkey_credentials @@ -192,6 +199,12 @@ func (s *Store) UpdateCredentialAfterLogin(ctx context.Context, credential passk return nil } +// ensurePasskeyUser inserts the passkey user when neither (principal, RP) +// nor (handle, RP) is bound, returns nil when both indexes already point +// at an identical user, and returns `passkey.ErrUserExists` when either +// index resolves to a different user. The third disagreement case (one +// index hits, the other misses, with the hit not matching the input) +// signals a corrupted store and surfaces as a generic error. func ensurePasskeyUser(ctx context.Context, tx pgx.Tx, user passkey.User) error { existingByPrincipal, principalErr := findPasskeyUserByPrincipal(ctx, tx, user.RPID, user.PrincipalID) if principalErr != nil && !errors.Is(principalErr, pgx.ErrNoRows) { @@ -239,6 +252,8 @@ func ensurePasskeyUser(ctx context.Context, tx pgx.Tx, user passkey.User) error return nil } +// findPasskeyUserByPrincipal reads the passkey user row keyed by (rpID, +// principalID). Returns `pgx.ErrNoRows` when no such user exists. func findPasskeyUserByPrincipal( ctx context.Context, query rowQuerier, @@ -255,6 +270,8 @@ func findPasskeyUserByPrincipal( )) } +// findPasskeyUserByHandle reads the passkey user row keyed by (rpID, +// handle). Returns `pgx.ErrNoRows` when no such user exists. func findPasskeyUserByHandle( ctx context.Context, query rowQuerier, @@ -271,6 +288,8 @@ func findPasskeyUserByHandle( )) } +// scanPasskeyUser reads a passkey-user row. The row must select rp_id, +// principal_id, user_handle, name, and display_name in that order. func scanPasskeyUser(row scanner) (passkey.User, error) { var user passkey.User if err := row.Scan( @@ -286,6 +305,9 @@ func scanPasskeyUser(row scanner) (passkey.User, error) { return clonePasskeyUser(user), nil } +// scanPasskeyCredential reads a passkey-credential row. The row must +// select rp_id, principal_id, user_handle, credential_id, and the +// JSONB-encoded WebAuthn record (as text) in that order. func scanPasskeyCredential(row scanner) (passkey.Credential, error) { var credential passkey.Credential var encodedCredential string @@ -305,6 +327,10 @@ func scanPasskeyCredential(row scanner) (passkey.Credential, error) { return clonePasskeyCredential(credential), nil } +// linkPasskeyIdentity links the supplied identity to req.PrincipalID inside +// the passkey-registration transaction. Returns the existing link when it +// already binds to the same principal (idempotent), or an error when it +// binds to a different principal (cross-principal rebind rejection). func linkPasskeyIdentity( ctx context.Context, query queryExecutor, @@ -312,6 +338,8 @@ func linkPasskeyIdentity( ) (authkit.ExternalIdentity, error) { link, err := findIdentityLink(ctx, query, req.Provider, req.Subject) if err == nil { + // Cross-principal rebind is rejected so a passkey enrollment + // cannot silently steal another principal's external identity. if link.PrincipalID == req.PrincipalID { return link, nil } @@ -349,28 +377,8 @@ func linkPasskeyIdentity( return link, nil } -func findIdentityLink( - ctx context.Context, - query rowQuerier, - provider string, - subject string, -) (authkit.ExternalIdentity, error) { - var link authkit.ExternalIdentity - err := query.QueryRow( - ctx, - `select provider, subject, principal_id - from authkit_external_identities - where provider = $1 and subject = $2`, - provider, - subject, - ).Scan(&link.Provider, &link.Subject, &link.PrincipalID) - if err != nil { - return authkit.ExternalIdentity{}, err - } - - return link, nil -} - +// validatePasskeyRegistration validates a registration's structural fields +// and the (user, credential) binding consistency. func validatePasskeyRegistration(registration passkey.Registration) error { if err := validatePasskeyUser(registration.User); err != nil { return err @@ -393,6 +401,7 @@ func validatePasskeyRegistration(registration passkey.Registration) error { return nil } +// validatePasskeyUser validates the structural fields of a passkey user. func validatePasskeyUser(user passkey.User) error { if user.RPID == "" { return errors.New("postgres: passkey RP ID is required") @@ -407,6 +416,9 @@ func validatePasskeyUser(user passkey.User) error { return nil } +// validatePasskeyCredential validates the structural fields of a passkey +// credential including non-empty RPID, PrincipalID, UserHandle, and +// CredentialID. func validatePasskeyCredential(credential passkey.Credential) error { if credential.RPID == "" { return errors.New("postgres: passkey credential RP ID is required") @@ -424,6 +436,8 @@ func validatePasskeyCredential(credential passkey.Credential) error { return nil } +// samePasskeyUser reports whether two passkey users describe the same +// record across every observable field. func samePasskeyUser(a passkey.User, b passkey.User) bool { return a.RPID == b.RPID && a.PrincipalID == b.PrincipalID && @@ -432,15 +446,7 @@ func samePasskeyUser(a passkey.User, b passkey.User) bool { a.DisplayName == b.DisplayName } -func encodeWebAuthnCredential(credential webauthn.Credential) (string, error) { - encoded, err := json.Marshal(credential) - if err != nil { - return "", fmt.Errorf("postgres: encode passkey credential: %w", err) - } - - return string(encoded), nil -} - +// clonePasskeyUser returns a copy of user with an independent Handle slice. func clonePasskeyUser(user passkey.User) passkey.User { return passkey.User{ RPID: user.RPID, @@ -451,6 +457,8 @@ func clonePasskeyUser(user passkey.User) passkey.User { } } +// clonePasskeyCredential returns a copy of credential with independent +// UserHandle, CredentialID, and WebAuthn byte slices. func clonePasskeyCredential(credential passkey.Credential) passkey.Credential { return passkey.Credential{ RPID: credential.RPID, @@ -461,6 +469,9 @@ func clonePasskeyCredential(credential passkey.Credential) passkey.Credential { } } +// cloneWebAuthnCredential returns a copy of the upstream webauthn.Credential +// with every nested byte slice copied so the stored record cannot mutate +// through aliasing. func cloneWebAuthnCredential(credential webauthn.Credential) webauthn.Credential { clone := credential clone.ID = clonePasskeyBytes(credential.ID) @@ -475,6 +486,8 @@ func cloneWebAuthnCredential(credential webauthn.Credential) webauthn.Credential return clone } +// clonePasskeyBytes returns an independent copy of value. A nil or empty +// input returns nil. func clonePasskeyBytes(value []byte) []byte { if len(value) == 0 { return nil diff --git a/store/postgres/passkey_test.go b/store/postgres/passkey_test.go new file mode 100644 index 0000000..e02cadd --- /dev/null +++ b/store/postgres/passkey_test.go @@ -0,0 +1,11 @@ +package postgres + +import ( + "testing" + + "github.com/meigma/authkit/passkey" +) + +func TestStoreSatisfiesPasskeyContracts(_ *testing.T) { + var _ passkey.Store = (*Store)(nil) +} diff --git a/store/postgres/principal.go b/store/postgres/principal.go new file mode 100644 index 0000000..fc899d3 --- /dev/null +++ b/store/postgres/principal.go @@ -0,0 +1,167 @@ +package postgres + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" + + "github.com/meigma/authkit" +) + +// CreatePrincipal creates a principal in PostgreSQL. The principal ID is +// randomly generated; the insert is retried on the rare event of an ID +// collision before returning an error. +func (s *Store) CreatePrincipal( + ctx context.Context, + req authkit.CreatePrincipalRequest, +) (authkit.Principal, error) { + if err := ctx.Err(); err != nil { + return authkit.Principal{}, err + } + if req.Kind != authkit.PrincipalKindUser && req.Kind != authkit.PrincipalKindService { + return authkit.Principal{}, fmt.Errorf("postgres: unsupported principal kind %q", req.Kind) + } + + return createPrincipal(ctx, s.pool, req) +} + +// FindPrincipal returns the principal identified by id, or +// `authkit.ErrPrincipalNotFound` if no such principal exists. +func (s *Store) FindPrincipal(ctx context.Context, id string) (authkit.Principal, error) { + if err := ctx.Err(); err != nil { + return authkit.Principal{}, err + } + if id == "" { + return authkit.Principal{}, errors.New("postgres: principal ID is required") + } + + principal, err := scanPrincipal(s.pool.QueryRow( + ctx, + `select id, kind, display_name, coalesce(attributes::text, '') + from authkit_principals + where id = $1`, + id, + )) + if errors.Is(err, pgx.ErrNoRows) { + return authkit.Principal{}, authkit.ErrPrincipalNotFound + } + if err != nil { + return authkit.Principal{}, err + } + + return principal, nil +} + +// ListPrincipals returns every principal in the store, sorted by ID. +func (s *Store) ListPrincipals(ctx context.Context) ([]authkit.Principal, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + rows, err := s.pool.Query( + ctx, + `select id, kind, display_name, coalesce(attributes::text, '') + from authkit_principals + order by id`, + ) + if err != nil { + return nil, fmt.Errorf("postgres: list principals: %w", err) + } + defer rows.Close() + + var principals []authkit.Principal + for rows.Next() { + principal, scanErr := scanPrincipal(rows) + if scanErr != nil { + return nil, scanErr + } + principals = append(principals, principal) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("postgres: read principals: %w", err) + } + + return principals, nil +} + +// createPrincipal inserts a new principal row using exec (a pool or a +// transaction) and returns it. The ID is generated with crypto/rand and +// retried on `uniqueViolation` up to `principalIDAttempts` times before +// giving up. Used by both CreatePrincipal and ProvisionIdentity. +func createPrincipal( + ctx context.Context, + exec sqlExecutor, + req authkit.CreatePrincipalRequest, +) (authkit.Principal, error) { + attributes, err := encodeAttributes(req.Attributes) + if err != nil { + return authkit.Principal{}, err + } + + for range principalIDAttempts { + principal := authkit.Principal{ + ID: principalIDPrefix + rand.Text(), + Kind: req.Kind, + DisplayName: req.DisplayName, + Attributes: cloneAttributes(req.Attributes), + } + _, err := exec.Exec( + ctx, + `insert into authkit_principals (id, kind, display_name, attributes) + values ($1, $2, $3, nullif($4, '')::jsonb)`, + principal.ID, + string(principal.Kind), + principal.DisplayName, + attributes, + ) + if err == nil { + return principal, nil + } + if !isPostgresCode(err, uniqueViolation) { + return authkit.Principal{}, fmt.Errorf("postgres: create principal: %w", err) + } + } + + return authkit.Principal{}, errors.New("postgres: create principal: generated duplicate principal IDs") +} + +// principalExists reports whether principalID identifies an existing +// principal. Used by methods that need an existence check before issuing a +// not-found-style error (e.g. ListPrincipalRoleAssignments, +// ResolvePrincipalActions). +func (s *Store) principalExists(ctx context.Context, principalID string) (bool, error) { + var exists bool + if err := s.pool.QueryRow( + ctx, + `select exists(select 1 from authkit_principals where id = $1)`, + principalID, + ).Scan(&exists); err != nil { + return false, fmt.Errorf("postgres: find principal: %w", err) + } + + return exists, nil +} + +// scanPrincipal reads a principal row from row (a pgx.Row or pgx.Rows). The +// row must select id, kind, display_name, and attributes::text in that +// order. +func scanPrincipal(row scanner) (authkit.Principal, error) { + var principal authkit.Principal + var kind string + var attributes string + if err := row.Scan(&principal.ID, &kind, &principal.DisplayName, &attributes); err != nil { + return authkit.Principal{}, err + } + + principal.Kind = authkit.PrincipalKind(kind) + attrs, err := decodeAttributes(attributes) + if err != nil { + return authkit.Principal{}, err + } + principal.Attributes = attrs + + return principal, nil +} diff --git a/store/postgres/principal_test.go b/store/postgres/principal_test.go new file mode 100644 index 0000000..f6cb78d --- /dev/null +++ b/store/postgres/principal_test.go @@ -0,0 +1,13 @@ +package postgres + +import ( + "testing" + + "github.com/meigma/authkit" +) + +func TestStoreSatisfiesPrincipalContracts(_ *testing.T) { + var _ authkit.PrincipalCreator = (*Store)(nil) + var _ authkit.PrincipalFinder = (*Store)(nil) + var _ authkit.PrincipalLister = (*Store)(nil) +} diff --git a/store/postgres/provisioning.go b/store/postgres/provisioning.go new file mode 100644 index 0000000..6c54367 --- /dev/null +++ b/store/postgres/provisioning.go @@ -0,0 +1,401 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" + + "github.com/meigma/authkit" + "github.com/meigma/authkit/provisioning" +) + +// CreateProvisioningRule creates a provisioning rule. The CEL condition is +// normalized at write time, the provider must be trusted, and every role in +// AssignRoleIDs must already exist. The full insert (rule + role +// assignments) runs in one transaction so a partial insert never escapes. +func (s *Store) CreateProvisioningRule( + ctx context.Context, + req authkit.CreateProvisioningRuleRequest, +) (authkit.ProvisioningRule, error) { + if err := ctx.Err(); err != nil { + return authkit.ProvisioningRule{}, err + } + + rule := provisioningRuleFromCreate(req) + + return withTx(ctx, s.pool, "create provisioning rule", func(tx pgx.Tx) (authkit.ProvisioningRule, error) { + if err := validateProvisioningRule(ctx, tx, rule); err != nil { + return authkit.ProvisioningRule{}, err + } + if err := insertProvisioningRule(ctx, tx, rule); err != nil { + return authkit.ProvisioningRule{}, err + } + if err := insertProvisioningRuleRoles(ctx, tx, rule.ID, rule.AssignRoleIDs); err != nil { + return authkit.ProvisioningRule{}, err + } + + return cloneProvisioningRule(rule), nil + }) +} + +// UpdateProvisioningRule replaces an existing provisioning rule wholesale. +// Returns `authkit.ErrProvisioningRuleNotFound` if no rule with the given +// ID exists. Like CreateProvisioningRule, the row update and the role +// rewrite run in one transaction. +func (s *Store) UpdateProvisioningRule( + ctx context.Context, + req authkit.UpdateProvisioningRuleRequest, +) (authkit.ProvisioningRule, error) { + if err := ctx.Err(); err != nil { + return authkit.ProvisioningRule{}, err + } + + rule := provisioningRuleFromUpdate(req) + + return withTx(ctx, s.pool, "update provisioning rule", func(tx pgx.Tx) (authkit.ProvisioningRule, error) { + exists, err := provisioningRuleExists(ctx, tx, rule.ID) + if err != nil { + return authkit.ProvisioningRule{}, err + } + if !exists { + return authkit.ProvisioningRule{}, authkit.ErrProvisioningRuleNotFound + } + if validateErr := validateProvisioningRule(ctx, tx, rule); validateErr != nil { + return authkit.ProvisioningRule{}, validateErr + } + + tag, err := tx.Exec( + ctx, + `update authkit_provisioning_rules + set display_name = $2, + provider = $3, + condition = $4, + enabled = $5, + updated_at = now() + where id = $1`, + rule.ID, + rule.DisplayName, + rule.Provider, + rule.Condition, + rule.Enabled, + ) + if err != nil { + return authkit.ProvisioningRule{}, fmt.Errorf("postgres: update provisioning rule: %w", err) + } + if tag.RowsAffected() == 0 { + return authkit.ProvisioningRule{}, authkit.ErrProvisioningRuleNotFound + } + // Replace the existing role assignments wholesale so the new rule's + // AssignRoleIDs are the authoritative set. + if _, err := tx.Exec( + ctx, + `delete from authkit_provisioning_rule_roles where rule_id = $1`, + rule.ID, + ); err != nil { + return authkit.ProvisioningRule{}, fmt.Errorf("postgres: clear provisioning rule roles: %w", err) + } + if err := insertProvisioningRuleRoles(ctx, tx, rule.ID, rule.AssignRoleIDs); err != nil { + return authkit.ProvisioningRule{}, err + } + + return cloneProvisioningRule(rule), nil + }) +} + +// DeleteProvisioningRule removes the provisioning rule identified by id. +// Returns `authkit.ErrProvisioningRuleNotFound` if no such rule exists. +func (s *Store) DeleteProvisioningRule(ctx context.Context, id string) error { + if err := ctx.Err(); err != nil { + return err + } + if id == "" { + return errors.New("postgres: provisioning rule ID is required") + } + + tag, err := s.pool.Exec(ctx, `delete from authkit_provisioning_rules where id = $1`, id) + if err != nil { + return fmt.Errorf("postgres: delete provisioning rule: %w", err) + } + if tag.RowsAffected() == 0 { + return authkit.ErrProvisioningRuleNotFound + } + + return nil +} + +// FindProvisioningRule returns the provisioning rule identified by id, or +// `authkit.ErrProvisioningRuleNotFound` if no such rule exists. +func (s *Store) FindProvisioningRule(ctx context.Context, id string) (authkit.ProvisioningRule, error) { + if err := ctx.Err(); err != nil { + return authkit.ProvisioningRule{}, err + } + if id == "" { + return authkit.ProvisioningRule{}, errors.New("postgres: provisioning rule ID is required") + } + + rule, err := findProvisioningRule(ctx, s.pool, id) + if errors.Is(err, pgx.ErrNoRows) { + return authkit.ProvisioningRule{}, authkit.ErrProvisioningRuleNotFound + } + if err != nil { + return authkit.ProvisioningRule{}, err + } + + return rule, nil +} + +// ListProvisioningRules returns all provisioning rules sorted by ID. Returns +// nil (no error) when no rules are stored. +func (s *Store) ListProvisioningRules(ctx context.Context) ([]authkit.ProvisioningRule, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + rows, err := s.pool.Query( + ctx, + `select r.id, r.display_name, r.provider, r.condition, r.enabled, + coalesce(array_agg(rr.role_id order by rr.role_id) + filter (where rr.role_id is not null), '{}'::text[]) as role_ids + from authkit_provisioning_rules as r + left join authkit_provisioning_rule_roles as rr on rr.rule_id = r.id + group by r.id + order by r.id`, + ) + if err != nil { + return nil, fmt.Errorf("postgres: list provisioning rules: %w", err) + } + defer rows.Close() + + var rules []authkit.ProvisioningRule + for rows.Next() { + rule, scanErr := scanProvisioningRule(rows) + if scanErr != nil { + return nil, scanErr + } + rules = append(rules, rule) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("postgres: read provisioning rules: %w", err) + } + + return rules, nil +} + +// validateProvisioningRule validates rule's structural fields, CEL +// condition, trusted provider, and referenced roles using query (a pool or +// a transaction). Used by both CreateProvisioningRule and +// UpdateProvisioningRule. +func validateProvisioningRule(ctx context.Context, query queryExecutor, rule authkit.ProvisioningRule) error { + if rule.ID == "" { + return errors.New("postgres: provisioning rule ID is required") + } + if rule.Provider == "" { + return errors.New("postgres: provisioning rule provider is required") + } + if err := provisioning.ValidateCondition(rule.Condition); err != nil { + return fmt.Errorf("postgres: %w", err) + } + if err := validateRequiredStrings("provisioning rule role ID", rule.AssignRoleIDs); err != nil { + return fmt.Errorf("postgres: %w", err) + } + + _, err := findTrustedProvider(ctx, query, rule.Provider) + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("postgres: provider %q is not trusted", rule.Provider) + } + if err != nil { + return err + } + + var roleCount int + if err := query.QueryRow( + ctx, + `select count(*) from authkit_roles where id = any($1)`, + rule.AssignRoleIDs, + ).Scan(&roleCount); err != nil { + return fmt.Errorf("postgres: validate provisioning rule roles: %w", err) + } + if roleCount != len(rule.AssignRoleIDs) { + return errors.New("postgres: provisioning rule references missing role") + } + + return nil +} + +// provisioningRuleExists reports whether a provisioning rule with the +// supplied id is already stored. +func provisioningRuleExists(ctx context.Context, query rowQuerier, id string) (bool, error) { + var exists bool + if err := query.QueryRow( + ctx, + `select exists(select 1 from authkit_provisioning_rules where id = $1)`, + id, + ).Scan(&exists); err != nil { + return false, fmt.Errorf("postgres: find provisioning rule: %w", err) + } + + return exists, nil +} + +// insertProvisioningRule inserts a single rule row, mapping a +// uniqueViolation to a friendlier "already exists" error. +func insertProvisioningRule(ctx context.Context, exec sqlExecutor, rule authkit.ProvisioningRule) error { + if _, err := exec.Exec( + ctx, + `insert into authkit_provisioning_rules + (id, display_name, provider, condition, enabled) + values ($1, $2, $3, $4, $5)`, + rule.ID, + rule.DisplayName, + rule.Provider, + rule.Condition, + rule.Enabled, + ); err != nil { + if isPostgresCode(err, uniqueViolation) { + return fmt.Errorf("postgres: provisioning rule %q already exists", rule.ID) + } + + return fmt.Errorf("postgres: create provisioning rule: %w", err) + } + + return nil +} + +// insertProvisioningRuleRoles inserts ruleID's role-assignment join rows. +// Idempotent via `on conflict do nothing`; an empty roleIDs slice is a +// no-op. +func insertProvisioningRuleRoles( + ctx context.Context, + exec sqlExecutor, + ruleID string, + roleIDs []string, +) error { + if len(roleIDs) == 0 { + return nil + } + if _, err := exec.Exec( + ctx, + `insert into authkit_provisioning_rule_roles (rule_id, role_id) + select $1, unnest($2::text[]) + on conflict (rule_id, role_id) do nothing`, + ruleID, + roleIDs, + ); err != nil { + return fmt.Errorf("postgres: assign provisioning rule roles: %w", err) + } + + return nil +} + +// findProvisioningRule reads a single provisioning rule with its assigned +// role IDs aggregated via array_agg. Returns `pgx.ErrNoRows` when the rule +// is missing; callers translate this to +// `authkit.ErrProvisioningRuleNotFound`. +func findProvisioningRule( + ctx context.Context, + query rowQuerier, + id string, +) (authkit.ProvisioningRule, error) { + rule, err := scanProvisioningRule(query.QueryRow( + ctx, + `select r.id, r.display_name, r.provider, r.condition, r.enabled, + coalesce(array_agg(rr.role_id order by rr.role_id) + filter (where rr.role_id is not null), '{}'::text[]) as role_ids + from authkit_provisioning_rules as r + left join authkit_provisioning_rule_roles as rr on rr.rule_id = r.id + where r.id = $1 + group by r.id`, + id, + )) + if err != nil { + return authkit.ProvisioningRule{}, err + } + + return rule, nil +} + +// scanProvisioningRule reads a provisioning rule row from row (a pgx.Row +// or pgx.Rows). The row must select id, display_name, provider, condition, +// enabled, and an aggregated role_ids text[] in that order. +func scanProvisioningRule(row scanner) (authkit.ProvisioningRule, error) { + var rule authkit.ProvisioningRule + if err := row.Scan( + &rule.ID, + &rule.DisplayName, + &rule.Provider, + &rule.Condition, + &rule.Enabled, + &rule.AssignRoleIDs, + ); err != nil { + return authkit.ProvisioningRule{}, err + } + + return cloneProvisioningRule(rule), nil +} + +// assignInitialRoles attaches roleIDs to principalID. Used by +// ProvisionIdentity. An empty roleIDs slice is a no-op; a non-empty slice +// must contain only non-empty entries. Maps a foreign-key violation back to +// a "initial role does not exist" error. +func assignInitialRoles(ctx context.Context, exec sqlExecutor, principalID string, roleIDs []string) error { + roleIDs = uniqueStrings(roleIDs) + if err := validateNonEmptyStrings("initial role ID", roleIDs); err != nil { + return fmt.Errorf("postgres: %w", err) + } + if len(roleIDs) == 0 { + return nil + } + + if _, err := exec.Exec( + ctx, + `insert into authkit_principal_roles (principal_id, role_id) + select $1, unnest($2::text[]) + on conflict (principal_id, role_id) do nothing`, + principalID, + roleIDs, + ); err != nil { + if isPostgresCode(err, foreignKeyViolation) { + return errors.New("postgres: initial role does not exist") + } + + return fmt.Errorf("postgres: assign initial roles: %w", err) + } + + return nil +} + +// provisioningRuleFromCreate converts a create request into a normalized +// ProvisioningRule (condition normalized, role IDs deduplicated and copied). +func provisioningRuleFromCreate(req authkit.CreateProvisioningRuleRequest) authkit.ProvisioningRule { + return normalizeProvisioningRule(authkit.ProvisioningRule{ + ID: req.ID, + DisplayName: req.DisplayName, + Provider: req.Provider, + Condition: provisioning.NormalizeCondition(req.Condition), + AssignRoleIDs: cloneStrings(req.AssignRoleIDs), + Enabled: req.Enabled, + }) +} + +// provisioningRuleFromUpdate converts an update request into a normalized +// ProvisioningRule using the same normalization rules as +// provisioningRuleFromCreate. +func provisioningRuleFromUpdate(req authkit.UpdateProvisioningRuleRequest) authkit.ProvisioningRule { + return normalizeProvisioningRule(authkit.ProvisioningRule{ + ID: req.ID, + DisplayName: req.DisplayName, + Provider: req.Provider, + Condition: provisioning.NormalizeCondition(req.Condition), + AssignRoleIDs: cloneStrings(req.AssignRoleIDs), + Enabled: req.Enabled, + }) +} + +// normalizeProvisioningRule deduplicates rule.AssignRoleIDs. +func normalizeProvisioningRule(rule authkit.ProvisioningRule) authkit.ProvisioningRule { + rule.AssignRoleIDs = uniqueStrings(rule.AssignRoleIDs) + + return rule +} diff --git a/store/postgres/provisioning_test.go b/store/postgres/provisioning_test.go new file mode 100644 index 0000000..ec9f7b7 --- /dev/null +++ b/store/postgres/provisioning_test.go @@ -0,0 +1,15 @@ +package postgres + +import ( + "testing" + + "github.com/meigma/authkit" +) + +func TestStoreSatisfiesProvisioningContracts(_ *testing.T) { + var _ authkit.ProvisioningRuleCreator = (*Store)(nil) + var _ authkit.ProvisioningRuleUpdater = (*Store)(nil) + var _ authkit.ProvisioningRuleDeleter = (*Store)(nil) + var _ authkit.ProvisioningRuleFinder = (*Store)(nil) + var _ authkit.ProvisioningRuleLister = (*Store)(nil) +} diff --git a/store/postgres/role.go b/store/postgres/role.go new file mode 100644 index 0000000..229c410 --- /dev/null +++ b/store/postgres/role.go @@ -0,0 +1,262 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + + "github.com/meigma/authkit" +) + +// CreateRole creates a local role. Returns an error if a role with the same +// ID already exists. +func (s *Store) CreateRole(ctx context.Context, req authkit.CreateRoleRequest) (authkit.Role, error) { + if err := ctx.Err(); err != nil { + return authkit.Role{}, err + } + if req.ID == "" { + return authkit.Role{}, errors.New("postgres: role ID is required") + } + + role := authkit.Role(req) + if _, err := s.pool.Exec( + ctx, + `insert into authkit_roles (id, display_name, description) + values ($1, $2, $3)`, + role.ID, + role.DisplayName, + role.Description, + ); err != nil { + if isPostgresCode(err, uniqueViolation) { + return authkit.Role{}, fmt.Errorf("postgres: role %q already exists", req.ID) + } + + return authkit.Role{}, fmt.Errorf("postgres: create role: %w", err) + } + + return role, nil +} + +// GrantRoleAction grants an action to a local role. Idempotent via +// `on conflict do nothing`. Returns an error if the role does not exist. +func (s *Store) GrantRoleAction(ctx context.Context, req authkit.GrantRoleActionRequest) error { + if err := ctx.Err(); err != nil { + return err + } + if req.RoleID == "" { + return errors.New("postgres: role ID is required") + } + if req.Action == "" { + return errors.New("postgres: action is required") + } + + if _, err := s.pool.Exec( + ctx, + `insert into authkit_role_actions (role_id, action) + values ($1, $2) + on conflict (role_id, action) do nothing`, + req.RoleID, + req.Action, + ); err != nil { + if isPostgresCode(err, foreignKeyViolation) { + return fmt.Errorf("postgres: role %q does not exist", req.RoleID) + } + + return fmt.Errorf("postgres: grant role action: %w", err) + } + + return nil +} + +// AssignPrincipalRole assigns principal req.PrincipalID to local role +// req.RoleID. Idempotent via `on conflict do nothing`. Returns an error if +// either side of the join is missing. +func (s *Store) AssignPrincipalRole(ctx context.Context, req authkit.AssignPrincipalRoleRequest) error { + if err := ctx.Err(); err != nil { + return err + } + if req.PrincipalID == "" { + return errors.New("postgres: principal ID is required") + } + if req.RoleID == "" { + return errors.New("postgres: role ID is required") + } + + if _, err := s.pool.Exec( + ctx, + `insert into authkit_principal_roles (principal_id, role_id) + values ($1, $2) + on conflict (principal_id, role_id) do nothing`, + req.PrincipalID, + req.RoleID, + ); err != nil { + if isPostgresCode(err, foreignKeyViolation) { + return fmt.Errorf( + "postgres: principal %q or role %q does not exist", + req.PrincipalID, + req.RoleID, + ) + } + + return fmt.Errorf("postgres: assign principal role: %w", err) + } + + return nil +} + +// UnassignPrincipalRole removes principal req.PrincipalID from local role +// req.RoleID. Idempotent: deleting a non-existent assignment is not an +// error. Returns `authkit.ErrPrincipalNotFound` when the principal itself +// is missing, or a generic error when the role is missing. +func (s *Store) UnassignPrincipalRole(ctx context.Context, req authkit.UnassignPrincipalRoleRequest) error { + if err := ctx.Err(); err != nil { + return err + } + if req.PrincipalID == "" { + return errors.New("postgres: principal ID is required") + } + if req.RoleID == "" { + return errors.New("postgres: role ID is required") + } + + exists, err := s.principalExists(ctx, req.PrincipalID) + if err != nil { + return err + } + if !exists { + return authkit.ErrPrincipalNotFound + } + roleExists, err := s.roleExists(ctx, req.RoleID) + if err != nil { + return err + } + if !roleExists { + return fmt.Errorf("postgres: role %q does not exist", req.RoleID) + } + + if _, err := s.pool.Exec( + ctx, + `delete from authkit_principal_roles where principal_id = $1 and role_id = $2`, + req.PrincipalID, + req.RoleID, + ); err != nil { + return fmt.Errorf("postgres: unassign principal role: %w", err) + } + + return nil +} + +// ListPrincipalRoleAssignments returns role assignments for principalID, +// sorted by role ID. Returns `authkit.ErrPrincipalNotFound` if the +// principal does not exist. +func (s *Store) ListPrincipalRoleAssignments( + ctx context.Context, + principalID string, +) ([]authkit.PrincipalRoleAssignment, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if principalID == "" { + return nil, errors.New("postgres: principal ID is required") + } + + exists, err := s.principalExists(ctx, principalID) + if err != nil { + return nil, err + } + if !exists { + return nil, authkit.ErrPrincipalNotFound + } + + rows, err := s.pool.Query( + ctx, + `select principal_id, role_id + from authkit_principal_roles + where principal_id = $1 + order by role_id`, + principalID, + ) + if err != nil { + return nil, fmt.Errorf("postgres: list principal role assignments: %w", err) + } + defer rows.Close() + + var assignments []authkit.PrincipalRoleAssignment + for rows.Next() { + var assignment authkit.PrincipalRoleAssignment + if err := rows.Scan(&assignment.PrincipalID, &assignment.RoleID); err != nil { + return nil, fmt.Errorf("postgres: scan principal role assignment: %w", err) + } + assignments = append(assignments, assignment) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("postgres: read principal role assignments: %w", err) + } + + return assignments, nil +} + +// ResolvePrincipalActions returns the distinct, sorted set of actions +// granted to principalID through its role assignments. Returns an error +// when the principal does not exist; returns nil (no error) when the +// principal exists but holds no granted actions. +func (s *Store) ResolvePrincipalActions(ctx context.Context, principalID string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if principalID == "" { + return nil, errors.New("postgres: principal ID is required") + } + + exists, err := s.principalExists(ctx, principalID) + if err != nil { + return nil, err + } + if !exists { + return nil, fmt.Errorf("postgres: principal %q does not exist", principalID) + } + + rows, err := s.pool.Query( + ctx, + `select distinct ra.action + from authkit_principal_roles as pr + join authkit_role_actions as ra on ra.role_id = pr.role_id + where pr.principal_id = $1 + order by ra.action`, + principalID, + ) + if err != nil { + return nil, fmt.Errorf("postgres: resolve principal actions: %w", err) + } + defer rows.Close() + + var actions []string + for rows.Next() { + var action string + if err := rows.Scan(&action); err != nil { + return nil, fmt.Errorf("postgres: scan principal action: %w", err) + } + actions = append(actions, action) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("postgres: read principal actions: %w", err) + } + + return actions, nil +} + +// roleExists reports whether roleID identifies an existing role. Used by +// methods that need to disambiguate "principal missing" from "role +// missing" in their error paths. +func (s *Store) roleExists(ctx context.Context, roleID string) (bool, error) { + var exists bool + if err := s.pool.QueryRow( + ctx, + `select exists(select 1 from authkit_roles where id = $1)`, + roleID, + ).Scan(&exists); err != nil { + return false, fmt.Errorf("postgres: find role: %w", err) + } + + return exists, nil +} diff --git a/store/postgres/role_test.go b/store/postgres/role_test.go new file mode 100644 index 0000000..266fc35 --- /dev/null +++ b/store/postgres/role_test.go @@ -0,0 +1,16 @@ +package postgres + +import ( + "testing" + + "github.com/meigma/authkit" +) + +func TestStoreSatisfiesRoleContracts(_ *testing.T) { + var _ authkit.RoleCreator = (*Store)(nil) + var _ authkit.RoleActionGranter = (*Store)(nil) + var _ authkit.PrincipalRoleAssigner = (*Store)(nil) + var _ authkit.PrincipalRoleUnassigner = (*Store)(nil) + var _ authkit.PrincipalRoleAssignmentLister = (*Store)(nil) + var _ authkit.PrincipalActionResolver = (*Store)(nil) +} diff --git a/store/postgres/store.go b/store/postgres/store.go index da0311a..4ebeac4 100644 --- a/store/postgres/store.go +++ b/store/postgres/store.go @@ -2,23 +2,11 @@ package postgres import ( "context" - "crypto/rand" - "crypto/sha256" - "encoding/json" "errors" - "fmt" - "maps" - "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" - - "github.com/meigma/authkit" - "github.com/meigma/authkit/apikey" - "github.com/meigma/authkit/oidc" - "github.com/meigma/authkit/provisioning" ) const ( @@ -28,30 +16,47 @@ const ( uniqueViolation = "23505" ) -// Store persists authkit principals, identity links, API tokens, and OIDC provider trust in PostgreSQL. +// Store persists authkit principals, roles, provisioning rules, identity +// links, API tokens, and OIDC provider trust in PostgreSQL. A single Store +// implements every authkit storage port (see `internal/storetest.Store`); +// domain methods are split across sibling files in this package by domain +// (principal, role, provisioning, identity, oidc, token, passkey). +// +// All exported methods are safe for concurrent use. Multi-write operations +// (CreateProvisioningRule, UpdateProvisioningRule, ProvisionIdentity, and +// the passkey CreateRegistration) execute inside a transaction; see `tx.go` +// for the shared `withTx` helper. type Store struct { pool *pgxpool.Pool } +// sqlExecutor narrows pgx's Exec surface so helpers can accept either a +// pool or a transaction. type sqlExecutor interface { Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) } +// rowQuerier narrows pgx's QueryRow surface for the same reason. type rowQuerier interface { QueryRow(ctx context.Context, sql string, args ...any) pgx.Row } +// queryExecutor combines exec, single-row read, and multi-row read so helpers +// that need every shape can take a single argument. type queryExecutor interface { sqlExecutor rowQuerier Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) } +// scanner abstracts pgx.Row and pgx.Rows to a single Scan method so +// row-shaping helpers work uniformly for QueryRow and Query loops. type scanner interface { Scan(dest ...any) error } -// NewStore constructs a PostgreSQL store around pool. +// NewStore constructs a PostgreSQL store around pool. The pool is the +// caller's to manage; NewStore does not retain ownership of its lifecycle. func NewStore(pool *pgxpool.Pool) (*Store, error) { if pool == nil { return nil, errors.New("postgres: pool is required") @@ -61,1524 +66,3 @@ func NewStore(pool *pgxpool.Pool) (*Store, error) { pool: pool, }, nil } - -// CreatePrincipal creates a principal in PostgreSQL. -func (s *Store) CreatePrincipal( - ctx context.Context, - req authkit.CreatePrincipalRequest, -) (authkit.Principal, error) { - if err := ctx.Err(); err != nil { - return authkit.Principal{}, err - } - if req.Kind != authkit.PrincipalKindUser && req.Kind != authkit.PrincipalKindService { - return authkit.Principal{}, fmt.Errorf("postgres: unsupported principal kind %q", req.Kind) - } - - return createPrincipal(ctx, s.pool, req) -} - -// FindPrincipal returns a principal by ID. -func (s *Store) FindPrincipal(ctx context.Context, id string) (authkit.Principal, error) { - if err := ctx.Err(); err != nil { - return authkit.Principal{}, err - } - if id == "" { - return authkit.Principal{}, errors.New("postgres: principal ID is required") - } - - principal, err := scanPrincipal(s.pool.QueryRow( - ctx, - `select id, kind, display_name, coalesce(attributes::text, '') - from authkit_principals - where id = $1`, - id, - )) - if errors.Is(err, pgx.ErrNoRows) { - return authkit.Principal{}, authkit.ErrPrincipalNotFound - } - if err != nil { - return authkit.Principal{}, err - } - - return principal, nil -} - -// ListPrincipals returns all principals sorted by ID. -func (s *Store) ListPrincipals(ctx context.Context) ([]authkit.Principal, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - - rows, err := s.pool.Query( - ctx, - `select id, kind, display_name, coalesce(attributes::text, '') - from authkit_principals - order by id`, - ) - if err != nil { - return nil, fmt.Errorf("postgres: list principals: %w", err) - } - defer rows.Close() - - var principals []authkit.Principal - for rows.Next() { - principal, scanErr := scanPrincipal(rows) - if scanErr != nil { - return nil, scanErr - } - principals = append(principals, principal) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("postgres: read principals: %w", err) - } - - return principals, nil -} - -// CreateRole creates a local role in PostgreSQL. -func (s *Store) CreateRole(ctx context.Context, req authkit.CreateRoleRequest) (authkit.Role, error) { - if err := ctx.Err(); err != nil { - return authkit.Role{}, err - } - if req.ID == "" { - return authkit.Role{}, errors.New("postgres: role ID is required") - } - - role := authkit.Role(req) - if _, err := s.pool.Exec( - ctx, - `insert into authkit_roles (id, display_name, description) - values ($1, $2, $3)`, - role.ID, - role.DisplayName, - role.Description, - ); err != nil { - if isPostgresCode(err, uniqueViolation) { - return authkit.Role{}, fmt.Errorf("postgres: role %q already exists", req.ID) - } - - return authkit.Role{}, fmt.Errorf("postgres: create role: %w", err) - } - - return role, nil -} - -// GrantRoleAction grants an action to a local role. -func (s *Store) GrantRoleAction(ctx context.Context, req authkit.GrantRoleActionRequest) error { - if err := ctx.Err(); err != nil { - return err - } - if req.RoleID == "" { - return errors.New("postgres: role ID is required") - } - if req.Action == "" { - return errors.New("postgres: action is required") - } - - if _, err := s.pool.Exec( - ctx, - `insert into authkit_role_actions (role_id, action) - values ($1, $2) - on conflict (role_id, action) do nothing`, - req.RoleID, - req.Action, - ); err != nil { - if isPostgresCode(err, foreignKeyViolation) { - return fmt.Errorf("postgres: role %q does not exist", req.RoleID) - } - - return fmt.Errorf("postgres: grant role action: %w", err) - } - - return nil -} - -// AssignPrincipalRole assigns a principal to a local role. -func (s *Store) AssignPrincipalRole(ctx context.Context, req authkit.AssignPrincipalRoleRequest) error { - if err := ctx.Err(); err != nil { - return err - } - if req.PrincipalID == "" { - return errors.New("postgres: principal ID is required") - } - if req.RoleID == "" { - return errors.New("postgres: role ID is required") - } - - if _, err := s.pool.Exec( - ctx, - `insert into authkit_principal_roles (principal_id, role_id) - values ($1, $2) - on conflict (principal_id, role_id) do nothing`, - req.PrincipalID, - req.RoleID, - ); err != nil { - if isPostgresCode(err, foreignKeyViolation) { - return fmt.Errorf( - "postgres: principal %q or role %q does not exist", - req.PrincipalID, - req.RoleID, - ) - } - - return fmt.Errorf("postgres: assign principal role: %w", err) - } - - return nil -} - -// UnassignPrincipalRole removes a principal from a local role. -func (s *Store) UnassignPrincipalRole(ctx context.Context, req authkit.UnassignPrincipalRoleRequest) error { - if err := ctx.Err(); err != nil { - return err - } - if req.PrincipalID == "" { - return errors.New("postgres: principal ID is required") - } - if req.RoleID == "" { - return errors.New("postgres: role ID is required") - } - - exists, err := s.principalExists(ctx, req.PrincipalID) - if err != nil { - return err - } - if !exists { - return authkit.ErrPrincipalNotFound - } - roleExists, err := s.roleExists(ctx, req.RoleID) - if err != nil { - return err - } - if !roleExists { - return fmt.Errorf("postgres: role %q does not exist", req.RoleID) - } - - if _, err := s.pool.Exec( - ctx, - `delete from authkit_principal_roles where principal_id = $1 and role_id = $2`, - req.PrincipalID, - req.RoleID, - ); err != nil { - return fmt.Errorf("postgres: unassign principal role: %w", err) - } - - return nil -} - -// ListPrincipalRoleAssignments returns role assignments for a principal. -func (s *Store) ListPrincipalRoleAssignments( - ctx context.Context, - principalID string, -) ([]authkit.PrincipalRoleAssignment, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - if principalID == "" { - return nil, errors.New("postgres: principal ID is required") - } - - exists, err := s.principalExists(ctx, principalID) - if err != nil { - return nil, err - } - if !exists { - return nil, authkit.ErrPrincipalNotFound - } - - rows, err := s.pool.Query( - ctx, - `select principal_id, role_id - from authkit_principal_roles - where principal_id = $1 - order by role_id`, - principalID, - ) - if err != nil { - return nil, fmt.Errorf("postgres: list principal role assignments: %w", err) - } - defer rows.Close() - - var assignments []authkit.PrincipalRoleAssignment - for rows.Next() { - var assignment authkit.PrincipalRoleAssignment - if err := rows.Scan(&assignment.PrincipalID, &assignment.RoleID); err != nil { - return nil, fmt.Errorf("postgres: scan principal role assignment: %w", err) - } - assignments = append(assignments, assignment) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("postgres: read principal role assignments: %w", err) - } - - return assignments, nil -} - -// ResolvePrincipalActions returns the distinct actions granted to principalID through roles. -func (s *Store) ResolvePrincipalActions(ctx context.Context, principalID string) ([]string, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - if principalID == "" { - return nil, errors.New("postgres: principal ID is required") - } - - exists, err := s.principalExists(ctx, principalID) - if err != nil { - return nil, err - } - if !exists { - return nil, fmt.Errorf("postgres: principal %q does not exist", principalID) - } - - rows, err := s.pool.Query( - ctx, - `select distinct ra.action - from authkit_principal_roles as pr - join authkit_role_actions as ra on ra.role_id = pr.role_id - where pr.principal_id = $1 - order by ra.action`, - principalID, - ) - if err != nil { - return nil, fmt.Errorf("postgres: resolve principal actions: %w", err) - } - defer rows.Close() - - var actions []string - for rows.Next() { - var action string - if err := rows.Scan(&action); err != nil { - return nil, fmt.Errorf("postgres: scan principal action: %w", err) - } - actions = append(actions, action) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("postgres: read principal actions: %w", err) - } - - return actions, nil -} - -// CreateProvisioningRule creates a provisioning rule in PostgreSQL. -func (s *Store) CreateProvisioningRule( - ctx context.Context, - req authkit.CreateProvisioningRuleRequest, -) (authkit.ProvisioningRule, error) { - if err := ctx.Err(); err != nil { - return authkit.ProvisioningRule{}, err - } - - rule := provisioningRuleFromCreate(req) - tx, err := s.pool.Begin(ctx) - if err != nil { - return authkit.ProvisioningRule{}, fmt.Errorf("postgres: begin create provisioning rule: %w", err) - } - defer func() { - _ = tx.Rollback(ctx) - }() - - if validationErr := validateProvisioningRule(ctx, tx, rule); validationErr != nil { - return authkit.ProvisioningRule{}, validationErr - } - if err := insertProvisioningRule(ctx, tx, rule); err != nil { - return authkit.ProvisioningRule{}, err - } - if err := insertProvisioningRuleRoles(ctx, tx, rule.ID, rule.AssignRoleIDs); err != nil { - return authkit.ProvisioningRule{}, err - } - if err := tx.Commit(ctx); err != nil { - return authkit.ProvisioningRule{}, fmt.Errorf("postgres: commit create provisioning rule: %w", err) - } - - return cloneProvisioningRule(rule), nil -} - -// UpdateProvisioningRule replaces a provisioning rule in PostgreSQL. -func (s *Store) UpdateProvisioningRule( - ctx context.Context, - req authkit.UpdateProvisioningRuleRequest, -) (authkit.ProvisioningRule, error) { - if err := ctx.Err(); err != nil { - return authkit.ProvisioningRule{}, err - } - - rule := provisioningRuleFromUpdate(req) - tx, err := s.pool.Begin(ctx) - if err != nil { - return authkit.ProvisioningRule{}, fmt.Errorf("postgres: begin update provisioning rule: %w", err) - } - defer func() { - _ = tx.Rollback(ctx) - }() - - exists, err := provisioningRuleExists(ctx, tx, rule.ID) - if err != nil { - return authkit.ProvisioningRule{}, err - } - if !exists { - return authkit.ProvisioningRule{}, authkit.ErrProvisioningRuleNotFound - } - if validationErr := validateProvisioningRule(ctx, tx, rule); validationErr != nil { - return authkit.ProvisioningRule{}, validationErr - } - - tag, err := tx.Exec( - ctx, - `update authkit_provisioning_rules - set display_name = $2, - provider = $3, - condition = $4, - enabled = $5, - updated_at = now() - where id = $1`, - rule.ID, - rule.DisplayName, - rule.Provider, - rule.Condition, - rule.Enabled, - ) - if err != nil { - return authkit.ProvisioningRule{}, fmt.Errorf("postgres: update provisioning rule: %w", err) - } - if tag.RowsAffected() == 0 { - return authkit.ProvisioningRule{}, authkit.ErrProvisioningRuleNotFound - } - if _, err := tx.Exec(ctx, `delete from authkit_provisioning_rule_roles where rule_id = $1`, rule.ID); err != nil { - return authkit.ProvisioningRule{}, fmt.Errorf("postgres: clear provisioning rule roles: %w", err) - } - if err := insertProvisioningRuleRoles(ctx, tx, rule.ID, rule.AssignRoleIDs); err != nil { - return authkit.ProvisioningRule{}, err - } - if err := tx.Commit(ctx); err != nil { - return authkit.ProvisioningRule{}, fmt.Errorf("postgres: commit update provisioning rule: %w", err) - } - - return cloneProvisioningRule(rule), nil -} - -// DeleteProvisioningRule deletes a provisioning rule from PostgreSQL. -func (s *Store) DeleteProvisioningRule(ctx context.Context, id string) error { - if err := ctx.Err(); err != nil { - return err - } - if id == "" { - return errors.New("postgres: provisioning rule ID is required") - } - - tag, err := s.pool.Exec(ctx, `delete from authkit_provisioning_rules where id = $1`, id) - if err != nil { - return fmt.Errorf("postgres: delete provisioning rule: %w", err) - } - if tag.RowsAffected() == 0 { - return authkit.ErrProvisioningRuleNotFound - } - - return nil -} - -// FindProvisioningRule returns a provisioning rule by ID. -func (s *Store) FindProvisioningRule(ctx context.Context, id string) (authkit.ProvisioningRule, error) { - if err := ctx.Err(); err != nil { - return authkit.ProvisioningRule{}, err - } - if id == "" { - return authkit.ProvisioningRule{}, errors.New("postgres: provisioning rule ID is required") - } - - rule, err := findProvisioningRule(ctx, s.pool, id) - if errors.Is(err, pgx.ErrNoRows) { - return authkit.ProvisioningRule{}, authkit.ErrProvisioningRuleNotFound - } - if err != nil { - return authkit.ProvisioningRule{}, err - } - - return rule, nil -} - -// ListProvisioningRules returns all provisioning rules. -func (s *Store) ListProvisioningRules(ctx context.Context) ([]authkit.ProvisioningRule, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - - rows, err := s.pool.Query( - ctx, - `select r.id, r.display_name, r.provider, r.condition, r.enabled, - coalesce(array_agg(rr.role_id order by rr.role_id) - filter (where rr.role_id is not null), '{}'::text[]) as role_ids - from authkit_provisioning_rules as r - left join authkit_provisioning_rule_roles as rr on rr.rule_id = r.id - group by r.id - order by r.id`, - ) - if err != nil { - return nil, fmt.Errorf("postgres: list provisioning rules: %w", err) - } - defer rows.Close() - - var rules []authkit.ProvisioningRule - for rows.Next() { - rule, scanErr := scanProvisioningRule(rows) - if scanErr != nil { - return nil, scanErr - } - rules = append(rules, rule) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("postgres: read provisioning rules: %w", err) - } - - return rules, nil -} - -func createPrincipal( - ctx context.Context, - exec sqlExecutor, - req authkit.CreatePrincipalRequest, -) (authkit.Principal, error) { - attributes, err := encodeAttributes(req.Attributes) - if err != nil { - return authkit.Principal{}, err - } - - for range principalIDAttempts { - principal := authkit.Principal{ - ID: principalIDPrefix + rand.Text(), - Kind: req.Kind, - DisplayName: req.DisplayName, - Attributes: cloneAttributes(req.Attributes), - } - _, err := exec.Exec( - ctx, - `insert into authkit_principals (id, kind, display_name, attributes) - values ($1, $2, $3, nullif($4, '')::jsonb)`, - principal.ID, - string(principal.Kind), - principal.DisplayName, - attributes, - ) - if err == nil { - return principal, nil - } - if !isPostgresCode(err, uniqueViolation) { - return authkit.Principal{}, fmt.Errorf("postgres: create principal: %w", err) - } - } - - return authkit.Principal{}, errors.New("postgres: create principal: generated duplicate principal IDs") -} - -// LinkIdentity links an external identity to an existing principal. -func (s *Store) LinkIdentity( - ctx context.Context, - req authkit.LinkIdentityRequest, -) (authkit.ExternalIdentity, error) { - if err := ctx.Err(); err != nil { - return authkit.ExternalIdentity{}, err - } - if req.Provider == "" { - return authkit.ExternalIdentity{}, errors.New("postgres: provider is required") - } - if req.Subject == "" { - return authkit.ExternalIdentity{}, errors.New("postgres: subject is required") - } - if req.PrincipalID == "" { - return authkit.ExternalIdentity{}, errors.New("postgres: principal ID is required") - } - - link, err := s.findIdentityLink(ctx, req.Provider, req.Subject) - if err == nil { - if link.PrincipalID == req.PrincipalID { - return link, nil - } - - return authkit.ExternalIdentity{}, fmt.Errorf( - "postgres: identity %q/%q is already linked to principal %q", - req.Provider, - req.Subject, - link.PrincipalID, - ) - } - if !errors.Is(err, pgx.ErrNoRows) { - return authkit.ExternalIdentity{}, fmt.Errorf("postgres: find identity link: %w", err) - } - - link = authkit.ExternalIdentity(req) - if _, err := s.pool.Exec( - ctx, - `insert into authkit_external_identities (provider, subject, principal_id) - values ($1, $2, $3)`, - link.Provider, - link.Subject, - link.PrincipalID, - ); err != nil { - if isPostgresCode(err, uniqueViolation) { - return s.resolveIdentityLinkConflict(ctx, req) - } - if isPostgresCode(err, foreignKeyViolation) { - return authkit.ExternalIdentity{}, fmt.Errorf( - "postgres: principal %q does not exist", - req.PrincipalID, - ) - } - - return authkit.ExternalIdentity{}, fmt.Errorf("postgres: link identity: %w", err) - } - - return link, nil -} - -// ResolveIdentity returns the principal linked to identity. -func (s *Store) ResolveIdentity( - ctx context.Context, - identity authkit.Identity, -) (*authkit.Principal, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - if identity.Provider == "" || identity.Subject == "" { - return nil, fmt.Errorf("%w: provider and subject are required", authkit.ErrUnresolvedIdentity) - } - - var principal authkit.Principal - var kind string - var attributes string - err := s.pool.QueryRow( - ctx, - `select p.id, p.kind, p.display_name, coalesce(p.attributes::text, '') - from authkit_external_identities as i - join authkit_principals as p on p.id = i.principal_id - where i.provider = $1 and i.subject = $2`, - identity.Provider, - identity.Subject, - ).Scan(&principal.ID, &kind, &principal.DisplayName, &attributes) - if errors.Is(err, pgx.ErrNoRows) { - return nil, fmt.Errorf( - "%w: identity %q/%q is not linked", - authkit.ErrUnresolvedIdentity, - identity.Provider, - identity.Subject, - ) - } - if err != nil { - return nil, fmt.Errorf("postgres: resolve identity: %w", err) - } - - principal.Kind = authkit.PrincipalKind(kind) - principal.Attributes, err = decodeAttributes(attributes) - if err != nil { - return nil, err - } - - return &principal, nil -} - -// ProvisionIdentity creates and links a principal for identity or returns the existing link. -func (s *Store) ProvisionIdentity( - ctx context.Context, - req authkit.ProvisionIdentityRequest, -) (authkit.ProvisionIdentityResult, error) { - if err := ctx.Err(); err != nil { - return authkit.ProvisionIdentityResult{}, err - } - if req.Identity.Provider == "" || req.Identity.Subject == "" { - return authkit.ProvisionIdentityResult{}, fmt.Errorf( - "%w: provider and subject are required", - authkit.ErrUnresolvedIdentity, - ) - } - if req.Principal.Kind != authkit.PrincipalKindUser && req.Principal.Kind != authkit.PrincipalKindService { - return authkit.ProvisionIdentityResult{}, fmt.Errorf( - "postgres: unsupported principal kind %q", - req.Principal.Kind, - ) - } - - tx, err := s.pool.Begin(ctx) - if err != nil { - return authkit.ProvisionIdentityResult{}, fmt.Errorf("postgres: begin provision identity: %w", err) - } - defer func() { - _ = tx.Rollback(ctx) - }() - - existing, err := findProvisionedIdentity(ctx, tx, req.Identity) - if err == nil { - if commitErr := tx.Commit(ctx); commitErr != nil { - return authkit.ProvisionIdentityResult{}, fmt.Errorf("postgres: commit provision identity: %w", commitErr) - } - - return existing, nil - } - if !errors.Is(err, pgx.ErrNoRows) { - return authkit.ProvisionIdentityResult{}, fmt.Errorf("postgres: find provisioned identity: %w", err) - } - - principal, err := createPrincipal(ctx, tx, req.Principal) - if err != nil { - return authkit.ProvisionIdentityResult{}, err - } - - link := authkit.ExternalIdentity{ - Provider: req.Identity.Provider, - Subject: req.Identity.Subject, - PrincipalID: principal.ID, - } - if _, err := tx.Exec( - ctx, - `insert into authkit_external_identities (provider, subject, principal_id) - values ($1, $2, $3)`, - link.Provider, - link.Subject, - link.PrincipalID, - ); err != nil { - return s.handleProvisionIdentityLinkError(ctx, tx, req.Identity, err) - } - if err := assignInitialRoles(ctx, tx, principal.ID, req.InitialRoleIDs); err != nil { - return authkit.ProvisionIdentityResult{}, err - } - if err := tx.Commit(ctx); err != nil { - return authkit.ProvisionIdentityResult{}, fmt.Errorf("postgres: commit provision identity: %w", err) - } - - return authkit.ProvisionIdentityResult{ - Principal: principal, - Link: link, - Created: true, - }, nil -} - -func (s *Store) handleProvisionIdentityLinkError( - ctx context.Context, - tx pgx.Tx, - identity authkit.Identity, - err error, -) (authkit.ProvisionIdentityResult, error) { - if !isPostgresCode(err, uniqueViolation) { - return authkit.ProvisionIdentityResult{}, fmt.Errorf("postgres: link provisioned identity: %w", err) - } - if rollbackErr := tx.Rollback(ctx); rollbackErr != nil { - return authkit.ProvisionIdentityResult{}, fmt.Errorf( - "postgres: rollback provision identity conflict: %w", - rollbackErr, - ) - } - - winner, findErr := findProvisionedIdentity(ctx, s.pool, identity) - if findErr != nil { - return authkit.ProvisionIdentityResult{}, fmt.Errorf( - "postgres: find provisioned identity conflict: %w", - findErr, - ) - } - - return winner, nil -} - -// CreateToken stores token. -func (s *Store) CreateToken(ctx context.Context, token apikey.StoredToken) error { - if err := ctx.Err(); err != nil { - return err - } - if token.ID == "" { - return errors.New("postgres: token ID is required") - } - - if _, err := s.pool.Exec( - ctx, - `insert into authkit_api_tokens - (id, principal_id, name, secret_hash, expires_at, last_used_at, revoked_at) - values ($1, $2, $3, $4, $5, $6, $7)`, - token.ID, - token.PrincipalID, - token.Name, - token.SecretHash[:], - token.ExpiresAt, - token.LastUsedAt, - token.RevokedAt, - ); err != nil { - if isPostgresCode(err, foreignKeyViolation) { - return fmt.Errorf( - "%w: postgres: principal %q does not exist", - authkit.ErrPrincipalNotFound, - token.PrincipalID, - ) - } - - return fmt.Errorf("postgres: create token: %w", err) - } - - return nil -} - -// FindToken returns the token for tokenID. -func (s *Store) FindToken(ctx context.Context, tokenID string) (apikey.StoredToken, error) { - if err := ctx.Err(); err != nil { - return apikey.StoredToken{}, err - } - - token, err := s.findToken(ctx, tokenID) - if errors.Is(err, pgx.ErrNoRows) { - return apikey.StoredToken{}, apikey.ErrTokenNotFound - } - if err != nil { - return apikey.StoredToken{}, err - } - - return token, nil -} - -// UpdateTokenLastUsed records the most recent successful use of tokenID. -func (s *Store) UpdateTokenLastUsed(ctx context.Context, tokenID string, usedAt time.Time) error { - if err := ctx.Err(); err != nil { - return err - } - - tag, err := s.pool.Exec( - ctx, - `update authkit_api_tokens set last_used_at = $2 where id = $1`, - tokenID, - usedAt, - ) - if err != nil { - return fmt.Errorf("postgres: update token last used: %w", err) - } - if tag.RowsAffected() == 0 { - return apikey.ErrTokenNotFound - } - - return nil -} - -// RevokeToken records tokenID as revoked. -func (s *Store) RevokeToken(ctx context.Context, tokenID string, revokedAt time.Time) error { - if err := ctx.Err(); err != nil { - return err - } - - tag, err := s.pool.Exec( - ctx, - `update authkit_api_tokens set revoked_at = $2 where id = $1`, - tokenID, - revokedAt, - ) - if err != nil { - return fmt.Errorf("postgres: revoke token: %w", err) - } - if tag.RowsAffected() == 0 { - return apikey.ErrTokenNotFound - } - - return nil -} - -// ListPrincipalTokenMetadata returns API-token metadata for principalID. -func (s *Store) ListPrincipalTokenMetadata( - ctx context.Context, - principalID string, -) ([]apikey.TokenMetadata, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - if principalID == "" { - return nil, errors.New("postgres: principal ID is required") - } - - exists, err := s.principalExists(ctx, principalID) - if err != nil { - return nil, err - } - if !exists { - return nil, authkit.ErrPrincipalNotFound - } - - rows, err := s.pool.Query( - ctx, - `select id, principal_id, name, expires_at, last_used_at, revoked_at - from authkit_api_tokens - where principal_id = $1 - order by id`, - principalID, - ) - if err != nil { - return nil, fmt.Errorf("postgres: list principal API token metadata: %w", err) - } - defer rows.Close() - - var tokens []apikey.TokenMetadata - for rows.Next() { - token, scanErr := scanTokenMetadata(rows) - if scanErr != nil { - return nil, scanErr - } - tokens = append(tokens, token) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("postgres: read principal API token metadata: %w", err) - } - - return tokens, nil -} - -// TrustProvider stores provider as trusted for its issuer. -func (s *Store) TrustProvider(ctx context.Context, provider oidc.Provider) (oidc.Provider, error) { - if err := ctx.Err(); err != nil { - return oidc.Provider{}, err - } - if err := provider.Validate(); err != nil { - return oidc.Provider{}, err - } - - trusted := cloneProvider(provider) - signingAlgorithms := trusted.SupportedSigningAlgorithms - if signingAlgorithms == nil { - signingAlgorithms = []string{} - } - forwardedClaims, err := encodeClaimPaths(trusted.ForwardedClaims) - if err != nil { - return oidc.Provider{}, err - } - if _, err := s.pool.Exec( - ctx, - `insert into authkit_oidc_providers - (issuer, jwks_url, audiences, supported_signing_algorithms, forwarded_claims) - values ($1, $2, $3, $4, $5::jsonb) - on conflict (issuer) do update set - jwks_url = excluded.jwks_url, - audiences = excluded.audiences, - supported_signing_algorithms = excluded.supported_signing_algorithms, - forwarded_claims = excluded.forwarded_claims, - updated_at = now()`, - trusted.Issuer, - trusted.JWKSURL, - trusted.Audiences, - signingAlgorithms, - forwardedClaims, - ); err != nil { - return oidc.Provider{}, fmt.Errorf("postgres: trust OIDC provider: %w", err) - } - - return cloneProvider(trusted), nil -} - -// FindProvider returns the trusted OIDC provider for issuer. -func (s *Store) FindProvider(ctx context.Context, issuer string) (oidc.Provider, error) { - if err := ctx.Err(); err != nil { - return oidc.Provider{}, err - } - - var provider oidc.Provider - var forwardedClaims string - err := s.pool.QueryRow( - ctx, - `select issuer, audiences, jwks_url, supported_signing_algorithms, - coalesce(forwarded_claims::text, '[]') - from authkit_oidc_providers - where issuer = $1`, - issuer, - ).Scan( - &provider.Issuer, - &provider.Audiences, - &provider.JWKSURL, - &provider.SupportedSigningAlgorithms, - &forwardedClaims, - ) - if errors.Is(err, pgx.ErrNoRows) { - return oidc.Provider{}, oidc.ErrProviderNotFound - } - if err != nil { - return oidc.Provider{}, fmt.Errorf("postgres: find OIDC provider: %w", err) - } - provider.ForwardedClaims, err = decodeClaimPaths(forwardedClaims) - if err != nil { - return oidc.Provider{}, err - } - if err := provider.Validate(); err != nil { - return oidc.Provider{}, fmt.Errorf("postgres: invalid OIDC provider %q: %w", issuer, err) - } - - return cloneProvider(provider), nil -} - -func (s *Store) findIdentityLink( - ctx context.Context, - provider string, - subject string, -) (authkit.ExternalIdentity, error) { - var link authkit.ExternalIdentity - err := s.pool.QueryRow( - ctx, - `select provider, subject, principal_id - from authkit_external_identities - where provider = $1 and subject = $2`, - provider, - subject, - ).Scan(&link.Provider, &link.Subject, &link.PrincipalID) - if err != nil { - return authkit.ExternalIdentity{}, err - } - - return link, nil -} - -func (s *Store) principalExists(ctx context.Context, principalID string) (bool, error) { - var exists bool - if err := s.pool.QueryRow( - ctx, - `select exists(select 1 from authkit_principals where id = $1)`, - principalID, - ).Scan(&exists); err != nil { - return false, fmt.Errorf("postgres: find principal: %w", err) - } - - return exists, nil -} - -func (s *Store) roleExists(ctx context.Context, roleID string) (bool, error) { - var exists bool - if err := s.pool.QueryRow( - ctx, - `select exists(select 1 from authkit_roles where id = $1)`, - roleID, - ).Scan(&exists); err != nil { - return false, fmt.Errorf("postgres: find role: %w", err) - } - - return exists, nil -} - -func scanPrincipal(row scanner) (authkit.Principal, error) { - var principal authkit.Principal - var kind string - var attributes string - if err := row.Scan(&principal.ID, &kind, &principal.DisplayName, &attributes); err != nil { - return authkit.Principal{}, err - } - - principal.Kind = authkit.PrincipalKind(kind) - attrs, err := decodeAttributes(attributes) - if err != nil { - return authkit.Principal{}, err - } - principal.Attributes = attrs - - return principal, nil -} - -func findProvisionedIdentity( - ctx context.Context, - query rowQuerier, - identity authkit.Identity, -) (authkit.ProvisionIdentityResult, error) { - var principal authkit.Principal - var kind string - var attributes string - var link authkit.ExternalIdentity - err := query.QueryRow( - ctx, - `select p.id, p.kind, p.display_name, coalesce(p.attributes::text, ''), - i.provider, i.subject, i.principal_id - from authkit_external_identities as i - join authkit_principals as p on p.id = i.principal_id - where i.provider = $1 and i.subject = $2`, - identity.Provider, - identity.Subject, - ).Scan( - &principal.ID, - &kind, - &principal.DisplayName, - &attributes, - &link.Provider, - &link.Subject, - &link.PrincipalID, - ) - if err != nil { - return authkit.ProvisionIdentityResult{}, err - } - - principal.Kind = authkit.PrincipalKind(kind) - principal.Attributes, err = decodeAttributes(attributes) - if err != nil { - return authkit.ProvisionIdentityResult{}, err - } - - return authkit.ProvisionIdentityResult{ - Principal: principal, - Link: link, - Created: false, - }, nil -} - -func validateProvisioningRule(ctx context.Context, query queryExecutor, rule authkit.ProvisioningRule) error { - if rule.ID == "" { - return errors.New("postgres: provisioning rule ID is required") - } - if rule.Provider == "" { - return errors.New("postgres: provisioning rule provider is required") - } - if err := provisioning.ValidateCondition(rule.Condition); err != nil { - return fmt.Errorf("postgres: %w", err) - } - if err := validateRequiredStrings("provisioning rule role ID", rule.AssignRoleIDs); err != nil { - return fmt.Errorf("postgres: %w", err) - } - - _, err := findTrustedProvider(ctx, query, rule.Provider) - if errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("postgres: provider %q is not trusted", rule.Provider) - } - if err != nil { - return err - } - - var roleCount int - if err := query.QueryRow( - ctx, - `select count(*) from authkit_roles where id = any($1)`, - rule.AssignRoleIDs, - ).Scan(&roleCount); err != nil { - return fmt.Errorf("postgres: validate provisioning rule roles: %w", err) - } - if roleCount != len(rule.AssignRoleIDs) { - return errors.New("postgres: provisioning rule references missing role") - } - - return nil -} - -func provisioningRuleExists(ctx context.Context, query rowQuerier, id string) (bool, error) { - var exists bool - if err := query.QueryRow( - ctx, - `select exists(select 1 from authkit_provisioning_rules where id = $1)`, - id, - ).Scan(&exists); err != nil { - return false, fmt.Errorf("postgres: find provisioning rule: %w", err) - } - - return exists, nil -} - -func insertProvisioningRule(ctx context.Context, exec sqlExecutor, rule authkit.ProvisioningRule) error { - if _, err := exec.Exec( - ctx, - `insert into authkit_provisioning_rules - (id, display_name, provider, condition, enabled) - values ($1, $2, $3, $4, $5)`, - rule.ID, - rule.DisplayName, - rule.Provider, - rule.Condition, - rule.Enabled, - ); err != nil { - if isPostgresCode(err, uniqueViolation) { - return fmt.Errorf("postgres: provisioning rule %q already exists", rule.ID) - } - - return fmt.Errorf("postgres: create provisioning rule: %w", err) - } - - return nil -} - -func insertProvisioningRuleRoles( - ctx context.Context, - exec sqlExecutor, - ruleID string, - roleIDs []string, -) error { - if len(roleIDs) == 0 { - return nil - } - if _, err := exec.Exec( - ctx, - `insert into authkit_provisioning_rule_roles (rule_id, role_id) - select $1, unnest($2::text[]) - on conflict (rule_id, role_id) do nothing`, - ruleID, - roleIDs, - ); err != nil { - return fmt.Errorf("postgres: assign provisioning rule roles: %w", err) - } - - return nil -} - -func findProvisioningRule( - ctx context.Context, - query rowQuerier, - id string, -) (authkit.ProvisioningRule, error) { - rule, err := scanProvisioningRule(query.QueryRow( - ctx, - `select r.id, r.display_name, r.provider, r.condition, r.enabled, - coalesce(array_agg(rr.role_id order by rr.role_id) - filter (where rr.role_id is not null), '{}'::text[]) as role_ids - from authkit_provisioning_rules as r - left join authkit_provisioning_rule_roles as rr on rr.rule_id = r.id - where r.id = $1 - group by r.id`, - id, - )) - if err != nil { - return authkit.ProvisioningRule{}, err - } - - return rule, nil -} - -func scanProvisioningRule(row scanner) (authkit.ProvisioningRule, error) { - var rule authkit.ProvisioningRule - if err := row.Scan( - &rule.ID, - &rule.DisplayName, - &rule.Provider, - &rule.Condition, - &rule.Enabled, - &rule.AssignRoleIDs, - ); err != nil { - return authkit.ProvisioningRule{}, err - } - - return cloneProvisioningRule(rule), nil -} - -func assignInitialRoles(ctx context.Context, exec sqlExecutor, principalID string, roleIDs []string) error { - roleIDs = uniqueStrings(roleIDs) - if err := validateNonEmptyStrings("initial role ID", roleIDs); err != nil { - return fmt.Errorf("postgres: %w", err) - } - if len(roleIDs) == 0 { - return nil - } - - if _, err := exec.Exec( - ctx, - `insert into authkit_principal_roles (principal_id, role_id) - select $1, unnest($2::text[]) - on conflict (principal_id, role_id) do nothing`, - principalID, - roleIDs, - ); err != nil { - if isPostgresCode(err, foreignKeyViolation) { - return errors.New("postgres: initial role does not exist") - } - - return fmt.Errorf("postgres: assign initial roles: %w", err) - } - - return nil -} - -func findTrustedProvider(ctx context.Context, query rowQuerier, issuer string) (oidc.Provider, error) { - var provider oidc.Provider - var forwardedClaims string - err := query.QueryRow( - ctx, - `select issuer, audiences, jwks_url, supported_signing_algorithms, - coalesce(forwarded_claims::text, '[]') - from authkit_oidc_providers - where issuer = $1`, - issuer, - ).Scan( - &provider.Issuer, - &provider.Audiences, - &provider.JWKSURL, - &provider.SupportedSigningAlgorithms, - &forwardedClaims, - ) - if err != nil { - return oidc.Provider{}, err - } - - provider.ForwardedClaims, err = decodeClaimPaths(forwardedClaims) - if err != nil { - return oidc.Provider{}, err - } - - return cloneProvider(provider), nil -} - -func (s *Store) resolveIdentityLinkConflict( - ctx context.Context, - req authkit.LinkIdentityRequest, -) (authkit.ExternalIdentity, error) { - link, err := s.findIdentityLink(ctx, req.Provider, req.Subject) - if err != nil { - return authkit.ExternalIdentity{}, fmt.Errorf("postgres: find identity link conflict: %w", err) - } - if link.PrincipalID == req.PrincipalID { - return link, nil - } - - return authkit.ExternalIdentity{}, fmt.Errorf( - "postgres: identity %q/%q is already linked to principal %q", - req.Provider, - req.Subject, - link.PrincipalID, - ) -} - -func (s *Store) findToken(ctx context.Context, tokenID string) (apikey.StoredToken, error) { - var token apikey.StoredToken - var secretHash []byte - var lastUsedAt pgtype.Timestamptz - var revokedAt pgtype.Timestamptz - err := s.pool.QueryRow( - ctx, - `select id, principal_id, name, secret_hash, expires_at, last_used_at, revoked_at - from authkit_api_tokens - where id = $1`, - tokenID, - ).Scan( - &token.ID, - &token.PrincipalID, - &token.Name, - &secretHash, - &token.ExpiresAt, - &lastUsedAt, - &revokedAt, - ) - if err != nil { - return apikey.StoredToken{}, err - } - if len(secretHash) != sha256.Size { - return apikey.StoredToken{}, fmt.Errorf( - "postgres: token %q has invalid secret hash length %d", - tokenID, - len(secretHash), - ) - } - - copy(token.SecretHash[:], secretHash) - token.ExpiresAt = token.ExpiresAt.UTC() - token.LastUsedAt = timeFromTimestamptz(lastUsedAt) - token.RevokedAt = timeFromTimestamptz(revokedAt) - - return token, nil -} - -func scanTokenMetadata(row scanner) (apikey.TokenMetadata, error) { - var token apikey.TokenMetadata - var lastUsedAt pgtype.Timestamptz - var revokedAt pgtype.Timestamptz - if err := row.Scan( - &token.ID, - &token.PrincipalID, - &token.Name, - &token.ExpiresAt, - &lastUsedAt, - &revokedAt, - ); err != nil { - return apikey.TokenMetadata{}, err - } - - token.ExpiresAt = token.ExpiresAt.UTC() - token.LastUsedAt = timeFromTimestamptz(lastUsedAt) - token.RevokedAt = timeFromTimestamptz(revokedAt) - - return token, nil -} - -func encodeAttributes(attrs map[string]any) (string, error) { - if len(attrs) == 0 { - return "", nil - } - - encoded, err := json.Marshal(attrs) - if err != nil { - return "", fmt.Errorf("postgres: encode principal attributes: %w", err) - } - - return string(encoded), nil -} - -func encodeClaimPaths(paths []authkit.ClaimPath) (string, error) { - if len(paths) == 0 { - return "[]", nil - } - - encoded, err := json.Marshal(paths) - if err != nil { - return "", fmt.Errorf("postgres: encode claim paths: %w", err) - } - - return string(encoded), nil -} - -func decodeClaimPaths(encoded string) ([]authkit.ClaimPath, error) { - if encoded == "" || encoded == "null" { - return nil, nil - } - - var paths []authkit.ClaimPath - if err := json.Unmarshal([]byte(encoded), &paths); err != nil { - return nil, fmt.Errorf("postgres: decode claim paths: %w", err) - } - if len(paths) == 0 { - return nil, nil - } - - return cloneClaimPaths(paths), nil -} - -func decodeAttributes(encoded string) (map[string]any, error) { - if encoded == "" || encoded == "null" { - //nolint:nilnil // Nil attributes are the normalized zero value for principals. - return nil, nil - } - - var attrs map[string]any - if err := json.Unmarshal([]byte(encoded), &attrs); err != nil { - return nil, fmt.Errorf("postgres: decode principal attributes: %w", err) - } - if len(attrs) == 0 { - //nolint:nilnil // Nil attributes are the normalized zero value for principals. - return nil, nil - } - - return attrs, nil -} - -func cloneAttributes(attrs map[string]any) map[string]any { - if len(attrs) == 0 { - return nil - } - - cloned := make(map[string]any, len(attrs)) - maps.Copy(cloned, attrs) - - return cloned -} - -func cloneProvider(provider oidc.Provider) oidc.Provider { - provider.Audiences = cloneStrings(provider.Audiences) - provider.SupportedSigningAlgorithms = cloneStrings(provider.SupportedSigningAlgorithms) - provider.ForwardedClaims = cloneClaimPaths(provider.ForwardedClaims) - - return provider -} - -func cloneStrings(values []string) []string { - if len(values) == 0 { - return nil - } - - cloned := make([]string, len(values)) - copy(cloned, values) - - return cloned -} - -func cloneClaimPaths(paths []authkit.ClaimPath) []authkit.ClaimPath { - if len(paths) == 0 { - return nil - } - - cloned := make([]authkit.ClaimPath, len(paths)) - for i, path := range paths { - cloned[i] = cloneClaimPath(path) - } - - return cloned -} - -func cloneClaimPath(path authkit.ClaimPath) authkit.ClaimPath { - if len(path) == 0 { - return nil - } - - cloned := make(authkit.ClaimPath, len(path)) - copy(cloned, path) - - return cloned -} - -func provisioningRuleFromCreate(req authkit.CreateProvisioningRuleRequest) authkit.ProvisioningRule { - return normalizeProvisioningRule(authkit.ProvisioningRule{ - ID: req.ID, - DisplayName: req.DisplayName, - Provider: req.Provider, - Condition: provisioning.NormalizeCondition(req.Condition), - AssignRoleIDs: cloneStrings(req.AssignRoleIDs), - Enabled: req.Enabled, - }) -} - -func provisioningRuleFromUpdate(req authkit.UpdateProvisioningRuleRequest) authkit.ProvisioningRule { - return normalizeProvisioningRule(authkit.ProvisioningRule{ - ID: req.ID, - DisplayName: req.DisplayName, - Provider: req.Provider, - Condition: provisioning.NormalizeCondition(req.Condition), - AssignRoleIDs: cloneStrings(req.AssignRoleIDs), - Enabled: req.Enabled, - }) -} - -func normalizeProvisioningRule(rule authkit.ProvisioningRule) authkit.ProvisioningRule { - rule.AssignRoleIDs = uniqueStrings(rule.AssignRoleIDs) - - return rule -} - -func cloneProvisioningRule(rule authkit.ProvisioningRule) authkit.ProvisioningRule { - rule.AssignRoleIDs = cloneStrings(rule.AssignRoleIDs) - - return rule -} - -func validateNonEmptyStrings(name string, values []string) error { - for i, value := range values { - if value == "" { - return fmt.Errorf("%s %d is required", name, i) - } - } - - return nil -} - -func validateRequiredStrings(name string, values []string) error { - if len(values) == 0 { - return fmt.Errorf("%s is required", name) - } - - return validateNonEmptyStrings(name, values) -} - -func uniqueStrings(values []string) []string { - if len(values) == 0 { - return nil - } - - unique := make([]string, 0, len(values)) - seen := make(map[string]struct{}, len(values)) - for _, value := range values { - if _, ok := seen[value]; ok { - continue - } - - seen[value] = struct{}{} - unique = append(unique, value) - } - - return unique -} - -func timeFromTimestamptz(value pgtype.Timestamptz) *time.Time { - if !value.Valid { - return nil - } - - t := value.Time.UTC() - - return &t -} - -func isPostgresCode(err error, code string) bool { - var pgErr *pgconn.PgError - - return errors.As(err, &pgErr) && pgErr.Code == code -} diff --git a/store/postgres/store_test.go b/store/postgres/store_test.go index 04043b5..29faf7b 100644 --- a/store/postgres/store_test.go +++ b/store/postgres/store_test.go @@ -6,38 +6,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/meigma/authkit" - "github.com/meigma/authkit/apikey" - "github.com/meigma/authkit/oidc" - "github.com/meigma/authkit/passkey" ) -func TestStoreSatisfiesAuthkitContracts(_ *testing.T) { - var _ authkit.PrincipalCreator = (*Store)(nil) - var _ authkit.PrincipalFinder = (*Store)(nil) - var _ authkit.PrincipalLister = (*Store)(nil) - var _ authkit.RoleCreator = (*Store)(nil) - var _ authkit.RoleActionGranter = (*Store)(nil) - var _ authkit.PrincipalRoleAssigner = (*Store)(nil) - var _ authkit.PrincipalRoleUnassigner = (*Store)(nil) - var _ authkit.PrincipalRoleAssignmentLister = (*Store)(nil) - var _ authkit.PrincipalActionResolver = (*Store)(nil) - var _ authkit.IdentityLinker = (*Store)(nil) - var _ authkit.IdentityProvisioner = (*Store)(nil) - var _ authkit.PrincipalResolver = (*Store)(nil) - var _ authkit.ProvisioningRuleCreator = (*Store)(nil) - var _ authkit.ProvisioningRuleUpdater = (*Store)(nil) - var _ authkit.ProvisioningRuleDeleter = (*Store)(nil) - var _ authkit.ProvisioningRuleFinder = (*Store)(nil) - var _ authkit.ProvisioningRuleLister = (*Store)(nil) - var _ apikey.TokenStore = (*Store)(nil) - var _ apikey.TokenMetadataLister = (*Store)(nil) - var _ oidc.ProviderSource = (*Store)(nil) - var _ oidc.ProviderTrustStore = (*Store)(nil) - var _ passkey.Store = (*Store)(nil) -} - func TestNewStoreValidatesPool(t *testing.T) { store, err := NewStore(nil) diff --git a/store/postgres/token.go b/store/postgres/token.go new file mode 100644 index 0000000..a2a48c2 --- /dev/null +++ b/store/postgres/token.go @@ -0,0 +1,249 @@ +package postgres + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + + "github.com/meigma/authkit" + "github.com/meigma/authkit/apikey" +) + +// CreateToken stores token. The token's PrincipalID must reference an +// existing principal; otherwise the returned error wraps +// `authkit.ErrPrincipalNotFound`. +func (s *Store) CreateToken(ctx context.Context, token apikey.StoredToken) error { + if err := ctx.Err(); err != nil { + return err + } + if token.ID == "" { + return errors.New("postgres: token ID is required") + } + + if _, err := s.pool.Exec( + ctx, + `insert into authkit_api_tokens + (id, principal_id, name, secret_hash, expires_at, last_used_at, revoked_at) + values ($1, $2, $3, $4, $5, $6, $7)`, + token.ID, + token.PrincipalID, + token.Name, + token.SecretHash[:], + token.ExpiresAt, + token.LastUsedAt, + token.RevokedAt, + ); err != nil { + if isPostgresCode(err, foreignKeyViolation) { + return fmt.Errorf( + "%w: postgres: principal %q does not exist", + authkit.ErrPrincipalNotFound, + token.PrincipalID, + ) + } + + return fmt.Errorf("postgres: create token: %w", err) + } + + return nil +} + +// FindToken returns the token identified by tokenID, or +// `apikey.ErrTokenNotFound` if no such token exists. +func (s *Store) FindToken(ctx context.Context, tokenID string) (apikey.StoredToken, error) { + if err := ctx.Err(); err != nil { + return apikey.StoredToken{}, err + } + + token, err := s.findToken(ctx, tokenID) + if errors.Is(err, pgx.ErrNoRows) { + return apikey.StoredToken{}, apikey.ErrTokenNotFound + } + if err != nil { + return apikey.StoredToken{}, err + } + + return token, nil +} + +// UpdateTokenLastUsed records the most recent successful use of tokenID. +// Returns `apikey.ErrTokenNotFound` if no such token exists. +func (s *Store) UpdateTokenLastUsed(ctx context.Context, tokenID string, usedAt time.Time) error { + if err := ctx.Err(); err != nil { + return err + } + + tag, err := s.pool.Exec( + ctx, + `update authkit_api_tokens set last_used_at = $2 where id = $1`, + tokenID, + usedAt, + ) + if err != nil { + return fmt.Errorf("postgres: update token last used: %w", err) + } + if tag.RowsAffected() == 0 { + return apikey.ErrTokenNotFound + } + + return nil +} + +// RevokeToken records tokenID as revoked at revokedAt. Returns +// `apikey.ErrTokenNotFound` if no such token exists. +func (s *Store) RevokeToken(ctx context.Context, tokenID string, revokedAt time.Time) error { + if err := ctx.Err(); err != nil { + return err + } + + tag, err := s.pool.Exec( + ctx, + `update authkit_api_tokens set revoked_at = $2 where id = $1`, + tokenID, + revokedAt, + ) + if err != nil { + return fmt.Errorf("postgres: revoke token: %w", err) + } + if tag.RowsAffected() == 0 { + return apikey.ErrTokenNotFound + } + + return nil +} + +// ListPrincipalTokenMetadata returns API-token metadata for principalID, +// sorted by token ID. Returns `authkit.ErrPrincipalNotFound` if no such +// principal exists. The returned slice contains no secret material. +func (s *Store) ListPrincipalTokenMetadata( + ctx context.Context, + principalID string, +) ([]apikey.TokenMetadata, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if principalID == "" { + return nil, errors.New("postgres: principal ID is required") + } + + exists, err := s.principalExists(ctx, principalID) + if err != nil { + return nil, err + } + if !exists { + return nil, authkit.ErrPrincipalNotFound + } + + rows, err := s.pool.Query( + ctx, + `select id, principal_id, name, expires_at, last_used_at, revoked_at + from authkit_api_tokens + where principal_id = $1 + order by id`, + principalID, + ) + if err != nil { + return nil, fmt.Errorf("postgres: list principal API token metadata: %w", err) + } + defer rows.Close() + + var tokens []apikey.TokenMetadata + for rows.Next() { + token, scanErr := scanTokenMetadata(rows) + if scanErr != nil { + return nil, scanErr + } + tokens = append(tokens, token) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("postgres: read principal API token metadata: %w", err) + } + + return tokens, nil +} + +// findToken reads the full StoredToken (including secret hash) for tokenID. +// Returns `pgx.ErrNoRows` when no such token exists; callers translate +// that to `apikey.ErrTokenNotFound`. +func (s *Store) findToken(ctx context.Context, tokenID string) (apikey.StoredToken, error) { + var token apikey.StoredToken + var secretHash []byte + var lastUsedAt pgtype.Timestamptz + var revokedAt pgtype.Timestamptz + err := s.pool.QueryRow( + ctx, + `select id, principal_id, name, secret_hash, expires_at, last_used_at, revoked_at + from authkit_api_tokens + where id = $1`, + tokenID, + ).Scan( + &token.ID, + &token.PrincipalID, + &token.Name, + &secretHash, + &token.ExpiresAt, + &lastUsedAt, + &revokedAt, + ) + if err != nil { + return apikey.StoredToken{}, err + } + // Reject malformed hashes early; copying into a fixed-size SHA-256 + // array silently truncates otherwise. + if len(secretHash) != sha256.Size { + return apikey.StoredToken{}, fmt.Errorf( + "postgres: token %q has invalid secret hash length %d", + tokenID, + len(secretHash), + ) + } + + copy(token.SecretHash[:], secretHash) + token.ExpiresAt = token.ExpiresAt.UTC() + token.LastUsedAt = timeFromTimestamptz(lastUsedAt) + token.RevokedAt = timeFromTimestamptz(revokedAt) + + return token, nil +} + +// scanTokenMetadata reads a token-metadata row from row. The row must +// select id, principal_id, name, expires_at, last_used_at, and revoked_at +// in that order. Secret material is intentionally not part of the metadata +// projection. +func scanTokenMetadata(row scanner) (apikey.TokenMetadata, error) { + var token apikey.TokenMetadata + var lastUsedAt pgtype.Timestamptz + var revokedAt pgtype.Timestamptz + if err := row.Scan( + &token.ID, + &token.PrincipalID, + &token.Name, + &token.ExpiresAt, + &lastUsedAt, + &revokedAt, + ); err != nil { + return apikey.TokenMetadata{}, err + } + + token.ExpiresAt = token.ExpiresAt.UTC() + token.LastUsedAt = timeFromTimestamptz(lastUsedAt) + token.RevokedAt = timeFromTimestamptz(revokedAt) + + return token, nil +} + +// timeFromTimestamptz converts a pgx Timestamptz into a pointer to time +// (nil when the column is NULL). The returned time is normalized to UTC. +func timeFromTimestamptz(value pgtype.Timestamptz) *time.Time { + if !value.Valid { + return nil + } + + t := value.Time.UTC() + + return &t +} diff --git a/store/postgres/token_test.go b/store/postgres/token_test.go new file mode 100644 index 0000000..5d3844f --- /dev/null +++ b/store/postgres/token_test.go @@ -0,0 +1,12 @@ +package postgres + +import ( + "testing" + + "github.com/meigma/authkit/apikey" +) + +func TestStoreSatisfiesTokenContracts(_ *testing.T) { + var _ apikey.TokenStore = (*Store)(nil) + var _ apikey.TokenMetadataLister = (*Store)(nil) +} diff --git a/store/postgres/tx.go b/store/postgres/tx.go new file mode 100644 index 0000000..c9a905b --- /dev/null +++ b/store/postgres/tx.go @@ -0,0 +1,47 @@ +package postgres + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// withTx runs fn inside a transaction on pool. If fn returns nil, the +// transaction is committed and fn's result returned. If fn returns a non-nil +// error, the deferred Rollback unwinds the transaction and the error +// propagates unchanged. Begin and Commit failures are wrapped with label so +// the call site is recognizable in error chains. +// +// Callers that need to read recovery state after a conflict (for example, +// to resolve a unique-violation race by re-reading the winning row) should +// have fn return a sentinel error, detect it after withTx returns, and read +// against the pool directly — the transaction is already rolled back by +// then. +func withTx[T any]( + ctx context.Context, + pool *pgxpool.Pool, + label string, + fn func(pgx.Tx) (T, error), +) (T, error) { + var zero T + + tx, err := pool.Begin(ctx) + if err != nil { + return zero, fmt.Errorf("postgres: begin %s: %w", label, err) + } + defer func() { + _ = tx.Rollback(ctx) + }() + + result, err := fn(tx) + if err != nil { + return zero, err + } + if err := tx.Commit(ctx); err != nil { + return zero, fmt.Errorf("postgres: commit %s: %w", label, err) + } + + return result, nil +} diff --git a/store/postgres/validation.go b/store/postgres/validation.go new file mode 100644 index 0000000..0d9e933 --- /dev/null +++ b/store/postgres/validation.go @@ -0,0 +1,47 @@ +package postgres + +import "fmt" + +// validateNonEmptyStrings returns an error when any value in values is empty. +// The error names name and the offending index for easier diagnosis. +func validateNonEmptyStrings(name string, values []string) error { + for i, value := range values { + if value == "" { + return fmt.Errorf("%s %d is required", name, i) + } + } + + return nil +} + +// validateRequiredStrings returns an error when values is empty or any +// element is empty. Use this when at least one non-empty entry must be +// supplied. +func validateRequiredStrings(name string, values []string) error { + if len(values) == 0 { + return fmt.Errorf("%s is required", name) + } + + return validateNonEmptyStrings(name, values) +} + +// uniqueStrings returns values with duplicates removed, preserving the order +// of first occurrence. A nil or empty input returns nil. +func uniqueStrings(values []string) []string { + if len(values) == 0 { + return nil + } + + unique := make([]string, 0, len(values)) + seen := make(map[string]struct{}, len(values)) + for _, value := range values { + if _, ok := seen[value]; ok { + continue + } + + seen[value] = struct{}{} + unique = append(unique, value) + } + + return unique +}