diff --git a/moon.yml b/moon.yml index 86ccc77..1f36dcf 100644 --- a/moon.yml +++ b/moon.yml @@ -70,7 +70,7 @@ tasks: cache: false integration: - command: 'go test -tags=integration ./store/postgres' + command: 'go test -count 1 -tags integration ./store/postgres ./testkit/internal/store/postgres ./testkit/internal/authflow' toolchains: ['go'] inputs: - '@group(go)' diff --git a/testkit/README.md b/testkit/README.md new file mode 100644 index 0000000..b627131 --- /dev/null +++ b/testkit/README.md @@ -0,0 +1,55 @@ +# authkit testkit + +`testkit` is a small pastebin-style web app used to exercise authkit in realistic application code. + +The current slice uses authkit's API-token exchange path for paste creation. Reading pastes remains public; creating pastes requires exchanging the startup API token for a short-lived authkit access JWT carried in a temporary app cookie. + +## Run + +```bash +go run ./testkit/cmd/testkit +``` + +The server listens on `:8080` by default. Override it with `TESTKIT_ADDR`: + +```bash +TESTKIT_ADDR=:8090 go run ./testkit/cmd/testkit +``` + +Startup prints a fresh development API token: + +```text +testkit seed API token: ak_... +``` + +Use that token on `/login`. The token is shown only at startup and expires after 24 hours. + +## Persistence + +By default, testkit stores pastes in process memory. Restarting the server clears them. + +Set `TESTKIT_DATABASE_URL` to use PostgreSQL paste persistence instead: + +```bash +TESTKIT_DATABASE_URL='postgres://testkit:testkit@localhost:5432/testkit?sslmode=disable' \ + go run ./testkit/cmd/testkit +``` + +When `TESTKIT_DATABASE_URL` is set, startup opens a Postgres pool, runs testkit's `testkit_*` paste migrations, runs authkit's Postgres migrations, stores paste data in `testkit_*` tables, and stores authkit principals/API tokens in `authkit_*` tables. + +Without `TESTKIT_DATABASE_URL`, both paste data and authkit state are in memory. + +## Routes + +- `GET /` lists recent pastes. +- `GET /login` renders the API-token login form. +- `POST /auth/token` exchanges an API token and sets the temporary access cookie. +- `POST /logout` clears the temporary access cookie. +- `GET /new` renders the create form for authenticated browsers. +- `POST /pastes` creates a paste for authenticated browsers and redirects to its page. +- `GET /p/{id}` renders a paste. +- `GET /raw/{id}` returns the paste body as `text/plain`. + +## Current Scope + +The browser cookie is a temporary testkit transport for authkit access JWTs. Ownership, edit/delete flows, refresh tokens, OIDC login, richer session management, and API endpoints are intentionally deferred until this API-token path is proven in the app. diff --git a/testkit/cmd/testkit/main.go b/testkit/cmd/testkit/main.go new file mode 100644 index 0000000..b973d66 --- /dev/null +++ b/testkit/cmd/testkit/main.go @@ -0,0 +1,152 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + + authmemory "github.com/meigma/authkit/store/memory" + authpostgres "github.com/meigma/authkit/store/postgres" + "github.com/meigma/authkit/testkit/internal/authflow" + "github.com/meigma/authkit/testkit/internal/httpui" + "github.com/meigma/authkit/testkit/internal/paste" + testkitmemory "github.com/meigma/authkit/testkit/internal/store/memory" + testkitpostgres "github.com/meigma/authkit/testkit/internal/store/postgres" +) + +const ( + defaultAddr = ":8080" + addrEnv = "TESTKIT_ADDR" + databaseURLEnv = "TESTKIT_DATABASE_URL" + serverReadHeaderTimeout = 5 * time.Second + serverReadTimeout = 10 * time.Second + serverWriteTimeout = 10 * time.Second + serverIdleTimeout = 60 * time.Second + shutdownTimeout = 5 * time.Second +) + +func main() { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + + err := run(ctx, os.Stdout) + stop() + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func run(ctx context.Context, out io.Writer) error { + addr := os.Getenv(addrEnv) + if addr == "" { + addr = defaultAddr + } + + stores, cleanup, err := newStores(ctx) + if err != nil { + return err + } + defer cleanup() + + pasteService, err := paste.NewService(stores.pastes) + if err != nil { + return err + } + authRuntime, err := authflow.NewRuntime(ctx, stores.auth) + if err != nil { + return err + } + handler, err := httpui.NewServer(pasteService, authRuntime) + if err != nil { + return err + } + + server := &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: serverReadHeaderTimeout, + ReadTimeout: serverReadTimeout, + WriteTimeout: serverWriteTimeout, + IdleTimeout: serverIdleTimeout, + } + serverErr := make(chan error, 1) + go func() { + _, _ = fmt.Fprintf(out, "testkit seed API token: %s\n", authRuntime.SeedAPIToken) + _, _ = fmt.Fprintf(out, "testkit listening on http://localhost%s\n", addr) + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + serverErr <- err + + return + } + serverErr <- nil + }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + return fmt.Errorf("testkit: shutdown server: %w", err) + } + + return nil + case err := <-serverErr: + return err + } +} + +type stores struct { + pastes paste.Repository + auth authflow.Store +} + +func newStores(ctx context.Context) (stores, func(), error) { + databaseURL := os.Getenv(databaseURLEnv) + if databaseURL == "" { + return stores{ + pastes: testkitmemory.NewStore(), + auth: authmemory.NewStore(), + }, func() {}, nil + } + + pool, err := pgxpool.New(ctx, databaseURL) + if err != nil { + return stores{}, nil, fmt.Errorf("testkit: open postgres pool: %w", err) + } + if migrateErr := testkitpostgres.Migrate(ctx, pool); migrateErr != nil { + pool.Close() + + return stores{}, nil, fmt.Errorf("testkit: migrate postgres: %w", migrateErr) + } + if migrateErr := authpostgres.Migrate(ctx, pool); migrateErr != nil { + pool.Close() + + return stores{}, nil, fmt.Errorf("testkit: migrate authkit postgres: %w", migrateErr) + } + + pasteStore, err := testkitpostgres.NewStore(pool) + if err != nil { + pool.Close() + + return stores{}, nil, err + } + authStore, err := authpostgres.NewStore(pool) + if err != nil { + pool.Close() + + return stores{}, nil, err + } + + return stores{ + pastes: pasteStore, + auth: authStore, + }, pool.Close, nil +} diff --git a/testkit/internal/authflow/doc.go b/testkit/internal/authflow/doc.go new file mode 100644 index 0000000..49d398f --- /dev/null +++ b/testkit/internal/authflow/doc.go @@ -0,0 +1,2 @@ +// Package authflow wires the authkit API-token exchange flow for testkit. +package authflow diff --git a/testkit/internal/authflow/runtime.go b/testkit/internal/authflow/runtime.go new file mode 100644 index 0000000..d518c53 --- /dev/null +++ b/testkit/internal/authflow/runtime.go @@ -0,0 +1,302 @@ +package authflow + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + "net/http" + "time" + + "github.com/lestrrat-go/jwx/v3/jwa" + "github.com/lestrrat-go/jwx/v3/jwk" + + "github.com/meigma/authkit" + "github.com/meigma/authkit/accessjwt" + "github.com/meigma/authkit/apikey" + "github.com/meigma/authkit/compose" + "github.com/meigma/authkit/exchange" + "github.com/meigma/authkit/httpauth" +) + +const ( + // CookieName is the temporary app-owned cookie carrying authkit access JWTs. + CookieName = "authkit_testkit_access" + + // LoginPath is the HTML login page used when browser authentication fails. + LoginPath = "/login" + + // SeedAPITokenTTL is the lifetime of the development bootstrap API token. + SeedAPITokenTTL = 24 * time.Hour + + // AccessTokenTTL is the lifetime of access JWTs issued by testkit. + AccessTokenTTL = 15 * time.Minute + + accessJWTIssuer = "https://testkit.local/authkit" + accessJWTAudience = "testkit" + accessJWTKeyID = "testkit-dev-access-key" + bootstrapPrincipalName = "Testkit author" + bootstrapAPITokenName = "testkit bootstrap token" + rsaKeyBits = 2048 + accessCookiePath = "/" + accessCookieMaxAge = int(AccessTokenTTL / time.Second) + clearedAccessCookieAge = -1 +) + +// Store is the authkit storage surface testkit needs for API-token exchange. +type Store interface { + authkit.PrincipalCreator + authkit.PrincipalFinder + authkit.PrincipalLister + apikey.TokenStore +} + +// Runtime contains the authkit components used by testkit HTTP handlers. +type Runtime struct { + // Middleware authenticates requests carrying authkit access JWTs. + Middleware *httpauth.Middleware + + // Exchanger exchanges opaque API tokens for authkit access JWTs. + Exchanger *exchange.APITokenExchanger + + // Principal is the bootstrap principal that owns the startup API token. + Principal authkit.Principal + + // SeedAPIToken is the plaintext startup API token shown once on stdout. + SeedAPIToken string + + // SeedAPITokenExpiresAt is when SeedAPIToken stops being accepted for exchange. + SeedAPITokenExpiresAt time.Time +} + +type options struct { + clock func() time.Time +} + +// Option configures Runtime construction. +type Option func(*options) + +// WithClock configures the clock used for token timestamps. +func WithClock(clock func() time.Time) Option { + return func(opts *options) { + if clock != nil { + opts.clock = clock + } + } +} + +// NewRuntime constructs the authkit API-token exchange runtime for testkit. +func NewRuntime(ctx context.Context, store Store, opts ...Option) (*Runtime, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if store == nil { + return nil, errors.New("authflow: store is required") + } + + cfg := options{ + clock: time.Now, + } + for _, opt := range opts { + if opt != nil { + opt(&cfg) + } + } + if cfg.clock == nil { + cfg.clock = time.Now + } + + principal, err := bootstrapPrincipal(ctx, store) + if err != nil { + return nil, err + } + apiTokens, err := apikey.NewService(store, apikey.WithClock(cfg.clock)) + if err != nil { + return nil, fmt.Errorf("authflow: create API-token service: %w", err) + } + seedToken, err := apiTokens.IssueToken(ctx, apikey.IssueRequest{ + PrincipalID: principal.ID, + Name: bootstrapAPITokenName, + ExpiresAt: cfg.clock().Add(SeedAPITokenTTL), + }) + if err != nil { + return nil, fmt.Errorf("authflow: issue seed API token: %w", err) + } + + accessIssuer, accessVerifier, err := newAccessJWTIssuerAndVerifier(cfg.clock) + if err != nil { + return nil, err + } + exchanger, err := exchange.NewAPITokenExchanger(exchange.APITokenOptions{ + APITokens: apiTokens, + Principals: store, + AccessTokens: accessIssuer, + }) + if err != nil { + return nil, fmt.Errorf("authflow: create API-token exchanger: %w", err) + } + composed, err := compose.NewHTTP(compose.HTTPOptions{ + PrincipalAuthenticators: []compose.PrincipalAuthenticatorSpec{ + compose.AccessJWT(accessVerifier, store), + }, + Authorizer: allowAuthorizer{}, + MiddlewareOptions: []httpauth.Option{ + httpauth.WithErrorRenderer(renderAuthError), + }, + }) + if err != nil { + return nil, fmt.Errorf("authflow: compose HTTP auth: %w", err) + } + + return &Runtime{ + Middleware: composed.Middleware, + Exchanger: exchanger, + Principal: principal, + SeedAPIToken: seedToken.Plaintext, + SeedAPITokenExpiresAt: seedToken.ExpiresAt, + }, nil +} + +// ExchangeAPIToken exchanges plaintext for an authkit access JWT. +func (r *Runtime) ExchangeAPIToken( + ctx context.Context, + plaintext string, +) (exchange.APITokenResult, error) { + if r == nil || r.Exchanger == nil { + return exchange.APITokenResult{}, errors.New("authflow: runtime exchanger is required") + } + + return r.Exchanger.Exchange(ctx, exchange.APITokenRequest{ + Plaintext: plaintext, + }) +} + +// Authenticate authenticates requests carrying authkit access JWTs. +func (r *Runtime) Authenticate(next http.Handler) http.Handler { + return r.Middleware.Authenticate(next) +} + +// SetAccessCookie writes the temporary testkit access JWT cookie. +func SetAccessCookie(w http.ResponseWriter, token accessjwt.IssuedToken) { + http.SetCookie(w, &http.Cookie{ + Name: CookieName, + Value: token.Plaintext, + Path: accessCookiePath, + Expires: token.ExpiresAt, + MaxAge: accessCookieMaxAge, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) +} + +// ClearAccessCookie clears the temporary testkit access JWT cookie. +func ClearAccessCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: CookieName, + Value: "", + Path: accessCookiePath, + MaxAge: clearedAccessCookieAge, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) +} + +func bootstrapPrincipal(ctx context.Context, store Store) (authkit.Principal, error) { + principals, err := store.ListPrincipals(ctx) + if err != nil { + return authkit.Principal{}, fmt.Errorf("authflow: list principals: %w", err) + } + for _, principal := range principals { + if principal.Kind == authkit.PrincipalKindUser && principal.DisplayName == bootstrapPrincipalName { + return principal, nil + } + } + + principal, err := store.CreatePrincipal(ctx, authkit.CreatePrincipalRequest{ + Kind: authkit.PrincipalKindUser, + DisplayName: bootstrapPrincipalName, + Attributes: map[string]any{ + "testkit": true, + }, + }) + if err != nil { + return authkit.Principal{}, fmt.Errorf("authflow: create bootstrap principal: %w", err) + } + + return principal, nil +} + +func newAccessJWTIssuerAndVerifier( + clock func() time.Time, +) (*accessjwt.Issuer, *accessjwt.Verifier, error) { + rawKey, err := rsa.GenerateKey(rand.Reader, rsaKeyBits) + if err != nil { + return nil, nil, fmt.Errorf("authflow: generate access JWT key: %w", err) + } + signingKey, err := jwk.Import(rawKey) + if err != nil { + return nil, nil, fmt.Errorf("authflow: import access JWT key: %w", err) + } + if setErr := signingKey.Set(jwk.KeyIDKey, accessJWTKeyID); setErr != nil { + return nil, nil, fmt.Errorf("authflow: set access JWT key ID: %w", setErr) + } + if setErr := signingKey.Set(jwk.AlgorithmKey, jwa.RS256()); setErr != nil { + return nil, nil, fmt.Errorf("authflow: set access JWT key algorithm: %w", setErr) + } + publicKey, err := jwk.PublicKeyOf(signingKey) + if err != nil { + return nil, nil, fmt.Errorf("authflow: derive access JWT public key: %w", err) + } + keySet := jwk.NewSet() + if addErr := keySet.AddKey(publicKey); addErr != nil { + return nil, nil, fmt.Errorf("authflow: build access JWT key set: %w", addErr) + } + + issuer, err := accessjwt.NewIssuer(accessjwt.IssuerOptions{ + Issuer: accessJWTIssuer, + Audience: accessJWTAudience, + TTL: AccessTokenTTL, + SigningKey: signingKey, + Clock: clock, + }) + if err != nil { + return nil, nil, fmt.Errorf("authflow: create access JWT issuer: %w", err) + } + verifier, err := accessjwt.NewVerifier(accessjwt.VerifierOptions{ + Issuer: accessJWTIssuer, + Audience: accessJWTAudience, + KeySet: keySet, + Clock: clock, + }) + if err != nil { + return nil, nil, fmt.Errorf("authflow: create access JWT verifier: %w", err) + } + + return issuer, verifier, nil +} + +func renderAuthError(w http.ResponseWriter, req *http.Request, err error) { + if errors.Is(err, authkit.ErrUnauthenticated) || errors.Is(err, authkit.ErrUnresolvedIdentity) { + ClearAccessCookie(w) + http.Redirect(w, req, LoginPath, http.StatusSeeOther) + + return + } + + status := http.StatusInternalServerError + if errors.Is(err, authkit.ErrUnauthorized) { + status = http.StatusForbidden + } + http.Error(w, http.StatusText(status), status) +} + +type allowAuthorizer struct{} + +func (allowAuthorizer) Can(ctx context.Context, _ authkit.AuthorizationCheck) (authkit.Decision, error) { + if err := ctx.Err(); err != nil { + return authkit.Decision{}, err + } + + return authkit.Decision{Allowed: true}, nil +} diff --git a/testkit/internal/authflow/runtime_integration_test.go b/testkit/internal/authflow/runtime_integration_test.go new file mode 100644 index 0000000..1de4503 --- /dev/null +++ b/testkit/internal/authflow/runtime_integration_test.go @@ -0,0 +1,83 @@ +//go:build integration + +package authflow + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/meigma/authkit/httpauth" + authpostgres "github.com/meigma/authkit/store/postgres" + testkitpostgres "github.com/meigma/authkit/testkit/internal/store/postgres" +) + +const postgresReadyOccurrences = 2 + +func TestRuntimeUsesPostgresAuthStore(t *testing.T) { + ctx := context.Background() + pool := newPostgresPool(t) + require.NoError(t, testkitpostgres.Migrate(ctx, pool)) + require.NoError(t, authpostgres.Migrate(ctx, pool)) + require.NoError(t, authpostgres.Migrate(ctx, pool)) + + store, err := authpostgres.NewStore(pool) + require.NoError(t, err) + runtime, err := NewRuntime(ctx, store, WithClock(fixedTime)) + require.NoError(t, err) + + result, err := runtime.ExchangeAPIToken(ctx, runtime.SeedAPIToken) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", bearer(result.AccessToken.Plaintext)) + runtime.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + principal, ok := httpauth.PrincipalFromContext(req.Context()) + assert.True(t, ok) + if ok { + assert.Equal(t, runtime.Principal.ID, principal.ID) + } + w.WriteHeader(http.StatusNoContent) + })).ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusNoContent, recorder.Code) +} + +func newPostgresPool(t *testing.T) *pgxpool.Pool { + t.Helper() + + ctx := context.Background() + container, err := tcpostgres.Run( + ctx, + "postgres:16-alpine", + tcpostgres.WithDatabase("testkit"), + tcpostgres.WithUsername("testkit"), + tcpostgres.WithPassword("testkit"), + testcontainers.WithAdditionalWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(postgresReadyOccurrences). + WithStartupTimeout(time.Minute), + ), + ) + require.NoError(t, err) + testcontainers.CleanupContainer(t, container) + + connectionString, err := container.ConnectionString(ctx, "sslmode=disable") + require.NoError(t, err) + + pool, err := pgxpool.New(ctx, connectionString) + require.NoError(t, err) + t.Cleanup(pool.Close) + + return pool +} diff --git a/testkit/internal/authflow/runtime_test.go b/testkit/internal/authflow/runtime_test.go new file mode 100644 index 0000000..fda0534 --- /dev/null +++ b/testkit/internal/authflow/runtime_test.go @@ -0,0 +1,107 @@ +package authflow + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/meigma/authkit" + "github.com/meigma/authkit/httpauth" + "github.com/meigma/authkit/store/memory" +) + +func TestRuntimeExchangesSeedAPITokenForAccessJWT(t *testing.T) { + runtime := newTestRuntime(t) + + result, err := runtime.ExchangeAPIToken(context.Background(), runtime.SeedAPIToken) + require.NoError(t, err) + + assert.Equal(t, runtime.Principal.ID, result.Principal.ID) + assert.Equal(t, runtime.Principal.ID, result.AccessToken.PrincipalID) + assert.Equal(t, fixedTime().Add(AccessTokenTTL), result.AccessToken.ExpiresAt) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", bearer(result.AccessToken.Plaintext)) + runtime.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + principal, ok := httpauth.PrincipalFromContext(req.Context()) + assert.True(t, ok) + if ok { + assert.Equal(t, runtime.Principal.ID, principal.ID) + } + w.WriteHeader(http.StatusNoContent) + })).ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusNoContent, recorder.Code) +} + +func TestRuntimeRejectsDirectAPITokenAsProtectedBearer(t *testing.T) { + runtime := newTestRuntime(t) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", bearer(runtime.SeedAPIToken)) + runtime.Authenticate(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + })).ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusSeeOther, recorder.Code) + assert.Equal(t, LoginPath, recorder.Header().Get("Location")) + assert.Equal(t, -1, findSetCookie(t, recorder, CookieName).MaxAge) +} + +func TestRuntimeRejectsInvalidAPITokenExchange(t *testing.T) { + runtime := newTestRuntime(t) + + _, err := runtime.ExchangeAPIToken(context.Background(), "invalid") + + require.ErrorIs(t, err, authkit.ErrUnauthenticated) +} + +func TestRuntimeReusesBootstrapPrincipal(t *testing.T) { + store := memory.NewStore() + first, err := NewRuntime(context.Background(), store, WithClock(fixedTime)) + require.NoError(t, err) + second, err := NewRuntime(context.Background(), store, WithClock(fixedTime)) + require.NoError(t, err) + + assert.Equal(t, first.Principal.ID, second.Principal.ID) + principals, err := store.ListPrincipals(context.Background()) + require.NoError(t, err) + assert.Len(t, principals, 1) +} + +func newTestRuntime(t *testing.T) *Runtime { + t.Helper() + + runtime, err := NewRuntime(context.Background(), memory.NewStore(), WithClock(fixedTime)) + require.NoError(t, err) + + return runtime +} + +func findSetCookie(t *testing.T, recorder *httptest.ResponseRecorder, name string) *http.Cookie { + t.Helper() + + for _, cookie := range recorder.Result().Cookies() { + if cookie.Name == name { + return cookie + } + } + require.Failf(t, "missing cookie", "cookie %q was not set", name) + + return nil +} + +func bearer(token string) string { + return "Bearer " + token +} + +func fixedTime() time.Time { + return time.Date(2026, time.May, 14, 10, 0, 0, 0, time.UTC) +} diff --git a/testkit/internal/httpui/csrf.go b/testkit/internal/httpui/csrf.go new file mode 100644 index 0000000..a1aaa07 --- /dev/null +++ b/testkit/internal/httpui/csrf.go @@ -0,0 +1,62 @@ +package httpui + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "net/http" + "strings" +) + +const ( + csrfCookieName = "testkit_csrf" + csrfFieldName = "csrf_token" + csrfTokenBytes = 32 +) + +var errInvalidCSRFToken = errors.New("httpui: invalid CSRF token") + +type csrfProtector struct{} + +func newCSRFProtector() csrfProtector { + return csrfProtector{} +} + +func (csrfProtector) token(w http.ResponseWriter, req *http.Request) (string, error) { + if cookie, err := req.Cookie(csrfCookieName); err == nil && strings.TrimSpace(cookie.Value) != "" { + return cookie.Value, nil + } + + raw := make([]byte, csrfTokenBytes) + if _, err := rand.Read(raw); err != nil { + return "", err + } + token := base64.RawURLEncoding.EncodeToString(raw) + http.SetCookie(w, &http.Cookie{ + Name: csrfCookieName, + Value: token, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) + + return token, nil +} + +func (csrfProtector) validate(req *http.Request) error { + cookie, err := req.Cookie(csrfCookieName) + if err != nil || strings.TrimSpace(cookie.Value) == "" { + return errInvalidCSRFToken + } + + formToken := req.PostFormValue(csrfFieldName) + if formToken == "" { + return errInvalidCSRFToken + } + if subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(formToken)) != 1 { + return errInvalidCSRFToken + } + + return nil +} diff --git a/testkit/internal/httpui/doc.go b/testkit/internal/httpui/doc.go new file mode 100644 index 0000000..937ae41 --- /dev/null +++ b/testkit/internal/httpui/doc.go @@ -0,0 +1,2 @@ +// Package httpui contains the server-rendered HTTP UI for the testkit pastebin. +package httpui diff --git a/testkit/internal/httpui/handlers.go b/testkit/internal/httpui/handlers.go new file mode 100644 index 0000000..ed6d246 --- /dev/null +++ b/testkit/internal/httpui/handlers.go @@ -0,0 +1,313 @@ +package httpui + +import ( + "bytes" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/meigma/authkit" + "github.com/meigma/authkit/testkit/internal/authflow" + "github.com/meigma/authkit/testkit/internal/paste" +) + +const ( + contentTypeHeader = "Content-Type" + htmlContentType = "text/html; charset=utf-8" + plainContentType = "text/plain; charset=utf-8" + formBodyOverhead = 8 * 1024 + loginFormMaxBytes = 16 * 1024 + pageError = "error" + pageIndex = "index" + pageLogin = "login" + pageNew = "new" + pagePaste = "paste" + pasteIDPathValue = "id" +) + +type pageData struct { + Title string + Pastes []paste.Paste + Paste paste.Paste + Form pasteForm + CSRFToken string + Error string +} + +type pasteForm struct { + Title string + Body string + Syntax string +} + +func (s *Server) handleIndex(w http.ResponseWriter, req *http.Request) { + pastes, err := s.pastes.ListRecent(req.Context(), paste.DefaultRecentLimit) + if err != nil { + s.renderError(w, http.StatusInternalServerError, "Could not load recent pastes.") + + return + } + + s.render(w, http.StatusOK, pageIndex, pageData{ + Title: "Recent pastes", + Pastes: pastes, + }) +} + +func (s *Server) handleNew(w http.ResponseWriter, req *http.Request) { + s.renderNew(w, req, http.StatusOK, pasteForm{}, "") +} + +func (s *Server) handleLogin(w http.ResponseWriter, req *http.Request) { + s.renderLogin(w, req, http.StatusOK, "") +} + +func (s *Server) handleExchangeAPIToken(w http.ResponseWriter, req *http.Request) { + req.Body = http.MaxBytesReader(w, req.Body, loginFormMaxBytes) + if err := req.ParseForm(); err != nil { + s.renderLogin(w, req, http.StatusBadRequest, "Could not read API token.") + + return + } + if err := s.csrf.validate(req); err != nil { + s.renderLogin(w, req, http.StatusForbidden, "Could not validate form.") + + return + } + + rawToken := strings.TrimSpace(req.PostFormValue("api_token")) + if rawToken == "" { + rawToken = bearerToken(req) + } + if rawToken == "" { + s.renderLogin(w, req, http.StatusUnauthorized, "API token is required.") + + return + } + + result, err := s.auth.ExchangeAPIToken(req.Context(), rawToken) + if err != nil { + s.renderExchangeError(w, req, err) + + return + } + + authflow.SetAccessCookie(w, result.AccessToken) + http.Redirect(w, req, "/new", http.StatusSeeOther) +} + +func (s *Server) handleLogout(w http.ResponseWriter, req *http.Request) { + req.Body = http.MaxBytesReader(w, req.Body, loginFormMaxBytes) + if err := req.ParseForm(); err != nil { + s.renderError(w, http.StatusBadRequest, "Could not read logout form.") + + return + } + if err := s.csrf.validate(req); err != nil { + s.renderError(w, http.StatusForbidden, "Could not validate form.") + + return + } + + authflow.ClearAccessCookie(w) + http.Redirect(w, req, "/", http.StatusSeeOther) +} + +func (s *Server) handleCreate(w http.ResponseWriter, req *http.Request) { + req.Body = http.MaxBytesReader(w, req.Body, int64(paste.DefaultMaxBodyBytes+formBodyOverhead)) + if err := req.ParseForm(); err != nil { + s.renderNew(w, req, http.StatusBadRequest, pasteForm{}, "Could not read paste form.") + + return + } + if err := s.csrf.validate(req); err != nil { + s.renderNew(w, req, http.StatusForbidden, pasteForm{}, "Could not validate form.") + + return + } + + form := pasteForm{ + Title: req.PostFormValue("title"), + Body: req.PostFormValue("body"), + Syntax: req.PostFormValue("syntax"), + } + created, err := s.pastes.Create(req.Context(), paste.CreatePasteRequest{ + Title: form.Title, + Body: form.Body, + Syntax: form.Syntax, + }) + if err != nil { + s.renderCreateError(w, req, form, err) + + return + } + + http.Redirect(w, req, pastePath(created.ID), http.StatusSeeOther) +} + +func (s *Server) handlePaste(w http.ResponseWriter, req *http.Request) { + found, err := s.pastes.Read(req.Context(), req.PathValue(pasteIDPathValue)) + if err != nil { + s.renderReadError(w, err) + + return + } + + s.render(w, http.StatusOK, pagePaste, pageData{ + Title: found.Title, + Paste: found, + }) +} + +func (s *Server) handleRaw(w http.ResponseWriter, req *http.Request) { + found, err := s.pastes.Read(req.Context(), req.PathValue(pasteIDPathValue)) + if err != nil { + if errors.Is(err, paste.ErrPasteNotFound) { + http.NotFound(w, req) + + return + } + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + w.Header().Set(contentTypeHeader, plainContentType) + if _, err := w.Write([]byte(found.Body)); err != nil { + return + } +} + +func (s *Server) withAccessCookie(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.Header.Get("Authorization") != "" { + next.ServeHTTP(w, req) + + return + } + + cookie, err := req.Cookie(authflow.CookieName) + if err != nil || strings.TrimSpace(cookie.Value) == "" { + next.ServeHTTP(w, req) + + return + } + + authedReq := req.Clone(req.Context()) + authedReq.Header = req.Header.Clone() + authedReq.Header.Set("Authorization", "Bearer "+cookie.Value) + next.ServeHTTP(w, authedReq) + }) +} + +func (s *Server) renderCreateError(w http.ResponseWriter, req *http.Request, form pasteForm, err error) { + switch { + case errors.Is(err, paste.ErrEmptyBody): + s.renderNew(w, req, http.StatusBadRequest, form, "Paste body is required.") + case isBodyTooLarge(err): + s.renderNew(w, req, http.StatusRequestEntityTooLarge, form, "Paste body is too large.") + default: + s.renderError(w, http.StatusInternalServerError, "Could not create paste.") + } +} + +func (s *Server) renderExchangeError(w http.ResponseWriter, req *http.Request, err error) { + if errors.Is(err, authkit.ErrUnauthenticated) { + s.renderLogin(w, req, http.StatusUnauthorized, "API token is invalid.") + + return + } + + s.renderError(w, http.StatusInternalServerError, "Could not exchange API token.") +} + +func (s *Server) renderReadError(w http.ResponseWriter, err error) { + if errors.Is(err, paste.ErrPasteNotFound) { + s.renderError(w, http.StatusNotFound, "Paste not found.") + + return + } + + s.renderError(w, http.StatusInternalServerError, "Could not load paste.") +} + +func (s *Server) renderNew( + w http.ResponseWriter, + req *http.Request, + status int, + form pasteForm, + message string, +) { + token, err := s.csrf.token(w, req) + if err != nil { + s.renderError(w, http.StatusInternalServerError, "Could not prepare form.") + + return + } + + s.render(w, status, pageNew, pageData{ + Title: "New paste", + Form: form, + CSRFToken: token, + Error: message, + }) +} + +func (s *Server) renderLogin(w http.ResponseWriter, req *http.Request, status int, message string) { + token, err := s.csrf.token(w, req) + if err != nil { + s.renderError(w, http.StatusInternalServerError, "Could not prepare form.") + + return + } + + s.render(w, status, pageLogin, pageData{ + Title: "API token login", + CSRFToken: token, + Error: message, + }) +} + +func (s *Server) renderError(w http.ResponseWriter, status int, message string) { + s.render(w, status, pageError, pageData{ + Title: http.StatusText(status), + Error: message, + }) +} + +func (s *Server) render(w http.ResponseWriter, status int, page string, data pageData) { + var buf bytes.Buffer + if err := s.templates.execute(&buf, page, data); err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + w.Header().Set(contentTypeHeader, htmlContentType) + w.WriteHeader(status) + if _, err := w.Write(buf.Bytes()); err != nil { + return + } +} + +func isBodyTooLarge(err error) bool { + var bodyErr paste.BodyTooLargeError + + return errors.As(err, &bodyErr) +} + +func pastePath(id string) string { + return fmt.Sprintf("/p/%s", url.PathEscape(id)) +} + +func bearerToken(req *http.Request) string { + header := req.Header.Get("Authorization") + parts := strings.Fields(header) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return "" + } + + return parts[1] +} diff --git a/testkit/internal/httpui/server.go b/testkit/internal/httpui/server.go new file mode 100644 index 0000000..21d804a --- /dev/null +++ b/testkit/internal/httpui/server.go @@ -0,0 +1,70 @@ +package httpui + +import ( + "errors" + "fmt" + "io/fs" + "net/http" + + "github.com/meigma/authkit/testkit/internal/authflow" + "github.com/meigma/authkit/testkit/internal/paste" +) + +const ( + staticDir = "static" + staticURL = "/static/" +) + +// Server serves the testkit pastebin UI. +type Server struct { + handler http.Handler + auth *authflow.Runtime + csrf csrfProtector + pastes *paste.Service + templates *templateSet +} + +// NewServer constructs a testkit HTTP UI server. +func NewServer(pastes *paste.Service, auth *authflow.Runtime) (*Server, error) { + if pastes == nil { + return nil, errors.New("httpui: paste service is required") + } + if auth == nil { + return nil, errors.New("httpui: auth runtime is required") + } + + templates, err := newTemplateSet() + if err != nil { + return nil, err + } + + staticFiles, err := fs.Sub(content, staticDir) + if err != nil { + return nil, fmt.Errorf("httpui: prepare static assets: %w", err) + } + + server := &Server{ + auth: auth, + csrf: newCSRFProtector(), + pastes: pastes, + templates: templates, + } + mux := http.NewServeMux() + mux.Handle("GET "+staticURL, http.StripPrefix(staticURL, http.FileServer(http.FS(staticFiles)))) + mux.HandleFunc("GET /{$}", server.handleIndex) + mux.HandleFunc("GET /login", server.handleLogin) + mux.HandleFunc("POST /auth/token", server.handleExchangeAPIToken) + mux.HandleFunc("POST /logout", server.handleLogout) + mux.Handle("GET /new", server.withAccessCookie(auth.Authenticate(http.HandlerFunc(server.handleNew)))) + mux.Handle("POST /pastes", server.withAccessCookie(auth.Authenticate(http.HandlerFunc(server.handleCreate)))) + mux.HandleFunc("GET /p/{id}", server.handlePaste) + mux.HandleFunc("GET /raw/{id}", server.handleRaw) + server.handler = mux + + return server, nil +} + +// ServeHTTP serves an HTTP request. +func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.handler.ServeHTTP(w, req) +} diff --git a/testkit/internal/httpui/server_test.go b/testkit/internal/httpui/server_test.go new file mode 100644 index 0000000..803f3a1 --- /dev/null +++ b/testkit/internal/httpui/server_test.go @@ -0,0 +1,406 @@ +package httpui + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + authmemory "github.com/meigma/authkit/store/memory" + "github.com/meigma/authkit/testkit/internal/authflow" + "github.com/meigma/authkit/testkit/internal/paste" + testkitmemory "github.com/meigma/authkit/testkit/internal/store/memory" +) + +const testPasteID = "paste-1" + +func TestServerRendersPublicPages(t *testing.T) { + server := newTestServer(t, testPasteID) + + tests := []struct { + name string + path string + wantStatus int + wantBody string + }{ + { + name: "index", + path: "/", + wantStatus: http.StatusOK, + wantBody: "No pastes yet.", + }, + { + name: "login form", + path: "/login", + wantStatus: http.StatusOK, + wantBody: "API token login", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, tt.path, nil)) + + assert.Equal(t, tt.wantStatus, recorder.Code) + assert.Contains(t, recorder.Body.String(), tt.wantBody) + assert.Equal(t, htmlContentType, recorder.Header().Get(contentTypeHeader)) + }) + } + + loginRecorder := httptest.NewRecorder() + server.ServeHTTP(loginRecorder, httptest.NewRequest(http.MethodGet, "/login", nil)) + assert.Contains(t, loginRecorder.Body.String(), `name="csrf_token"`) + assert.NotEmpty(t, findCookie(t, loginRecorder, csrfCookieName).Value) +} + +func TestServerRequiresAuthenticationForPasteCreation(t *testing.T) { + server := newTestServer(t, testPasteID) + + tests := []struct { + name string + req *http.Request + }{ + { + name: "new paste form", + req: httptest.NewRequest(http.MethodGet, "/new", nil), + }, + { + name: "create paste", + req: newPostFormRequest(t, "/pastes", url.Values{ + "body": {"hello"}, + }), + }, + { + name: "API token is not a runtime bearer token", + req: newAuthorizedRequest( + httptest.NewRequest(http.MethodGet, "/new", nil), + bearer(server.auth.SeedAPIToken), + ), + }, + { + name: "API token is not an access cookie", + req: newCookieRequest( + httptest.NewRequest(http.MethodGet, "/new", nil), + &http.Cookie{Name: authflow.CookieName, Value: server.auth.SeedAPIToken}, + ), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, tt.req) + + assert.Equal(t, http.StatusSeeOther, recorder.Code) + assert.Equal(t, authflow.LoginPath, recorder.Header().Get("Location")) + }) + } +} + +func TestServerExchangesAPITokenAndCreatesPaste(t *testing.T) { + server := newTestServer(t, testPasteID) + browser := exchangeAccessCookie(t, server) + + newRecorder := httptest.NewRecorder() + newReq := httptest.NewRequest(http.MethodGet, "/new", nil) + newReq.AddCookie(browser.access) + newReq.AddCookie(browser.csrf) + server.ServeHTTP(newRecorder, newReq) + + require.Equal(t, http.StatusOK, newRecorder.Code) + assert.Contains(t, newRecorder.Body.String(), "Create paste") + assert.Contains(t, newRecorder.Body.String(), `name="csrf_token"`) + + createRecorder := httptest.NewRecorder() + createReq := newPostFormRequest(t, "/pastes", url.Values{ + "title": {"Example title"}, + "body": {"hello from the paste"}, + "syntax": {"text"}, + csrfFieldName: {browser.csrf.Value}, + }) + createReq.AddCookie(browser.access) + createReq.AddCookie(browser.csrf) + server.ServeHTTP(createRecorder, createReq) + + require.Equal(t, http.StatusSeeOther, createRecorder.Code) + assert.Equal(t, "/p/"+testPasteID, createRecorder.Header().Get("Location")) + + pasteRecorder := httptest.NewRecorder() + server.ServeHTTP(pasteRecorder, httptest.NewRequest(http.MethodGet, "/p/"+testPasteID, nil)) + + assert.Equal(t, http.StatusOK, pasteRecorder.Code) + assert.Contains(t, pasteRecorder.Body.String(), "Example title") + assert.Contains(t, pasteRecorder.Body.String(), "hello from the paste") + assert.Contains(t, pasteRecorder.Body.String(), "text") + + rawRecorder := httptest.NewRecorder() + server.ServeHTTP(rawRecorder, httptest.NewRequest(http.MethodGet, "/raw/"+testPasteID, nil)) + + assert.Equal(t, http.StatusOK, rawRecorder.Code) + assert.Equal(t, plainContentType, rawRecorder.Header().Get(contentTypeHeader)) + assert.Equal(t, "hello from the paste", rawRecorder.Body.String()) + + indexRecorder := httptest.NewRecorder() + server.ServeHTTP(indexRecorder, httptest.NewRequest(http.MethodGet, "/", nil)) + + assert.Equal(t, http.StatusOK, indexRecorder.Code) + assert.Contains(t, indexRecorder.Body.String(), "Example title") + assert.Contains(t, indexRecorder.Body.String(), "/p/"+testPasteID) +} + +func TestServerRejectsInvalidAPITokenExchange(t *testing.T) { + server := newTestServer(t, testPasteID) + csrfCookie := csrfFromLogin(t, server) + + recorder := httptest.NewRecorder() + req := newPostFormRequest(t, "/auth/token", url.Values{ + "api_token": {"invalid"}, + csrfFieldName: {csrfCookie.Value}, + }) + req.AddCookie(csrfCookie) + server.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusUnauthorized, recorder.Code) + assert.Contains(t, recorder.Body.String(), "API token is invalid.") + assert.NotContains(t, recorder.Body.String(), `value="invalid"`) +} + +func TestServerRejectsEmptyPasteBody(t *testing.T) { + server := newTestServer(t, testPasteID) + browser := exchangeAccessCookie(t, server) + + req := newPostFormRequest(t, "/pastes", url.Values{ + "title": {"Empty paste"}, + "body": {" \n\t "}, + csrfFieldName: {browser.csrf.Value}, + }) + req.AddCookie(browser.access) + req.AddCookie(browser.csrf) + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusBadRequest, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Paste body is required.") + assert.Contains(t, recorder.Body.String(), "Empty paste") +} + +func TestServerLogoutClearsAccessCookie(t *testing.T) { + server := newTestServer(t, testPasteID) + browser := exchangeAccessCookie(t, server) + + req := newPostFormRequest(t, "/logout", url.Values{ + csrfFieldName: {browser.csrf.Value}, + }) + req.AddCookie(browser.access) + req.AddCookie(browser.csrf) + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusSeeOther, recorder.Code) + assert.Equal(t, "/", recorder.Header().Get("Location")) + cleared := findCookie(t, recorder, authflow.CookieName) + assert.Equal(t, -1, cleared.MaxAge) + assert.Empty(t, cleared.Value) +} + +func TestServerReturnsNotFoundForMissingPaste(t *testing.T) { + server := newTestServer(t, testPasteID) + + tests := []struct { + name string + path string + }{ + {name: "paste page", path: "/p/missing"}, + {name: "raw paste", path: "/raw/missing"}, + {name: "unknown route", path: "/missing"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, tt.path, nil)) + + assert.Equal(t, http.StatusNotFound, recorder.Code) + }) + } +} + +func TestServerRejectsMissingCSRFToken(t *testing.T) { + server := newTestServer(t, testPasteID) + browser := exchangeAccessCookie(t, server) + + tests := []struct { + name string + req *http.Request + }{ + { + name: "API-token exchange", + req: newPostFormRequest(t, "/auth/token", url.Values{ + "api_token": {server.auth.SeedAPIToken}, + }), + }, + { + name: "paste create", + req: func() *http.Request { + req := newPostFormRequest(t, "/pastes", url.Values{ + "body": {"hello"}, + }) + req.AddCookie(browser.access) + + return req + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, tt.req) + + assert.Equal(t, http.StatusForbidden, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Could not validate form.") + }) + } +} + +type testServer struct { + *Server + + auth *authflow.Runtime +} + +func newTestServer(t *testing.T, ids ...string) *testServer { + t.Helper() + + sequence := sequentialIDs(ids...) + service, err := paste.NewService( + testkitmemory.NewStore(), + paste.WithIDGenerator(sequence.next), + paste.WithClock(fixedTime), + ) + require.NoError(t, err) + + authRuntime, err := authflow.NewRuntime( + context.Background(), + authmemory.NewStore(), + authflow.WithClock(fixedTime), + ) + require.NoError(t, err) + server, err := NewServer(service, authRuntime) + require.NoError(t, err) + + return &testServer{ + Server: server, + auth: authRuntime, + } +} + +type browserCookies struct { + access *http.Cookie + csrf *http.Cookie +} + +func exchangeAccessCookie(t *testing.T, server *testServer) browserCookies { + t.Helper() + + csrfCookie := csrfFromLogin(t, server) + req := newPostFormRequest(t, "/auth/token", url.Values{ + "api_token": {server.auth.SeedAPIToken}, + csrfFieldName: {csrfCookie.Value}, + }) + req.AddCookie(csrfCookie) + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusSeeOther, recorder.Code) + assert.Equal(t, "/new", recorder.Header().Get("Location")) + + return browserCookies{ + access: findCookie(t, recorder, authflow.CookieName), + csrf: csrfCookie, + } +} + +func csrfFromLogin(t *testing.T, server *testServer) *http.Cookie { + t.Helper() + + recorder := httptest.NewRecorder() + server.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/login", nil)) + require.Equal(t, http.StatusOK, recorder.Code) + + return findCookie(t, recorder, csrfCookieName) +} + +func findCookie(t *testing.T, recorder *httptest.ResponseRecorder, name string) *http.Cookie { + t.Helper() + + for _, cookie := range recorder.Result().Cookies() { + if cookie.Name == name { + return cookie + } + } + require.Failf(t, "missing cookie", "cookie %q was not set", name) + + return nil +} + +func newPostFormRequest(t *testing.T, path string, values url.Values) *http.Request { + t.Helper() + + body := "" + if values != nil { + body = values.Encode() + } + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(body)) + req.Header.Set(contentTypeHeader, "application/x-www-form-urlencoded") + + return req +} + +func newAuthorizedRequest(req *http.Request, authorization string) *http.Request { + req.Header.Set("Authorization", authorization) + + return req +} + +func newCookieRequest(req *http.Request, cookie *http.Cookie) *http.Request { + req.AddCookie(cookie) + + return req +} + +func bearer(token string) string { + return "Bearer " + token +} + +type idSequence struct { + values []string + nextID int +} + +func sequentialIDs(ids ...string) *idSequence { + return &idSequence{values: ids} +} + +func (s *idSequence) next() (string, error) { + if s.nextID >= len(s.values) { + return "", errors.New("test: no more IDs") + } + + id := s.values[s.nextID] + s.nextID++ + + return id, nil +} + +func fixedTime() time.Time { + return time.Date(2026, time.May, 14, 10, 0, 0, 0, time.UTC) +} diff --git a/testkit/internal/httpui/static/app.css b/testkit/internal/httpui/static/app.css new file mode 100644 index 0000000..d019f09 --- /dev/null +++ b/testkit/internal/httpui/static/app.css @@ -0,0 +1,209 @@ +:root { + color-scheme: light; + --bg: #f7f7f4; + --panel: #ffffff; + --text: #1f2428; + --muted: #65717a; + --line: #d8ddd8; + --accent: #116149; + --accent-strong: #0b4534; + --error: #9d1c24; + --code-bg: #171b1f; + --code-text: #ecf2f0; +} + +* { + box-sizing: border-box; +} + +body { + margin: 0; + background: var(--bg); + color: var(--text); + font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; + line-height: 1.5; +} + +a { + color: var(--accent); +} + +.site-header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 1rem; + padding: 1rem clamp(1rem, 4vw, 3rem); + border-bottom: 1px solid var(--line); + background: var(--panel); +} + +.brand { + color: var(--text); + font-weight: 700; + text-decoration: none; +} + +nav { + display: flex; + align-items: center; + gap: 1rem; +} + +.button, +button { + display: inline-flex; + align-items: center; + justify-content: center; + min-height: 2.35rem; + padding: 0.45rem 0.8rem; + border: 1px solid var(--accent); + border-radius: 6px; + background: var(--accent); + color: #ffffff; + font: inherit; + font-weight: 650; + text-decoration: none; + cursor: pointer; +} + +.button:hover, +button:hover { + background: var(--accent-strong); +} + +.page { + width: min(100%, 980px); + margin: 0 auto; + padding: 2rem clamp(1rem, 4vw, 3rem) 4rem; +} + +.page-heading, +.paste-header { + display: flex; + align-items: flex-start; + justify-content: space-between; + gap: 1rem; + margin-bottom: 1.25rem; +} + +h1 { + margin: 0; + font-size: clamp(1.8rem, 3vw, 2.5rem); + line-height: 1.1; +} + +.paste-list { + display: grid; + gap: 0.75rem; + margin: 0; + padding: 0; + list-style: none; +} + +.paste-list li { + display: flex; + align-items: center; + justify-content: space-between; + gap: 1rem; + padding: 0.9rem 1rem; + border: 1px solid var(--line); + border-radius: 8px; + background: var(--panel); +} + +.paste-list a { + color: var(--text); + font-weight: 650; + overflow-wrap: anywhere; +} + +.paste-list span, +.paste-header p, +.empty { + color: var(--muted); +} + +.paste-form { + display: grid; + gap: 1rem; + padding: 1rem; + border: 1px solid var(--line); + border-radius: 8px; + background: var(--panel); +} + +.auth-form { + max-width: 34rem; +} + +label { + display: grid; + gap: 0.35rem; + color: var(--muted); + font-weight: 650; +} + +input, +textarea { + width: 100%; + border: 1px solid var(--line); + border-radius: 6px; + padding: 0.65rem 0.75rem; + color: var(--text); + font: inherit; +} + +textarea { + min-height: 18rem; + resize: vertical; + font-family: ui-monospace, "SFMono-Regular", Consolas, monospace; +} + +.error { + padding: 0.8rem 1rem; + border: 1px solid color-mix(in srgb, var(--error) 35%, transparent); + border-radius: 8px; + background: color-mix(in srgb, var(--error) 8%, white); + color: var(--error); +} + +.paste { + display: grid; + gap: 1rem; +} + +.paste-header p { + display: flex; + gap: 0.75rem; + flex-wrap: wrap; + margin: 0.5rem 0 0; +} + +pre { + margin: 0; + padding: 1rem; + overflow: auto; + border-radius: 8px; + background: var(--code-bg); + color: var(--code-text); +} + +code { + font-family: ui-monospace, "SFMono-Regular", Consolas, monospace; + font-size: 0.95rem; +} + +@media (max-width: 640px) { + .site-header, + .page-heading, + .paste-header, + .paste-list li { + align-items: stretch; + flex-direction: column; + } + + nav { + justify-content: space-between; + } +} diff --git a/testkit/internal/httpui/templates.go b/testkit/internal/httpui/templates.go new file mode 100644 index 0000000..8b81270 --- /dev/null +++ b/testkit/internal/httpui/templates.go @@ -0,0 +1,62 @@ +package httpui + +import ( + "embed" + "fmt" + "html/template" + "io" + "time" +) + +const ( + layoutTemplate = "templates/layout.html" + pageTemplate = "content" + timeFormat = "2006-01-02 15:04 UTC" +) + +//go:embed templates/*.html static/*.css +var content embed.FS + +type templateSet struct { + pages map[string]*template.Template +} + +func newTemplateSet() (*templateSet, error) { + pageFiles := map[string]string{ + pageError: "templates/error.html", + pageIndex: "templates/index.html", + pageLogin: "templates/login.html", + pageNew: "templates/new.html", + pagePaste: "templates/paste.html", + } + pages := make(map[string]*template.Template, len(pageFiles)) + funcs := template.FuncMap{ + "formatTime": formatTime, + } + + for name, file := range pageFiles { + parsed, err := template.New(name).Funcs(funcs).ParseFS(content, layoutTemplate, file) + if err != nil { + return nil, fmt.Errorf("httpui: parse template %s: %w", name, err) + } + pages[name] = parsed + } + + return &templateSet{pages: pages}, nil +} + +func (t *templateSet) execute(w io.Writer, page string, data pageData) error { + tmpl, exists := t.pages[page] + if !exists { + return fmt.Errorf("httpui: template %s not found", page) + } + if err := tmpl.ExecuteTemplate(w, pageTemplate, data); err != nil { + return fmt.Errorf("httpui: execute template %s: %w", page, err) + } + + return nil +} + +func formatTime(value time.Time) string { + return value.UTC().Format(timeFormat) +} diff --git a/testkit/internal/httpui/templates/error.html b/testkit/internal/httpui/templates/error.html new file mode 100644 index 0000000..5088509 --- /dev/null +++ b/testkit/internal/httpui/templates/error.html @@ -0,0 +1,6 @@ +{{define "page"}} +
+

