diff --git a/internal/admin/auth_audit_test.go b/internal/admin/auth_audit_test.go
new file mode 100644
index 00000000..d36308d5
--- /dev/null
+++ b/internal/admin/auth_audit_test.go
@@ -0,0 +1,120 @@
+package admin
+
+import (
+	"bytes"
+	"log/slog"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func newAuthServiceWithAudit(t *testing.T) (*AuthService, *bytes.Buffer) {
+	t.Helper()
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	signer := newSignerForTest(t, 1, clk)
+	verifier := newVerifierForTest(t, []byte{1}, clk)
+
+	creds := MapCredentialStore{
+		"AKIA_ADMIN": "ADMIN_SECRET",
+	}
+	roles := map[string]Role{
+		"AKIA_ADMIN": RoleFull,
+	}
+	buf := &bytes.Buffer{}
+	logger := slog.New(slog.NewJSONHandler(buf, &slog.HandlerOptions{Level: slog.LevelInfo}))
+	svc := NewAuthService(signer, creds, roles, AuthServiceOpts{
+		Clock:    clk,
+		Verifier: verifier,
+		Logger:   logger,
+	})
+	return svc, buf
+}
+
+func TestAudit_LoginSuccessRecordsActor(t *testing.T) {
+	svc, buf := newAuthServiceWithAudit(t)
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "ADMIN_SECRET"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+
+	require.Equal(t, http.StatusOK, rec.Code)
+	out := buf.String()
+	require.Contains(t, out, `"msg":"admin_audit"`)
+	require.Contains(t, out, `"action":"login"`)
+	require.Contains(t, out, `"actor":"AKIA_ADMIN"`)
+	require.Contains(t, out, `"claimed_actor":"AKIA_ADMIN"`)
+	require.Contains(t, out, `"status":200`)
+}
+
+func TestAudit_LoginFailureRecordsClaimedActor(t *testing.T) {
+	svc, buf := newAuthServiceWithAudit(t)
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "WRONG"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+
+	require.Equal(t, http.StatusUnauthorized, rec.Code)
+	out := buf.String()
+	require.Contains(t, out, `"action":"login"`)
+	// We did NOT authenticate, so actor is empty.
+	require.Contains(t, out, `"actor":""`)
+	// But the claimed actor is still logged so operators can track
+	// which access key was targeted by brute-force attempts.
+	require.Contains(t, out, `"claimed_actor":"AKIA_ADMIN"`)
+	require.Contains(t, out, `"status":401`)
+}
+
+func TestAudit_LogoutDecodesCookieForActor(t *testing.T) {
+	svc, buf := newAuthServiceWithAudit(t)
+
+	// Log in first.
+	loginReq := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "ADMIN_SECRET"})
+	loginRec := httptest.NewRecorder()
+	svc.HandleLogin(loginRec, loginReq)
+	require.Equal(t, http.StatusOK, loginRec.Code)
+	cookies := loginRec.Result().Cookies()
+	buf.Reset()
+
+	// Now log out with the session cookie — audit must record actor.
+	req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/auth/logout", nil)
+	req.RemoteAddr = "127.0.0.1:1"
+	for _, c := range cookies {
+		req.AddCookie(c)
+	}
+	rec := httptest.NewRecorder()
+	svc.HandleLogout(rec, req)
+
+	require.Equal(t, http.StatusNoContent, rec.Code)
+	out := buf.String()
+	require.Contains(t, out, `"action":"logout"`)
+	require.Contains(t, out, `"actor":"AKIA_ADMIN"`)
+}
+
+func TestAudit_LogoutWithoutCookieEmptyActor(t *testing.T) {
+	svc, buf := newAuthServiceWithAudit(t)
+	req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/auth/logout", nil)
+	req.RemoteAddr = "127.0.0.1:1"
+	rec := httptest.NewRecorder()
+	svc.HandleLogout(rec, req)
+
+	require.Equal(t, http.StatusNoContent, rec.Code)
+	out := buf.String()
+	require.Contains(t, out, `"action":"logout"`)
+	require.Contains(t, out, `"actor":""`)
+}
+
+func TestAudit_LoginLengthTimingHashed(t *testing.T) {
+	// Same-length secret mismatch and different-length secret mismatch
+	// must both reach the failure path without short-circuiting on
+	// length. We cannot time them precisely in a unit test, but we can
+	// at least verify both paths emit the same failure response.
+	svc, _ := newAuthServiceWithAudit(t)
+	for _, secret := range []string{"x", "much-longer-wrong-secret-value-here"} {
+		req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: secret})
+		rec := httptest.NewRecorder()
+		svc.HandleLogin(rec, req)
+		require.Equal(t, http.StatusUnauthorized, rec.Code)
+		require.Contains(t, rec.Body.String(), "invalid_credentials")
+	}
+}
diff --git a/internal/admin/auth_handler.go b/internal/admin/auth_handler.go
new file mode 100644
index 00000000..6804a653
--- /dev/null
+++ b/internal/admin/auth_handler.go
@@ -0,0 +1,442 @@
+package admin
+
+import (
+	"crypto/hmac"
+	"crypto/rand"
+	"crypto/sha256"
+	"crypto/subtle"
+	"encoding/base64"
+	"io"
+	"log/slog"
+	"net/http"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/cockroachdb/errors"
+	"github.com/goccy/go-json"
+)
+
+// CredentialStore is the read-side view of the static SigV4 credential
+// table the server was configured with. It returns the secret for a
+// given access key, or ("", false) if the key is unknown. Supplying the
+// same map the S3/DynamoDB adapters use keeps authentication consistent
+// across the protocol surface.
+type CredentialStore interface {
+	LookupSecret(accessKey string) (string, bool)
+}
+
+// MapCredentialStore adapts a plain map into the CredentialStore
+// interface. Callers typically load this from config at startup and
+// hand the same map to the S3 adapter and the admin service.
+type MapCredentialStore map[string]string
+
+// LookupSecret implements CredentialStore.
+func (m MapCredentialStore) LookupSecret(accessKey string) (string, bool) {
+	secret, ok := m[strings.TrimSpace(accessKey)]
+	return secret, ok
+}
+
+// AuthService wires the login/logout handlers, token minting, role
+// lookup, and per-IP rate limiter together. Construct it once at
+// startup and reuse across the admin listener's lifetime.
+type AuthService struct {
+	signer       *Signer
+	verifier     *Verifier
+	creds        CredentialStore
+	roles        map[string]Role
+	limiter      *rateLimiter
+	loginWindow  time.Duration
+	sessionTTL   time.Duration
+	secureCookie bool
+	cookieDomain string
+	clock        Clock
+	logger       *slog.Logger
+}
+
+// AuthServiceOpts covers the knobs a caller may want to vary in tests.
+// Zero values fall back to production defaults.
+type AuthServiceOpts struct {
+	// InsecureCookie disables the Secure attribute on the issued
+	// cookies. It exists only for local plaintext-loopback development
+	// and is expected to stay false in any real deployment.
+	InsecureCookie bool
+	// CookieDomain is optional and rarely used. Empty means "host-only
+	// cookie", which is the default and the safest choice.
+	CookieDomain string
+	// LoginLimit is the per-IP rate limit (default 5).
+	LoginLimit int
+	// LoginWindow is the rate-limit window (default 1 minute).
+	LoginWindow time.Duration
+	// Clock drives rate-limiter aging. Defaults to SystemClock.
+	Clock Clock
+	// Verifier lets the logout handler best-effort decode the
+	// incoming session cookie and include the actor in the audit
+	// log. When nil, logout events are still audited but with an
+	// empty actor field.
+	Verifier *Verifier
+	// Logger is the slog destination for admin_audit entries emitted
+	// by the login/logout handlers. nil falls back to slog.Default().
+	Logger *slog.Logger
+}
+
+// NewAuthService constructs an AuthService. The signer must be primary
+// (use NewSigner with the current key); token verification uses the
+// Verifier passed separately to SessionAuth.
+func NewAuthService(signer *Signer, creds CredentialStore, roles map[string]Role, opts AuthServiceOpts) *AuthService {
+	limit := opts.LoginLimit
+	if limit <= 0 {
+		limit = 5
+	}
+	window := opts.LoginWindow
+	if window <= 0 {
+		window = time.Minute
+	}
+	if opts.Clock == nil {
+		opts.Clock = SystemClock
+	}
+	logger := opts.Logger
+	if logger == nil {
+		logger = slog.Default()
+	}
+	return &AuthService{
+		signer:       signer,
+		verifier:     opts.Verifier,
+		creds:        creds,
+		roles:        roles,
+		limiter:      newRateLimiter(limit, window, opts.Clock),
+		loginWindow:  window,
+		sessionTTL:   sessionTTL,
+		secureCookie: !opts.InsecureCookie,
+		cookieDomain: opts.CookieDomain,
+		clock:        opts.Clock,
+		logger:       logger,
+	}
+}
+
+// loginRequest is the JSON body the login endpoint accepts.
+type loginRequest struct {
+	AccessKey string `json:"access_key"`
+	SecretKey string `json:"secret_key"`
+}
+
+// loginResponse is the JSON body the login endpoint returns on success.
+// The CSRF token is delivered exclusively via the admin_csrf cookie (see
+// the Set-Cookie headers the handler sets on the same response); we do
+// not echo it in the JSON body to avoid encouraging clients to cache or
+// log the token out of band.
+type loginResponse struct {
+	Role      Role      `json:"role"`
+	ExpiresAt time.Time `json:"expires_at"`
+}
+
+// HandleLogin validates credentials and issues the session + CSRF cookies.
+// It is safe to expose without the SessionAuth middleware because this is
+// where a session first comes from; rate limiting, Content-Type validation,
+// and constant-time credential comparison guard it.
+//
+// Login events (success and failure) emit admin_audit slog entries
+// directly. The generic Audit middleware cannot do this because it runs
+// before the handler knows who the caller is claiming to be.
+func (s *AuthService) HandleLogin(w http.ResponseWriter, r *http.Request) {
+	rec := newStatusRecorder(w)
+	defer s.auditLogin(r, rec)
+
+	if !s.preflightLogin(rec, r) {
+		return
+	}
+	req, ok := readLoginRequest(rec, r)
+	rec.claimedActor = req.AccessKey
+	if !ok {
+		return
+	}
+	principal, ok := s.authenticate(rec, req)
+	if !ok {
+		return
+	}
+	s.issueSession(rec, principal)
+	rec.actor = principal.AccessKey
+}
+
+func (s *AuthService) preflightLogin(w http.ResponseWriter, r *http.Request) bool {
+	if r.Method != http.MethodPost {
+		writeJSONError(w, http.StatusMethodNotAllowed, "method_not_allowed", "login requires POST")
+		return false
+	}
+	if !s.limiter.allow(clientIP(r)) {
+		// Retry-After must be derived from the actual rate-limit
+		// window so tests and callers that tune LoginWindow get an
+		// accurate hint; clamp to at least 1 second so we never
+		// send a zero value.
+		retryAfter := int(s.loginWindow.Seconds())
+		if retryAfter < 1 {
+			retryAfter = 1
+		}
+		w.Header().Set("Retry-After", strconv.Itoa(retryAfter))
+		writeJSONError(w, http.StatusTooManyRequests, "rate_limited",
+			"too many login attempts from this source; try again later")
+		return false
+	}
+	ct := strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type")))
+	if !strings.HasPrefix(ct, "application/json") {
+		writeJSONError(w, http.StatusUnsupportedMediaType, "unsupported_media_type",
+			"login requires Content-Type: application/json")
+		return false
+	}
+	return true
+}
+
+func readLoginRequest(w http.ResponseWriter, r *http.Request) (loginRequest, bool) {
+	raw, err := io.ReadAll(r.Body)
+	if err != nil {
+		if IsMaxBytesError(err) {
+			WriteMaxBytesError(w)
+			return loginRequest{}, false
+		}
+		writeJSONError(w, http.StatusBadRequest, "invalid_body", "failed to read body")
+		return loginRequest{}, false
+	}
+	var req loginRequest
+	if err := json.Unmarshal(raw, &req); err != nil {
+		writeJSONError(w, http.StatusBadRequest, "invalid_body", "body is not valid JSON")
+		return loginRequest{}, false
+	}
+	// Access keys are AWS-style identifiers that users sometimes copy
+	// with surrounding whitespace; trimming there is harmless and
+	// matches how the S3 adapter normalises its credential table at
+	// load time. Secrets, by contrast, are opaque bytes — trimming
+	// would accept inputs the SigV4 adapter would reject, creating a
+	// cross-protocol inconsistency. Leave SecretKey untouched.
+	req.AccessKey = strings.TrimSpace(req.AccessKey)
+	if req.AccessKey == "" || req.SecretKey == "" {
+		writeJSONError(w, http.StatusBadRequest, "missing_fields",
+			"access_key and secret_key are required")
+		return loginRequest{}, false
+	}
+	return req, true
+}
+
+// secretCompareKey is a per-process random key used to derive
+// fixed-length digests of incoming and expected login secrets before a
+// constant-time comparison. The key itself does not need to be secret —
+// its only job is to:
+//
+//  1. normalise inputs to a fixed 32-byte width so subtle.ConstantTimeCompare
+//     cannot leak the length of the expected secret via an early-return,
+//  2. make the construction a keyed MAC rather than a naked password hash,
+//     which keeps static analysis (CodeQL) aligned with the intent: this
+//     is a timing-safe comparator, not a persisted password hash.
+//
+// We deliberately do not use bcrypt / argon2 here: nothing is persisted,
+// the secret is received in plaintext over TLS at login time, and the
+// rate limiter already bounds online guessing. A computationally
+// expensive KDF would add latency to every login attempt without
+// changing the threat model.
+var (
+	secretCompareKey     []byte
+	secretCompareKeyOnce sync.Once
+)
+
+// unknownKeySecretPlaceholder is a deterministic dummy we hash when
+// the incoming access key is unknown. We hash it fresh on every call
+// (rather than precomputing the digest once) so the unknown-key
+// branch performs the same HMAC work as the known-key branch —
+// otherwise an attacker could enumerate valid access keys by
+// measuring login latency. It is NOT a credential: nothing grants
+// access to any resource, and grepping for it finds only this file.
+const unknownKeySecretPlaceholder = "admin-auth-unknown-key-placeholder" //nolint:gosec // intentional non-credential sentinel; see comment above.
+
+func initSecretCompareKey() {
+	secretCompareKey = make([]byte, sha256.Size)
+	if _, err := rand.Read(secretCompareKey); err != nil {
+		// rand.Read never fails on supported platforms; if it did,
+		// panicking is the right reaction — the admin listener
+		// cannot authenticate anyone anyway.
+		panic("admin: crypto/rand failure while initialising secret compare key: " + err.Error())
+	}
+}
+
+// digestForCompare returns HMAC-SHA256(secretCompareKey, s). Used only
+// for timing-safe comparison of login secrets; never stored.
+func digestForCompare(s string) []byte {
+	secretCompareKeyOnce.Do(initSecretCompareKey)
+	mac := hmac.New(sha256.New, secretCompareKey)
+	mac.Write([]byte(s))
+	return mac.Sum(nil)
+}
+
+func (s *AuthService) authenticate(w http.ResponseWriter, req loginRequest) (AuthPrincipal, bool) {
+	providedHash := digestForCompare(req.SecretKey)
+	expected, known := s.creds.LookupSecret(req.AccessKey)
+	// Compute the expected digest fresh in BOTH branches so the
+	// amount of HMAC work is identical regardless of whether the
+	// access key is known. A precomputed placeholder digest would
+	// make the unknown-key path measurably faster, letting an
+	// attacker enumerate valid access keys via login latency.
+	expectedSecret := expected
+	if !known {
+		expectedSecret = unknownKeySecretPlaceholder
+	}
+	expectedHash := digestForCompare(expectedSecret)
+	match := subtle.ConstantTimeCompare(providedHash, expectedHash) == 1
+	if !known || !match {
+		writeJSONError(w, http.StatusUnauthorized, "invalid_credentials",
+			"access_key or secret_key is invalid")
+		return AuthPrincipal{}, false
+	}
+	role, ok := s.roles[req.AccessKey]
+	if !ok {
+		writeJSONError(w, http.StatusForbidden, "forbidden",
+			"access_key is not authorised for admin access")
+		return AuthPrincipal{}, false
+	}
+	return AuthPrincipal{AccessKey: req.AccessKey, Role: role}, true
+}
+
+func (s *AuthService) issueSession(w http.ResponseWriter, principal AuthPrincipal) {
+	token, err := s.signer.Sign(principal)
+	if err != nil {
+		writeJSONError(w, http.StatusInternalServerError, "internal", "failed to mint session token")
+		return
+	}
+	csrf, err := newCSRFToken()
+	if err != nil {
+		writeJSONError(w, http.StatusInternalServerError, "internal", "failed to mint csrf token")
+		return
+	}
+	// Compute the response expiry from the AuthService clock plus
+	// the configured session TTL. Injected test clocks therefore
+	// produce deterministic values, and in practice callers (main
+	// and NewServer) pass the same clock to both the Signer and
+	// the AuthService, so the response's expires_at matches the
+	// JWT exp claim. We do not cross-check against the signer's
+	// clock here because Signer does not expose it; if a future
+	// caller wires them independently, expires_at may drift by up
+	// to the delta between the two clocks.
+	expires := s.clock().UTC().Add(s.sessionTTL)
+	http.SetCookie(w, s.buildCookie(sessionCookieName, token, true))
+	http.SetCookie(w, s.buildCookie(csrfCookieName, csrf, false))
+	w.Header().Set("Cache-Control", "no-store")
+	w.Header().Set("Content-Type", "application/json; charset=utf-8")
+	w.WriteHeader(http.StatusOK)
+	_ = json.NewEncoder(w).Encode(loginResponse{Role: principal.Role, ExpiresAt: expires})
+}
+
+// HandleLogout clears both cookies. The route is wired behind the
+// protected middleware chain (SessionAuth + CSRF), so unauthenticated
+// or cross-site callers are rejected before they reach this handler —
+// that is what prevents logout-CSRF. We still best-effort decode the
+// incoming session cookie so the audit log can record who logged out;
+// a missing or invalid cookie leaves actor empty.
+func (s *AuthService) HandleLogout(w http.ResponseWriter, r *http.Request) {
+	rec := newStatusRecorder(w)
+	defer s.auditLogout(r, rec)
+	if r.Method != http.MethodPost {
+		writeJSONError(rec, http.StatusMethodNotAllowed, "method_not_allowed", "logout requires POST")
+		return
+	}
+	if s.verifier != nil {
+		if c, err := r.Cookie(sessionCookieName); err == nil && strings.TrimSpace(c.Value) != "" {
+			if p, verr := s.verifier.Verify(c.Value); verr == nil {
+				rec.actor = p.AccessKey
+			}
+		}
+	}
+	http.SetCookie(rec, s.buildExpiredCookie(sessionCookieName, true))
+	http.SetCookie(rec, s.buildExpiredCookie(csrfCookieName, false))
+	rec.Header().Set("Cache-Control", "no-store")
+	rec.WriteHeader(http.StatusNoContent)
+}
+
+// statusRecorder captures the response status + writes we emit so the
+// audit log can include both the final code and the claimed actor.
+type statusRecorder struct {
+	http.ResponseWriter
+	status       int
+	claimedActor string // what the caller said they were
+	actor        string // what we authenticated them as (empty on failure)
+}
+
+func newStatusRecorder(w http.ResponseWriter) *statusRecorder {
+	return &statusRecorder{ResponseWriter: w}
+}
+
+func (r *statusRecorder) WriteHeader(code int) {
+	if r.status == 0 {
+		r.status = code
+	}
+	r.ResponseWriter.WriteHeader(code)
+}
+
+func (r *statusRecorder) Write(b []byte) (int, error) {
+	if r.status == 0 {
+		r.status = http.StatusOK
+	}
+	n, err := r.ResponseWriter.Write(b)
+	if err != nil {
+		return n, errors.Wrap(err, "status recorder write")
+	}
+	return n, nil
+}
+
+func (s *AuthService) auditLogin(r *http.Request, rec *statusRecorder) {
+	s.logger.LogAttrs(r.Context(), slog.LevelInfo, "admin_audit",
+		slog.String("action", "login"),
+		slog.String("actor", rec.actor),
+		slog.String("claimed_actor", rec.claimedActor),
+		slog.String("remote", r.RemoteAddr),
+		slog.Int("status", nonZero(rec.status, http.StatusOK)),
+	)
+}
+
+func (s *AuthService) auditLogout(r *http.Request, rec *statusRecorder) {
+	s.logger.LogAttrs(r.Context(), slog.LevelInfo, "admin_audit",
+		slog.String("action", "logout"),
+		slog.String("actor", rec.actor),
+		slog.String("remote", r.RemoteAddr),
+		slog.Int("status", nonZero(rec.status, http.StatusOK)),
+	)
+}
+
+func nonZero(v, fallback int) int {
+	if v == 0 {
+		return fallback
+	}
+	return v
+}
+
+func (s *AuthService) buildCookie(name, value string, httpOnly bool) *http.Cookie {
+	return &http.Cookie{
+		Name:     name,
+		Value:    value,
+		Path:     pathPrefixAdmin,
+		Domain:   s.cookieDomain,
+		MaxAge:   int(s.sessionTTL.Seconds()),
+		Secure:   s.secureCookie,
+		HttpOnly: httpOnly,
+		SameSite: http.SameSiteStrictMode,
+	}
+}
+
+func (s *AuthService) buildExpiredCookie(name string, httpOnly bool) *http.Cookie {
+	return &http.Cookie{
+		Name:     name,
+		Value:    "",
+		Path:     pathPrefixAdmin,
+		Domain:   s.cookieDomain,
+		MaxAge:   -1,
+		Expires:  time.Unix(0, 0),
+		Secure:   s.secureCookie,
+		HttpOnly: httpOnly,
+		SameSite: http.SameSiteStrictMode,
+	}
+}
+
+func newCSRFToken() (string, error) {
+	var raw [32]byte
+	if _, err := rand.Read(raw[:]); err != nil {
+		return "", errors.Wrap(err, "read random bytes for csrf token")
+	}
+	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
+}
diff --git a/internal/admin/auth_handler_test.go b/internal/admin/auth_handler_test.go
new file mode 100644
index 00000000..95696c52
--- /dev/null
+++ b/internal/admin/auth_handler_test.go
@@ -0,0 +1,273 @@
+package admin
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/goccy/go-json"
+	"github.com/stretchr/testify/require"
+)
+
+func newAuthServiceForTest(t *testing.T) (*AuthService, *Verifier) {
+	t.Helper()
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	signer := newSignerForTest(t, 1, clk)
+	verifier := newVerifierForTest(t, []byte{1}, clk)
+
+	creds := MapCredentialStore{
+		"AKIA_ADMIN": "ADMIN_SECRET",
+		"AKIA_RO":    "RO_SECRET",
+		"AKIA_OTHER": "OTHER_SECRET", // present in creds but not in roles
+	}
+	roles := map[string]Role{
+		"AKIA_ADMIN": RoleFull,
+		"AKIA_RO":    RoleReadOnly,
+	}
+	svc := NewAuthService(signer, creds, roles, AuthServiceOpts{
+		Clock: clk,
+	})
+	return svc, verifier
+}
+
+func postJSON(t *testing.T, body any) *http.Request {
+	t.Helper()
+	buf, err := json.Marshal(body)
+	require.NoError(t, err)
+	req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/auth/login", strings.NewReader(string(buf)))
+	req.Header.Set("Content-Type", "application/json")
+	req.RemoteAddr = "127.0.0.1:50001"
+	return req
+}
+
+func TestLogin_HappyPathFull(t *testing.T) {
+	svc, verifier := newAuthServiceForTest(t)
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "ADMIN_SECRET"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+
+	require.Equal(t, http.StatusOK, rec.Code)
+
+	var resp loginResponse
+	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+	require.Equal(t, RoleFull, resp.Role)
+
+	// Find both cookies and assert attributes.
+	var session, csrf *http.Cookie
+	for _, c := range rec.Result().Cookies() {
+		switch c.Name {
+		case sessionCookieName:
+			session = c
+		case csrfCookieName:
+			csrf = c
+		}
+	}
+	require.NotNil(t, session, "expected admin_session cookie")
+	require.NotNil(t, csrf, "expected admin_csrf cookie")
+
+	// Cookie hardening.
+	require.True(t, session.HttpOnly)
+	require.True(t, session.Secure)
+	require.Equal(t, http.SameSiteStrictMode, session.SameSite)
+	require.Equal(t, "/admin", session.Path)
+	require.Equal(t, int(sessionTTL.Seconds()), session.MaxAge)
+
+	require.False(t, csrf.HttpOnly, "CSRF cookie must be readable by SPA")
+	require.True(t, csrf.Secure)
+	require.Equal(t, http.SameSiteStrictMode, csrf.SameSite)
+	require.Equal(t, "/admin", csrf.Path)
+
+	// Token must verify.
+	principal, err := verifier.Verify(session.Value)
+	require.NoError(t, err)
+	require.Equal(t, "AKIA_ADMIN", principal.AccessKey)
+	require.Equal(t, RoleFull, principal.Role)
+}
+
+func TestLogin_ReadOnlyMappedToRoleReadOnly(t *testing.T) {
+	svc, verifier := newAuthServiceForTest(t)
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_RO", SecretKey: "RO_SECRET"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusOK, rec.Code)
+
+	var session *http.Cookie
+	for _, c := range rec.Result().Cookies() {
+		if c.Name == sessionCookieName {
+			session = c
+		}
+	}
+	require.NotNil(t, session)
+	principal, err := verifier.Verify(session.Value)
+	require.NoError(t, err)
+	require.Equal(t, RoleReadOnly, principal.Role)
+}
+
+func TestLogin_WrongSecretRejected(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "WRONG"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusUnauthorized, rec.Code)
+	require.Contains(t, rec.Body.String(), "invalid_credentials")
+	// No cookies on failure.
+	require.Empty(t, rec.Result().Cookies())
+}
+
+func TestLogin_UnknownAccessKeyRejected(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_NOBODY", SecretKey: "anything"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusUnauthorized, rec.Code)
+	require.Contains(t, rec.Body.String(), "invalid_credentials")
+}
+
+func TestLogin_CredentialValidButNotAdminRejected(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	// AKIA_OTHER is in creds but not in the role index.
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_OTHER", SecretKey: "OTHER_SECRET"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusForbidden, rec.Code)
+	require.Contains(t, rec.Body.String(), "forbidden")
+}
+
+func TestLogin_MissingFields(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	for _, body := range []loginRequest{
+		{AccessKey: "", SecretKey: "x"},
+		{AccessKey: "x", SecretKey: ""},
+		{},
+	} {
+		req := postJSON(t, body)
+		rec := httptest.NewRecorder()
+		svc.HandleLogin(rec, req)
+		require.Equal(t, http.StatusBadRequest, rec.Code)
+	}
+}
+
+func TestLogin_RequiresJSON(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/auth/login",
+		strings.NewReader("access_key=x&secret_key=y"))
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.RemoteAddr = "127.0.0.1:1"
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusUnsupportedMediaType, rec.Code)
+}
+
+func TestLogin_OnlyPOST(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/auth/login", nil)
+	req.RemoteAddr = "127.0.0.1:1"
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusMethodNotAllowed, rec.Code)
+}
+
+func TestLogin_RateLimitPerIP(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	// Configure is already default 5/min.
+	for i := 0; i < 5; i++ {
+		req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "WRONG"})
+		rec := httptest.NewRecorder()
+		svc.HandleLogin(rec, req)
+		require.Equalf(t, http.StatusUnauthorized, rec.Code, "attempt %d", i+1)
+	}
+	// 6th attempt from the same IP must be rate limited.
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "WRONG"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusTooManyRequests, rec.Code)
+	require.Equal(t, "60", rec.Header().Get("Retry-After"))
+}
+
+func TestLogin_RateLimitIsPerIPNotGlobal(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	for i := 0; i < 5; i++ {
+		req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "WRONG"})
+		rec := httptest.NewRecorder()
+		svc.HandleLogin(rec, req)
+		require.Equal(t, http.StatusUnauthorized, rec.Code)
+	}
+	// Different IP — should not be throttled.
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "ADMIN_SECRET"})
+	req.RemoteAddr = "10.0.0.1:12345"
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestLogin_InsecureCookieOptIn(t *testing.T) {
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	signer := newSignerForTest(t, 1, clk)
+	creds := MapCredentialStore{"AKIA_ADMIN": "ADMIN_SECRET"}
+	roles := map[string]Role{"AKIA_ADMIN": RoleFull}
+	svc := NewAuthService(signer, creds, roles, AuthServiceOpts{Clock: clk, InsecureCookie: true})
+
+	req := postJSON(t, loginRequest{AccessKey: "AKIA_ADMIN", SecretKey: "ADMIN_SECRET"})
+	rec := httptest.NewRecorder()
+	svc.HandleLogin(rec, req)
+	require.Equal(t, http.StatusOK, rec.Code)
+	for _, c := range rec.Result().Cookies() {
+		require.Falsef(t, c.Secure, "cookie %s must not be Secure in dev mode", c.Name)
+	}
+}
+
+func TestLogout_ExpiresBothCookies(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/auth/logout", nil)
+	req.RemoteAddr = "127.0.0.1:1"
+	rec := httptest.NewRecorder()
+	svc.HandleLogout(rec, req)
+
+	require.Equal(t, http.StatusNoContent, rec.Code)
+	var names []string
+	for _, c := range rec.Result().Cookies() {
+		require.Equal(t, -1, c.MaxAge)
+		require.Equal(t, "", c.Value)
+		names = append(names, c.Name)
+	}
+	require.ElementsMatch(t, []string{sessionCookieName, csrfCookieName}, names)
+}
+
+func TestLogout_OnlyPOST(t *testing.T) {
+	svc, _ := newAuthServiceForTest(t)
+	req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/auth/logout", nil)
+	req.RemoteAddr = "127.0.0.1:1"
+	rec := httptest.NewRecorder()
+	svc.HandleLogout(rec, req)
+	require.Equal(t, http.StatusMethodNotAllowed, rec.Code)
+}
+
+func TestRateLimiter_ResetsAfterWindow(t *testing.T) {
+	now := time.Unix(1_700_000_000, 0).UTC()
+	clk := func() time.Time { return now }
+	rl := newRateLimiter(2, time.Minute, clk)
+
+	require.True(t, rl.allow("1.1.1.1"))
+	require.True(t, rl.allow("1.1.1.1"))
+	require.False(t, rl.allow("1.1.1.1"))
+
+	now = now.Add(61 * time.Second)
+	require.True(t, rl.allow("1.1.1.1"))
+}
+
+func TestClientIP(t *testing.T) {
+	for _, tc := range []struct {
+		remote string
+		want   string
+	}{
+		{"127.0.0.1:12345", "127.0.0.1"},
+		{"[::1]:12345", "::1"},
+		{"no-port", "no-port"},
+	} {
+		req := httptest.NewRequest(http.MethodGet, "/", nil)
+		req.RemoteAddr = tc.remote
+		require.Equal(t, tc.want, clientIP(req))
+	}
+}
diff --git a/internal/admin/cluster_handler.go b/internal/admin/cluster_handler.go
new file mode 100644
index 00000000..d0be3d97
--- /dev/null
+++ b/internal/admin/cluster_handler.go
@@ -0,0 +1,104 @@
+package admin
+
+import (
+	"context"
+	"log/slog"
+	"net/http"
+	"time"
+
+	"github.com/goccy/go-json"
+)
+
+// ClusterInfo is the lightweight snapshot the admin dashboard displays on
+// its landing page. Everything here is cheap to assemble; we deliberately
+// do not include per-shard key counts or byte statistics to keep the
+// endpoint safe to poll.
+type ClusterInfo struct {
+	NodeID    string      `json:"node_id"`
+	Version   string      `json:"version"`
+	Timestamp time.Time   `json:"timestamp"`
+	Groups    []GroupInfo `json:"groups"`
+}
+
+// GroupInfo describes a single Raft group from the local node's point of
+// view. LeaderID is the empty string during an election or when the node
+// has not yet discovered the leader.
+type GroupInfo struct {
+	GroupID  uint64   `json:"group_id"`
+	LeaderID string   `json:"leader_id"`
+	Members  []string `json:"members"`
+	IsLeader bool     `json:"is_leader"`
+}
+
+// ClusterInfoSource is the small contract the cluster handler calls out
+// to. Production wires this to a real Raft/engine view; tests use a stub.
+type ClusterInfoSource interface {
+	Describe(ctx context.Context) (ClusterInfo, error)
+}
+
+// ClusterInfoFunc is a convenience adapter for wiring a plain function
+// without defining an interface implementation.
+type ClusterInfoFunc func(ctx context.Context) (ClusterInfo, error)
+
+// Describe implements ClusterInfoSource.
+func (f ClusterInfoFunc) Describe(ctx context.Context) (ClusterInfo, error) {
+	return f(ctx)
+}
+
+// ClusterHandler serves GET /admin/api/v1/cluster.
+type ClusterHandler struct {
+	source ClusterInfoSource
+	logger *slog.Logger
+}
+
+// NewClusterHandler wires a source into the HTTP handler and seeds
+// logging with slog.Default(). Callers that want a tagged logger can
+// chain WithLogger(...) on the returned handler.
+func NewClusterHandler(source ClusterInfoSource) *ClusterHandler {
+	return &ClusterHandler{source: source, logger: slog.Default()}
+}
+
+// WithLogger overrides the default slog destination. Kept as an option
+// so main.go can attach a component tag without changing the
+// constructor signature.
+func (h *ClusterHandler) WithLogger(l *slog.Logger) *ClusterHandler {
+	if l == nil {
+		return h
+	}
+	h.logger = l
+	return h
+}
+
+// ServeHTTP renders the cluster snapshot as JSON. Errors from the source
+// are logged on the server with full detail and surfaced to the client as
+// a generic "cluster_describe_failed" code. Leaking err.Error() to
+// unauthenticated-ish clients would reveal raft/store internals.
+func (h *ClusterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodGet {
+		writeJSONError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET")
+		return
+	}
+	info, err := h.source.Describe(r.Context())
+	if err != nil {
+		h.logger.LogAttrs(r.Context(), slog.LevelError, "admin cluster describe failed",
+			slog.String("error", err.Error()),
+		)
+		writeJSONError(w, http.StatusInternalServerError, "cluster_describe_failed",
+			"failed to describe cluster state; see server logs")
+		return
+	}
+	if info.Timestamp.IsZero() {
+		info.Timestamp = time.Now().UTC()
+	}
+	w.Header().Set("Content-Type", "application/json; charset=utf-8")
+	w.Header().Set("Cache-Control", "no-store")
+	w.WriteHeader(http.StatusOK)
+	if err := json.NewEncoder(w).Encode(info); err != nil {
+		// The 200 header is already on the wire, so we cannot
+		// change the status — but a truncated JSON body is hard
+		// to diagnose without a breadcrumb.
+		h.logger.LogAttrs(r.Context(), slog.LevelWarn, "admin cluster response encode failed",
+			slog.String("error", err.Error()),
+		)
+	}
+}
diff --git a/internal/admin/cluster_handler_test.go b/internal/admin/cluster_handler_test.go
new file mode 100644
index 00000000..f77dbd1b
--- /dev/null
+++ b/internal/admin/cluster_handler_test.go
@@ -0,0 +1,80 @@
+package admin
+
+import (
+	"context"
+	"errors"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/goccy/go-json"
+	"github.com/stretchr/testify/require"
+)
+
+func TestClusterHandler_HappyPath(t *testing.T) {
+	source := ClusterInfoFunc(func(_ context.Context) (ClusterInfo, error) {
+		return ClusterInfo{
+			NodeID:  "node-1",
+			Version: "0.42.0",
+			Groups: []GroupInfo{
+				{GroupID: 1, LeaderID: "node-1", IsLeader: true, Members: []string{"node-1", "node-2", "node-3"}},
+				{GroupID: 2, LeaderID: "node-2", Members: []string{"node-1", "node-2", "node-3"}},
+			},
+		}, nil
+	})
+	h := NewClusterHandler(source)
+	req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/cluster", nil)
+	rec := httptest.NewRecorder()
+	h.ServeHTTP(rec, req)
+
+	require.Equal(t, http.StatusOK, rec.Code)
+	var got ClusterInfo
+	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
+	require.Equal(t, "node-1", got.NodeID)
+	require.Equal(t, "0.42.0", got.Version)
+	require.Len(t, got.Groups, 2)
+	require.True(t, got.Groups[0].IsLeader)
+	require.False(t, got.Timestamp.IsZero())
+}
+
+func TestClusterHandler_PreservesExplicitTimestamp(t *testing.T) {
+	ts := time.Date(2026, 4, 24, 10, 0, 0, 0, time.UTC)
+	source := ClusterInfoFunc(func(_ context.Context) (ClusterInfo, error) {
+		return ClusterInfo{NodeID: "n", Version: "v", Timestamp: ts}, nil
+	})
+	h := NewClusterHandler(source)
+	req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/cluster", nil)
+	rec := httptest.NewRecorder()
+	h.ServeHTTP(rec, req)
+
+	var got ClusterInfo
+	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
+	require.Equal(t, ts, got.Timestamp.UTC())
+}
+
+func TestClusterHandler_SourceErrorReturns500(t *testing.T) {
+	source := ClusterInfoFunc(func(_ context.Context) (ClusterInfo, error) {
+		return ClusterInfo{}, errors.New("raft storage sentinel XYZ123")
+	})
+	h := NewClusterHandler(source)
+	req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/cluster", nil)
+	rec := httptest.NewRecorder()
+	h.ServeHTTP(rec, req)
+
+	require.Equal(t, http.StatusInternalServerError, rec.Code)
+	require.Contains(t, rec.Body.String(), "cluster_describe_failed")
+	// The raw error must not leak to the client.
+	require.NotContains(t, rec.Body.String(), "XYZ123")
+	require.NotContains(t, rec.Body.String(), "raft storage sentinel")
+}
+
+func TestClusterHandler_OnlyGET(t *testing.T) {
+	h := NewClusterHandler(ClusterInfoFunc(func(_ context.Context) (ClusterInfo, error) {
+		return ClusterInfo{}, nil
+	}))
+	req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/cluster", nil)
+	rec := httptest.NewRecorder()
+	h.ServeHTTP(rec, req)
+	require.Equal(t, http.StatusMethodNotAllowed, rec.Code)
+}
diff --git a/internal/admin/config.go b/internal/admin/config.go
new file mode 100644
index 00000000..241a68d0
--- /dev/null
+++ b/internal/admin/config.go
@@ -0,0 +1,266 @@
+package admin
+
+import (
+	"encoding/base64"
+	"net"
+	"strings"
+
+	"github.com/cockroachdb/errors"
+)
+
+const (
+	// sessionSigningKeyLen is the required raw byte length for the admin
+	// JWT HS256 signing key.
+	sessionSigningKeyLen = 64
+)
+
+// Config captures everything the admin listener needs at startup. It mirrors
+// the Section 7.1 table in docs/design/2026_04_24_proposed_admin_dashboard.md
+// and intentionally uses plain Go fields rather than a config library so the
+// existing flag-based wiring in main.go can hand values over without a new
+// dependency.
+type Config struct {
+	// Enabled toggles the admin listener. Default false.
+	Enabled bool
+
+	// Listen is the host:port for the admin HTTP server. Default
+	// 127.0.0.1:8080 (loopback only).
+	Listen string
+
+	// TLSCertFile / TLSKeyFile enable TLS when both are set.
+	TLSCertFile string
+	TLSKeyFile  string
+
+	// AllowPlaintextNonLoopback opts out of the TLS-on-non-loopback
+	// requirement. Refusing to honour it is the default.
+	AllowPlaintextNonLoopback bool
+
+	// SessionSigningKey is the base64-encoded cluster-wide HS256 key. It
+	// must decode to exactly 64 bytes.
+	SessionSigningKey string
+
+	// SessionSigningKeyPrevious is an optional base64-encoded previous
+	// signing key, used to verify tokens issued before a key rotation.
+	SessionSigningKeyPrevious string
+
+	// ReadOnlyAccessKeys grants the GET subset of admin endpoints.
+	ReadOnlyAccessKeys []string
+
+	// FullAccessKeys grants the full CRUD surface of admin endpoints.
+	FullAccessKeys []string
+
+	// AllowInsecureDevCookie turns off the always-on Secure cookie
+	// attribute. Intended only for local plaintext development; it is off
+	// by default and the startup banner calls it out loudly.
+	AllowInsecureDevCookie bool
+}
+
+// Validate returns the first configuration error found, if any. It does not
+// try to collect every error because any of these conditions is a hard
+// startup failure.
+func (c *Config) Validate() error {
+	if c == nil {
+		return errors.New("admin config is nil")
+	}
+	if !c.Enabled {
+		return nil
+	}
+	if err := c.validateListen(); err != nil {
+		return err
+	}
+	if err := c.validateTLS(); err != nil {
+		return err
+	}
+	if err := c.validateSigningKeys(); err != nil {
+		return err
+	}
+	return validateAccessKeyRoles(c.ReadOnlyAccessKeys, c.FullAccessKeys)
+}
+
+func (c *Config) validateListen() error {
+	listen := strings.TrimSpace(c.Listen)
+	if listen == "" {
+		return errors.New("-adminListen must not be empty when -adminEnabled=true")
+	}
+	if _, _, err := net.SplitHostPort(listen); err != nil {
+		return errors.Wrapf(err, "-adminListen %q is not host:port", c.Listen)
+	}
+	return nil
+}
+
+func (c *Config) validateTLS() error {
+	certSet := strings.TrimSpace(c.TLSCertFile) != ""
+	keySet := strings.TrimSpace(c.TLSKeyFile) != ""
+	if certSet != keySet {
+		// A lone cert or key almost always means a typo. Silently
+		// treating it as "TLS off" would downgrade transport
+		// security while the operator thinks TLS is enabled; fail
+		// fast so the misconfiguration is visible at startup.
+		return errors.New("-adminTLSCertFile and -adminTLSKeyFile must be set together;" +
+			" partial TLS configuration is not allowed")
+	}
+	tlsConfigured := certSet && keySet
+	if tlsConfigured || !addressRequiresTLS(strings.TrimSpace(c.Listen)) || c.AllowPlaintextNonLoopback {
+		return nil
+	}
+	// errors.Newf already carries a stack; errors.WithStack here is
+	// for the project's wrapcheck linter, which requires every
+	// cockroachdb/errors return at a package boundary to be wrapped.
+	return errors.WithStack(errors.Newf(
+		"-adminListen %q is not loopback but TLS is not configured;"+
+			" set -adminTLSCertFile + -adminTLSKeyFile, or explicitly pass"+
+			" -adminAllowPlaintextNonLoopback (strongly discouraged)",
+		c.Listen,
+	))
+}
+
+func (c *Config) validateSigningKeys() error {
+	primary, err := decodeSigningKey("-adminSessionSigningKey", c.SessionSigningKey)
+	if err != nil {
+		return err
+	}
+	if len(primary) == 0 {
+		return errors.New("-adminSessionSigningKey is required when -adminEnabled=true")
+	}
+	if strings.TrimSpace(c.SessionSigningKeyPrevious) == "" {
+		return nil
+	}
+	if _, err := decodeSigningKey("-adminSessionSigningKeyPrevious", c.SessionSigningKeyPrevious); err != nil {
+		return err
+	}
+	return nil
+}
+
+// DecodedSigningKeys returns the raw HS256 keys in verification order: the
+// primary signing key first, followed by an optional previous key. Validate
+// must be called first; this method also asserts that contract defensively
+// so a missing key cannot quietly produce a `[][]byte{nil}` result and feed
+// a nil HMAC key into the verifier.
+func (c *Config) DecodedSigningKeys() ([][]byte, error) {
+	if c == nil || !c.Enabled {
+		return nil, errors.New("DecodedSigningKeys called on a disabled admin config")
+	}
+	primary, err := decodeSigningKey("-adminSessionSigningKey", c.SessionSigningKey)
+	if err != nil {
+		return nil, err
+	}
+	if len(primary) == 0 {
+		return nil, errors.New("-adminSessionSigningKey is empty; call Validate first")
+	}
+	keys := [][]byte{primary}
+	if strings.TrimSpace(c.SessionSigningKeyPrevious) == "" {
+		return keys, nil
+	}
+	prev, err := decodeSigningKey("-adminSessionSigningKeyPrevious", c.SessionSigningKeyPrevious)
+	if err != nil {
+		return nil, err
+	}
+	if len(prev) == 0 {
+		return nil, errors.New("-adminSessionSigningKeyPrevious is set but decoded to zero bytes")
+	}
+	return append(keys, prev), nil
+}
+
+// RoleIndex returns a map from access key to Role after Validate has
+// succeeded. The caller must not mutate the returned map.
+func (c *Config) RoleIndex() map[string]Role {
+	index := make(map[string]Role, len(c.FullAccessKeys)+len(c.ReadOnlyAccessKeys))
+	for _, k := range c.FullAccessKeys {
+		trim := strings.TrimSpace(k)
+		if trim == "" {
+			continue
+		}
+		index[trim] = RoleFull
+	}
+	for _, k := range c.ReadOnlyAccessKeys {
+		trim := strings.TrimSpace(k)
+		if trim == "" {
+			continue
+		}
+		// Overlap is rejected by Validate; this branch only runs for
+		// keys exclusive to read_only.
+		index[trim] = RoleReadOnly
+	}
+	return index
+}
+
+func decodeSigningKey(field, encoded string) ([]byte, error) {
+	trim := strings.TrimSpace(encoded)
+	if trim == "" {
+		return nil, nil
+	}
+	// Operators routinely copy keys out of Kubernetes Secrets or
+	// similar tooling that may emit either standard or URL-safe
+	// base64, with or without padding. Try each alphabet in turn so
+	// a correct key is never rejected over a formatting mismatch.
+	decoders := []*base64.Encoding{
+		base64.StdEncoding,
+		base64.RawStdEncoding,
+		base64.URLEncoding,
+		base64.RawURLEncoding,
+	}
+	var (
+		raw    []byte
+		decErr error
+	)
+	for _, dec := range decoders {
+		raw, decErr = dec.DecodeString(trim)
+		if decErr == nil {
+			break
+		}
+	}
+	if decErr != nil {
+		return nil, errors.Wrapf(decErr, "%s is not valid base64", field)
+	}
+	if len(raw) != sessionSigningKeyLen {
+		return nil, errors.WithStack(errors.Newf(
+			"%s must decode to %d bytes but got %d bytes",
+			field, sessionSigningKeyLen, len(raw),
+		))
+	}
+	return raw, nil
+}
+
+func validateAccessKeyRoles(readOnly, full []string) error {
+	fullSet := make(map[string]struct{}, len(full))
+	for _, k := range full {
+		trim := strings.TrimSpace(k)
+		if trim == "" {
+			continue
+		}
+		fullSet[trim] = struct{}{}
+	}
+	for _, k := range readOnly {
+		trim := strings.TrimSpace(k)
+		if trim == "" {
+			continue
+		}
+		if _, dup := fullSet[trim]; dup {
+			return errors.WithStack(errors.Newf(
+				"access key %q is listed in both -adminReadOnlyAccessKeys and -adminFullAccessKeys;"+
+					" this would silently grant write access depending on lookup order, so it is rejected at startup",
+				trim,
+			))
+		}
+	}
+	return nil
+}
+
+// addressRequiresTLS reports whether a listen address is exposed beyond
+// loopback and therefore must use TLS. Mirrors monitoring.AddressRequiresToken
+// so the admin package does not import monitoring.
+func addressRequiresTLS(addr string) bool {
+	host, _, err := net.SplitHostPort(strings.TrimSpace(addr))
+	if err != nil {
+		return true
+	}
+	host = strings.TrimSpace(host)
+	if host == "" || host == "0.0.0.0" || host == "::" {
+		return true
+	}
+	if strings.EqualFold(host, "localhost") {
+		return false
+	}
+	ip := net.ParseIP(host)
+	return ip == nil || !ip.IsLoopback()
+}
diff --git a/internal/admin/config_test.go b/internal/admin/config_test.go
new file mode 100644
index 00000000..fd54742d
--- /dev/null
+++ b/internal/admin/config_test.go
@@ -0,0 +1,225 @@
+package admin
+
+import (
+	"encoding/base64"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func makeKey(seed byte) string {
+	raw := make([]byte, sessionSigningKeyLen)
+	for i := range raw {
+		raw[i] = seed
+	}
+	return base64.StdEncoding.EncodeToString(raw)
+}
+
+func TestConfigValidate_DisabledNoOp(t *testing.T) {
+	c := &Config{}
+	require.NoError(t, c.Validate())
+}
+
+func TestConfigValidate_RequiresListen(t *testing.T) {
+	c := &Config{
+		Enabled:           true,
+		SessionSigningKey: makeKey(1),
+	}
+	err := c.Validate()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "-adminListen must not be empty")
+}
+
+func TestConfigValidate_RequiresSigningKey(t *testing.T) {
+	c := &Config{
+		Enabled: true,
+		Listen:  "127.0.0.1:8080",
+	}
+	err := c.Validate()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "adminSessionSigningKey is required")
+}
+
+func TestConfigValidate_SigningKeyWrongLength(t *testing.T) {
+	short := base64.StdEncoding.EncodeToString([]byte("too short"))
+	c := &Config{
+		Enabled:           true,
+		Listen:            "127.0.0.1:8080",
+		SessionSigningKey: short,
+	}
+	err := c.Validate()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "must decode to 64 bytes")
+}
+
+func TestConfigValidate_SigningKeyNotBase64(t *testing.T) {
+	c := &Config{
+		Enabled:           true,
+		Listen:            "127.0.0.1:8080",
+		SessionSigningKey: "!!!not base64!!!",
+	}
+	err := c.Validate()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "not valid base64")
+}
+
+func TestConfigValidate_LoopbackNoTLSOK(t *testing.T) {
+	for _, addr := range []string{"127.0.0.1:8080", "[::1]:8080", "localhost:8080"} {
+		c := &Config{
+			Enabled:           true,
+			Listen:            addr,
+			SessionSigningKey: makeKey(7),
+		}
+		require.NoErrorf(t, c.Validate(), "loopback %s", addr)
+	}
+}
+
+func TestConfigValidate_NonLoopbackRequiresTLS(t *testing.T) {
+	c := &Config{
+		Enabled:           true,
+		Listen:            "0.0.0.0:8080",
+		SessionSigningKey: makeKey(2),
+	}
+	err := c.Validate()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "TLS is not configured")
+	require.Contains(t, err.Error(), "adminAllowPlaintextNonLoopback")
+}
+
+func TestConfigValidate_PartialTLSRejected(t *testing.T) {
+	for _, tc := range []struct {
+		name, cert, key string
+	}{
+		{"only cert", "cert.pem", ""},
+		{"only key", "", "key.pem"},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			c := &Config{
+				Enabled:           true,
+				Listen:            "127.0.0.1:8080",
+				TLSCertFile:       tc.cert,
+				TLSKeyFile:        tc.key,
+				SessionSigningKey: makeKey(10),
+			}
+			err := c.Validate()
+			require.Error(t, err)
+			require.Contains(t, err.Error(), "partial TLS configuration")
+		})
+	}
+}
+
+func TestConfigValidate_NonLoopbackWithTLSOK(t *testing.T) {
+	c := &Config{
+		Enabled:           true,
+		Listen:            "10.0.0.1:8443",
+		TLSCertFile:       "cert.pem",
+		TLSKeyFile:        "key.pem",
+		SessionSigningKey: makeKey(3),
+	}
+	require.NoError(t, c.Validate())
+}
+
+func TestConfigValidate_NonLoopbackPlaintextOptInOK(t *testing.T) {
+	c := &Config{
+		Enabled:                   true,
+		Listen:                    "0.0.0.0:8080",
+		AllowPlaintextNonLoopback: true,
+		SessionSigningKey:         makeKey(4),
+	}
+	require.NoError(t, c.Validate())
+}
+
+func TestConfigValidate_OverlappingRolesRejected(t *testing.T) {
+	c := &Config{
+		Enabled:            true,
+		Listen:             "127.0.0.1:8080",
+		SessionSigningKey:  makeKey(5),
+		ReadOnlyAccessKeys: []string{"AKIA1", "AKIA2"},
+		FullAccessKeys:     []string{"AKIA2", "AKIA3"},
+	}
+	err := c.Validate()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "AKIA2")
+	require.Contains(t, err.Error(), "both")
+}
+
+func TestConfigValidate_RoleIndexExclusive(t *testing.T) {
+	c := &Config{
+		Enabled:            true,
+		Listen:             "127.0.0.1:8080",
+		SessionSigningKey:  makeKey(6),
+		ReadOnlyAccessKeys: []string{"AKIA_RO"},
+		FullAccessKeys:     []string{"AKIA_ADMIN"},
+	}
+	require.NoError(t, c.Validate())
+	idx := c.RoleIndex()
+	require.Equal(t, RoleReadOnly, idx["AKIA_RO"])
+	require.Equal(t, RoleFull, idx["AKIA_ADMIN"])
+	_, unknown := idx["AKIA_NOBODY"]
+	require.False(t, unknown)
+}
+
+func TestConfigValidate_PreviousSigningKeyValidated(t *testing.T) {
+	c := &Config{
+		Enabled:                   true,
+		Listen:                    "127.0.0.1:8080",
+		SessionSigningKey:         makeKey(1),
+		SessionSigningKeyPrevious: base64.StdEncoding.EncodeToString([]byte("short")),
+	}
+	err := c.Validate()
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "adminSessionSigningKeyPrevious")
+}
+
+func TestConfigDecodedSigningKeys_Order(t *testing.T) {
+	c := &Config{
+		Enabled:                   true,
+		Listen:                    "127.0.0.1:8080",
+		SessionSigningKey:         makeKey(1),
+		SessionSigningKeyPrevious: makeKey(2),
+	}
+	require.NoError(t, c.Validate())
+	keys, err := c.DecodedSigningKeys()
+	require.NoError(t, err)
+	require.Len(t, keys, 2)
+	require.Equal(t, byte(1), keys[0][0])
+	require.Equal(t, byte(2), keys[1][0])
+}
+
+func TestRole_AllowsWrite(t *testing.T) {
+	require.True(t, RoleFull.AllowsWrite())
+	require.False(t, RoleReadOnly.AllowsWrite())
+	require.False(t, Role("").AllowsWrite())
+}
+
+func TestAddressRequiresTLS(t *testing.T) {
+	require.False(t, addressRequiresTLS("127.0.0.1:8080"))
+	require.False(t, addressRequiresTLS("[::1]:8080"))
+	require.False(t, addressRequiresTLS("localhost:8080"))
+	require.True(t, addressRequiresTLS(":8080"))
+	require.True(t, addressRequiresTLS("0.0.0.0:8080"))
+	require.True(t, addressRequiresTLS("10.0.0.1:8080"))
+	require.True(t, addressRequiresTLS("garbage"))
+}
+
+func TestConfigValidate_DisabledIgnoresBadFields(t *testing.T) {
+	c := &Config{
+		Enabled: false,
+		Listen:  "not a host port",
+	}
+	require.NoError(t, c.Validate())
+}
+
+// TestConfigValidate_PreservesContext ensures the validation message
+// includes the offending field so operators can resolve errors quickly.
+func TestConfigValidate_PreservesContext(t *testing.T) {
+	c := &Config{
+		Enabled:           true,
+		Listen:            "::::::::",
+		SessionSigningKey: makeKey(8),
+	}
+	err := c.Validate()
+	require.Error(t, err)
+	require.True(t, strings.Contains(err.Error(), "-adminListen"))
+}
diff --git a/internal/admin/jwt.go b/internal/admin/jwt.go
new file mode 100644
index 00000000..2b34c701
--- /dev/null
+++ b/internal/admin/jwt.go
@@ -0,0 +1,215 @@
+package admin
+
+import (
+	"crypto/hmac"
+	"crypto/rand"
+	"crypto/sha256"
+	"crypto/subtle"
+	"encoding/base64"
+	"strings"
+	"time"
+
+	"github.com/cockroachdb/errors"
+	"github.com/goccy/go-json"
+)
+
+// Session TTL for admin JWTs. Aligns with the 1h Max-Age specified for the
+// session cookie in the design doc (Section 6.1).
+const sessionTTL = 1 * time.Hour
+
+// jwtSegments is the fixed number of dot-separated segments in a valid
+// HS256 JWT (header.payload.signature).
+const jwtSegments = 3
+
+// jwtHeader is the fixed HS256 JWT header. Admin never issues anything else.
+var jwtHeaderEncoded = base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
+
+type jwtClaims struct {
+	Sub  string `json:"sub"`
+	Role string `json:"role"`
+	IAT  int64  `json:"iat"`
+	EXP  int64  `json:"exp"`
+	JTI  string `json:"jti"`
+}
+
+// Clock is the small time abstraction used by the signer/verifier so tests
+// can control token freshness without sleeping.
+type Clock func() time.Time
+
+// SystemClock returns wall-clock time and is the default for production.
+func SystemClock() time.Time { return time.Now().UTC() }
+
+// Signer issues HS256-signed JWTs using the primary admin signing key. Only
+// the primary key can sign new tokens; the previous key is verify-only and
+// lives on Verifier.
+type Signer struct {
+	key   []byte
+	clock Clock
+}
+
+// NewSigner constructs a Signer; key must be exactly sessionSigningKeyLen
+// bytes (validated up-front so we do not catch this inside the hot path).
+func NewSigner(key []byte, clock Clock) (*Signer, error) {
+	if len(key) != sessionSigningKeyLen {
+		return nil, errors.WithStack(errors.Newf("signer key must be %d bytes, got %d", sessionSigningKeyLen, len(key)))
+	}
+	if clock == nil {
+		clock = SystemClock
+	}
+	copied := append([]byte{}, key...)
+	return &Signer{key: copied, clock: clock}, nil
+}
+
+// Sign mints a fresh JWT for principal with the admin session TTL.
+func (s *Signer) Sign(principal AuthPrincipal) (string, error) {
+	jti, err := randomJTI()
+	if err != nil {
+		return "", err
+	}
+	now := s.clock().UTC()
+	claims := jwtClaims{
+		Sub:  principal.AccessKey,
+		Role: string(principal.Role),
+		IAT:  now.Unix(),
+		EXP:  now.Add(sessionTTL).Unix(),
+		JTI:  jti,
+	}
+	payload, err := json.Marshal(claims)
+	if err != nil {
+		return "", errors.Wrap(err, "marshal jwt claims")
+	}
+	encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+	signingInput := jwtHeaderEncoded + "." + encodedPayload
+	mac := hmac.New(sha256.New, s.key)
+	mac.Write([]byte(signingInput))
+	sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+	return signingInput + "." + sig, nil
+}
+
+// Verifier validates HS256 admin tokens. It tries the primary key first and
+// falls back to the optional previous key so operators can rotate keys
+// without logging everybody out at once.
+type Verifier struct {
+	keys  [][]byte
+	clock Clock
+}
+
+// NewVerifier builds a verifier from keys in priority order (primary first,
+// optional previous second). Zero-length keys are rejected.
+func NewVerifier(keys [][]byte, clock Clock) (*Verifier, error) {
+	if len(keys) == 0 {
+		return nil, errors.New("verifier requires at least one key")
+	}
+	for i, k := range keys {
+		if len(k) != sessionSigningKeyLen {
+			return nil, errors.WithStack(errors.Newf("verifier key[%d] must be %d bytes, got %d", i, sessionSigningKeyLen, len(k)))
+		}
+	}
+	copied := make([][]byte, len(keys))
+	for i, k := range keys {
+		copied[i] = append([]byte{}, k...)
+	}
+	if clock == nil {
+		clock = SystemClock
+	}
+	return &Verifier{keys: copied, clock: clock}, nil
+}
+
+// ErrInvalidToken is returned for any verification failure without leaking
+// which specific check failed. Callers should log the wrapped error but
+// return a single 401 to clients regardless of the cause.
+var ErrInvalidToken = errors.New("invalid admin session token")
+
+// Verify parses token, checks the signature against each configured key,
+// and confirms it is within its validity window. On success it returns the
+// embedded AuthPrincipal.
+func (v *Verifier) Verify(token string) (AuthPrincipal, error) {
+	signingInput, payloadSeg, sig, err := splitSignedToken(token)
+	if err != nil {
+		return AuthPrincipal{}, err
+	}
+	if err := v.checkSignature(signingInput, sig); err != nil {
+		return AuthPrincipal{}, err
+	}
+	claims, err := decodeClaims(payloadSeg)
+	if err != nil {
+		return AuthPrincipal{}, err
+	}
+	return v.validateClaims(claims)
+}
+
+// clockSkewToleranceSeconds is the slack we allow on the "issued in the
+// future" check so that minor NTP drift between admin nodes does not
+// produce spurious 401s.
+const clockSkewToleranceSeconds = 30
+
+func splitSignedToken(token string) (signingInput, payloadSeg string, sig []byte, err error) {
+	parts := strings.Split(token, ".")
+	if len(parts) != jwtSegments {
+		return "", "", nil, errors.Wrap(ErrInvalidToken, "token does not have three segments")
+	}
+	if parts[0] != jwtHeaderEncoded {
+		return "", "", nil, errors.Wrap(ErrInvalidToken, "unsupported jwt header")
+	}
+	sig, err = base64.RawURLEncoding.DecodeString(parts[2])
+	if err != nil {
+		return "", "", nil, errors.Wrap(ErrInvalidToken, "malformed signature")
+	}
+	return parts[0] + "." + parts[1], parts[1], sig, nil
+}
+
+func (v *Verifier) checkSignature(signingInput string, providedSig []byte) error {
+	for _, k := range v.keys {
+		mac := hmac.New(sha256.New, k)
+		mac.Write([]byte(signingInput))
+		if subtle.ConstantTimeCompare(mac.Sum(nil), providedSig) == 1 {
+			return nil
+		}
+	}
+	return errors.Wrap(ErrInvalidToken, "signature mismatch")
+}
+
+func decodeClaims(payloadSeg string) (jwtClaims, error) {
+	payload, err := base64.RawURLEncoding.DecodeString(payloadSeg)
+	if err != nil {
+		return jwtClaims{}, errors.Wrap(ErrInvalidToken, "malformed payload")
+	}
+	var claims jwtClaims
+	if err := json.Unmarshal(payload, &claims); err != nil {
+		return jwtClaims{}, errors.Wrap(ErrInvalidToken, "payload is not json")
+	}
+	return claims, nil
+}
+
+func (v *Verifier) validateClaims(claims jwtClaims) (AuthPrincipal, error) {
+	now := v.clock().UTC().Unix()
+	if claims.EXP == 0 || now >= claims.EXP {
+		return AuthPrincipal{}, errors.Wrap(ErrInvalidToken, "token expired")
+	}
+	// A missing iat is treated the same as a missing exp: the admin
+	// Signer always sets iat, so a token without one is either
+	// malformed, produced by a foreign signer that happens to share
+	// the HS256 key, or a future-version token we do not recognise.
+	if claims.IAT == 0 {
+		return AuthPrincipal{}, errors.Wrap(ErrInvalidToken, "missing iat")
+	}
+	if now+clockSkewToleranceSeconds < claims.IAT {
+		return AuthPrincipal{}, errors.Wrap(ErrInvalidToken, "token issued in the future")
+	}
+	if claims.Sub == "" {
+		return AuthPrincipal{}, errors.Wrap(ErrInvalidToken, "missing sub")
+	}
+	role := Role(claims.Role)
+	if role != RoleReadOnly && role != RoleFull {
+		return AuthPrincipal{}, errors.Wrap(ErrInvalidToken, "unknown role")
+	}
+	return AuthPrincipal{AccessKey: claims.Sub, Role: role}, nil
+}
+
+func randomJTI() (string, error) {
+	var raw [16]byte
+	if _, err := rand.Read(raw[:]); err != nil {
+		return "", errors.Wrap(err, "read random bytes for jti")
+	}
+	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
+}
diff --git a/internal/admin/jwt_test.go b/internal/admin/jwt_test.go
new file mode 100644
index 00000000..b7301f4e
--- /dev/null
+++ b/internal/admin/jwt_test.go
@@ -0,0 +1,184 @@
+package admin
+
+import (
+	"bytes"
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/base64"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/errors"
+	"github.com/stretchr/testify/require"
+)
+
+func fixedClock(t time.Time) Clock { return func() time.Time { return t } }
+
+func newSignerForTest(t *testing.T, seed byte, clk Clock) *Signer {
+	t.Helper()
+	key := bytes.Repeat([]byte{seed}, sessionSigningKeyLen)
+	s, err := NewSigner(key, clk)
+	require.NoError(t, err)
+	return s
+}
+
+func newVerifierForTest(t *testing.T, seeds []byte, clk Clock) *Verifier {
+	t.Helper()
+	keys := make([][]byte, 0, len(seeds))
+	for _, seed := range seeds {
+		keys = append(keys, bytes.Repeat([]byte{seed}, sessionSigningKeyLen))
+	}
+	v, err := NewVerifier(keys, clk)
+	require.NoError(t, err)
+	return v
+}
+
+func TestJWT_SignVerifyRoundTrip(t *testing.T) {
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	signer := newSignerForTest(t, 1, clk)
+	verifier := newVerifierForTest(t, []byte{1}, clk)
+
+	principal := AuthPrincipal{AccessKey: "AKIA_OK", Role: RoleFull}
+	token, err := signer.Sign(principal)
+	require.NoError(t, err)
+	require.Equal(t, 2, strings.Count(token, "."))
+
+	got, err := verifier.Verify(token)
+	require.NoError(t, err)
+	require.Equal(t, principal, got)
+}
+
+func TestJWT_RejectsExpired(t *testing.T) {
+	signClk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	verifyClk := fixedClock(time.Unix(1_700_000_000+int64(sessionTTL.Seconds()+1), 0).UTC())
+	signer := newSignerForTest(t, 1, signClk)
+	verifier := newVerifierForTest(t, []byte{1}, verifyClk)
+
+	token, err := signer.Sign(AuthPrincipal{AccessKey: "AKIA", Role: RoleReadOnly})
+	require.NoError(t, err)
+
+	_, err = verifier.Verify(token)
+	require.Error(t, err)
+	require.True(t, errors.Is(err, ErrInvalidToken))
+}
+
+func TestJWT_RejectsFutureIssued(t *testing.T) {
+	// Sign in the future; verifier clock is now.
+	signClk := fixedClock(time.Unix(1_700_000_000+600, 0).UTC())
+	verifyClk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	signer := newSignerForTest(t, 1, signClk)
+	verifier := newVerifierForTest(t, []byte{1}, verifyClk)
+
+	token, err := signer.Sign(AuthPrincipal{AccessKey: "AKIA", Role: RoleFull})
+	require.NoError(t, err)
+
+	_, err = verifier.Verify(token)
+	require.Error(t, err)
+	require.True(t, errors.Is(err, ErrInvalidToken))
+}
+
+func TestJWT_RejectsWrongSignature(t *testing.T) {
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	signer := newSignerForTest(t, 1, clk)
+	verifier := newVerifierForTest(t, []byte{9}, clk) // different key
+
+	token, err := signer.Sign(AuthPrincipal{AccessKey: "AKIA", Role: RoleFull})
+	require.NoError(t, err)
+
+	_, err = verifier.Verify(token)
+	require.Error(t, err)
+	require.True(t, errors.Is(err, ErrInvalidToken))
+}
+
+func TestJWT_PreviousKeyAccepted(t *testing.T) {
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	// Previous key signs the token.
+	oldSigner := newSignerForTest(t, 2, clk)
+	token, err := oldSigner.Sign(AuthPrincipal{AccessKey: "AKIA_OLD", Role: RoleReadOnly})
+	require.NoError(t, err)
+
+	// Verifier has primary=new, previous=old.
+	verifier := newVerifierForTest(t, []byte{1, 2}, clk)
+	got, err := verifier.Verify(token)
+	require.NoError(t, err)
+	require.Equal(t, "AKIA_OLD", got.AccessKey)
+	require.Equal(t, RoleReadOnly, got.Role)
+}
+
+func TestJWT_AfterRotationOldPreviousRejected(t *testing.T) {
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	// Token minted with key seed=7.
+	signer := newSignerForTest(t, 7, clk)
+	token, err := signer.Sign(AuthPrincipal{AccessKey: "AKIA", Role: RoleFull})
+	require.NoError(t, err)
+
+	// After rotation completes, only seeds {1,2} are configured.
+	verifier := newVerifierForTest(t, []byte{1, 2}, clk)
+	_, err = verifier.Verify(token)
+	require.Error(t, err)
+	require.True(t, errors.Is(err, ErrInvalidToken))
+}
+
+func TestJWT_MalformedTokens(t *testing.T) {
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	verifier := newVerifierForTest(t, []byte{1}, clk)
+
+	cases := []string{
+		"",
+		"abc",
+		"a.b",
+		"a.b.c.d",
+		"not-header.payload.sig",
+	}
+	for _, tok := range cases {
+		_, err := verifier.Verify(tok)
+		require.Errorf(t, err, "token %q should fail", tok)
+		require.True(t, errors.Is(err, ErrInvalidToken))
+	}
+}
+
+func TestJWT_RejectsUnknownRole(t *testing.T) {
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	key := bytes.Repeat([]byte{3}, sessionSigningKeyLen)
+	// Manually craft a token with role=admin (unsupported).
+	payload := []byte(`{"sub":"AKIA","role":"admin","iat":1700000000,"exp":1700003600,"jti":"j"}`)
+	encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+	signingInput := jwtHeaderEncoded + "." + encodedPayload
+	mac := hmac.New(sha256.New, key)
+	mac.Write([]byte(signingInput))
+	sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+	token := signingInput + "." + sig
+
+	verifier := newVerifierForTest(t, []byte{3}, clk)
+	_, err := verifier.Verify(token)
+	require.Error(t, err)
+	require.True(t, errors.Is(err, ErrInvalidToken))
+}
+
+func TestJWT_RejectsMissingSub(t *testing.T) {
+	clk := fixedClock(time.Unix(1_700_000_000, 0).UTC())
+	key := bytes.Repeat([]byte{4}, sessionSigningKeyLen)
+	payload := []byte(`{"sub":"","role":"full","iat":1700000000,"exp":1700003600,"jti":"j"}`)
+	encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+	signingInput := jwtHeaderEncoded + "." + encodedPayload
+	mac := hmac.New(sha256.New, key)
+	mac.Write([]byte(signingInput))
+	sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+	token := signingInput + "." + sig
+
+	verifier := newVerifierForTest(t, []byte{4}, clk)
+	_, err := verifier.Verify(token)
+	require.Error(t, err)
+	require.True(t, errors.Is(err, ErrInvalidToken))
+}
+
+func TestNewSigner_RejectsWrongKeyLength(t *testing.T) {
+	_, err := NewSigner([]byte("short"), nil)
+	require.Error(t, err)
+}
+
+func TestNewVerifier_RejectsEmptyKeys(t *testing.T) {
+	_, err := NewVerifier(nil, nil)
+	require.Error(t, err)
+}
diff --git a/internal/admin/logout_csrf_test.go b/internal/admin/logout_csrf_test.go
new file mode 100644
index 00000000..de608333
--- /dev/null
+++ b/internal/admin/logout_csrf_test.go
@@ -0,0 +1,100 @@
+package admin
+
+import (
+	"bytes"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+// TestLogout_RejectsUnauthenticated ensures a cross-site caller cannot
+// POST /admin/api/v1/auth/logout without a valid session, which was the
+// logout-CSRF vector Codex flagged.
+func TestLogout_RejectsUnauthenticated(t *testing.T) { + srv := newServerForTest(t) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/auth/logout", nil) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + // The server must not have set any cookies on a rejected logout. + require.Empty(t, rec.Result().Cookies()) +} + +// TestLogout_RequiresCSRF ensures that even with a valid session cookie, +// logout refuses to execute without a matching X-Admin-CSRF header. +// SameSite=Strict already blocks the cross-site leg, but the server-side +// CSRF check is an explicit belt-and-braces guard. +func TestLogout_RequiresCSRF(t *testing.T) { + srv := newServerForTest(t) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + // Log in to collect cookies. + cookies := loginForTest(t, ts) + + // POST /auth/logout with session cookie but NO X-Admin-CSRF header. + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/auth/logout", nil) + for _, c := range cookies { + req.AddCookie(c) + } + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "csrf_missing") +} + +// TestLogout_HappyPath verifies that a well-formed logout (session +// cookie + matching CSRF header + cookie) succeeds and returns 204. 
+func TestLogout_HappyPath(t *testing.T) { + srv := newServerForTest(t) + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + cookies := loginForTest(t, ts) + var csrfValue string + for _, c := range cookies { + if c.Name == csrfCookieName { + csrfValue = c.Value + } + } + require.NotEmpty(t, csrfValue) + + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/auth/logout", nil) + for _, c := range cookies { + req.AddCookie(c) + } + req.Header.Set(csrfHeaderName, csrfValue) + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + require.Equal(t, http.StatusNoContent, rec.Code) + + // Expired-cookie Set-Cookie headers must still be emitted so the + // client actually forgets the session. + var cleared int + for _, c := range rec.Result().Cookies() { + if c.MaxAge == -1 { + cleared++ + } + } + require.Equal(t, 2, cleared, "expected both admin_session and admin_csrf to be cleared") +} + +// loginForTest POSTs a valid login against ts and returns the resulting +// cookies. It is a small helper shared by the logout tests above. 
+func loginForTest(t *testing.T, ts *httptest.Server) []*http.Cookie { + t.Helper() + body := []byte(`{"access_key":"AKIA_ADMIN","secret_key":"ADMIN_SECRET"}`) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, + ts.URL+"/admin/api/v1/auth/login", bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + cookies := resp.Cookies() + require.Len(t, cookies, 2) + return cookies +} diff --git a/internal/admin/middleware.go b/internal/admin/middleware.go new file mode 100644 index 00000000..a1622bc7 --- /dev/null +++ b/internal/admin/middleware.go @@ -0,0 +1,232 @@ +package admin + +import ( + "context" + "crypto/subtle" + "errors" + "log/slog" + "net/http" + "strings" + "time" + + pkgerrors "github.com/cockroachdb/errors" +) + +// Cookie names used throughout the admin surface. We define them in one +// place so the login handler, CSRF middleware, and logout handler cannot +// drift. +const ( + sessionCookieName = "admin_session" + csrfCookieName = "admin_csrf" + csrfHeaderName = "X-Admin-CSRF" + + // defaultBodyLimit matches docs/design 4.4: 64 KiB upper bound for + // every POST/PUT endpoint. DynamoDB table schemas and S3 bucket + // metadata are each well under this bound. + defaultBodyLimit int64 = 64 << 10 +) + +// contextKey is the unexported type for storing values in request +// contexts. Using a string type directly would risk collisions with other +// packages. +type contextKey int + +const ( + ctxKeyPrincipal contextKey = iota + 1 +) + +// PrincipalFromContext returns the authenticated principal associated +// with the request context, or false if the middleware did not set one. 
+func PrincipalFromContext(ctx context.Context) (AuthPrincipal, bool) { + v, ok := ctx.Value(ctxKeyPrincipal).(AuthPrincipal) + return v, ok +} + +// BodyLimit caps each request body at `limit` bytes via +// http.MaxBytesReader. Handlers that read the body are responsible for +// detecting overflow (via IsMaxBytesError / errors.As on +// *http.MaxBytesError) and calling WriteMaxBytesError to respond 413. +// We intentionally do not centralise that translation in the middleware +// chain: different handlers parse bodies with different decoders (json, +// form, multipart) and each already has a natural error path, so a +// wrapper ResponseWriter would either double-write or mask subsequent +// errors. +func BodyLimit(limit int64) func(http.Handler) http.Handler { + if limit <= 0 { + limit = defaultBodyLimit + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Body != nil { + r.Body = http.MaxBytesReader(w, r.Body, limit) + } + next.ServeHTTP(w, r) + }) + } +} + +// WriteMaxBytesError is the canonical 413 response body for admin +// handlers that detected an http.MaxBytesError while reading a request +// body. It keeps the JSON error shape consistent with the rest of the +// admin surface. +func WriteMaxBytesError(w http.ResponseWriter) { + writeJSONError(w, http.StatusRequestEntityTooLarge, "payload_too_large", + "request body exceeds the 64 KiB admin limit") +} + +// IsMaxBytesError reports whether err was produced because the client +// uploaded more bytes than BodyLimit permits. +func IsMaxBytesError(err error) bool { + if err == nil { + return false + } + var mbe *http.MaxBytesError + return errors.As(err, &mbe) +} + +// SessionAuth parses the admin_session cookie, validates it against the +// verifier, and attaches the resulting AuthPrincipal to the request +// context. Requests without a session, or with an invalid/expired one, +// are rejected with 401. 
+func SessionAuth(verifier *Verifier) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(sessionCookieName) + if err != nil || strings.TrimSpace(cookie.Value) == "" { + writeJSONError(w, http.StatusUnauthorized, "unauthenticated", "missing session cookie") + return + } + principal, err := verifier.Verify(cookie.Value) + if err != nil { + writeJSONError(w, http.StatusUnauthorized, "unauthenticated", "invalid or expired session") + return + } + ctx := context.WithValue(r.Context(), ctxKeyPrincipal, principal) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// RequireWriteRole blocks the handler unless the current principal may +// execute write operations. Must be composed after SessionAuth. +func RequireWriteRole() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + principal, ok := PrincipalFromContext(r.Context()) + if !ok { + writeJSONError(w, http.StatusUnauthorized, "unauthenticated", "no principal on context") + return + } + if !principal.Role.AllowsWrite() { + writeJSONError(w, http.StatusForbidden, "forbidden", + "this endpoint requires admin.full_access_keys membership") + return + } + next.ServeHTTP(w, r) + }) + } +} + +// CSRFDoubleSubmit enforces the double-submit cookie rule for state +// changing methods (POST, PUT, DELETE, PATCH). The admin_csrf cookie is +// minted at login; the SPA copies its value into the X-Admin-CSRF header +// on every write. We reject the request if either the cookie or the +// header is missing or if they do not match. GET/HEAD pass through +// untouched. 
+func CSRFDoubleSubmit() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + next.ServeHTTP(w, r) + return + } + cookie, err := r.Cookie(csrfCookieName) + if err != nil || strings.TrimSpace(cookie.Value) == "" { + writeJSONError(w, http.StatusForbidden, "csrf_missing", "admin_csrf cookie is required") + return + } + header := strings.TrimSpace(r.Header.Get(csrfHeaderName)) + if header == "" { + writeJSONError(w, http.StatusForbidden, "csrf_missing", + "X-Admin-CSRF header is required for write operations") + return + } + // Constant-time comparison on the byte contents once we + // know the lengths match. A length mismatch is itself not + // secret (the server mints both tokens with a fixed 32-byte + // width, so any divergence means an attacker forged or + // corrupted the value — the response is 403 in every case + // anyway), so short-circuiting there does not leak anything + // useful to a timing attacker. + if len(cookie.Value) != len(header) || + subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(header)) != 1 { + writeJSONError(w, http.StatusForbidden, "csrf_mismatch", "CSRF token mismatch") + return + } + next.ServeHTTP(w, r) + }) + } +} + +// auditRecorder is the ResponseWriter wrapper the Audit middleware uses +// to learn the final status code without forcing the handler to pass it +// back explicitly. 
+type auditRecorder struct { + http.ResponseWriter + status int + written bool +} + +func (a *auditRecorder) WriteHeader(code int) { + if !a.written { + a.status = code + a.written = true + } + a.ResponseWriter.WriteHeader(code) +} + +func (a *auditRecorder) Write(b []byte) (int, error) { + if !a.written { + a.status = http.StatusOK + a.written = true + } + n, err := a.ResponseWriter.Write(b) + if err != nil { + return n, pkgerrors.Wrap(err, "audit recorder write") + } + return n, nil +} + +// Audit writes a structured slog line for every state-changing admin +// request, as required by docs/design Section 10. GET/HEAD requests are +// not audited (read traffic can be too loud and does not modify state). +// The logger uses the "admin_audit" key so operators can filter. Callers +// wire this middleware after SessionAuth so the principal is available. +func Audit(logger *slog.Logger) func(http.Handler) http.Handler { + if logger == nil { + logger = slog.Default() + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + next.ServeHTTP(w, r) + return + } + rec := &auditRecorder{ResponseWriter: w} + start := time.Now() + next.ServeHTTP(rec, r) + principal, _ := PrincipalFromContext(r.Context()) + logger.LogAttrs(r.Context(), slog.LevelInfo, "admin_audit", + slog.String("actor", principal.AccessKey), + slog.String("role", string(principal.Role)), + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.Int("status", rec.status), + slog.String("remote", r.RemoteAddr), + slog.Duration("duration", time.Since(start)), + ) + }) + } +} diff --git a/internal/admin/middleware_test.go b/internal/admin/middleware_test.go new file mode 100644 index 00000000..fc7c00df --- /dev/null +++ b/internal/admin/middleware_test.go @@ -0,0 +1,231 @@ +package admin + +import ( + "bytes" + "context" + "io" + "log/slog" + 
"net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestBodyLimit_Allows(t *testing.T) { + var gotBody []byte + h := BodyLimit(128)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + require.NoError(t, err) + gotBody = b + w.WriteHeader(http.StatusNoContent) + })) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/x", strings.NewReader("hello")) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusNoContent, rec.Code) + require.Equal(t, "hello", string(gotBody)) +} + +func TestBodyLimit_Exceeded(t *testing.T) { + oversize := strings.Repeat("x", 200) + h := BodyLimit(64)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := io.ReadAll(r.Body) + if err != nil { + if IsMaxBytesError(err) { + WriteMaxBytesError(w) + return + } + t.Fatalf("unexpected error: %v", err) + } + w.WriteHeader(http.StatusNoContent) + })) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/x", strings.NewReader(oversize)) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) + require.Contains(t, rec.Body.String(), "payload_too_large") +} + +func TestSessionAuth_MissingCookie(t *testing.T) { + clk := fixedClock(time.Unix(1_700_000_000, 0).UTC()) + v := newVerifierForTest(t, []byte{1}, clk) + h := SessionAuth(v)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/cluster", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + require.Contains(t, rec.Body.String(), "unauthenticated") +} + +func TestSessionAuth_HappyPathPutsPrincipalOnContext(t *testing.T) { + clk := fixedClock(time.Unix(1_700_000_000, 0).UTC()) + signer := newSignerForTest(t, 1, clk) + v := newVerifierForTest(t, 
[]byte{1}, clk) + + principal := AuthPrincipal{AccessKey: "AKIA", Role: RoleFull} + token, err := signer.Sign(principal) + require.NoError(t, err) + + var gotPrincipal AuthPrincipal + h := SessionAuth(v)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p, ok := PrincipalFromContext(r.Context()) + require.True(t, ok) + gotPrincipal = p + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/cluster", nil) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: token}) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, principal, gotPrincipal) +} + +func TestSessionAuth_InvalidToken(t *testing.T) { + clk := fixedClock(time.Unix(1_700_000_000, 0).UTC()) + v := newVerifierForTest(t, []byte{1}, clk) + h := SessionAuth(v)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/cluster", nil) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "garbage"}) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestRequireWriteRole_ReadOnlyRejected(t *testing.T) { + h := RequireWriteRole()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + ctx := context.WithValue(context.Background(), ctxKeyPrincipal, + AuthPrincipal{AccessKey: "AKIA_RO", Role: RoleReadOnly}) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/dynamo/tables", nil).WithContext(ctx) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "forbidden") +} + +func TestRequireWriteRole_FullAllowed(t *testing.T) { + h := RequireWriteRole()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusCreated) + })) + ctx := 
context.WithValue(context.Background(), ctxKeyPrincipal, + AuthPrincipal{AccessKey: "AKIA_ADMIN", Role: RoleFull}) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/dynamo/tables", nil).WithContext(ctx) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusCreated, rec.Code) +} + +func TestRequireWriteRole_NoPrincipal(t *testing.T) { + h := RequireWriteRole()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/dynamo/tables", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestCSRF_GETPasses(t *testing.T) { + called := false + h := CSRFDoubleSubmit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/dynamo/tables", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.True(t, called) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestCSRF_WriteMissingCookie(t *testing.T) { + h := CSRFDoubleSubmit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/dynamo/tables", nil) + req.Header.Set(csrfHeaderName, "tok") + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "csrf_missing") +} + +func TestCSRF_WriteMissingHeader(t *testing.T) { + h := CSRFDoubleSubmit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/dynamo/tables", nil) + req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: "tok"}) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, 
http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "csrf_missing") +} + +func TestCSRF_Mismatch(t *testing.T) { + h := CSRFDoubleSubmit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/dynamo/tables", nil) + req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: "a"}) + req.Header.Set(csrfHeaderName, "b") + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "csrf_mismatch") +} + +func TestCSRF_MatchAllows(t *testing.T) { + h := CSRFDoubleSubmit()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/dynamo/tables", nil) + req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: "same"}) + req.Header.Set(csrfHeaderName, "same") + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusAccepted, rec.Code) +} + +func TestAudit_LogsWriteRequest(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})) + ctx := context.WithValue(context.Background(), ctxKeyPrincipal, + AuthPrincipal{AccessKey: "AKIA_ADMIN", Role: RoleFull}) + h := Audit(logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + req := httptest.NewRequest(http.MethodPost, "/admin/api/v1/dynamo/tables", nil).WithContext(ctx) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusNoContent, rec.Code) + out := buf.String() + require.Contains(t, out, `"msg":"admin_audit"`) + require.Contains(t, out, `"actor":"AKIA_ADMIN"`) + require.Contains(t, out, `"method":"POST"`) + require.Contains(t, out, `"status":204`) +} + +func TestAudit_SkipsReads(t *testing.T) { + var buf bytes.Buffer 
+ logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})) + h := Audit(logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/dynamo/tables", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + require.Empty(t, buf.String()) +} diff --git a/internal/admin/principal.go b/internal/admin/principal.go new file mode 100644 index 00000000..883bac79 --- /dev/null +++ b/internal/admin/principal.go @@ -0,0 +1,28 @@ +package admin + +// Role represents the authorization tier of an authenticated admin session. +type Role string + +const ( + // RoleReadOnly permits only GET endpoints. + RoleReadOnly Role = "read_only" + // RoleFull permits the entire admin CRUD surface. + RoleFull Role = "full" +) + +// AllowsWrite reports whether the role may execute state-mutating operations. +func (r Role) AllowsWrite() bool { + return r == RoleFull +} + +// AuthPrincipal is the authenticated caller derived from a session cookie or, +// in the future, a follower→leader forwarded request. The admin handler and +// adapter internal entrypoints pass it around instead of a raw HTTP request +// so the authorization model stays consistent whether the request arrived +// via SigV4 or JWT. +type AuthPrincipal struct { + // AccessKey is the caller's SigV4 access key identifier. + AccessKey string + // Role is the role resolved from the server-side access key table. + Role Role +} diff --git a/internal/admin/ratelimit.go b/internal/admin/ratelimit.go new file mode 100644 index 00000000..ce9dc259 --- /dev/null +++ b/internal/admin/ratelimit.go @@ -0,0 +1,111 @@ +package admin + +import ( + "net" + "net/http" + "sync" + "time" +) + +// rateLimiterMaxEntries is a hard cap on the number of distinct source +// IPs the limiter will track at once. 
Hitting the cap means an attacker +// is spraying us with unique source addresses; we respond by refusing +// to add new entries (and therefore refusing those logins) until the +// window ages them out. We first sweep expired windows before +// concluding the map is full, so well-behaved traffic never trips on +// the cap. +const rateLimiterMaxEntries = 1024 + +// rateLimiter is a fixed-window, in-memory per-IP rate limiter. It is +// intentionally simple: the admin login endpoint is low-volume, we only +// need to slow brute-force guessing, and distributed accounting would +// require Raft round-trips per login (unreasonable for the threat model). +// Entries older than window are pruned lazily on the next hit. +type rateLimiter struct { + limit int + window time.Duration + clock Clock + + mu sync.Mutex + entries map[string]*rateLimiterEntry +} + +type rateLimiterEntry struct { + windowStart time.Time + count int +} + +func newRateLimiter(limit int, window time.Duration, clock Clock) *rateLimiter { + if clock == nil { + clock = SystemClock + } + return &rateLimiter{ + limit: limit, + window: window, + clock: clock, + entries: make(map[string]*rateLimiterEntry), + } +} + +// allow returns true if the client at ip may perform one more action +// within the current window, otherwise false. It is safe for concurrent +// use. +func (rl *rateLimiter) allow(ip string) bool { + now := rl.clock().UTC() + rl.mu.Lock() + defer rl.mu.Unlock() + + e, ok := rl.entries[ip] + if ok { + if now.Sub(e.windowStart) >= rl.window { + e.windowStart = now + e.count = 1 + return true + } + if e.count >= rl.limit { + return false + } + e.count++ + return true + } + + // Unknown IP — we need to create a new entry. Enforce the hard + // cap on distinct tracked IPs before doing so. + if len(rl.entries) >= rateLimiterMaxEntries { + // Try to reclaim space from expired windows first. If that + // still leaves us at cap, refuse the new entry. 
Refusing + // (rather than evicting an arbitrary old entry) is safer: + // it prevents a spray of fresh IPs from silently erasing a + // legitimate user's in-progress rate-limit state. + rl.sweepExpiredLocked(now) + if len(rl.entries) >= rateLimiterMaxEntries { + return false + } + } + rl.entries[ip] = &rateLimiterEntry{windowStart: now, count: 1} + return true +} + +// sweepExpiredLocked drops entries whose window has elapsed. The caller +// must hold rl.mu. +func (rl *rateLimiter) sweepExpiredLocked(now time.Time) { + for k, v := range rl.entries { + if now.Sub(v.windowStart) >= rl.window { + delete(rl.entries, k) + } + } +} + +// clientIP extracts the IP part of the request's remote address. It falls +// back to the full RemoteAddr if SplitHostPort fails. We do not consult +// X-Forwarded-For here because the admin listener is expected to run +// directly on the node (loopback or behind a trusted load balancer in +// the TLS case); honouring client-controlled headers would let an +// attacker evade the limiter. +func clientIP(r *http.Request) string { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} diff --git a/internal/admin/ratelimit_test.go b/internal/admin/ratelimit_test.go new file mode 100644 index 00000000..49df26c7 --- /dev/null +++ b/internal/admin/ratelimit_test.go @@ -0,0 +1,45 @@ +package admin + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateLimiter_HardCapRefusesNewIPs(t *testing.T) { + now := time.Unix(1_700_000_000, 0).UTC() + clk := func() time.Time { return now } + rl := newRateLimiter(5, time.Minute, clk) + + // Fill the map to the cap by bringing in distinct IPs. + for i := 0; i < rateLimiterMaxEntries; i++ { + ip := "10.0." + strconv.Itoa(i/256) + "." 
+ strconv.Itoa(i%256) + require.Truef(t, rl.allow(ip), "IP %s should be allowed when map is below cap", ip) + } + require.Equal(t, rateLimiterMaxEntries, len(rl.entries)) + + // A brand-new IP must be refused because the cap is reached and + // no entries have expired yet. + require.False(t, rl.allow("192.168.1.1")) + + // Existing IPs are still accounted for (not evicted). + require.True(t, rl.allow("10.0.0.0")) +} + +func TestRateLimiter_HardCapReclaimsAfterWindow(t *testing.T) { + now := time.Unix(1_700_000_000, 0).UTC() + clk := func() time.Time { return now } + rl := newRateLimiter(5, time.Minute, clk) + + for i := 0; i < rateLimiterMaxEntries; i++ { + ip := "10.0." + strconv.Itoa(i/256) + "." + strconv.Itoa(i%256) + require.True(t, rl.allow(ip)) + } + + // Advance past the window — the next allow() call must be able to + // sweep the expired entries and then admit the new IP. + now = now.Add(2 * time.Minute) + require.True(t, rl.allow("192.168.2.2")) +} diff --git a/internal/admin/router.go b/internal/admin/router.go new file mode 100644 index 00000000..afb36ac3 --- /dev/null +++ b/internal/admin/router.go @@ -0,0 +1,258 @@ +package admin + +import ( + "errors" + "io" + "io/fs" + "net/http" + "path" + "strings" + + "github.com/goccy/go-json" +) + +// Constants for the admin URL namespace. Centralised here so the router, +// handlers, and tests all agree on the paths. The admin listener only +// serves URLs under /admin/*; anything else yields a 404. +// +// The "root" variants (without a trailing slash) are treated as the +// directory itself so that requests like `/admin/api/v1` or +// `/admin/assets` resolve to a JSON 404 rather than falling through to +// the SPA fallback and being answered with index.html. +// +// pathAPIRoot / pathPrefixAPI guard the whole `/admin/api*` namespace — +// not just v1 — so that requests to currently-unimplemented API +// versions (`/admin/api`, `/admin/api/v2`, ...) 
return a JSON 404 +// instead of being silently answered with the SPA HTML. +const ( + pathPrefixAdmin = "/admin" + pathAPIRoot = "/admin/api" + pathPrefixAPI = pathAPIRoot + "/" + pathAPIv1Root = "/admin/api/v1" + pathPrefixAPIv1 = pathAPIv1Root + "/" + pathHealthz = "/admin/healthz" + pathAssetsRoot = "/admin/assets" + pathPrefixAssets = pathAssetsRoot + "/" + pathRootAssetsDir = "assets" + pathIndexHTML = "index.html" +) + +// APIHandler is the bridge between the router and all JSON API endpoints. +// Everything under /admin/api/v1/ resolves through it; individual endpoint +// routing is the handler's responsibility (see apiMux below). +type APIHandler http.Handler + +// Router dispatches admin HTTP requests in the strict order mandated by +// the design doc (Section 5.3): API routes first, then healthz, then +// static assets, then SPA fallback. We do NOT use http.ServeMux because +// its LongestPrefix matching rules would let /admin/api/v1/... slip into +// the SPA catch-all when the JSON handler returns a 404. +type Router struct { + api http.Handler + static fs.FS + notFind http.Handler +} + +// NewRouter builds the admin router. +// +// - api handles /admin/api/v1/*. It must return a JSON body itself; the +// router never rewrites its response. +// - static, if non-nil, backs both /admin/assets/* and the /admin/* +// SPA catch-all (which always serves index.html). A nil static FS +// causes 404s for asset and SPA routes, which is the expected state +// while the SPA has not been built yet. +func NewRouter(api http.Handler, static fs.FS) *Router { + return &Router{ + api: api, + static: static, + notFind: http.HandlerFunc(writeJSONNotFound), + } +} + +// ServeHTTP is the single entrypoint. Routing cascade in priority order: +// 1. /admin/api/v1/* → API handler +// 2. /admin/healthz → plain text +// 3. /admin/assets/* → static file +// 4. /admin/* → index.html (SPA fallback) +// 5. 
anything else → 404 JSON +func (rt *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { + rt.dispatch(rt.classify(r.URL.Path)).ServeHTTP(w, r) +} + +// routeKind enumerates the admin URL classes the router distinguishes. +// Splitting classify/dispatch keeps ServeHTTP under the cyclomatic +// complexity ceiling while preserving the strict evaluation order that +// API-before-SPA routing depends on. +type routeKind int + +const ( + routeAPIv1 routeKind = iota + routeAPIOther + routeHealthz + routeAssetsRoot + routeAsset + routeSPA + routeUnknown +) + +func (rt *Router) classify(p string) routeKind { + if k, ok := classifyAPI(p); ok { + return k + } + if k, ok := classifyAssets(p); ok { + return k + } + if p == pathHealthz { + return routeHealthz + } + if p == pathPrefixAdmin || strings.HasPrefix(p, pathPrefixAdmin+"/") { + return routeSPA + } + return routeUnknown +} + +func classifyAPI(p string) (routeKind, bool) { + switch { + case strings.HasPrefix(p, pathPrefixAPIv1): + return routeAPIv1, true + case p == pathAPIRoot, p == pathAPIv1Root, strings.HasPrefix(p, pathPrefixAPI): + return routeAPIOther, true + } + return 0, false +} + +func classifyAssets(p string) (routeKind, bool) { + switch { + case p == pathAssetsRoot: + return routeAssetsRoot, true + case strings.HasPrefix(p, pathPrefixAssets): + return routeAsset, true + } + return 0, false +} + +func (rt *Router) dispatch(k routeKind) http.Handler { + switch k { + case routeAPIv1: + if rt.api == nil { + return rt.notFind + } + return rt.api + case routeHealthz: + return http.HandlerFunc(rt.serveHealth) + case routeAsset: + return http.HandlerFunc(rt.serveAsset) + case routeSPA: + return http.HandlerFunc(rt.serveSPA) + case routeAPIOther, routeAssetsRoot, routeUnknown: + return rt.notFind + } + return rt.notFind +} + +func (rt *Router) serveHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + writeJSONError(w, http.StatusMethodNotAllowed, 
"method_not_allowed", "only GET or HEAD supported") + return + } + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(http.StatusOK) + if r.Method == http.MethodGet { + _, _ = w.Write([]byte("ok\n")) + } +} + +func (rt *Router) serveAsset(w http.ResponseWriter, r *http.Request) { + if rt.static == nil { + rt.notFind.ServeHTTP(w, r) + return + } + // Drop /admin/assets/ prefix → relative path under pathRootAssetsDir. + rel := strings.TrimPrefix(r.URL.Path, pathPrefixAssets) + // Defence against traversal and malformed paths: require the + // already-normalised form that fs.ValidPath enforces (no + // ".." segments, no "//" segments, no leading "/"). Anything + // that does not pass validation resolves to a 404 JSON rather + // than risking a 500 from the underlying fs.FS — legitimate + // filenames containing ".." as a substring (e.g. "app..js") + // still pass because ValidPath checks segments, not substrings. + if rel == "" || !fs.ValidPath(rel) { + rt.notFind.ServeHTTP(w, r) + return + } + name := path.Join(pathRootAssetsDir, rel) + f, err := rt.static.Open(name) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + rt.notFind.ServeHTTP(w, r) + return + } + writeJSONError(w, http.StatusInternalServerError, "internal", "failed to open asset") + return + } + defer f.Close() + info, err := f.Stat() + if err != nil || info.IsDir() { + rt.notFind.ServeHTTP(w, r) + return + } + readSeeker, ok := f.(io.ReadSeeker) + if !ok { + // embed.FS files implement ReadSeeker, but be defensive. + writeJSONError(w, http.StatusInternalServerError, "internal", "asset is not seekable") + return + } + http.ServeContent(w, r, name, info.ModTime(), readSeeker) +} + +func (rt *Router) serveSPA(w http.ResponseWriter, r *http.Request) { + // Reject non-GET/HEAD methods before inspecting rt.static so the + // response is uniform across admin binaries whether or not the + // SPA bundle happens to be configured. 
Without this, a POST to + // /admin/something returned a JSON 404 with a nil static and a + // JSON 405 with a populated static — same URL, different answer. + if r.Method != http.MethodGet && r.Method != http.MethodHead { + writeJSONError(w, http.StatusMethodNotAllowed, "method_not_allowed", "only GET or HEAD supported") + return + } + if rt.static == nil { + rt.notFind.ServeHTTP(w, r) + return + } + f, err := rt.static.Open(pathIndexHTML) + if err != nil { + rt.notFind.ServeHTTP(w, r) + return + } + defer f.Close() + info, err := f.Stat() + if err != nil || info.IsDir() { + rt.notFind.ServeHTTP(w, r) + return + } + readSeeker, ok := f.(io.ReadSeeker) + if !ok { + writeJSONError(w, http.StatusInternalServerError, "internal", "index.html is not seekable") + return + } + w.Header().Set("Cache-Control", "no-store") + http.ServeContent(w, r, pathIndexHTML, info.ModTime(), readSeeker) +} + +// errorResponse is the JSON shape for every admin error. +type errorResponse struct { + Error string `json:"error"` + Message string `json:"message,omitempty"` +} + +func writeJSONNotFound(w http.ResponseWriter, _ *http.Request) { + writeJSONError(w, http.StatusNotFound, "not_found", "") +} + +func writeJSONError(w http.ResponseWriter, status int, code, msg string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(errorResponse{Error: code, Message: msg}) +} diff --git a/internal/admin/router_test.go b/internal/admin/router_test.go new file mode 100644 index 00000000..2868602e --- /dev/null +++ b/internal/admin/router_test.go @@ -0,0 +1,195 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" +) + +func newTestStatic() fstest.MapFS { + return fstest.MapFS{ + "index.html": {Data: []byte("spa")}, + "assets/app.js": {Data: []byte("console.log('ok');")}, + 
"assets/app.css": {Data: []byte("body { color: red; }")}, + } +} + +func TestRouter_APIPathIsNeverSwallowedBySPA(t *testing.T) { + api := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-Came-From", "api") + writeJSONError(w, http.StatusNotFound, "unknown_endpoint", "no handler") + }) + r := NewRouter(api, newTestStatic()) + + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/unknown", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusNotFound, rec.Code) + require.Equal(t, "api", rec.Header().Get("X-Came-From")) + require.Contains(t, rec.Header().Get("Content-Type"), "application/json") + require.NotContains(t, rec.Body.String(), "")}, + "assets/app..js": {Data: []byte("console.log('double-dot');")}, + } + r := NewRouter(nil, fs) + req := httptest.NewRequest(http.MethodGet, "/admin/assets/app..js", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "double-dot") +} + +// A client-supplied double slash must resolve to a JSON 404, not a 500. 
+func TestRouter_StaticDoubleSlashReturnsJSON404(t *testing.T) { + r := NewRouter(nil, newTestStatic()) + req := httptest.NewRequest(http.MethodGet, "/admin/assets//app.js", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusNotFound, rec.Code) + require.Contains(t, rec.Header().Get("Content-Type"), "application/json") +} + +func TestRouter_SPAFallbackServesIndex(t *testing.T) { + r := NewRouter(nil, newTestStatic()) + for _, p := range []string{"/admin", "/admin/", "/admin/dynamo", "/admin/s3/bucket-42"} { + req := httptest.NewRequest(http.MethodGet, p, nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equalf(t, http.StatusOK, rec.Code, "path %s", p) + require.Containsf(t, rec.Body.String(), "spa", "path %s", p) + } +} + +func TestRouter_SPAWithoutStaticReturns404(t *testing.T) { + r := NewRouter(nil, nil) + req := httptest.NewRequest(http.MethodGet, "/admin/dynamo", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestRouter_BareAPIRootReturnsJSON404NotHTML(t *testing.T) { + r := NewRouter(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }), newTestStatic()) + // /admin/api, /admin/api/v1 (bare), /admin/api/v2 (unimplemented), + // /admin/api/v2/foo (deeper under unknown version), and the bare + // /admin/assets directory must all resolve to a JSON 404 — never + // the SPA HTML fallback — so API clients and probes get a + // machine-readable answer. 
+ paths := []string{ + "/admin/api", + "/admin/api/v1", + "/admin/api/v2", + "/admin/api/v2/tables", + "/admin/assets", + } + for _, p := range paths { + req := httptest.NewRequest(http.MethodGet, p, nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + require.Equalf(t, http.StatusNotFound, rec.Code, "path %s", p) + require.Containsf(t, rec.Header().Get("Content-Type"), "application/json", "path %s", p) + require.NotContainsf(t, rec.Body.String(), "/cmdline") + adminSessionSigningKeyFile = flag.String("adminSessionSigningKeyFile", "", "Path to a file containing the base64-encoded primary admin HS256 key; avoids leaking the secret via argv") + adminSessionSigningKeyPrevious = flag.String("adminSessionSigningKeyPrevious", "", "Optional previous admin HS256 key accepted only for verification during rotation; prefer -adminSessionSigningKeyPreviousFile") + adminSessionSigningKeyPreviousFile = flag.String("adminSessionSigningKeyPreviousFile", "", "Path to a file containing the base64-encoded previous admin HS256 key used for rotation") + adminReadOnlyAccessKeys = flag.String("adminReadOnlyAccessKeys", "", "Comma-separated SigV4 access keys granted read-only admin access") + adminFullAccessKeys = flag.String("adminFullAccessKeys", "", "Comma-separated SigV4 access keys granted full-access admin role") ) // memoryPressureExit is set to true by the memwatch OnExceed callback to @@ -317,6 +330,10 @@ func run() error { return err } + if err := startAdminFromFlags(runCtx, &lc, eg, runtimes); err != nil { + return waitErrgroupAfterStartupFailure(cancel, eg, err) + } + if err := eg.Wait(); err != nil { return errors.Wrapf(err, "failed to serve") } diff --git a/main_admin.go b/main_admin.go new file mode 100644 index 00000000..a640d7ff --- /dev/null +++ b/main_admin.go @@ -0,0 +1,343 @@ +package main + +import ( + "context" + "crypto/tls" + "log/slog" + "net" + "net/http" + "os" + "strings" + "time" + + "github.com/bootjp/elastickv/internal/admin" + 
"github.com/cockroachdb/errors" + "golang.org/x/sync/errgroup" +) + +// Environment variables that the admin listener consults before +// falling back to the command-line flag values. Exposing secrets via +// env vars / file paths keeps them out of /proc//cmdline. +const ( + envAdminSessionSigningKey = "ELASTICKV_ADMIN_SESSION_SIGNING_KEY" + envAdminSessionSigningKeyPrevious = "ELASTICKV_ADMIN_SESSION_SIGNING_KEY_PREVIOUS" +) + +const ( + adminReadHeaderTimeout = 5 * time.Second + adminWriteTimeout = 10 * time.Second + adminIdleTimeout = 30 * time.Second + adminShutdownTimeout = 5 * time.Second + + // adminBuildVersion is surfaced in GET /admin/api/v1/cluster. Until + // we wire real ldflags-injected build info, a placeholder is fine. + adminBuildVersion = "dev" +) + +// buildVersion returns the elastickv binary version for admin purposes. +// It is intentionally a function, not a constant, so build tooling can +// link-replace it via -ldflags in the future. +func buildVersion() string { return adminBuildVersion } + +// adminListenerConfig is the subset of startup inputs that goes into the +// admin listener. Collecting them in a struct keeps the main.go call site +// compact and makes unit testing the builder easier. +type adminListenerConfig struct { + enabled bool + listen string + tlsCertFile string + tlsKeyFile string + allowPlaintextNonLoopback bool + allowInsecureDevCookie bool + + sessionSigningKey string + sessionSigningKeyPrevious string + + readOnlyAccessKeys []string + fullAccessKeys []string +} + +// startAdminFromFlags is the single entrypoint main.run() uses to stand +// up the admin listener. It owns the flag → config translation and the +// credentials loading so run() does not inherit that complexity. +// +// When admin is disabled (the default) the function returns immediately +// without touching --s3CredentialsFile: pulling the admin feature into +// a hard dependency on that file would break deployments that never +// intended to use it. 
+func startAdminFromFlags(ctx context.Context, lc *net.ListenConfig, eg *errgroup.Group, runtimes []*raftGroupRuntime) error { + if !*adminEnabled { + return nil + } + staticCreds, err := loadS3StaticCredentials(*s3CredsFile) + if err != nil { + return errors.Wrapf(err, "load static credentials for admin listener") + } + // An admin listener with zero credentials would accept logins + // only to reject every one of them with invalid_credentials, so a + // missing or empty credentials file is a wiring bug rather than a + // valid "locked down" state. Failing fast here also guards against + // the typed-nil MapCredentialStore case inside NewServer (an + // untyped `== nil` check cannot detect a nil-map-valued interface + // on its own). + if len(staticCreds) == 0 { + return errors.New("admin listener is enabled but no static credentials are configured; " + + "set -s3CredentialsFile to a file with at least one entry") + } + primaryKey, err := resolveSigningKey(*adminSessionSigningKey, *adminSessionSigningKeyFile, envAdminSessionSigningKey) + if err != nil { + return errors.Wrap(err, "resolve -adminSessionSigningKey") + } + previousKey, err := resolveSigningKey(*adminSessionSigningKeyPrevious, *adminSessionSigningKeyPreviousFile, envAdminSessionSigningKeyPrevious) + if err != nil { + return errors.Wrap(err, "resolve -adminSessionSigningKeyPrevious") + } + cfg := adminListenerConfig{ + enabled: *adminEnabled, + listen: *adminListen, + tlsCertFile: *adminTLSCertFile, + tlsKeyFile: *adminTLSKeyFile, + allowPlaintextNonLoopback: *adminAllowPlaintextNonLoopback, + allowInsecureDevCookie: *adminAllowInsecureDevCookie, + sessionSigningKey: primaryKey, + sessionSigningKeyPrevious: previousKey, + readOnlyAccessKeys: parseCSV(*adminReadOnlyAccessKeys), + fullAccessKeys: parseCSV(*adminFullAccessKeys), + } + clusterSrc := newClusterInfoSource(*raftId, buildVersion(), runtimes) + _, err = startAdminServer(ctx, lc, eg, cfg, staticCreds, clusterSrc, buildVersion()) + return err +} + 
+// buildAdminConfig translates flag values into an admin.Config. +func buildAdminConfig(in adminListenerConfig) admin.Config { + return admin.Config{ + Enabled: in.enabled, + Listen: in.listen, + TLSCertFile: in.tlsCertFile, + TLSKeyFile: in.tlsKeyFile, + AllowPlaintextNonLoopback: in.allowPlaintextNonLoopback, + SessionSigningKey: in.sessionSigningKey, + SessionSigningKeyPrevious: in.sessionSigningKeyPrevious, + ReadOnlyAccessKeys: in.readOnlyAccessKeys, + FullAccessKeys: in.fullAccessKeys, + AllowInsecureDevCookie: in.allowInsecureDevCookie, + } +} + +// startAdminServer validates the admin configuration, constructs the admin +// server, and attaches its lifecycle to eg. It is a no-op when the admin +// listener is disabled. Errors at this point are hard startup failures: +// the design doc mandates ハードエラーで起動失敗 for every invalid +// configuration, and we honour that uniformly. +// +// The returned address is the actual host:port the listener bound to; it +// differs from adminCfg.Listen only when the caller passed a port of 0, +// but tests rely on this to avoid the bind-close-rebind race that a +// pre-allocated free-port helper would otherwise introduce. When admin +// is disabled the returned address is empty. 
+func startAdminServer( + ctx context.Context, + lc *net.ListenConfig, + eg *errgroup.Group, + cfg adminListenerConfig, + creds map[string]string, + cluster admin.ClusterInfoSource, + version string, +) (string, error) { + adminCfg := buildAdminConfig(cfg) + enabled, err := checkAdminConfig(&adminCfg, cluster) + if err != nil || !enabled { + return "", err + } + server, err := buildAdminHTTPServer(&adminCfg, creds, cluster) + if err != nil { + return "", err + } + httpSrv := newAdminHTTPServer(server) + listener, err := lc.Listen(ctx, "tcp", adminCfg.Listen) + if err != nil { + return "", errors.Wrapf(err, "failed to listen on admin address %s", adminCfg.Listen) + } + tlsEnabled := strings.TrimSpace(adminCfg.TLSCertFile) != "" && strings.TrimSpace(adminCfg.TLSKeyFile) != "" + if tlsEnabled { + httpSrv.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + actualAddr := listener.Addr().String() + // Use the real bound address in log lines and in the lifecycle + // task so the shutdown banner matches startup. + boundCfg := adminCfg + boundCfg.Listen = actualAddr + registerAdminLifecycle(ctx, eg, httpSrv, listener, &boundCfg, tlsEnabled, version) + return actualAddr, nil +} + +// checkAdminConfig validates adminCfg; returns (enabled=false, nil) when +// admin is disabled and requires no further work. 
+func checkAdminConfig(adminCfg *admin.Config, cluster admin.ClusterInfoSource) (bool, error) { + if err := adminCfg.Validate(); err != nil { + if !adminCfg.Enabled { + return false, nil + } + return false, errors.Wrap(err, "admin config is invalid") + } + if !adminCfg.Enabled { + return false, nil + } + if cluster == nil { + return false, errors.New("admin: cluster info source is required") + } + return true, nil +} + +func buildAdminHTTPServer(adminCfg *admin.Config, creds map[string]string, cluster admin.ClusterInfoSource) (*admin.Server, error) { + primaryKeys, err := adminCfg.DecodedSigningKeys() + if err != nil { + return nil, errors.Wrap(err, "decode admin signing keys") + } + signer, err := admin.NewSigner(primaryKeys[0], nil) + if err != nil { + return nil, errors.Wrap(err, "build admin signer") + } + verifier, err := admin.NewVerifier(primaryKeys, nil) + if err != nil { + return nil, errors.Wrap(err, "build admin verifier") + } + server, err := admin.NewServer(admin.ServerDeps{ + Signer: signer, + Verifier: verifier, + Credentials: admin.MapCredentialStore(creds), + Roles: adminCfg.RoleIndex(), + ClusterInfo: cluster, + StaticFS: nil, + AuthOpts: admin.AuthServiceOpts{ + InsecureCookie: adminCfg.AllowInsecureDevCookie, + }, + Logger: slog.Default().With(slog.String("component", "admin")), + }) + if err != nil { + return nil, errors.Wrap(err, "build admin server") + } + return server, nil +} + +func newAdminHTTPServer(server *admin.Server) *http.Server { + return &http.Server{ + Handler: server.Handler(), + ReadHeaderTimeout: adminReadHeaderTimeout, + WriteTimeout: adminWriteTimeout, + IdleTimeout: adminIdleTimeout, + } +} + +func registerAdminLifecycle( + ctx context.Context, + eg *errgroup.Group, + httpSrv *http.Server, + listener net.Listener, + adminCfg *admin.Config, + tlsEnabled bool, + version string, +) { + addr := adminCfg.Listen + eg.Go(func() error { + <-ctx.Done() + slog.Info("shutting down admin listener", "address", addr, "reason", ctx.Err()) 
+ shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), adminShutdownTimeout) + defer cancel() + err := httpSrv.Shutdown(shutdownCtx) + if err == nil || errors.Is(err, http.ErrServerClosed) || errors.Is(err, net.ErrClosed) { + return nil + } + return errors.WithStack(err) + }) + eg.Go(func() error { + slog.Info("starting admin listener", "address", addr, "tls", tlsEnabled, "version", version) + var serveErr error + if tlsEnabled { + serveErr = httpSrv.ServeTLS(listener, adminCfg.TLSCertFile, adminCfg.TLSKeyFile) + } else { + serveErr = httpSrv.Serve(listener) + } + if serveErr == nil || errors.Is(serveErr, http.ErrServerClosed) || errors.Is(serveErr, net.ErrClosed) { + return nil + } + return errors.Wrapf(serveErr, "admin listener on %s stopped with error", addr) + }) +} + +// newClusterInfoSource builds a ClusterInfoSource that reads from the +// runtime raftGroupRuntime slice. It lives here (rather than +// internal/admin) so the admin package stays free of main-package types. +// +// Membership is fetched via engine.Configuration(ctx); the call is +// best-effort — if it fails (for instance because the engine is in the +// middle of a leadership transition) we leave Members empty rather +// than fail the whole cluster snapshot. +func newClusterInfoSource(nodeID, version string, runtimes []*raftGroupRuntime) admin.ClusterInfoSource { + return admin.ClusterInfoFunc(func(ctx context.Context) (admin.ClusterInfo, error) { + groups := make([]admin.GroupInfo, 0, len(runtimes)) + for _, rt := range runtimes { + if rt == nil || rt.engine == nil { + continue + } + status := rt.engine.Status() + // Seed as an empty-but-non-nil slice so a + // Configuration() failure still JSON-encodes as `[]` + // rather than `null`; API consumers that treat + // members as an always-array field rely on this. 
+ members := []string{} + if cfg, err := rt.engine.Configuration(ctx); err == nil { + members = make([]string, 0, len(cfg.Servers)) + for _, srv := range cfg.Servers { + members = append(members, srv.ID) + } + } + groups = append(groups, admin.GroupInfo{ + GroupID: rt.spec.id, + LeaderID: status.Leader.ID, + IsLeader: strings.EqualFold(string(status.State), "leader"), + Members: members, + }) + } + return admin.ClusterInfo{ + NodeID: nodeID, + Version: version, + Groups: groups, + }, nil + }) +} + +// resolveSigningKey picks the effective admin signing key from, in +// priority order: the --*File flag (file contents), the env var, and +// finally the --*Flag argv value. Preferring the file/env paths keeps +// the raw base64 out of /proc//cmdline on Linux. Returns the empty +// string when every source is unset — callers that require a value +// (validated elsewhere) must handle that case themselves. +func resolveSigningKey(flagValue, filePath, envVar string) (string, error) { + if strings.TrimSpace(filePath) != "" { + b, err := os.ReadFile(filePath) + if err != nil { + return "", errors.Wrapf(err, "read admin signing key file %q", filePath) + } + return strings.TrimSpace(string(b)), nil + } + if v := strings.TrimSpace(os.Getenv(envVar)); v != "" { + return v, nil + } + return strings.TrimSpace(flagValue), nil +} + +// parseCSV splits a flag value like "a,b,c" into a slice with empty and +// whitespace-only entries dropped. It is not in shard_config.go because +// admin's comma-separated list format is simpler than raft groups. 
+func parseCSV(raw string) []string { + parts := strings.Split(raw, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + if trim := strings.TrimSpace(p); trim != "" { + out = append(out, trim) + } + } + return out +} diff --git a/main_admin_test.go b/main_admin_test.go new file mode 100644 index 00000000..13ecd07a --- /dev/null +++ b/main_admin_test.go @@ -0,0 +1,238 @@ +package main + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/pem" + "io" + "math/big" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/bootjp/elastickv/internal/admin" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func freshKey() string { + raw := make([]byte, 64) + // Deterministic seed is fine; tests only care it is the right length. + for i := range raw { + raw[i] = byte(i + 1) + } + return base64.StdEncoding.EncodeToString(raw) +} + +func TestBuildAdminConfig_Passthrough(t *testing.T) { + in := adminListenerConfig{ + enabled: true, + listen: "127.0.0.1:18080", + tlsCertFile: "c", + tlsKeyFile: "k", + allowPlaintextNonLoopback: true, + allowInsecureDevCookie: true, + sessionSigningKey: "a", + sessionSigningKeyPrevious: "b", + readOnlyAccessKeys: []string{"X"}, + fullAccessKeys: []string{"Y"}, + } + out := buildAdminConfig(in) + require.Equal(t, true, out.Enabled) + require.Equal(t, in.listen, out.Listen) + require.Equal(t, in.tlsCertFile, out.TLSCertFile) + require.Equal(t, in.tlsKeyFile, out.TLSKeyFile) + require.Equal(t, in.allowPlaintextNonLoopback, out.AllowPlaintextNonLoopback) + require.Equal(t, in.allowInsecureDevCookie, out.AllowInsecureDevCookie) + require.Equal(t, in.sessionSigningKey, out.SessionSigningKey) + require.Equal(t, in.sessionSigningKeyPrevious, out.SessionSigningKeyPrevious) + require.Equal(t, in.readOnlyAccessKeys, out.ReadOnlyAccessKeys) + require.Equal(t, in.fullAccessKeys, 
out.FullAccessKeys) +} + +func TestParseCSV(t *testing.T) { + require.Equal(t, []string{"a", "b", "c"}, parseCSV("a,b,c")) + require.Equal(t, []string{"a", "b"}, parseCSV(" a ,, b ,")) + require.Equal(t, []string{}, parseCSV("")) +} + +func TestStartAdminServer_DisabledNoOp(t *testing.T) { + eg, ctx := errgroup.WithContext(context.Background()) + defer func() { _ = eg.Wait() }() + var lc net.ListenConfig + _, err := startAdminServer(ctx, &lc, eg, adminListenerConfig{enabled: false}, nil, nil, "") + require.NoError(t, err) +} + +func TestStartAdminServer_InvalidConfigRejected(t *testing.T) { + eg, ctx := errgroup.WithContext(context.Background()) + defer func() { _ = eg.Wait() }() + var lc net.ListenConfig + cfg := adminListenerConfig{ + enabled: true, + listen: "127.0.0.1:0", + // missing signing key + } + _, err := startAdminServer(ctx, &lc, eg, cfg, map[string]string{}, nil, "") + require.Error(t, err) +} + +func TestStartAdminServer_NonLoopbackWithoutTLSRejected(t *testing.T) { + eg, ctx := errgroup.WithContext(context.Background()) + defer func() { _ = eg.Wait() }() + var lc net.ListenConfig + cfg := adminListenerConfig{ + enabled: true, + listen: "0.0.0.0:0", + sessionSigningKey: freshKey(), + } + _, err := startAdminServer(ctx, &lc, eg, cfg, map[string]string{}, nil, "") + require.Error(t, err) + require.Contains(t, err.Error(), "TLS") +} + +func TestStartAdminServer_RejectsMissingClusterSource(t *testing.T) { + eg, ctx := errgroup.WithContext(context.Background()) + defer func() { _ = eg.Wait() }() + var lc net.ListenConfig + cfg := adminListenerConfig{ + enabled: true, + listen: "127.0.0.1:0", + sessionSigningKey: freshKey(), + } + _, err := startAdminServer(ctx, &lc, eg, cfg, map[string]string{}, nil, "") + require.Error(t, err) + require.Contains(t, err.Error(), "cluster info source") +} + +func TestStartAdminServer_ServesHealthz(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + eg, eCtx := errgroup.WithContext(ctx) + defer 
func() { + cancel() + _ = eg.Wait() + }() + + var lc net.ListenConfig + cfg := adminListenerConfig{ + enabled: true, + listen: "127.0.0.1:0", + sessionSigningKey: freshKey(), + allowInsecureDevCookie: true, + } + cluster := admin.ClusterInfoFunc(func(_ context.Context) (admin.ClusterInfo, error) { + return admin.ClusterInfo{NodeID: "n1", Version: "test"}, nil + }) + addr, err := startAdminServer(eCtx, &lc, eg, cfg, map[string]string{}, cluster, "test") + require.NoError(t, err) + + // Poll /admin/healthz until success or the test deadline. + client := &http.Client{Timeout: 2 * time.Second} + deadline := time.Now().Add(3 * time.Second) + var resp *http.Response + for time.Now().Before(deadline) { + req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+addr+"/admin/healthz", nil) + require.NoError(t, reqErr) + resp, err = client.Do(req) + if err == nil { + break + } + time.Sleep(20 * time.Millisecond) + } + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + body, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + require.Equal(t, "ok\n", string(body)) +} + +func TestStartAdminServer_ServesTLS(t *testing.T) { + certFile, keyFile := writeSelfSignedCert(t) + ctx, cancel := context.WithCancel(context.Background()) + eg, eCtx := errgroup.WithContext(ctx) + defer func() { + cancel() + _ = eg.Wait() + }() + + var lc net.ListenConfig + cfg := adminListenerConfig{ + enabled: true, + listen: "127.0.0.1:0", + tlsCertFile: certFile, + tlsKeyFile: keyFile, + sessionSigningKey: freshKey(), + } + cluster := admin.ClusterInfoFunc(func(_ context.Context) (admin.ClusterInfo, error) { + return admin.ClusterInfo{NodeID: "n-tls", Version: "test"}, nil + }) + addr, err := startAdminServer(eCtx, &lc, eg, cfg, map[string]string{}, cluster, "test") + require.NoError(t, err) + + transport := &http.Transport{TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec // self-signed certificate in test; server identity is not what we 
assert here. + MinVersion: tls.VersionTLS12, + }} + client := &http.Client{Transport: transport, Timeout: 2 * time.Second} + deadline := time.Now().Add(3 * time.Second) + var resp *http.Response + for time.Now().Before(deadline) { + req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, "https://"+addr+"/admin/healthz", nil) + require.NoError(t, reqErr) + resp, err = client.Do(req) + if err == nil { + break + } + time.Sleep(20 * time.Millisecond) + } + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() +} + +// writeSelfSignedCert writes a short-lived self-signed certificate to a +// temp dir and returns the cert / key paths. The certificate is valid +// for 127.0.0.1 only; tests that need TLS should run with it. +func writeSelfSignedCert(t *testing.T) (string, string) { + t.Helper() + dir := t.TempDir() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(t, err) + tmpl := x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "elastickv-admin-test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, + DNSNames: []string{"localhost"}, + } + der, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + require.NoError(t, err) + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + certOut, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: der})) + require.NoError(t, certOut.Close()) + + keyBytes, err := x509.MarshalECPrivateKey(priv) + require.NoError(t, err) + keyOut, err := 
os.Create(keyPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes})) + require.NoError(t, keyOut.Close()) + return certPath, keyPath +}