{{.Title}}

+
+ +{{end}} diff --git a/testkit/internal/httpui/templates/index.html b/testkit/internal/httpui/templates/index.html new file mode 100644 index 0000000..caa34c1 --- /dev/null +++ b/testkit/internal/httpui/templates/index.html @@ -0,0 +1,19 @@ +{{define "page"}} +
+

Recent pastes

+ Create paste +
+ +{{if .Pastes}} +
    + {{range .Pastes}} +
  1. + {{if .Title}}{{.Title}}{{else}}Untitled paste{{end}} + {{formatTime .CreatedAt}} +
  2. + {{end}} +
+{{else}} +

No pastes yet.

+{{end}} +{{end}} diff --git a/testkit/internal/httpui/templates/layout.html b/testkit/internal/httpui/templates/layout.html new file mode 100644 index 0000000..4759a09 --- /dev/null +++ b/testkit/internal/httpui/templates/layout.html @@ -0,0 +1,24 @@ +{{define "content"}} + + + + + + {{if .Title}}{{.Title}} - {{end}}authkit testkit pastebin + + + + +
+ {{template "page" .}} +
+ + +{{end}} diff --git a/testkit/internal/httpui/templates/login.html b/testkit/internal/httpui/templates/login.html new file mode 100644 index 0000000..5859dae --- /dev/null +++ b/testkit/internal/httpui/templates/login.html @@ -0,0 +1,18 @@ +{{define "page"}} +
+

API token login

+
+ +{{if .Error}} + +{{end}} + +
+ + + +
+{{end}} diff --git a/testkit/internal/httpui/templates/new.html b/testkit/internal/httpui/templates/new.html new file mode 100644 index 0000000..7e6e838 --- /dev/null +++ b/testkit/internal/httpui/templates/new.html @@ -0,0 +1,26 @@ +{{define "page"}} +
+

Create paste

+
+ +{{if .Error}} + +{{end}} + +
+ + + + + +
+{{end}} diff --git a/testkit/internal/httpui/templates/paste.html b/testkit/internal/httpui/templates/paste.html new file mode 100644 index 0000000..8d1d1a2 --- /dev/null +++ b/testkit/internal/httpui/templates/paste.html @@ -0,0 +1,15 @@ +{{define "page"}} +
+
+
+

{{if .Paste.Title}}{{.Paste.Title}}{{else}}Untitled paste{{end}}

+

+ + {{if .Paste.Syntax}}{{.Paste.Syntax}}{{end}} +

+
+ Raw +
+
{{.Paste.Body}}
+
+{{end}} diff --git a/testkit/internal/paste/doc.go b/testkit/internal/paste/doc.go new file mode 100644 index 0000000..0794f44 --- /dev/null +++ b/testkit/internal/paste/doc.go @@ -0,0 +1,2 @@ +// Package paste contains the core pastebin behavior for authkit's testkit app. +package paste diff --git a/testkit/internal/paste/errors.go b/testkit/internal/paste/errors.go new file mode 100644 index 0000000..5662e54 --- /dev/null +++ b/testkit/internal/paste/errors.go @@ -0,0 +1,27 @@ +package paste + +import ( + "errors" + "fmt" +) + +var ( + // ErrEmptyBody indicates that a paste body is blank after trimming whitespace. + ErrEmptyBody = errors.New("paste: body is required") + + // ErrPasteNotFound indicates that a paste ID does not exist. + ErrPasteNotFound = errors.New("paste: paste not found") + + // ErrDuplicatePasteID indicates that a paste ID already exists in storage. + ErrDuplicatePasteID = errors.New("paste: duplicate paste ID") +) + +// BodyTooLargeError indicates that a paste body exceeds the configured byte limit. +type BodyTooLargeError struct { + // MaxBytes is the largest accepted paste body size. + MaxBytes int +} + +func (e BodyTooLargeError) Error() string { + return fmt.Sprintf("paste: body exceeds %d bytes", e.MaxBytes) +} diff --git a/testkit/internal/paste/service.go b/testkit/internal/paste/service.go new file mode 100644 index 0000000..81848bc --- /dev/null +++ b/testkit/internal/paste/service.go @@ -0,0 +1,153 @@ +package paste + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "strings" + "time" +) + +const idEntropyBytes = 12 + +// IDGenerator returns a new paste ID. +type IDGenerator func() (string, error) + +// Option configures a Service. +type Option func(*serviceOptions) + +type serviceOptions struct { + clock func() time.Time + idGenerator IDGenerator + maxBodyBytes int +} + +// Service coordinates paste creation and lookup behavior. +type Service struct { + repo Repository + clock func() time.Time + idGenerator IDGenerator + maxBodyBytes int +} + +// NewService constructs a paste service around repo. +func NewService(repo Repository, opts ...Option) (*Service, error) { + if repo == nil { + return nil, errors.New("paste: repository is required") + } + + cfg := serviceOptions{ + clock: time.Now, + idGenerator: generateID, + maxBodyBytes: DefaultMaxBodyBytes, + } + for _, opt := range opts { + if opt != nil { + opt(&cfg) + } + } + if cfg.clock == nil { + cfg.clock = time.Now + } + if cfg.idGenerator == nil { + cfg.idGenerator = generateID + } + if cfg.maxBodyBytes <= 0 { + cfg.maxBodyBytes = DefaultMaxBodyBytes + } + + return &Service{ + repo: repo, + clock: cfg.clock, + idGenerator: cfg.idGenerator, + maxBodyBytes: cfg.maxBodyBytes, + }, nil +} + +// WithClock sets the clock used for paste creation timestamps. +func WithClock(clock func() time.Time) Option { + return func(opts *serviceOptions) { + opts.clock = clock + } +} + +// WithIDGenerator sets the ID generator used for new pastes. +func WithIDGenerator(generator IDGenerator) Option { + return func(opts *serviceOptions) { + opts.idGenerator = generator + } +} + +// WithMaxBodyBytes sets the maximum accepted paste body size. +func WithMaxBodyBytes(maxBytes int) Option { + return func(opts *serviceOptions) { + opts.maxBodyBytes = maxBytes + } +} + +// Create validates and stores a new paste. +func (s *Service) Create(ctx context.Context, req CreatePasteRequest) (Paste, error) { + if strings.TrimSpace(req.Body) == "" { + return Paste{}, ErrEmptyBody + } + if len(req.Body) > s.maxBodyBytes { + return Paste{}, BodyTooLargeError{MaxBytes: s.maxBodyBytes} + } + + id, err := s.idGenerator() + if err != nil { + return Paste{}, fmt.Errorf("paste: generate ID: %w", err) + } + id = strings.TrimSpace(id) + if id == "" { + return Paste{}, errors.New("paste: generated ID is empty") + } + + created := Paste{ + ID: id, + Title: strings.TrimSpace(req.Title), + Body: req.Body, + Syntax: strings.TrimSpace(req.Syntax), + CreatedAt: s.clock().UTC(), + } + if err := s.repo.Create(ctx, created); err != nil { + return Paste{}, fmt.Errorf("paste: store paste: %w", err) + } + + return created, nil +} + +// Read returns a paste by ID. +func (s *Service) Read(ctx context.Context, id string) (Paste, error) { + paste, err := s.repo.Find(ctx, strings.TrimSpace(id)) + if err != nil { + return Paste{}, fmt.Errorf("paste: find paste: %w", err) + } + + return paste, nil +} + +// ListRecent returns recent pastes, newest first. +func (s *Service) ListRecent(ctx context.Context, limit int) ([]Paste, error) { + if limit <= 0 { + limit = DefaultRecentLimit + } + + pastes, err := s.repo.ListRecent(ctx, limit) + if err != nil { + return nil, fmt.Errorf("paste: list recent pastes: %w", err) + } + + return pastes, nil +} + +func generateID() (string, error) { + var entropy [idEntropyBytes]byte + if _, err := rand.Read(entropy[:]); err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(entropy[:]), nil +} diff --git a/testkit/internal/paste/service_test.go b/testkit/internal/paste/service_test.go new file mode 100644 index 0000000..ec09f00 --- /dev/null +++ b/testkit/internal/paste/service_test.go @@ -0,0 +1,158 @@ +package paste_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/meigma/authkit/testkit/internal/paste" + "github.com/meigma/authkit/testkit/internal/store/memory" +) + +const ( + firstPasteID = "paste-1" + secondPasteID = "paste-2" + maxTestBytes = 5 +) + +func TestServiceCreatesPaste(t *testing.T) { + service := newTestService(t, firstPasteID, fixedTime()) + + created, err := service.Create(context.Background(), paste.CreatePasteRequest{ + Title: " Example paste ", + Body: "hello, pastebin", + Syntax: " text ", + }) + + require.NoError(t, err) + assert.Equal(t, firstPasteID, created.ID) + assert.Equal(t, "Example paste", created.Title) + assert.Equal(t, "hello, pastebin", created.Body) + assert.Equal(t, "text", created.Syntax) + assert.Equal(t, fixedTime(), created.CreatedAt) + + found, err := service.Read(context.Background(), firstPasteID) + require.NoError(t, err) + assert.Equal(t, created, found) +} + +func TestServiceRejectsInvalidBodies(t *testing.T) { + tests := []struct { + name string + body string + assertErr func(*testing.T, error) + }{ + { + name: "empty body", + body: " \n\t ", + assertErr: func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, paste.ErrEmptyBody) + }, + }, + { + name: "body above limit", + body: "abcdef", + assertErr: func(t *testing.T, err error) { + t.Helper() + var bodyErr paste.BodyTooLargeError + require.ErrorAs(t, err, &bodyErr) + assert.Equal(t, maxTestBytes, bodyErr.MaxBytes) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + service := newTestService(t, firstPasteID, fixedTime(), paste.WithMaxBodyBytes(maxTestBytes)) + + _, err := service.Create(context.Background(), paste.CreatePasteRequest{ + Body: tt.body, + }) + + tt.assertErr(t, err) + }) + } +} + +func TestServiceReturnsMissingPasteError(t *testing.T) { + service := newTestService(t, firstPasteID, fixedTime()) + + _, err := service.Read(context.Background(), "missing") + + assert.ErrorIs(t, err, paste.ErrPasteNotFound) +} + +func TestServiceListsRecentPastesNewestFirst(t *testing.T) { + repo := memory.NewStore() + ids := sequentialIDs(firstPasteID, secondPasteID) + now := fixedTime() + service, err := paste.NewService( + repo, + paste.WithIDGenerator(ids.next), + paste.WithClock(func() time.Time { + return now + }), + ) + require.NoError(t, err) + + _, err = service.Create(context.Background(), paste.CreatePasteRequest{Body: "older"}) + require.NoError(t, err) + now = now.Add(time.Minute) + _, err = service.Create(context.Background(), paste.CreatePasteRequest{Body: "newer"}) + require.NoError(t, err) + + recent, err := service.ListRecent(context.Background(), paste.DefaultRecentLimit) + + require.NoError(t, err) + require.Len(t, recent, 2) + assert.Equal(t, secondPasteID, recent[0].ID) + assert.Equal(t, firstPasteID, recent[1].ID) +} + +func newTestService(t *testing.T, id string, now time.Time, opts ...paste.Option) *paste.Service { + t.Helper() + + allOpts := []paste.Option{ + paste.WithIDGenerator(func() (string, error) { + return id, nil + }), + paste.WithClock(func() time.Time { + return now + }), + } + allOpts = append(allOpts, opts...) + + service, err := paste.NewService(memory.NewStore(), allOpts...) + require.NoError(t, err) + + return service +} + +type idSequence struct { + values []string + nextID int +} + +func sequentialIDs(ids ...string) *idSequence { + return &idSequence{values: ids} +} + +func (s *idSequence) next() (string, error) { + if s.nextID >= len(s.values) { + return "", errors.New("test: no more IDs") + } + + id := s.values[s.nextID] + s.nextID++ + + return id, nil +} + +func fixedTime() time.Time { + return time.Date(2026, time.May, 14, 10, 0, 0, 0, time.UTC) +} diff --git a/testkit/internal/paste/store.go b/testkit/internal/paste/store.go new file mode 100644 index 0000000..53ab615 --- /dev/null +++ b/testkit/internal/paste/store.go @@ -0,0 +1,15 @@ +package paste + +import "context" + +// Repository stores and retrieves pastes. +type Repository interface { + // Create stores a new paste. + Create(ctx context.Context, paste Paste) error + + // Find returns a paste by ID. + Find(ctx context.Context, id string) (Paste, error) + + // ListRecent returns recent pastes, newest first, up to limit. + ListRecent(ctx context.Context, limit int) ([]Paste, error) +} diff --git a/testkit/internal/paste/types.go b/testkit/internal/paste/types.go new file mode 100644 index 0000000..37d21a7 --- /dev/null +++ b/testkit/internal/paste/types.go @@ -0,0 +1,41 @@ +package paste + +import "time" + +const ( + // DefaultMaxBodyBytes is the first-slice paste body limit. + DefaultMaxBodyBytes = 64 * 1024 + + // DefaultRecentLimit is the default number of pastes shown in recent lists. + DefaultRecentLimit = 50 +) + +// Paste is a stored pastebin entry. +type Paste struct { + // ID is the stable URL identifier for the paste. + ID string + + // Title is the optional human-readable paste title. + Title string + + // Body is the exact paste content. + Body string + + // Syntax is an optional display label for the paste content type. + Syntax string + + // CreatedAt is when the paste was created. + CreatedAt time.Time +} + +// CreatePasteRequest describes a paste creation request. +type CreatePasteRequest struct { + // Title is an optional human-readable paste title. + Title string + + // Body is the required paste content. + Body string + + // Syntax is an optional display label for the paste content type. + Syntax string +} diff --git a/testkit/internal/store/memory/doc.go b/testkit/internal/store/memory/doc.go new file mode 100644 index 0000000..286bdf9 --- /dev/null +++ b/testkit/internal/store/memory/doc.go @@ -0,0 +1,2 @@ +// Package memory contains the in-memory testkit paste repository. +package memory diff --git a/testkit/internal/store/memory/store.go b/testkit/internal/store/memory/store.go new file mode 100644 index 0000000..3f8d80a --- /dev/null +++ b/testkit/internal/store/memory/store.go @@ -0,0 +1,98 @@ +package memory + +import ( + "context" + "slices" + "strings" + "sync" + + "github.com/meigma/authkit/testkit/internal/paste" +) + +const ( + sortBefore = -1 + sortEqual = 0 + sortAfter = 1 +) + +// Store keeps pastes in process memory. +type Store struct { + mu sync.RWMutex + pastes map[string]paste.Paste +} + +// NewStore constructs an empty in-memory paste store. +func NewStore() *Store { + return &Store{ + pastes: make(map[string]paste.Paste), + } +} + +// Create stores a new paste. +func (s *Store) Create(ctx context.Context, created paste.Paste) error { + if err := ctx.Err(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.pastes[created.ID]; exists { + return paste.ErrDuplicatePasteID + } + s.pastes[created.ID] = created + + return nil +} + +// Find returns a paste by ID. +func (s *Store) Find(ctx context.Context, id string) (paste.Paste, error) { + if err := ctx.Err(); err != nil { + return paste.Paste{}, err + } + + s.mu.RLock() + defer s.mu.RUnlock() + + found, exists := s.pastes[id] + if !exists { + return paste.Paste{}, paste.ErrPasteNotFound + } + + return found, nil +} + +// ListRecent returns recent pastes, newest first, up to limit. +func (s *Store) ListRecent(ctx context.Context, limit int) ([]paste.Paste, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if limit <= 0 { + return []paste.Paste{}, nil + } + + s.mu.RLock() + defer s.mu.RUnlock() + + recent := make([]paste.Paste, 0, len(s.pastes)) + for _, stored := range s.pastes { + recent = append(recent, stored) + } + slices.SortFunc(recent, comparePasteRecency) + if len(recent) > limit { + recent = recent[:limit] + } + + return recent, nil +} + +func comparePasteRecency(left paste.Paste, right paste.Paste) int { + switch { + case left.CreatedAt.After(right.CreatedAt): + return sortBefore + case right.CreatedAt.After(left.CreatedAt): + return sortAfter + } + + return strings.Compare(left.ID, right.ID) +} diff --git a/testkit/internal/store/memory/store_test.go b/testkit/internal/store/memory/store_test.go new file mode 100644 index 0000000..f87e26c --- /dev/null +++ b/testkit/internal/store/memory/store_test.go @@ -0,0 +1,17 @@ +package memory_test + +import ( + "testing" + + "github.com/meigma/authkit/testkit/internal/paste" + "github.com/meigma/authkit/testkit/internal/store/memory" + "github.com/meigma/authkit/testkit/internal/store/storetest" +) + +func TestSharedStoreBehavior(t *testing.T) { + storetest.Run(t, func(t *testing.T) paste.Repository { + t.Helper() + + return memory.NewStore() + }) +} diff --git a/testkit/internal/store/postgres/doc.go b/testkit/internal/store/postgres/doc.go new file mode 100644 index 0000000..b259070 --- /dev/null +++ b/testkit/internal/store/postgres/doc.go @@ -0,0 +1,2 @@ +// Package postgres contains the PostgreSQL testkit paste repository. +package postgres diff --git a/testkit/internal/store/postgres/migrate.go b/testkit/internal/store/postgres/migrate.go new file mode 100644 index 0000000..ba605a8 --- /dev/null +++ b/testkit/internal/store/postgres/migrate.go @@ -0,0 +1,131 @@ +package postgres + +import ( + "context" + "embed" + "errors" + "fmt" + "io/fs" + "strconv" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" +) + +const ( + migrationLockID int64 = 0x746573746b6974 + versionBase int = 10 + versionBits int = 64 +) + +//go:embed migrations/*.sql +var migrationFiles embed.FS + +// Migrate applies testkit's PostgreSQL paste schema migrations. +func Migrate(ctx context.Context, pool *pgxpool.Pool) error { + if err := ctx.Err(); err != nil { + return err + } + if pool == nil { + return errors.New("testkit postgres: pool is required") + } + + tx, err := pool.Begin(ctx) + if err != nil { + return fmt.Errorf("testkit postgres: begin migration: %w", err) + } + defer func() { + _ = tx.Rollback(ctx) + }() + + if _, execErr := tx.Exec(ctx, `select pg_advisory_xact_lock($1)`, migrationLockID); execErr != nil { + return fmt.Errorf("testkit postgres: acquire migration lock: %w", execErr) + } + if _, execErr := tx.Exec(ctx, createMigrationTableSQL); execErr != nil { + return fmt.Errorf("testkit postgres: create migration table: %w", execErr) + } + + entries, err := fs.ReadDir(migrationFiles, "migrations") + if err != nil { + return fmt.Errorf("testkit postgres: read migrations: %w", err) + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + if err := applyMigration(ctx, tx, entry.Name()); err != nil { + return err + } + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("testkit postgres: commit migration: %w", err) + } + + return nil +} + +const createMigrationTableSQL = ` +create table if not exists testkit_schema_migrations ( + version bigint primary key, + name text not null, + applied_at timestamptz not null default now() +)` + +type migrationTx interface { + Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row +} + +func applyMigration(ctx context.Context, tx migrationTx, name string) error { + version, err := migrationVersion(name) + if err != nil { + return err + } + + var applied bool + if queryErr := tx.QueryRow( + ctx, + `select exists (select 1 from testkit_schema_migrations where version = $1)`, + version, + ).Scan(&applied); queryErr != nil { + return fmt.Errorf("testkit postgres: check migration %s: %w", name, queryErr) + } + if applied { + return nil + } + + sql, err := migrationFiles.ReadFile("migrations/" + name) + if err != nil { + return fmt.Errorf("testkit postgres: read migration %s: %w", name, err) + } + if _, err := tx.Exec(ctx, string(sql)); err != nil { + return fmt.Errorf("testkit postgres: apply migration %s: %w", name, err) + } + if _, err := tx.Exec( + ctx, + `insert into testkit_schema_migrations (version, name) values ($1, $2)`, + version, + name, + ); err != nil { + return fmt.Errorf("testkit postgres: record migration %s: %w", name, err) + } + + return nil +} + +func migrationVersion(name string) (int64, error) { + prefix, _, found := strings.Cut(name, "_") + if !found { + return 0, fmt.Errorf("testkit postgres: migration %q has no version prefix", name) + } + + version, err := strconv.ParseInt(prefix, versionBase, versionBits) + if err != nil { + return 0, fmt.Errorf("testkit postgres: parse migration version %q: %w", name, err) + } + + return version, nil +} diff --git a/testkit/internal/store/postgres/migrations/000001_pastes.sql b/testkit/internal/store/postgres/migrations/000001_pastes.sql new file mode 100644 index 0000000..1765088 --- /dev/null +++ b/testkit/internal/store/postgres/migrations/000001_pastes.sql @@ -0,0 +1,10 @@ +create table testkit_pastes ( + id text primary key, + title text not null, + body text not null, + syntax text not null, + created_at timestamptz not null +); + +create index testkit_pastes_created_at_id_idx + on testkit_pastes (created_at desc, id asc); diff --git a/testkit/internal/store/postgres/store.go b/testkit/internal/store/postgres/store.go new file mode 100644 index 0000000..316ae1c --- /dev/null +++ b/testkit/internal/store/postgres/store.go @@ -0,0 +1,142 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/meigma/authkit/testkit/internal/paste" +) + +const uniqueViolation = "23505" + +// Store persists testkit pastes in PostgreSQL. +type Store struct { + pool *pgxpool.Pool +} + +// NewStore constructs a PostgreSQL paste store around pool. +func NewStore(pool *pgxpool.Pool) (*Store, error) { + if pool == nil { + return nil, errors.New("testkit postgres: pool is required") + } + + return &Store{pool: pool}, nil +} + +// Create stores a new paste. +func (s *Store) Create(ctx context.Context, created paste.Paste) error { + if err := ctx.Err(); err != nil { + return err + } + + _, err := s.pool.Exec( + ctx, + `insert into testkit_pastes (id, title, body, syntax, created_at) + values ($1, $2, $3, $4, $5)`, + created.ID, + created.Title, + created.Body, + created.Syntax, + created.CreatedAt.UTC(), + ) + if err != nil { + if isPostgresCode(err, uniqueViolation) { + return paste.ErrDuplicatePasteID + } + + return fmt.Errorf("testkit postgres: create paste: %w", err) + } + + return nil +} + +// Find returns a paste by ID. +func (s *Store) Find(ctx context.Context, id string) (paste.Paste, error) { + if err := ctx.Err(); err != nil { + return paste.Paste{}, err + } + + found, err := scanPaste(s.pool.QueryRow( + ctx, + `select id, title, body, syntax, created_at + from testkit_pastes + where id = $1`, + id, + )) + if errors.Is(err, pgx.ErrNoRows) { + return paste.Paste{}, paste.ErrPasteNotFound + } + if err != nil { + return paste.Paste{}, err + } + + return found, nil +} + +// ListRecent returns recent pastes, newest first, up to limit. +func (s *Store) ListRecent(ctx context.Context, limit int) ([]paste.Paste, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if limit <= 0 { + return []paste.Paste{}, nil + } + + rows, err := s.pool.Query( + ctx, + `select id, title, body, syntax, created_at + from testkit_pastes + order by created_at desc, id asc + limit $1`, + limit, + ) + if err != nil { + return nil, fmt.Errorf("testkit postgres: list recent pastes: %w", err) + } + defer rows.Close() + + var pastes []paste.Paste + for rows.Next() { + found, err := scanPaste(rows) + if err != nil { + return nil, err + } + pastes = append(pastes, found) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("testkit postgres: read recent pastes: %w", err) + } + + return pastes, nil +} + +type scanner interface { + Scan(dest ...any) error +} + +func scanPaste(row scanner) (paste.Paste, error) { + var found paste.Paste + if err := row.Scan( + &found.ID, + &found.Title, + &found.Body, + &found.Syntax, + &found.CreatedAt, + ); err != nil { + return paste.Paste{}, err + } + found.CreatedAt = found.CreatedAt.UTC() + + return found, nil +} + +func isPostgresCode(err error, code string) bool { + var pgErr *pgconn.PgError + + return errors.As(err, &pgErr) && pgErr.Code == code +} diff --git a/testkit/internal/store/postgres/store_integration_test.go b/testkit/internal/store/postgres/store_integration_test.go new file mode 100644 index 0000000..0030744 --- /dev/null +++ b/testkit/internal/store/postgres/store_integration_test.go @@ -0,0 +1,143 @@ +//go:build integration + +package postgres + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + + "github.com/meigma/authkit/testkit/internal/paste" + "github.com/meigma/authkit/testkit/internal/store/storetest" +) + +const ( + concurrentMigrationCalls = 5 + expectedMigrationRows = 1 + postgresReadyOccurrences = 2 +) + +func TestSharedStoreBehavior(t *testing.T) { + ctx := context.Background() + pool := newPostgresPool(t) + require.NoError(t, Migrate(ctx, pool)) + + storetest.Run(t, func(t *testing.T) paste.Repository { + t.Helper() + resetStore(t, pool) + + store, err := NewStore(pool) + require.NoError(t, err) + + return store + }) +} + +func TestMigrateCreatesSchema(t *testing.T) { + ctx := context.Background() + pool := newPostgresPool(t) + + require.NoError(t, Migrate(ctx, pool)) + require.NoError(t, Migrate(ctx, pool)) + + for _, table := range []string{ + "testkit_schema_migrations", + "testkit_pastes", + } { + t.Run(table, func(t *testing.T) { + var exists bool + err := pool.QueryRow( + ctx, + `select exists ( + select 1 from information_schema.tables + where table_schema = 'public' and table_name = $1 + )`, + table, + ).Scan(&exists) + + require.NoError(t, err) + assert.True(t, exists) + }) + } + + assertMigrationRows(t, pool) +} + +func TestMigrateConcurrentCalls(t *testing.T) { + ctx := context.Background() + pool := newPostgresPool(t) + errs := make(chan error, concurrentMigrationCalls) + var wg sync.WaitGroup + + for range cap(errs) { + wg.Add(1) + go func() { + defer wg.Done() + errs <- Migrate(ctx, pool) + }() + } + + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + assertMigrationRows(t, pool) +} + +func newPostgresPool(t *testing.T) *pgxpool.Pool { + t.Helper() + + ctx := context.Background() + container, err := tcpostgres.Run( + ctx, + "postgres:16-alpine", + tcpostgres.WithDatabase("testkit"), + tcpostgres.WithUsername("testkit"), + tcpostgres.WithPassword("testkit"), + testcontainers.WithAdditionalWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(postgresReadyOccurrences). + WithStartupTimeout(time.Minute), + ), + ) + require.NoError(t, err) + testcontainers.CleanupContainer(t, container) + + connectionString, err := container.ConnectionString(ctx, "sslmode=disable") + require.NoError(t, err) + + pool, err := pgxpool.New(ctx, connectionString) + require.NoError(t, err) + t.Cleanup(pool.Close) + + return pool +} + +func resetStore(t *testing.T, pool *pgxpool.Pool) { + t.Helper() + + _, err := pool.Exec(context.Background(), `truncate table testkit_pastes`) + require.NoError(t, err) +} + +func assertMigrationRows(t *testing.T, pool *pgxpool.Pool) { + t.Helper() + + var migrationRows int + err := pool.QueryRow( + context.Background(), + `select count(*) from testkit_schema_migrations where version = 1`, + ).Scan(&migrationRows) + require.NoError(t, err) + assert.Equal(t, expectedMigrationRows, migrationRows) +} diff --git a/testkit/internal/store/storetest/storetest.go b/testkit/internal/store/storetest/storetest.go new file mode 100644 index 0000000..e62fd58 --- /dev/null +++ b/testkit/internal/store/storetest/storetest.go @@ -0,0 +1,95 @@ +// Package storetest contains shared behavior tests for testkit paste stores. +package storetest + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/meigma/authkit/testkit/internal/paste" +) + +const ( + firstPasteID = "paste-1" + recentLimit = 2 + secondPasteID = "paste-2" + thirdOffset = 2 + thirdPasteID = "paste-3" +) + +// Run runs the shared paste repository behavior suite. +func Run(t *testing.T, newRepo func(*testing.T) paste.Repository) { + t.Helper() + + t.Run("create and find paste", func(t *testing.T) { + repo := newRepo(t) + created := newPaste(firstPasteID, "Example", "hello", "text", firstTime()) + + require.NoError(t, repo.Create(context.Background(), created)) + found, err := repo.Find(context.Background(), firstPasteID) + + require.NoError(t, err) + assert.Equal(t, created, found) + }) + + t.Run("missing paste returns not found", func(t *testing.T) { + repo := newRepo(t) + + _, err := repo.Find(context.Background(), "missing") + + assert.ErrorIs(t, err, paste.ErrPasteNotFound) + }) + + t.Run("duplicate paste ID is rejected", func(t *testing.T) { + repo := newRepo(t) + created := newPaste(firstPasteID, "Example", "hello", "text", firstTime()) + + require.NoError(t, repo.Create(context.Background(), created)) + err := repo.Create(context.Background(), created) + + assert.ErrorIs(t, err, paste.ErrDuplicatePasteID) + }) + + t.Run("recent list is newest first and limited", func(t *testing.T) { + repo := newRepo(t) + + require.NoError(t, repo.Create(context.Background(), newPaste(firstPasteID, "Old", "one", "", firstTime()))) + require.NoError(t, repo.Create(context.Background(), newPaste(secondPasteID, "New", "two", "", secondTime()))) + require.NoError( + t, + repo.Create(context.Background(), newPaste(thirdPasteID, "Newest", "three", "", thirdTime())), + ) + + recent, err := repo.ListRecent(context.Background(), recentLimit) + + require.NoError(t, err) + require.Len(t, recent, recentLimit) + assert.Equal(t, thirdPasteID, recent[0].ID) + assert.Equal(t, secondPasteID, recent[1].ID) + }) +} + +func newPaste(id string, title string, body string, syntax string, createdAt time.Time) paste.Paste { + return paste.Paste{ + ID: id, + Title: title, + Body: body, + Syntax: syntax, + CreatedAt: createdAt, + } +} + +func firstTime() time.Time { + return time.Date(2026, time.May, 14, 10, 0, 0, 0, time.UTC) +} + +func secondTime() time.Time { + return firstTime().Add(time.Minute) +} + +func thirdTime() time.Time { + return firstTime().Add(thirdOffset * time.Minute) +}